# AI_Devlop/AI_Web_Scraper/model_downloader.py
import os
import json
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from huggingface_hub import HfApi
def download_model(config_path='./config.json'):
    """Download a model and tokenizer from the Hugging Face Hub.

    Reads the model name, local cache path, and optional quantization
    settings from a JSON config file, downloads the model and tokenizer,
    saves both to the local path, and returns them.

    Args:
        config_path: Path to a JSON config file with keys
            ``model_name``, ``model_local_path``, and optionally
            ``model_settings.use_quantization`` (bool, default False).

    Returns:
        Tuple ``(model, tokenizer)`` on success, ``(None, None)`` if the
        download fails for any reason.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    model_name = config['model_name']
    local_path = config['model_local_path']
    model_settings = config.get('model_settings', {})
    use_quantization = model_settings.get('use_quantization', False)

    # exist_ok avoids the check-then-create race of the exists()/makedirs pair.
    os.makedirs(local_path, exist_ok=True)

    print(f"모델 {model_name}을(를) {local_path}에 다운로드 중...")
    try:
        # Apply 8-bit quantization only when requested in the config.
        if use_quantization:
            print("8bit 양자화 적용")
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_enable_fp32_cpu_offload=True,
            )
        else:
            quantization_config = None

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            cache_dir=local_path,
            quantization_config=quantization_config,
            # 8-bit loading requires accelerate's device placement; plain CPU otherwise.
            device_map="auto" if quantization_config else "cpu",
            torch_dtype="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir=local_path,
        )

        # Persist both artifacts so later runs can load straight from local_path.
        model.save_pretrained(local_path)
        tokenizer.save_pretrained(local_path)
        print(f"모델 다운로드 완료: {local_path}")
        return model, tokenizer
    except Exception as e:
        # Best-effort boundary: report the failure and signal it to the caller
        # via the (None, None) sentinel rather than crashing the script.
        print(f"모델 다운로드 실패: {e}")
        return None, None
# Script entry point: download using the default ./config.json settings.
if __name__ == "__main__":
    download_model()