Files
AIEC-RAG---/AIEC-RAG/rag_api_server_production.py

558 lines
19 KiB
Python
Raw Normal View History

2025-09-25 10:33:37 +08:00
"""
RAG API 生产环境服务
使用独立的生产配置文件 rag_config_production.yaml
与测试环境配置完全分离
"""
import os
import sys
import json
import uvicorn
from pathlib import Path
# 设置环境变量让prompt_loader使用生产配置
os.environ['RAG_CONFIG_PATH'] = 'rag_config_production.yaml'
from datetime import datetime
from typing import Dict, Any, List, Optional, AsyncGenerator
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel, Field
from dotenv import load_dotenv
import asyncio
from concurrent.futures import ThreadPoolExecutor, Future as ThreadFuture
import time
# 项目根目录
project_root = Path(__file__).parent
sys.path.append(str(project_root))
# 加载环境变量
env_path = project_root / '.env'
if env_path.exists():
load_dotenv(env_path)
print(f"[OK] 已加载环境变量文件: {env_path}")
else:
print(f"[WARNING] 未找到.env文件使用系统环境变量")
from retriver.langsmith.langsmith_retriever import create_langsmith_retriever, check_langsmith_connection
# ============= 生产环境配置 =============
# 使用项目模块的默认参数,不硬编码
OUTPUT_DIR = project_root / "api_outputs"
# ============= 请求/响应模型 =============
class RetrieveRequest(BaseModel):
"""检索请求模型"""
query: str = Field(..., description="查询问题")
mode: str = Field(default="0", description="调试模式: 0=自动")
save_output: bool = Field(default=True, description="是否保存输出到文件")
class RetrieveResponse(BaseModel):
"""检索响应模型"""
success: bool = Field(..., description="是否成功")
query: str = Field(..., description="原始查询")
answer: str = Field(..., description="生成的答案")
supporting_facts: List[List[str]] = Field(default_factory=list, description="支撑段落")
supporting_events: List[List[str]] = Field(default_factory=list, description="支撑事件")
metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据")
output_file: Optional[str] = Field(None, description="输出文件路径")
error: Optional[str] = Field(None, description="错误信息")
class HealthResponse(BaseModel):
"""健康检查响应"""
status: str = "healthy"
environment: str = "production"
langsmith_connected: bool = False
timestamp: str = ""
# ============= 全局状态管理 =============
class AppState:
"""应用状态管理"""
retriever = None
langsmith_connected = False
request_count = 0
last_request_time = None
# ============= 生命周期管理 =============
def initialize_retriever():
"""初始化检索器(使用项目默认配置)"""
try:
# 检查LangSmith连接
AppState.langsmith_connected = check_langsmith_connection()
# 创建检索器,使用项目默认参数
# keyword 是必需参数,使用项目默认值 "test"
AppState.retriever = create_langsmith_retriever(
keyword="test" # 项目默认关键词
)
print("[OK] 生产环境检索器初始化成功(使用项目默认参数)")
print(" 环境: PRODUCTION")
except Exception as e:
print(f"[ERROR] 检索器初始化失败: {e}")
AppState.retriever = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理"""
# 启动时执行
print("="*60)
print("[STARTING] 启动RAG API服务生产环境")
print("="*60)
# 创建输出目录
OUTPUT_DIR.mkdir(exist_ok=True)
print(f"[OK] 输出目录: {OUTPUT_DIR}")
# 初始化检索器
initialize_retriever()
if AppState.retriever:
print("[OK] 服务就绪(使用项目默认参数)")
else:
print("[WARNING] 服务启动但检索器未初始化")
print("="*60)
yield
# 关闭时执行
print("\n[?] 服务关闭")
# ============= FastAPI应用 =============
app = FastAPI(
title="RAG API Service (Production)",
description="RAG检索服务生产环境",
version="1.0.0",
lifespan=lifespan
)
# 配置CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 生产环境可以改为具体域名
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# ============= API端点 =============
@app.get("/", response_model=Dict[str, str])
async def root():
"""根路径"""
return {
"service": "RAG API Production Service",
"status": "running",
"environment": "production",
"version": "1.0.0"
}
@app.get("/health", response_model=HealthResponse)
async def health_check():
"""健康检查"""
return HealthResponse(
status="healthy" if AppState.retriever else "degraded",
environment="production",
langsmith_connected=AppState.langsmith_connected,
timestamp=datetime.now().isoformat()
)
@app.post("/query", response_model=RetrieveResponse)
async def query(request: RetrieveRequest):
"""RAG查询接口生产环境"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
AppState.request_count += 1
AppState.last_request_time = datetime.now()
print(f"\n{'='*60}")
print(f"[IN] 收到查询请求 #{AppState.request_count}")
print(f" 时间: {timestamp}")
print(f" 查询: {request.query}")
print(f" 环境: PRODUCTION")
print(f"{'='*60}")
if not AppState.retriever:
raise HTTPException(status_code=503, detail="检索器未初始化")
try:
# 执行检索,使用项目的默认参数
result = AppState.retriever.retrieve(
request.query,
request.mode
)
# 保存输出
output_file = None
if request.save_output:
output_file = OUTPUT_DIR / f"output_{timestamp}.json"
output_data = {
"timestamp": timestamp,
"environment": "production",
"query": request.query,
"answer": result.get("answer", ""),
"supporting_facts": result.get("supporting_facts", []),
"supporting_events": result.get("supporting_events", []),
"metadata": {
"request_count": AppState.request_count,
"mode": request.mode
}
}
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(output_data, f, ensure_ascii=False, indent=2)
print(f"[SAVE] 输出已保存: {output_file}")
response = RetrieveResponse(
success=True,
query=request.query,
answer=result.get("answer", ""),
supporting_facts=result.get("supporting_facts", []),
supporting_events=result.get("supporting_events", []),
metadata={
"environment": "production",
"timestamp": timestamp
},
output_file=str(output_file) if output_file else None
)
print(f"[OK] 查询完成 (答案长度: {len(response.answer)} 字符)")
return response
except Exception as e:
print(f"[ERROR] 查询失败: {e}")
import traceback
traceback.print_exc()
return RetrieveResponse(
success=False,
query=request.query,
answer="",
error=str(e),
metadata={"environment": "production", "timestamp": timestamp}
)
@app.post("/retrieve", response_model=RetrieveResponse)
async def retrieve(request: RetrieveRequest):
"""检索接口兼容旧版API"""
return await query(request)
# ============= 流式检索支持 =============
# 创建线程池用于异步执行检索
executor = ThreadPoolExecutor(max_workers=4)
@app.post("/retrieve/stream")
async def retrieve_stream(request: RetrieveRequest):
"""
流式检索接口 - 支持SSE实时状态反馈
返回6类关键信息
1. 子问题分解列表
2. 主要来源文档
3. 当前检索轮次
4. 总文档数
5. 耗时统计
6. 最终答案
"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
start_time = time.time()
async def generate_sse_events() -> AsyncGenerator[str, None]:
"""生成SSE事件流"""
AppState.request_count += 1
AppState.last_request_time = datetime.now()
print(f"\n{'='*60}")
print(f"[STREAM] 收到流式查询请求 #{AppState.request_count}")
print(f" 时间: {timestamp}")
print(f" 查询: {request.query}")
print(f"{'='*60}")
if not AppState.retriever:
error_data = {
"type": "error",
"message": "检索器未初始化",
"progress": 0
}
yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return
# 状态队列,用于接收检索过程中的状态更新
status_queue = asyncio.Queue()
# 使用线程安全的Future
result_future = ThreadFuture()
# 获取当前运行的事件循环
loop = asyncio.get_running_loop()
def status_callback(status_type: str, data: Any):
"""状态回调函数(在线程中运行)"""
elapsed = time.time() - start_time
status_data = {
"type": status_type,
"data": data,
"elapsed": round(elapsed, 1),
"timestamp": time.time()
}
# 根据不同类型设置进度
progress_map = {
"starting": 5,
"complexity_check": 10,
"sub_queries": 15,
"initial_retrieval": 30,
"documents": 40,
"sufficiency_check": 50,
"iteration": 60,
"parallel_retrieval": 70,
"pagerank": 80,
"generating": 90,
"answer": 100
}
status_data["progress"] = progress_map.get(status_type, 50)
# 将状态放入队列 - 使用线程安全的方式
try:
# 使用asyncio.run_coroutine_threadsafe确保线程安全
future = asyncio.run_coroutine_threadsafe(
status_queue.put(status_data),
loop # 使用之前保存的loop
)
# 等待完成,设置超时防止阻塞
future.result(timeout=1.0)
except Exception as e:
print(f"[WARNING] Failed to send status update: {e}")
def retriever_with_callback():
"""带回调的检索器执行"""
try:
# 调用增强版检索器(稍后实现)
from retriver.langsmith.langsmith_retriever_stream import stream_retrieve
result = stream_retrieve(
AppState.retriever,
request.query,
request.mode,
status_callback
)
result_future.set_result(result)
except Exception as e:
result_future.set_exception(e)
# 在线程池中执行检索
executor.submit(retriever_with_callback)
# 初始状态已在stream_retrieve中发送这里不再重复发送
# yield f"data: {json.dumps({'type': 'starting', 'message': '正在分析您的问题...', 'progress': 5}, ensure_ascii=False)}\n\n"
# 持续发送状态更新直到检索完成
while not result_future.done():
try:
# 非阻塞获取状态超时100ms
status = await asyncio.wait_for(status_queue.get(), timeout=0.1)
# 处理不同类型的状态
if status["type"] == "sub_queries":
# 子问题列表
queries = status["data"][:3] # 最多显示3个
message = f"📝 问题分解为{len(status['data'])}个子查询"
status["formatted_data"] = queries
status["message"] = message
elif status["type"] == "documents":
# 文档信息
data = status["data"]
retrieval_type = data.get("retrieval_type", "检索")
retrieval_reason = data.get("retrieval_reason", "")
new_docs = data.get("new_docs", 0)
total_docs = data.get("count", 0)
# 只在有新文档时显示
if new_docs > 0:
if data.get("is_incremental", False):
# 增量检索(第二次及以后)
message = f"📊 {retrieval_reason}:新增 {new_docs} 篇文档(总计 {total_docs} 篇)"
else:
# 初始检索
message = f"📊 已检索 {total_docs} 篇文档"
if "sources" in data:
sources = data["sources"][:5] # 最多显示5个
status["formatted_sources"] = sources
message += f" | 主要来源:{', '.join(sources)}"
status["message"] = message
else:
# 没有新文档,不显示
continue
elif status["type"] == "iteration":
# 迭代信息
current = status["data"]["current"]
max_iter = status["data"]["max"]
# 只有在实际进行迭代时才显示current > 0表示已经开始迭代
if current > 0:
message = f"🔄 迭代检索:第{current}"
else:
message = f"🔄 开始检索流程"
status["message"] = message
elif status["type"] == "sufficiency_check":
# 充分性检查
is_sufficient = status["data"].get("is_sufficient", False)
confidence = status["data"].get("confidence", 0)
message = f"✅ 充分性检查:{'充分' if is_sufficient else '需要更多信息'} (置信度: {confidence:.1%})"
status["message"] = message
elif status["type"] == "generating":
# 正在生成答案
message = "💡 正在生成答案..."
status["message"] = message
elif status["type"] == "answer":
# 完整答案
answer_data = {
"type": "answer",
"data": {
"content": status["data"].get("content", ""),
"complete": True
},
"message": "✅ 答案生成完成",
"progress": 100,
"elapsed": status.get("elapsed", round(time.time() - start_time, 1))
}
yield f"data: {json.dumps(answer_data, ensure_ascii=False)}\n\n"
continue # 跳过默认处理
# 发送格式化的状态
yield f"data: {json.dumps(status, ensure_ascii=False)}\n\n"
except asyncio.TimeoutError:
# 没有新状态,继续等待
continue
except Exception as e:
print(f"[ERROR] 状态处理异常: {e}")
continue
# 获取最终结果
try:
# 在异步上下文中等待线程Future完成
result = await loop.run_in_executor(None, result_future.result, 60) # 最多等待60秒
# 发送最终答案
if result:
answer_data = {
"type": "answer",
"data": {
"content": result.get("answer", ""),
"supporting_facts": result.get("supporting_facts", []),
"supporting_events": result.get("supporting_events", []),
"total_documents": result.get("total_passages", 0),
"iterations": result.get("iterations", 0)
},
"message": "✅ 答案生成完成",
"progress": 100,
"elapsed": round(time.time() - start_time, 1)
}
yield f"data: {json.dumps(answer_data, ensure_ascii=False)}\n\n"
print(f"[OK] 流式查询完成 (耗时: {round(time.time() - start_time, 1)}秒)")
except Exception as e:
error_data = {
"type": "error",
"message": f"检索过程出错: {str(e)}",
"progress": 0,
"elapsed": round(time.time() - start_time, 1)
}
yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
print(f"[ERROR] 流式查询失败: {e}")
# 发送完成信号
yield "data: [DONE]\n\n"
# 返回SSE响应
return StreamingResponse(
generate_sse_events(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # 禁用nginx缓冲
}
)
@app.get("/config", response_model=Dict[str, Any])
async def get_config():
"""获取当前生产配置信息"""
return {
"environment": "production",
"note": "生产环境使用项目默认配置",
"description": "所有参数使用项目模块的内置默认值",
"stats": {
"request_count": AppState.request_count,
"last_request_time": AppState.last_request_time.isoformat() if AppState.last_request_time else None
}
}
@app.get("/outputs", response_model=List[str])
async def list_outputs():
"""列出所有输出文件"""
if not OUTPUT_DIR.exists():
return []
files = [f.name for f in OUTPUT_DIR.glob("output_*.json")]
return sorted(files, reverse=True)[:100] # 只返回最近100个
@app.get("/outputs/{filename}")
async def get_output(filename: str):
"""获取指定输出文件"""
file_path = OUTPUT_DIR / filename
if not file_path.exists():
raise HTTPException(status_code=404, detail="文件不存在")
return FileResponse(file_path)
# ============= 主函数 =============
def main():
"""启动生产环境服务"""
import argparse
parser = argparse.ArgumentParser(description="RAG API生产环境服务")
parser.add_argument("--host", default="0.0.0.0", help="服务地址")
parser.add_argument("--port", type=int, default=8000, help="服务端口")
args = parser.parse_args()
# 生产环境强制关闭自动重载
uvicorn.run(
"rag_api_server_production:app",
host=args.host,
port=args.port,
reload=False, # 生产环境禁用自动重载
workers=1, # 可以根据需要调整worker数量
log_level="info"
)
if __name__ == "__main__":
main()