72 lines
2.4 KiB
Plaintext
72 lines
2.4 KiB
Plaintext
## 📄 train_model.py - Concept 설명 문서
|
|
|
|
### ✅ 목적
|
|
CDS(Complete Data Set)를 기반으로 한 학습용 데이터셋을 생성하고, XGBoost 모델을 학습시켜 상위 분석 파이프라인에서 활용할 수 있도록 모델을 저장하는 스크립트입니다.
|
|
|
|
---
|
|
|
|
### 📂 입력 데이터
|
|
- `cds_dir`: `_ohlcv.csv` 형식의 종목별 CDS가 저장된 디렉토리
|
|
- 예: `AAPL_ohlcv.csv`, `MSFT_ohlcv.csv`
|
|
- 파일 구조는 OHLCV(Time Series) 형태
|
|
|
|
---
|
|
|
|
### ⚙️ 주요 기능
|
|
1. **데이터 적재 및 통합**
|
|
- 종목별 CDS 파일을 모두 읽어 피처(X), 타깃(y)으로 변환
|
|
- `build_dataset()` 호출 → 기술 지표 피처 등 포함 가능
|
|
|
|
2. **클래스 분포 확인 및 불균형 보정**
|
|
- `y_total`의 클래스 비율(상승/하락) 출력
|
|
- `scale_pos_weight` 자동 조정 → 불균형에 강한 학습 구조 지원
|
|
|
|
3. **XGBoost 모델 학습 + 하이퍼파라미터 튜닝**
|
|
- `GridSearchCV`로 최적 파라미터 탐색
|
|
- 튜닝 대상: `max_depth`, `learning_rate`, `n_estimators`
|
|
|
|
4. **모델 평가 지표 출력**
|
|
- 정확도 (`accuracy_score`)
|
|
- AUC (`roc_auc_score`)
|
|
- F1 점수 (`f1_score`)
|
|
- LogLoss (`log_loss`)
|
|
- Precision@TopN (`precision@TopN`, 예: P@50)
|
|
|
|
5. **모델 저장 (버전 관리 포함)**
|
|
- 저장 경로: `data_analysis_engine/models/model_YYYY-MM-DD.json`
|
|
- 날짜 기반 버전 관리 자동 수행
|
|
|
|
6. **학습 로그 자동 기록**
|
|
- `train_log.csv`에 날짜, 성능 지표, 파라미터, 샘플 수 기록
|
|
|
|
---
|
|
|
|
### 🧪 사용 방법
|
|
```bash
|
|
python -m data_analysis_engine.train_model
|
|
```
|
|
또는 내부에서 import 후 `train_model("data")` 호출
|
|
|
|
---
|
|
|
|
### 🧠 향후 확장 가능성
|
|
- Optuna 기반 자동 하이퍼파라미터 탐색
|
|
- K-fold 교차 검증 평가 구조 도입
|
|
- 예측 기반 ROI 피처 학습 (Target 다양화)
|
|
- 외부 평가 세트 적용 및 모델 비교 리포트 생성
|
|
|
|
---
|
|
|
|
### ⚠️ 주의사항
|
|
- 모든 CDS 파일은 비어 있지 않아야 하며, `_ohlcv.csv` 확장자를 따라야 함
|
|
- 피처 수가 변하면 모델 구조도 반드시 재학습 필요
|
|
|
|
---
|
|
|
|
### 📌 관련 파일
|
|
- `dataset_builder.py`: X, y 전처리 생성 및 기술 지표 포함
|
|
- `xgboost_model.py`: 모델 클래스 정의 및 저장/불러오기
|
|
- `X_total.csv`, `y_total.csv`: 통합 학습 데이터
|
|
- `model_YYYY-MM-DD.json`: 학습된 XGBoost 모델
|
|
- `train_log.csv`: 학습 결과 누적 로그
|