feat: 셀레니움 선택적 사용 및 폴백 메커니즘 추가
This commit is contained in:
@@ -95,6 +95,12 @@ class AIAgent:
|
||||
|
||||
model_settings = self.config.get('model_settings', {})
|
||||
use_quantization = bool(model_settings.get('use_quantization', False))
|
||||
# 양자화 비트/오프로딩 옵션
|
||||
try:
|
||||
quant_bits = int(model_settings.get('quantization_bits', 8))
|
||||
except Exception:
|
||||
quant_bits = 8
|
||||
cpu_offload = bool(model_settings.get('cpu_offload', False))
|
||||
torch_dtype_cfg = str(model_settings.get('torch_dtype', 'auto')).lower()
|
||||
|
||||
# dtype 파싱
|
||||
@@ -114,20 +120,7 @@ class AIAgent:
|
||||
if not model_source:
|
||||
raise RuntimeError("모델 경로/이름이 설정되지 않았습니다.")
|
||||
|
||||
# quantization 설정 (가능한 경우에만)
|
||||
quant_args = {}
|
||||
if use_quantization:
|
||||
try:
|
||||
from transformers import BitsAndBytesConfig
|
||||
quant_args["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
llm_int8_enable_fp32_cpu_offload=True
|
||||
)
|
||||
print("8bit 양자화 적용")
|
||||
except Exception as _:
|
||||
# transformers/bitsandbytes 호환 문제 시 양자화 비활성화
|
||||
print("bitsandbytes/transformers 호환 문제로 양자화를 비활성화합니다.")
|
||||
quant_args = {}
|
||||
# (이전) quant_args 경로 제거: load_kwargs에서 직접 처리
|
||||
|
||||
# 메모리 제한/오프로딩 설정
|
||||
mm_cfg = model_settings.get('max_memory', {}) if isinstance(model_settings.get('max_memory', {}), dict) else {}
|
||||
@@ -167,11 +160,28 @@ class AIAgent:
|
||||
if max_memory:
|
||||
load_kwargs["max_memory"] = max_memory
|
||||
|
||||
# use_quantization=True면 8bit 우선 시도 (항상 레거시 플래그 사용)
|
||||
# use_quantization=True면 4bit 우선, 아니면 8bit 레거시 플래그 사용
|
||||
if use_quantization:
|
||||
load_kwargs["load_in_8bit"] = True
|
||||
load_kwargs["llm_int8_enable_fp32_cpu_offload"] = True
|
||||
print("8bit 양자화 적용 (레거시 플래그)")
|
||||
if quant_bits == 4:
|
||||
try:
|
||||
from transformers import BitsAndBytesConfig
|
||||
load_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_compute_dtype=__import__('torch').bfloat16
|
||||
)
|
||||
print("4bit 양자화 적용 (bnb nf4)")
|
||||
except Exception as _:
|
||||
load_kwargs["load_in_8bit"] = True
|
||||
if cpu_offload:
|
||||
load_kwargs["llm_int8_enable_fp32_cpu_offload"] = True
|
||||
print("4bit 미지원 → 8bit(레거시)로 폴백")
|
||||
else:
|
||||
load_kwargs["load_in_8bit"] = True
|
||||
if cpu_offload:
|
||||
load_kwargs["llm_int8_enable_fp32_cpu_offload"] = True
|
||||
print("8bit 양자화 적용 (레거시 플래그)")
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_source,
|
||||
@@ -206,11 +216,11 @@ class AIAgent:
|
||||
except Exception as e_noq:
|
||||
print(f"비양자화 재시도 실패: {e_noq}")
|
||||
|
||||
# 2b. 8-bit 양자화로 재시도 (가능 시)
|
||||
# 2b. 양자화로 재시도 (4bit 우선, 아니면 8bit)
|
||||
loaded = False
|
||||
try:
|
||||
print("8bit 양자화로 재시도합니다...")
|
||||
print("양자화로 재시도합니다...")
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_source, trust_remote_code=True)
|
||||
# config 재생성 및 quantization_config 제거
|
||||
cfg = AutoConfig.from_pretrained(model_source, trust_remote_code=True)
|
||||
if hasattr(cfg, 'quantization_config'):
|
||||
try:
|
||||
@@ -224,20 +234,31 @@ class AIAgent:
|
||||
offload_state_dict=True,
|
||||
trust_remote_code=True,
|
||||
config=cfg,
|
||||
load_in_8bit=True,
|
||||
llm_int8_enable_fp32_cpu_offload=True,
|
||||
)
|
||||
if dtype is not None:
|
||||
retry_kwargs["torch_dtype"] = dtype
|
||||
if max_memory:
|
||||
retry_kwargs["max_memory"] = max_memory
|
||||
if quant_bits == 4:
|
||||
from transformers import BitsAndBytesConfig
|
||||
retry_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_compute_dtype=__import__('torch').bfloat16
|
||||
)
|
||||
else:
|
||||
retry_kwargs["load_in_8bit"] = True
|
||||
if cpu_offload:
|
||||
retry_kwargs["llm_int8_enable_fp32_cpu_offload"] = True
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(model_source, **retry_kwargs)
|
||||
except Exception as e_int8:
|
||||
print(f"8bit 재시도 실패: {e_int8}")
|
||||
loaded = True
|
||||
except Exception as e_q:
|
||||
print(f"양자화 재시도 실패: {e_q}")
|
||||
|
||||
if not tried_int8:
|
||||
print("CPU로 폴백합니다.")
|
||||
if not loaded:
|
||||
print("CPU로 폴백합니다.")
|
||||
try:
|
||||
import torch, gc
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user