From 96b0bf608dd4a0b0ec4af1558424bb198a9b0783 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EB=B0=95=EC=83=81=ED=98=B8=20Sangho=20Park?= Date: Thu, 28 Aug 2025 10:46:40 +0900 Subject: [PATCH] =?UTF-8?q?feat:=208=EB=B9=84=ED=8A=B8=20=EC=96=91?= =?UTF-8?q?=EC=9E=90=ED=99=94=20=EC=98=B5=EC=85=98=20=EC=B6=94=EA=B0=80=20?= =?UTF-8?q?=EB=B0=8F=20CPU=20=EB=A9=94=EB=AA=A8=EB=A6=AC=20=EC=B5=9C?= =?UTF-8?q?=EC=A0=81=ED=99=94=20=EA=B5=AC=ED=98=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AI_Web_Scraper/model_downloader.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/AI_Web_Scraper/model_downloader.py b/AI_Web_Scraper/model_downloader.py index 96ded6c..9eddeba 100644 --- a/AI_Web_Scraper/model_downloader.py +++ b/AI_Web_Scraper/model_downloader.py @@ -1,6 +1,6 @@ import os import json -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig from huggingface_hub import HfApi def download_model(config_path='./config.json'): @@ -12,6 +12,8 @@ def download_model(config_path='./config.json'): model_name = config['model_name'] local_path = config['model_local_path'] + model_settings = config.get('model_settings', {}) + use_quantization = model_settings.get('use_quantization', False) if not os.path.exists(local_path): os.makedirs(local_path) @@ -19,10 +21,21 @@ def download_model(config_path='./config.json'): print(f"모델 {model_name}을 {local_path}에 다운로드 중...") try: + # 양자화 설정 적용 + if use_quantization: + print("8bit 양자화 적용") + quantization_config = BitsAndBytesConfig( + load_in_8bit=True, + llm_int8_enable_fp32_cpu_offload=True + ) + else: + quantization_config = None + model = AutoModelForCausalLM.from_pretrained( model_name, cache_dir=local_path, - device_map="auto", # GPU 자동 할당 + quantization_config=quantization_config, + device_map="cpu", # 다운로드 시 CPU에 로드하여 메모리 절약 torch_dtype="auto" ) tokenizer = AutoTokenizer.from_pretrained(