first commit
This commit is contained in:
309
rag_api_server_production.py
Normal file
309
rag_api_server_production.py
Normal file
@ -0,0 +1,309 @@
|
||||
"""
|
||||
RAG API 生产环境服务
|
||||
使用独立的生产配置文件 rag_config_production.yaml
|
||||
与测试环境配置完全分离
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import uvicorn
|
||||
from pathlib import Path
|
||||
|
||||
# Set the environment variable so that prompt_loader uses the production config.
# NOTE(review): presumably this has to run before the project imports further
# down read RAG_CONFIG_PATH — confirm before moving this line.
os.environ['RAG_CONFIG_PATH'] = 'rag_config_production.yaml'
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Project root: the directory that contains this file.
project_root = Path(__file__).parent
# Append the project root to the module search path so that the project
# packages imported below can be resolved.
sys.path.append(str(project_root))

# Load environment variables from the project's .env file when it exists;
# otherwise only report that the file is missing and carry on.
env_path = project_root / '.env'
if not env_path.exists():
    print(f"[WARNING] 未找到.env文件,使用系统环境变量")
else:
    load_dotenv(env_path)
    print(f"[OK] 已加载环境变量文件: {env_path}")
|
||||
|
||||
from retriver.langsmith.langsmith_retriever import create_langsmith_retriever, check_langsmith_connection
|
||||
|
||||
|
||||
# ============= 生产环境配置 =============
|
||||
# Use the default parameters of the project modules; nothing is hard-coded here.
# Directory (under the project root) where query results are written as JSON
# files and from which the /outputs endpoints read.
OUTPUT_DIR = project_root / "api_outputs"
|
||||
|
||||
|
||||
# ============= 请求/响应模型 =============
|
||||
class RetrieveRequest(BaseModel):
    """检索请求模型"""
    # Request body accepted by the /query and /retrieve endpoints.
    # NOTE: the class docstring and the Field descriptions are kept verbatim;
    # Pydantic/FastAPI can publish them in the generated API schema.

    # The question to answer (required).
    query: str = Field(..., description="查询问题")
    # Debug mode forwarded unchanged to the retriever; "0" means automatic.
    mode: str = Field(default="0", description="调试模式: 0=自动")
    # When true, the result is also written to a JSON file in the output directory.
    save_output: bool = Field(default=True, description="是否保存输出到文件")
|
||||
|
||||
|
||||
class RetrieveResponse(BaseModel):
    """检索响应模型"""
    # Response body returned by the /query and /retrieve endpoints.
    # NOTE: the class docstring and the Field descriptions are kept verbatim;
    # Pydantic/FastAPI can publish them in the generated API schema.

    # Whether the query was processed successfully (required).
    success: bool = Field(..., description="是否成功")
    # The query text as received (required).
    query: str = Field(..., description="原始查询")
    # The generated answer; the query handler sends "" on failure (required).
    answer: str = Field(..., description="生成的答案")
    # Supporting passages, as a list of string lists; defaults to an empty list.
    supporting_facts: List[List[str]] = Field(default_factory=list, description="支撑段落")
    # Supporting events, as a list of string lists; defaults to an empty list.
    supporting_events: List[List[str]] = Field(default_factory=list, description="支撑事件")
    # Free-form metadata; defaults to an empty dict.
    metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据")
    # Path of the saved output file, or None when nothing was saved.
    output_file: Optional[str] = Field(None, description="输出文件路径")
    # Error message when the query failed, otherwise None.
    error: Optional[str] = Field(None, description="错误信息")
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
    """健康检查响应"""
    # Response body of the /health endpoint.
    # NOTE: the class docstring is kept verbatim; Pydantic/FastAPI can publish
    # it in the generated API schema.

    # Service status; the health endpoint sends "healthy" or "degraded".
    status: str = "healthy"
    # Deployment environment label.
    environment: str = "production"
    # Result of the LangSmith connection check performed at startup.
    langsmith_connected: bool = False
    # Time of the health check; the health endpoint fills in an ISO-8601 string.
    timestamp: str = ""
|
||||
|
||||
|
||||
# ============= 全局状态管理 =============
|
||||
class AppState:
    """Mutable application state shared by the startup code and the handlers.

    The values live on the class itself and are read and written directly
    as ``AppState.<name>`` by the code in this file.
    """

    # Retriever created at startup; stays None when initialisation failed.
    retriever: Optional[Any] = None
    # Outcome of the LangSmith connection check done during initialisation.
    langsmith_connected: bool = False
    # Number of query requests received so far.
    request_count: int = 0
    # Time of the most recent query request, or None before the first one.
    last_request_time: Optional[datetime] = None
|
||||
|
||||
|
||||
# ============= 生命周期管理 =============
|
||||
def initialize_retriever():
    """Create the retriever and store it, with the connection flag, on AppState.

    Any exception raised along the way is reported on stdout and leaves
    ``AppState.retriever`` set to ``None``; nothing is re-raised.
    """
    try:
        # Record whether LangSmith is reachable.
        connected = check_langsmith_connection()
        AppState.langsmith_connected = connected

        # Build the retriever with the project's default parameters.
        # ``keyword`` is a required argument; "test" is the project default.
        retriever = create_langsmith_retriever(keyword="test")
        AppState.retriever = retriever

        print("[OK] 生产环境检索器初始化成功(使用项目默认参数)")
        print(" 环境: PRODUCTION")
    except Exception as exc:
        print(f"[ERROR] 检索器初始化失败: {exc}")
        AppState.retriever = None
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the app serves requests and log on shutdown.

    Startup creates the output directory and initialises the retriever;
    the code after ``yield`` runs when the application shuts down.
    """
    banner = "=" * 60

    # ---- startup ----
    print(banner)
    print("[STARTING] 启动RAG API服务(生产环境)")
    print(banner)

    # Make sure the output directory exists.
    OUTPUT_DIR.mkdir(exist_ok=True)
    print(f"[OK] 输出目录: {OUTPUT_DIR}")

    # Initialise the retriever; failures leave AppState.retriever unset.
    initialize_retriever()

    if not AppState.retriever:
        print("[WARNING] 服务启动但检索器未初始化")
    else:
        print("[OK] 服务就绪(使用项目默认参数)")
    print(banner)

    yield

    # ---- shutdown ----
    print("\n[?] 服务关闭")
|
||||
|
||||
|
||||
# ============= FastAPI应用 =============
|
||||
app = FastAPI(
    title="RAG API Service (Production)",
    description="RAG检索服务(生产环境)",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS configuration.
# Allowed origins can be restricted without a code change by setting the
# RAG_CORS_ORIGINS environment variable to a comma-separated list, e.g.
#   RAG_CORS_ORIGINS=https://app.example.com,https://admin.example.com
# When the variable is unset or empty, the previous behaviour (any origin)
# is kept.
# NOTE(review): FastAPI's CORS documentation states that allow_origins must
# not be ["*"] when allow_credentials is True. Configure explicit origins in
# production, or turn credentials off if cookies/auth headers are not needed.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("RAG_CORS_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
|
||||
# ============= API端点 =============
|
||||
@app.get("/", response_model=Dict[str, str])
async def root():
    """根路径"""
    # Root endpoint: returns a static description of the service.
    # The docstring is kept verbatim because FastAPI publishes it as the
    # operation description in the OpenAPI schema.
    service_info = dict(
        service="RAG API Production Service",
        status="running",
        environment="production",
        version="1.0.0",
    )
    return service_info
|
||||
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """健康检查"""
    # Health endpoint: "healthy" when a retriever is available, otherwise
    # "degraded"; also reports the LangSmith connection flag and current time.
    # The docstring is kept verbatim because FastAPI publishes it as the
    # operation description in the OpenAPI schema.
    if AppState.retriever:
        current_status = "healthy"
    else:
        current_status = "degraded"

    return HealthResponse(
        status=current_status,
        environment="production",
        langsmith_connected=AppState.langsmith_connected,
        timestamp=datetime.now().isoformat(),
    )
|
||||
|
||||
|
||||
def _save_query_output(
    timestamp: str,
    request_id: int,
    request: "RetrieveRequest",
    result: Dict[str, Any],
) -> Path:
    """Write one query result as JSON into OUTPUT_DIR and return the file path.

    The file name contains both the second-resolution timestamp and the
    request sequence number, so two requests handled within the same second
    by this process no longer overwrite each other's file. The name still
    matches the ``output_*.json`` pattern used by the listing endpoint.

    Raises whatever ``open``/``json.dump`` raise (e.g. ``OSError``); the
    caller handles it.
    """
    output_file = OUTPUT_DIR / f"output_{timestamp}_{request_id:06d}.json"
    output_data = {
        "timestamp": timestamp,
        "environment": "production",
        "query": request.query,
        "answer": result.get("answer", ""),
        "supporting_facts": result.get("supporting_facts", []),
        "supporting_events": result.get("supporting_events", []),
        "metadata": {
            "request_count": request_id,
            "mode": request.mode,
        },
    }

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=2)

    return output_file


@app.post("/query", response_model=RetrieveResponse)
async def query(request: RetrieveRequest):
    """RAG查询接口(生产环境)"""
    # RAG query endpoint.
    # The docstring is kept verbatim because FastAPI publishes it as the
    # operation description in the OpenAPI schema.
    #
    # Returns a RetrieveResponse. Retrieval/saving errors are reported in the
    # response body (success=False, error=<message>) rather than as an HTTP
    # error; only a missing retriever yields HTTP 503.
    now = datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    AppState.request_count += 1
    # Capture this request's number once so the log line, the file name and
    # the saved metadata all agree.
    request_id = AppState.request_count
    AppState.last_request_time = now

    print(f"\n{'='*60}")
    print(f"[IN] 收到查询请求 #{request_id}")
    print(f" 时间: {timestamp}")
    print(f" 查询: {request.query}")
    print(f" 环境: PRODUCTION")
    print(f"{'='*60}")

    if not AppState.retriever:
        raise HTTPException(status_code=503, detail="检索器未初始化")

    try:
        # Run the retrieval with the project's default parameters.
        result = AppState.retriever.retrieve(
            request.query,
            request.mode
        )

        # Optionally persist the result.
        output_file = None
        if request.save_output:
            output_file = _save_query_output(timestamp, request_id, request, result)
            print(f"[SAVE] 输出已保存: {output_file}")

        response = RetrieveResponse(
            success=True,
            query=request.query,
            answer=result.get("answer", ""),
            supporting_facts=result.get("supporting_facts", []),
            supporting_events=result.get("supporting_events", []),
            metadata={
                "environment": "production",
                "timestamp": timestamp
            },
            output_file=str(output_file) if output_file else None
        )

        print(f"[OK] 查询完成 (答案长度: {len(response.answer)} 字符)")
        return response

    except Exception as e:
        # Top-level boundary of the handler: log the failure with its
        # traceback and report it to the client in the response body.
        print(f"[ERROR] 查询失败: {e}")
        import traceback
        traceback.print_exc()

        return RetrieveResponse(
            success=False,
            query=request.query,
            answer="",
            error=str(e),
            metadata={"environment": "production", "timestamp": timestamp}
        )
|
||||
|
||||
|
||||
@app.post("/retrieve", response_model=RetrieveResponse)
async def retrieve(request: RetrieveRequest):
    """检索接口(兼容旧版API)"""
    # Retrieval endpoint kept for compatibility with the old API: it simply
    # delegates to the /query handler and returns its response.
    # The docstring is kept verbatim because FastAPI publishes it as the
    # operation description in the OpenAPI schema.
    response = await query(request)
    return response
|
||||
|
||||
|
||||
@app.get("/config", response_model=Dict[str, Any])
async def get_config():
    """获取当前生产配置信息"""
    # Config endpoint: returns static notes about the production setup plus
    # the request counter and the time of the last query (ISO-8601 or None).
    # The docstring is kept verbatim because FastAPI publishes it as the
    # operation description in the OpenAPI schema.
    last_seen = AppState.last_request_time
    stats = {
        "request_count": AppState.request_count,
        "last_request_time": None if not last_seen else last_seen.isoformat(),
    }

    return {
        "environment": "production",
        "note": "生产环境使用项目默认配置",
        "description": "所有参数使用项目模块的内置默认值",
        "stats": stats,
    }
|
||||
|
||||
|
||||
@app.get("/outputs", response_model=List[str])
async def list_outputs():
    """列出所有输出文件"""
    # Listing endpoint: returns the names of the saved output files, sorted
    # by name in descending order and capped at 100 entries.
    # The docstring is kept verbatim because FastAPI publishes it as the
    # operation description in the OpenAPI schema.
    if not OUTPUT_DIR.exists():
        return []

    names_descending = sorted(
        (entry.name for entry in OUTPUT_DIR.glob("output_*.json")),
        reverse=True,
    )
    # Only the first 100 names (the most recent ones, per the file naming).
    return names_descending[:100]
|
||||
|
||||
|
||||
@app.get("/outputs/{filename}")
async def get_output(filename: str):
    """获取指定输出文件"""
    # Download endpoint: returns one file from the output directory.
    # The docstring is kept verbatim because FastAPI publishes it as the
    # operation description in the OpenAPI schema.
    #
    # ``filename`` comes straight from the URL, so it is untrusted. Resolve
    # the requested path and serve it only when it is a regular file located
    # directly inside OUTPUT_DIR; this rejects "..", ".", backslash-separated
    # paths on Windows and symlinks pointing elsewhere. Rejected names get
    # the same 404 as missing files so nothing about the file system leaks.
    base_dir = OUTPUT_DIR.resolve()
    file_path = (base_dir / filename).resolve()

    if file_path.parent != base_dir or not file_path.is_file():
        raise HTTPException(status_code=404, detail="文件不存在")

    return FileResponse(file_path)
|
||||
|
||||
|
||||
# ============= 主函数 =============
|
||||
def main():
    """Parse the command-line options and start the production server.

    Options: ``--host`` (default ``0.0.0.0``) and ``--port`` (default 8000).
    """
    import argparse

    cli = argparse.ArgumentParser(description="RAG API生产环境服务")
    cli.add_argument("--host", default="0.0.0.0", help="服务地址")
    cli.add_argument("--port", type=int, default=8000, help="服务端口")
    options = cli.parse_args()

    server_settings = {
        "host": options.host,
        "port": options.port,
        "reload": False,   # auto-reload is always off in production
        "workers": 1,      # adjust the number of workers as needed
        "log_level": "info",
    }
    uvicorn.run("rag_api_server_production:app", **server_settings)
|
||||
|
||||
|
||||
# Start the server only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user