This repository has been archived on 2025-06-07. You can view files and clone it, but cannot push or open issues or pull requests.
2025-05-06 21:23:04 +09:00

49 lines
1.9 KiB
Python

'''
data_analysis_engine/predict.py
분석 엔진 전체 흐름을 실행하는 진입점 모듈입니다.
- CLI 또는 스크립트 기반 실행 가능
- 입력: CDS CSV 파일들이 저장된 폴더 경로, 모델 파일 경로, 상위 추출 수
- 처리: analyzer.analyze_stocks() 호출
- 출력: 예측 결과를 콘솔에 출력하고 CSV로 저장
'''
import os
import glob
import argparse
import pandas as pd
from data_analysis_engine.analyzer import analyze_stocks
def main():
parser = argparse.ArgumentParser(description="SightRay 분석 엔진 실행")
parser.add_argument('--cds_dir', type=str, required=False, default=None, help='CDS CSV 파일들이 있는 디렉토리 경로')
parser.add_argument('--model_path', type=str, required=False, default=None, help='저장된 XGBoost 모델 파일 경로')
parser.add_argument('--top_n', type=int, default=5, help='상위 예측 종목 수 (기본 5개)')
parser.add_argument('--output_path', type=str, default='prediction_result.csv', help='결과 저장 파일명')
args = parser.parse_args()
cds_dir = args.cds_dir or input("CDS 디렉토리 경로를 입력하세요 (예: data/CS/2024-12-31): ").strip()
model_path = args.model_path or input("모델 파일 경로를 입력하세요 (예: data_analysis_engine/models/model_2024-04-17.json): ").strip()
# CDS 경로 리스트 생성 (*.csv)
cds_files = glob.glob(os.path.join(cds_dir, '*.csv'))
if not cds_files:
print("[오류] CDS 파일이 존재하지 않습니다.")
return
# 분석 실행
result_df = analyze_stocks(cds_files, model_path, args.top_n)
# 결과 출력 및 저장
print("\n[예측 결과 요약]")
print(result_df)
result_df.rename(columns={'probability': 'predicted_score'}, inplace=True)
result_df.to_csv(args.output_path, index=False)
print(f"\n결과가 다음 위치에 저장되었습니다: {args.output_path}")
if __name__ == '__main__':
main()