"""Production entry point for the RAG API service.

Runs against its own production configuration file,
``rag_config_production.yaml``, which is kept completely separate from the
test-environment configuration.
"""

import json
import os
import sys
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field

# Point prompt_loader at the production configuration. This is set before the
# .env file is loaded and before the project modules are imported.
# NOTE(review): presumably the project modules read this variable at import
# time, which is why the ordering matters -- confirm against prompt_loader.
os.environ["RAG_CONFIG_PATH"] = "rag_config_production.yaml"

# Project root: the directory containing this file. It is appended to sys.path
# so the project-local import further down resolves.
project_root = Path(__file__).parent
sys.path.append(str(project_root))

# Load environment variables from the project's .env file when it exists;
# otherwise only the process environment is used.
env_path = project_root / ".env"
if not env_path.exists():
    print("[WARNING] 未找到.env文件,使用系统环境变量")
else:
    load_dotenv(env_path)
    print(f"[OK] 已加载环境变量文件: {env_path}")

from retriver.langsmith.langsmith_retriever import (  # noqa: E402
    check_langsmith_connection,
    create_langsmith_retriever,
)

# ============= Production configuration =============
# Retrieval parameters are the project modules' own defaults; nothing is
# hard-coded here apart from where outputs are written.
OUTPUT_DIR = project_root / "api_outputs"


# ============= Request / response models =============
class RetrieveRequest(BaseModel):
    """Request body accepted by the query endpoints."""

    query: str = Field(..., description="查询问题")
    mode: str = Field(default="0", description="调试模式: 0=自动")
    save_output: bool = Field(default=True, description="是否保存输出到文件")


class RetrieveResponse(BaseModel):
    """Response body returned by the query endpoints."""

    success: bool = Field(..., description="是否成功")
    query: str = Field(..., description="原始查询")
    answer: str = Field(..., description="生成的答案")
    supporting_facts: List[List[str]] = Field(
        default_factory=list, description="支撑段落"
    )
    supporting_events: List[List[str]] = Field(
        default_factory=list, description="支撑事件"
    )
    metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据")
    output_file: Optional[str] = Field(default=None, description="输出文件路径")
    error: Optional[str] = Field(default=None, description="错误信息")


class HealthResponse(BaseModel):
    """Response body returned by the health-check endpoint."""

    status: str = "healthy"
    environment: str = "production"
    langsmith_connected: bool = False
    timestamp: str = ""


# ============= Global state management =============
class AppState:
    """Process-wide mutable state shared by the lifespan hook and the handlers.

    The attributes are read and written directly on the class
    (``AppState.retriever``); this module never instantiates it.
    """

    # Object returned by create_langsmith_retriever(), or None when
    # initialization failed or has not run yet.
    retriever = None
    # Result of check_langsmith_connection() taken at startup.
    langsmith_connected = False
    # Number of query requests received since the process started.
    request_count = 0
    # datetime of the most recent query request, or None before the first one.
    last_request_time = None


# ============= Lifecycle management =============
def initialize_retriever():
    """Create the retriever with the project's default parameters.

    Stores the LangSmith connectivity flag and the retriever on ``AppState``.
    Any exception is logged and leaves ``AppState.retriever`` set to ``None``
    so the service can still start in a degraded state.
    """
    try:
        # Record LangSmith connectivity for the health endpoint.
        AppState.langsmith_connected = check_langsmith_connection()

        # `keyword` is a required argument; "test" is the project default.
        AppState.retriever = create_langsmith_retriever(
            keyword="test"
        )
        print("[OK] 生产环境检索器初始化成功(使用项目默认参数)")
        print(" 环境: PRODUCTION")
    except Exception as e:
        print(f"[ERROR] 检索器初始化失败: {e}")
        AppState.retriever = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the app serves requests and log on shutdown.

    Startup creates the output directory and initializes the retriever. A
    failed retriever initialization does not abort startup.
    """
    # --- startup ---
    print("="*60)
    print("[STARTING] 启动RAG API服务(生产环境)")
    print("="*60)

    # Make sure the directory for saved query outputs exists.
    OUTPUT_DIR.mkdir(exist_ok=True)
    print(f"[OK] 输出目录: {OUTPUT_DIR}")

    initialize_retriever()

    if AppState.retriever:
        print("[OK] 服务就绪(使用项目默认参数)")
    else:
        print("[WARNING] 服务启动但检索器未初始化")

    print("="*60)

    yield

    # --- shutdown ---
    print("\n[?] 服务关闭")


# ============= FastAPI application =============
app = FastAPI(
    title="RAG API Service (Production)",
    description="RAG检索服务(生产环境)",
    version="1.0.0",
    lifespan=lifespan
)

# CORS configuration.
# A wildcard origin must not be combined with credentialed requests, so
# credentials stay disabled while every origin is allowed. To enable
# credentials, list the allowed origins explicitly first.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # may be narrowed to specific domains in production
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ============= API endpoints =============
@app.get("/", response_model=Dict[str, str])
async def root():
    """Return static identification data for the service."""
    return {
        "service": "RAG API Production Service",
        "status": "running",
        "environment": "production",
        "version": "1.0.0"
    }


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report service health.

    ``status`` is "healthy" when a retriever is available and "degraded"
    otherwise.
    """
    return HealthResponse(
        status="healthy" if AppState.retriever else "degraded",
        environment="production",
        langsmith_connected=AppState.langsmith_connected,
        timestamp=datetime.now().isoformat()
    )


def _unique_output_path(timestamp: str) -> Path:
    """Return an unused path in OUTPUT_DIR for a query made at *timestamp*.

    The timestamp has one-second resolution, so two queries handled within the
    same second would otherwise write to the same file and the second would
    overwrite the first. The plain ``output_<timestamp>.json`` name is used
    when it is free; a numeric suffix is appended only on collision.
    """
    candidate = OUTPUT_DIR / f"output_{timestamp}.json"
    attempt = 1
    while candidate.exists():
        candidate = OUTPUT_DIR / f"output_{timestamp}_{attempt}.json"
        attempt += 1
    return candidate


@app.post("/query", response_model=RetrieveResponse)
async def query(request: RetrieveRequest):
    """Answer a query with the retriever and optionally save the result.

    Raises:
        HTTPException: 503 when the retriever was not initialized.

    Any other failure is logged and reported in the body as
    ``success=False`` with the error message, not as an HTTP error.
    """
    # Take one clock reading so the timestamp string and last_request_time
    # describe the same instant.
    now = datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    AppState.request_count += 1
    AppState.last_request_time = now
    # Capture this request's number once; the log line and the saved metadata
    # then agree even if the counter changes later.
    request_no = AppState.request_count

    print(f"\n{'='*60}")
    print(f"[IN] 收到查询请求 #{request_no}")
    print(f" 时间: {timestamp}")
    print(f" 查询: {request.query}")
    print(f" 环境: PRODUCTION")
    print(f"{'='*60}")

    if not AppState.retriever:
        raise HTTPException(status_code=503, detail="检索器未初始化")

    try:
        # Retrieval runs with the project's default parameters.
        # NOTE(review): retrieve() is called synchronously inside an async
        # handler, so other requests wait while it runs. Confirm the retriever
        # is thread-safe before moving this call to a thread pool.
        result = AppState.retriever.retrieve(
            request.query,
            request.mode
        )

        answer = result.get("answer", "")
        supporting_facts = result.get("supporting_facts", [])
        supporting_events = result.get("supporting_events", [])

        # Optionally persist the result as JSON.
        output_file = None
        if request.save_output:
            output_file = _unique_output_path(timestamp)
            output_data = {
                "timestamp": timestamp,
                "environment": "production",
                "query": request.query,
                "answer": answer,
                "supporting_facts": supporting_facts,
                "supporting_events": supporting_events,
                "metadata": {
                    "request_count": request_no,
                    "mode": request.mode
                }
            }
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(output_data, f, ensure_ascii=False, indent=2)
            print(f"[SAVE] 输出已保存: {output_file}")

        response = RetrieveResponse(
            success=True,
            query=request.query,
            answer=answer,
            supporting_facts=supporting_facts,
            supporting_events=supporting_events,
            metadata={
                "environment": "production",
                "timestamp": timestamp
            },
            output_file=str(output_file) if output_file else None
        )

        print(f"[OK] 查询完成 (答案长度: {len(response.answer)} 字符)")
        return response

    except Exception as e:
        # Top-level boundary: log the full traceback and report the failure
        # in the response body.
        print(f"[ERROR] 查询失败: {e}")
        import traceback
        traceback.print_exc()

        return RetrieveResponse(
            success=False,
            query=request.query,
            answer="",
            error=str(e),
            metadata={"environment": "production", "timestamp": timestamp}
        )


@app.post("/retrieve", response_model=RetrieveResponse)
async def retrieve(request: RetrieveRequest):
    """Alias of the query endpoint, kept for compatibility with the old API."""
    return await query(request)


@app.get("/config", response_model=Dict[str, Any])
async def get_config():
    """Return a description of the production setup and request statistics."""
    return {
        "environment": "production",
        "note": "生产环境使用项目默认配置",
        "description": "所有参数使用项目模块的内置默认值",
        "stats": {
            "request_count": AppState.request_count,
            "last_request_time": AppState.last_request_time.isoformat() if AppState.last_request_time else None
        }
    }


@app.get("/outputs", response_model=List[str])
async def list_outputs():
    """List saved output file names, sorted in descending name order.

    At most 100 names are returned. Because the names start with a timestamp,
    descending order puts the most recent files first.
    """
    if not OUTPUT_DIR.exists():
        return []

    files = [f.name for f in OUTPUT_DIR.glob("output_*.json")]
    return sorted(files, reverse=True)[:100]


@app.get("/outputs/{filename}")
async def get_output(filename: str):
    """Return one saved output file from OUTPUT_DIR.

    Raises:
        HTTPException: 404 when *filename* is not a bare file name or does not
            refer to an existing regular file inside OUTPUT_DIR.
    """
    # `filename` is untrusted. Reject anything that carries a directory, drive
    # or separator component, so the joined path cannot leave OUTPUT_DIR.
    if Path(filename).name != filename:
        raise HTTPException(status_code=404, detail="文件不存在")

    file_path = OUTPUT_DIR / filename
    # is_file() also rejects directories such as "..", which exists() accepts.
    if not file_path.is_file():
        raise HTTPException(status_code=404, detail="文件不存在")

    return FileResponse(file_path)


# ============= Entry point =============
def main():
    """Parse ``--host`` / ``--port`` and start the production server."""
    import argparse

    parser = argparse.ArgumentParser(description="RAG API生产环境服务")
    parser.add_argument("--host", default="0.0.0.0", help="服务地址")
    parser.add_argument("--port", type=int, default=8000, help="服务端口")

    args = parser.parse_args()

    # Auto-reload is always off in production.
    # NOTE(review): the import string assumes this module is importable as
    # `rag_api_server_production` -- confirm it matches the file name.
    uvicorn.run(
        "rag_api_server_production:app",
        host=args.host,
        port=args.port,
        reload=False,  # auto-reload disabled in production
        workers=1,  # adjust the worker count as needed
        log_level="info"
    )


if __name__ == "__main__":
    main()