This repository has been archived on 2025-06-07. You can view files and clone it, but cannot push or open issues or pull requests.
2025-05-06 21:23:04 +09:00

35 lines
985 B
Python

'''
📁 data_analysis_engine/models/xgboost_model.py
XGBoost 모델 클래스: 학습, 예측, 저장, 로드, 전처리 지원
'''
import os
import xgboost as xgb
import pandas as pd
from data_analysis_engine.dataset_builder import build_dataset
class XGBoostModel:
    """Wrapper around ``xgb.XGBClassifier`` (configured with ``eval_metric='logloss'``).

    Provides training, last-row probability prediction, saving/loading the
    model file, and feature extraction via ``build_dataset``.
    """

    def __init__(self):
        # NOTE(review): ``use_label_encoder`` is deprecated in newer XGBoost
        # releases (ignored there, with a warning). It is kept so older
        # installs do not take the legacy label-encoder path; confirm the
        # pinned xgboost version before removing it.
        self.model = xgb.XGBClassifier(use_label_encoder=False, eval_metric='logloss')

    def fit(self, X: pd.DataFrame, y: pd.Series):
        """Train the underlying classifier on features ``X`` and labels ``y``."""
        self.model.fit(X, y)

    def predict_proba(self, X: pd.DataFrame) -> float:
        """Return the column-1 class probability for the LAST row of ``X``.

        Every row of ``X`` is scored, but only the final row's result is
        returned. Column 1 is the second class in the classifier's class
        order (the positive class when labels are 0/1).

        An empty ``X`` fails: the ``[-1]`` lookup raises ``IndexError``,
        unless XGBoost rejects the input first.
        """
        probabilities = self.model.predict_proba(X)
        return float(probabilities[-1][1])

    def save_model(self, path: str):
        """Save the model to ``path``, creating the parent directory if needed.

        ``os.path.dirname`` returns ``''`` for a bare filename such as
        ``'model.json'``, and ``os.makedirs('')`` raises. The directory is
        therefore only created when ``path`` actually contains one; a bare
        filename is saved relative to the current working directory.
        """
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        self.model.save_model(path)

    def load_model(self, path: str):
        """Load model state from ``path`` into the wrapped classifier."""
        self.model.load_model(path)

    def preprocess(self, df: pd.DataFrame) -> pd.DataFrame:
        """Extract only the feature columns to use for prediction.

        Delegates to ``build_dataset`` and discards its second return value
        (presumably the target labels; confirm against ``build_dataset``).
        """
        X, _ = build_dataset(df)
        return X