49 lines
1.9 KiB
Python
49 lines
1.9 KiB
Python
'''
|
|
data_analysis_engine/predict.py
|
|
|
|
분석 엔진 전체 흐름을 실행하는 진입점 모듈입니다.
|
|
|
|
- CLI 또는 스크립트 기반 실행 가능
|
|
- 입력: CDS CSV 파일들이 저장된 폴더 경로, 모델 파일 경로, 상위 추출 수
|
|
- 처리: analyzer.analyze_stocks() 호출
|
|
- 출력: 예측 결과를 콘솔에 출력하고 CSV로 저장
|
|
'''
|
|
|
|
import os
|
|
import glob
|
|
import argparse
|
|
import pandas as pd
|
|
from data_analysis_engine.analyzer import analyze_stocks
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description="SightRay 분석 엔진 실행")
|
|
parser.add_argument('--cds_dir', type=str, required=False, default=None, help='CDS CSV 파일들이 있는 디렉토리 경로')
|
|
parser.add_argument('--model_path', type=str, required=False, default=None, help='저장된 XGBoost 모델 파일 경로')
|
|
parser.add_argument('--top_n', type=int, default=5, help='상위 예측 종목 수 (기본 5개)')
|
|
parser.add_argument('--output_path', type=str, default='prediction_result.csv', help='결과 저장 파일명')
|
|
|
|
args = parser.parse_args()
|
|
|
|
cds_dir = args.cds_dir or input("CDS 디렉토리 경로를 입력하세요 (예: data/CS/2024-12-31): ").strip()
|
|
model_path = args.model_path or input("모델 파일 경로를 입력하세요 (예: data_analysis_engine/models/model_2024-04-17.json): ").strip()
|
|
|
|
# CDS 경로 리스트 생성 (*.csv)
|
|
cds_files = glob.glob(os.path.join(cds_dir, '*.csv'))
|
|
if not cds_files:
|
|
print("[오류] CDS 파일이 존재하지 않습니다.")
|
|
return
|
|
|
|
# 분석 실행
|
|
result_df = analyze_stocks(cds_files, model_path, args.top_n)
|
|
|
|
# 결과 출력 및 저장
|
|
print("\n[예측 결과 요약]")
|
|
print(result_df)
|
|
|
|
result_df.rename(columns={'probability': 'predicted_score'}, inplace=True)
|
|
result_df.to_csv(args.output_path, index=False)
|
|
print(f"\n결과가 다음 위치에 저장되었습니다: {args.output_path}")
|
|
|
|
if __name__ == '__main__':
|
|
main()
|