309 lines
9.4 KiB
Python
309 lines
9.4 KiB
Python
|
|
"""
|
|||
|
|
RAG API 生产环境服务
|
|||
|
|
使用独立的生产配置文件 rag_config_production.yaml
|
|||
|
|
与测试环境配置完全分离
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
import json
|
|||
|
|
import uvicorn
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
# 设置环境变量,让prompt_loader使用生产配置
|
|||
|
|
os.environ['RAG_CONFIG_PATH'] = 'rag_config_production.yaml'
|
|||
|
|
from datetime import datetime
|
|||
|
|
from typing import Dict, Any, List, Optional
|
|||
|
|
from contextlib import asynccontextmanager
|
|||
|
|
from fastapi import FastAPI, HTTPException
|
|||
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|||
|
|
from fastapi.responses import FileResponse
|
|||
|
|
from pydantic import BaseModel, Field
|
|||
|
|
from dotenv import load_dotenv
|
|||
|
|
|
|||
|
|
# 项目根目录
|
|||
|
|
project_root = Path(__file__).parent
|
|||
|
|
sys.path.append(str(project_root))
|
|||
|
|
|
|||
|
|
# 加载环境变量
|
|||
|
|
env_path = project_root / '.env'
|
|||
|
|
if env_path.exists():
|
|||
|
|
load_dotenv(env_path)
|
|||
|
|
print(f"[OK] 已加载环境变量文件: {env_path}")
|
|||
|
|
else:
|
|||
|
|
print(f"[WARNING] 未找到.env文件,使用系统环境变量")
|
|||
|
|
|
|||
|
|
from retriver.langsmith.langsmith_retriever import create_langsmith_retriever, check_langsmith_connection
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============= 生产环境配置 =============
|
|||
|
|
# 使用项目模块的默认参数,不硬编码
|
|||
|
|
OUTPUT_DIR = project_root / "api_outputs"
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============= 请求/响应模型 =============
|
|||
|
|
class RetrieveRequest(BaseModel):
|
|||
|
|
"""检索请求模型"""
|
|||
|
|
query: str = Field(..., description="查询问题")
|
|||
|
|
mode: str = Field(default="0", description="调试模式: 0=自动")
|
|||
|
|
save_output: bool = Field(default=True, description="是否保存输出到文件")
|
|||
|
|
|
|||
|
|
|
|||
|
|
class RetrieveResponse(BaseModel):
|
|||
|
|
"""检索响应模型"""
|
|||
|
|
success: bool = Field(..., description="是否成功")
|
|||
|
|
query: str = Field(..., description="原始查询")
|
|||
|
|
answer: str = Field(..., description="生成的答案")
|
|||
|
|
supporting_facts: List[List[str]] = Field(default_factory=list, description="支撑段落")
|
|||
|
|
supporting_events: List[List[str]] = Field(default_factory=list, description="支撑事件")
|
|||
|
|
metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据")
|
|||
|
|
output_file: Optional[str] = Field(None, description="输出文件路径")
|
|||
|
|
error: Optional[str] = Field(None, description="错误信息")
|
|||
|
|
|
|||
|
|
|
|||
|
|
class HealthResponse(BaseModel):
|
|||
|
|
"""健康检查响应"""
|
|||
|
|
status: str = "healthy"
|
|||
|
|
environment: str = "production"
|
|||
|
|
langsmith_connected: bool = False
|
|||
|
|
timestamp: str = ""
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============= 全局状态管理 =============
|
|||
|
|
class AppState:
|
|||
|
|
"""应用状态管理"""
|
|||
|
|
retriever = None
|
|||
|
|
langsmith_connected = False
|
|||
|
|
request_count = 0
|
|||
|
|
last_request_time = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============= 生命周期管理 =============
|
|||
|
|
def initialize_retriever():
|
|||
|
|
"""初始化检索器(使用项目默认配置)"""
|
|||
|
|
try:
|
|||
|
|
# 检查LangSmith连接
|
|||
|
|
AppState.langsmith_connected = check_langsmith_connection()
|
|||
|
|
|
|||
|
|
# 创建检索器,使用项目默认参数
|
|||
|
|
# keyword 是必需参数,使用项目默认值 "test"
|
|||
|
|
AppState.retriever = create_langsmith_retriever(
|
|||
|
|
keyword="test" # 项目默认关键词
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
print("[OK] 生产环境检索器初始化成功(使用项目默认参数)")
|
|||
|
|
print(" 环境: PRODUCTION")
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"[ERROR] 检索器初始化失败: {e}")
|
|||
|
|
AppState.retriever = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
@asynccontextmanager
|
|||
|
|
async def lifespan(app: FastAPI):
|
|||
|
|
"""应用生命周期管理"""
|
|||
|
|
# 启动时执行
|
|||
|
|
print("="*60)
|
|||
|
|
print("[STARTING] 启动RAG API服务(生产环境)")
|
|||
|
|
print("="*60)
|
|||
|
|
|
|||
|
|
# 创建输出目录
|
|||
|
|
OUTPUT_DIR.mkdir(exist_ok=True)
|
|||
|
|
print(f"[OK] 输出目录: {OUTPUT_DIR}")
|
|||
|
|
|
|||
|
|
# 初始化检索器
|
|||
|
|
initialize_retriever()
|
|||
|
|
|
|||
|
|
if AppState.retriever:
|
|||
|
|
print("[OK] 服务就绪(使用项目默认参数)")
|
|||
|
|
else:
|
|||
|
|
print("[WARNING] 服务启动但检索器未初始化")
|
|||
|
|
print("="*60)
|
|||
|
|
|
|||
|
|
yield
|
|||
|
|
|
|||
|
|
# 关闭时执行
|
|||
|
|
print("\n[?] 服务关闭")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============= FastAPI应用 =============
|
|||
|
|
app = FastAPI(
|
|||
|
|
title="RAG API Service (Production)",
|
|||
|
|
description="RAG检索服务(生产环境)",
|
|||
|
|
version="1.0.0",
|
|||
|
|
lifespan=lifespan
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 配置CORS
|
|||
|
|
app.add_middleware(
|
|||
|
|
CORSMiddleware,
|
|||
|
|
allow_origins=["*"], # 生产环境可以改为具体域名
|
|||
|
|
allow_credentials=True,
|
|||
|
|
allow_methods=["*"],
|
|||
|
|
allow_headers=["*"],
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============= API端点 =============
|
|||
|
|
@app.get("/", response_model=Dict[str, str])
|
|||
|
|
async def root():
|
|||
|
|
"""根路径"""
|
|||
|
|
return {
|
|||
|
|
"service": "RAG API Production Service",
|
|||
|
|
"status": "running",
|
|||
|
|
"environment": "production",
|
|||
|
|
"version": "1.0.0"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.get("/health", response_model=HealthResponse)
|
|||
|
|
async def health_check():
|
|||
|
|
"""健康检查"""
|
|||
|
|
return HealthResponse(
|
|||
|
|
status="healthy" if AppState.retriever else "degraded",
|
|||
|
|
environment="production",
|
|||
|
|
langsmith_connected=AppState.langsmith_connected,
|
|||
|
|
timestamp=datetime.now().isoformat()
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.post("/query", response_model=RetrieveResponse)
|
|||
|
|
async def query(request: RetrieveRequest):
|
|||
|
|
"""RAG查询接口(生产环境)"""
|
|||
|
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|||
|
|
AppState.request_count += 1
|
|||
|
|
AppState.last_request_time = datetime.now()
|
|||
|
|
|
|||
|
|
print(f"\n{'='*60}")
|
|||
|
|
print(f"[IN] 收到查询请求 #{AppState.request_count}")
|
|||
|
|
print(f" 时间: {timestamp}")
|
|||
|
|
print(f" 查询: {request.query}")
|
|||
|
|
print(f" 环境: PRODUCTION")
|
|||
|
|
print(f"{'='*60}")
|
|||
|
|
|
|||
|
|
if not AppState.retriever:
|
|||
|
|
raise HTTPException(status_code=503, detail="检索器未初始化")
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
# 执行检索,使用项目的默认参数
|
|||
|
|
result = AppState.retriever.retrieve(
|
|||
|
|
request.query,
|
|||
|
|
request.mode
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 保存输出
|
|||
|
|
output_file = None
|
|||
|
|
if request.save_output:
|
|||
|
|
output_file = OUTPUT_DIR / f"output_{timestamp}.json"
|
|||
|
|
output_data = {
|
|||
|
|
"timestamp": timestamp,
|
|||
|
|
"environment": "production",
|
|||
|
|
"query": request.query,
|
|||
|
|
"answer": result.get("answer", ""),
|
|||
|
|
"supporting_facts": result.get("supporting_facts", []),
|
|||
|
|
"supporting_events": result.get("supporting_events", []),
|
|||
|
|
"metadata": {
|
|||
|
|
"request_count": AppState.request_count,
|
|||
|
|
"mode": request.mode
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
with open(output_file, 'w', encoding='utf-8') as f:
|
|||
|
|
json.dump(output_data, f, ensure_ascii=False, indent=2)
|
|||
|
|
|
|||
|
|
print(f"[SAVE] 输出已保存: {output_file}")
|
|||
|
|
|
|||
|
|
response = RetrieveResponse(
|
|||
|
|
success=True,
|
|||
|
|
query=request.query,
|
|||
|
|
answer=result.get("answer", ""),
|
|||
|
|
supporting_facts=result.get("supporting_facts", []),
|
|||
|
|
supporting_events=result.get("supporting_events", []),
|
|||
|
|
metadata={
|
|||
|
|
"environment": "production",
|
|||
|
|
"timestamp": timestamp
|
|||
|
|
},
|
|||
|
|
output_file=str(output_file) if output_file else None
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
print(f"[OK] 查询完成 (答案长度: {len(response.answer)} 字符)")
|
|||
|
|
return response
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"[ERROR] 查询失败: {e}")
|
|||
|
|
import traceback
|
|||
|
|
traceback.print_exc()
|
|||
|
|
|
|||
|
|
return RetrieveResponse(
|
|||
|
|
success=False,
|
|||
|
|
query=request.query,
|
|||
|
|
answer="",
|
|||
|
|
error=str(e),
|
|||
|
|
metadata={"environment": "production", "timestamp": timestamp}
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.post("/retrieve", response_model=RetrieveResponse)
|
|||
|
|
async def retrieve(request: RetrieveRequest):
|
|||
|
|
"""检索接口(兼容旧版API)"""
|
|||
|
|
return await query(request)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.get("/config", response_model=Dict[str, Any])
|
|||
|
|
async def get_config():
|
|||
|
|
"""获取当前生产配置信息"""
|
|||
|
|
return {
|
|||
|
|
"environment": "production",
|
|||
|
|
"note": "生产环境使用项目默认配置",
|
|||
|
|
"description": "所有参数使用项目模块的内置默认值",
|
|||
|
|
"stats": {
|
|||
|
|
"request_count": AppState.request_count,
|
|||
|
|
"last_request_time": AppState.last_request_time.isoformat() if AppState.last_request_time else None
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.get("/outputs", response_model=List[str])
|
|||
|
|
async def list_outputs():
|
|||
|
|
"""列出所有输出文件"""
|
|||
|
|
if not OUTPUT_DIR.exists():
|
|||
|
|
return []
|
|||
|
|
|
|||
|
|
files = [f.name for f in OUTPUT_DIR.glob("output_*.json")]
|
|||
|
|
return sorted(files, reverse=True)[:100] # 只返回最近100个
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.get("/outputs/{filename}")
|
|||
|
|
async def get_output(filename: str):
|
|||
|
|
"""获取指定输出文件"""
|
|||
|
|
file_path = OUTPUT_DIR / filename
|
|||
|
|
|
|||
|
|
if not file_path.exists():
|
|||
|
|
raise HTTPException(status_code=404, detail="文件不存在")
|
|||
|
|
|
|||
|
|
return FileResponse(file_path)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============= 主函数 =============
|
|||
|
|
def main():
|
|||
|
|
"""启动生产环境服务"""
|
|||
|
|
import argparse
|
|||
|
|
|
|||
|
|
parser = argparse.ArgumentParser(description="RAG API生产环境服务")
|
|||
|
|
parser.add_argument("--host", default="0.0.0.0", help="服务地址")
|
|||
|
|
parser.add_argument("--port", type=int, default=8000, help="服务端口")
|
|||
|
|
|
|||
|
|
args = parser.parse_args()
|
|||
|
|
|
|||
|
|
# 生产环境强制关闭自动重载
|
|||
|
|
uvicorn.run(
|
|||
|
|
"rag_api_server_production:app",
|
|||
|
|
host=args.host,
|
|||
|
|
port=args.port,
|
|||
|
|
reload=False, # 生产环境禁用自动重载
|
|||
|
|
workers=1, # 可以根据需要调整worker数量
|
|||
|
|
log_level="info"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
main()
|