35 lines
985 B
Python
35 lines
985 B
Python
'''
|
|
📁 data_analysis_engine/models/xgboost_model.py
|
|
|
|
XGBoost 모델 클래스: 학습, 예측, 저장, 로드, 전처리 지원
|
|
'''
|
|
|
|
import os
|
|
import xgboost as xgb
|
|
import pandas as pd
|
|
from data_analysis_engine.dataset_builder import build_dataset
|
|
|
|
class XGBoostModel:
    """Thin wrapper around ``xgb.XGBClassifier``.

    Supports training, single-value probability prediction, saving, loading,
    and feature preprocessing.
    """

    def __init__(self):
        # NOTE(review): `use_label_encoder` is deprecated since XGBoost 1.6 and
        # is ignored (with a warning) in 2.x. It is kept here for compatibility
        # with older pinned versions; confirm against the project's xgboost pin.
        self.model = xgb.XGBClassifier(use_label_encoder=False, eval_metric='logloss')

    def fit(self, X: pd.DataFrame, y: pd.Series):
        """Train the underlying classifier on features ``X`` and targets ``y``."""
        self.model.fit(X, y)

    def predict_proba(self, X: pd.DataFrame) -> float:
        """Return the class-1 probability for the LAST row of ``X`` as a float.

        Probabilities are computed for every row of ``X``, but only the final
        row's value (column index 1) is returned. Column 1 is presumably the
        positive class of a binary target — confirm against the training labels.
        """
        return float(self.model.predict_proba(X)[-1][1])

    def save_model(self, path: str):
        """Save the model to ``path``, creating parent directories if needed."""
        directory = os.path.dirname(path)
        # A bare filename has an empty dirname. os.makedirs("") raises
        # FileNotFoundError, so only create the directory when one is given.
        if directory:
            os.makedirs(directory, exist_ok=True)
        self.model.save_model(path)

    def load_model(self, path: str):
        """Load model state from ``path`` into the existing classifier instance."""
        self.model.load_model(path)

    def preprocess(self, df: pd.DataFrame) -> pd.DataFrame:
        """Extract only the features to be used at prediction time.

        Delegates to ``build_dataset`` and returns the first element of the
        pair it produces. The second element is discarded; it is presumably
        the target — verify against ``build_dataset``.
        """
        X, _ = build_dataset(df)
        return X
|