first commit
This commit is contained in:
558
AIEC-RAG/rag_api_server_production.py
Normal file
558
AIEC-RAG/rag_api_server_production.py
Normal file
@ -0,0 +1,558 @@
|
||||
"""
|
||||
RAG API 生产环境服务
|
||||
使用独立的生产配置文件 rag_config_production.yaml
|
||||
与测试环境配置完全分离
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import uvicorn
|
||||
from pathlib import Path
|
||||
|
||||
# 设置环境变量,让prompt_loader使用生产配置
|
||||
os.environ['RAG_CONFIG_PATH'] = 'rag_config_production.yaml'
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional, AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from dotenv import load_dotenv
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor, Future as ThreadFuture
|
||||
import time
|
||||
|
||||
# 项目根目录
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.append(str(project_root))
|
||||
|
||||
# 加载环境变量
|
||||
env_path = project_root / '.env'
|
||||
if env_path.exists():
|
||||
load_dotenv(env_path)
|
||||
print(f"[OK] 已加载环境变量文件: {env_path}")
|
||||
else:
|
||||
print(f"[WARNING] 未找到.env文件,使用系统环境变量")
|
||||
|
||||
from retriver.langsmith.langsmith_retriever import create_langsmith_retriever, check_langsmith_connection
|
||||
|
||||
|
||||
# ============= 生产环境配置 =============
|
||||
# 使用项目模块的默认参数,不硬编码
|
||||
OUTPUT_DIR = project_root / "api_outputs"
|
||||
|
||||
|
||||
# ============= 请求/响应模型 =============
|
||||
class RetrieveRequest(BaseModel):
|
||||
"""检索请求模型"""
|
||||
query: str = Field(..., description="查询问题")
|
||||
mode: str = Field(default="0", description="调试模式: 0=自动")
|
||||
save_output: bool = Field(default=True, description="是否保存输出到文件")
|
||||
|
||||
|
||||
class RetrieveResponse(BaseModel):
|
||||
"""检索响应模型"""
|
||||
success: bool = Field(..., description="是否成功")
|
||||
query: str = Field(..., description="原始查询")
|
||||
answer: str = Field(..., description="生成的答案")
|
||||
supporting_facts: List[List[str]] = Field(default_factory=list, description="支撑段落")
|
||||
supporting_events: List[List[str]] = Field(default_factory=list, description="支撑事件")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据")
|
||||
output_file: Optional[str] = Field(None, description="输出文件路径")
|
||||
error: Optional[str] = Field(None, description="错误信息")
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
"""健康检查响应"""
|
||||
status: str = "healthy"
|
||||
environment: str = "production"
|
||||
langsmith_connected: bool = False
|
||||
timestamp: str = ""
|
||||
|
||||
|
||||
# ============= 全局状态管理 =============
|
||||
class AppState:
|
||||
"""应用状态管理"""
|
||||
retriever = None
|
||||
langsmith_connected = False
|
||||
request_count = 0
|
||||
last_request_time = None
|
||||
|
||||
|
||||
# ============= 生命周期管理 =============
|
||||
def initialize_retriever():
|
||||
"""初始化检索器(使用项目默认配置)"""
|
||||
try:
|
||||
# 检查LangSmith连接
|
||||
AppState.langsmith_connected = check_langsmith_connection()
|
||||
|
||||
# 创建检索器,使用项目默认参数
|
||||
# keyword 是必需参数,使用项目默认值 "test"
|
||||
AppState.retriever = create_langsmith_retriever(
|
||||
keyword="test" # 项目默认关键词
|
||||
)
|
||||
|
||||
print("[OK] 生产环境检索器初始化成功(使用项目默认参数)")
|
||||
print(" 环境: PRODUCTION")
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 检索器初始化失败: {e}")
|
||||
AppState.retriever = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
# 启动时执行
|
||||
print("="*60)
|
||||
print("[STARTING] 启动RAG API服务(生产环境)")
|
||||
print("="*60)
|
||||
|
||||
# 创建输出目录
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
print(f"[OK] 输出目录: {OUTPUT_DIR}")
|
||||
|
||||
# 初始化检索器
|
||||
initialize_retriever()
|
||||
|
||||
if AppState.retriever:
|
||||
print("[OK] 服务就绪(使用项目默认参数)")
|
||||
else:
|
||||
print("[WARNING] 服务启动但检索器未初始化")
|
||||
print("="*60)
|
||||
|
||||
yield
|
||||
|
||||
# 关闭时执行
|
||||
print("\n[?] 服务关闭")
|
||||
|
||||
|
||||
# ============= FastAPI应用 =============
|
||||
app = FastAPI(
|
||||
title="RAG API Service (Production)",
|
||||
description="RAG检索服务(生产环境)",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# 配置CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # 生产环境可以改为具体域名
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# ============= API端点 =============
|
||||
@app.get("/", response_model=Dict[str, str])
|
||||
async def root():
|
||||
"""根路径"""
|
||||
return {
|
||||
"service": "RAG API Production Service",
|
||||
"status": "running",
|
||||
"environment": "production",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
|
||||
async def health_check():
|
||||
"""健康检查"""
|
||||
return HealthResponse(
|
||||
status="healthy" if AppState.retriever else "degraded",
|
||||
environment="production",
|
||||
langsmith_connected=AppState.langsmith_connected,
|
||||
timestamp=datetime.now().isoformat()
|
||||
)
|
||||
|
||||
|
||||
@app.post("/query", response_model=RetrieveResponse)
|
||||
async def query(request: RetrieveRequest):
|
||||
"""RAG查询接口(生产环境)"""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
AppState.request_count += 1
|
||||
AppState.last_request_time = datetime.now()
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"[IN] 收到查询请求 #{AppState.request_count}")
|
||||
print(f" 时间: {timestamp}")
|
||||
print(f" 查询: {request.query}")
|
||||
print(f" 环境: PRODUCTION")
|
||||
print(f"{'='*60}")
|
||||
|
||||
if not AppState.retriever:
|
||||
raise HTTPException(status_code=503, detail="检索器未初始化")
|
||||
|
||||
try:
|
||||
# 执行检索,使用项目的默认参数
|
||||
result = AppState.retriever.retrieve(
|
||||
request.query,
|
||||
request.mode
|
||||
)
|
||||
|
||||
# 保存输出
|
||||
output_file = None
|
||||
if request.save_output:
|
||||
output_file = OUTPUT_DIR / f"output_{timestamp}.json"
|
||||
output_data = {
|
||||
"timestamp": timestamp,
|
||||
"environment": "production",
|
||||
"query": request.query,
|
||||
"answer": result.get("answer", ""),
|
||||
"supporting_facts": result.get("supporting_facts", []),
|
||||
"supporting_events": result.get("supporting_events", []),
|
||||
"metadata": {
|
||||
"request_count": AppState.request_count,
|
||||
"mode": request.mode
|
||||
}
|
||||
}
|
||||
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(output_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"[SAVE] 输出已保存: {output_file}")
|
||||
|
||||
response = RetrieveResponse(
|
||||
success=True,
|
||||
query=request.query,
|
||||
answer=result.get("answer", ""),
|
||||
supporting_facts=result.get("supporting_facts", []),
|
||||
supporting_events=result.get("supporting_events", []),
|
||||
metadata={
|
||||
"environment": "production",
|
||||
"timestamp": timestamp
|
||||
},
|
||||
output_file=str(output_file) if output_file else None
|
||||
)
|
||||
|
||||
print(f"[OK] 查询完成 (答案长度: {len(response.answer)} 字符)")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 查询失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
return RetrieveResponse(
|
||||
success=False,
|
||||
query=request.query,
|
||||
answer="",
|
||||
error=str(e),
|
||||
metadata={"environment": "production", "timestamp": timestamp}
|
||||
)
|
||||
|
||||
|
||||
@app.post("/retrieve", response_model=RetrieveResponse)
|
||||
async def retrieve(request: RetrieveRequest):
|
||||
"""检索接口(兼容旧版API)"""
|
||||
return await query(request)
|
||||
|
||||
|
||||
# ============= 流式检索支持 =============
|
||||
# 创建线程池用于异步执行检索
|
||||
executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
@app.post("/retrieve/stream")
|
||||
async def retrieve_stream(request: RetrieveRequest):
|
||||
"""
|
||||
流式检索接口 - 支持SSE实时状态反馈
|
||||
|
||||
返回6类关键信息:
|
||||
1. 子问题分解列表
|
||||
2. 主要来源文档
|
||||
3. 当前检索轮次
|
||||
4. 总文档数
|
||||
5. 耗时统计
|
||||
6. 最终答案
|
||||
"""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
start_time = time.time()
|
||||
|
||||
async def generate_sse_events() -> AsyncGenerator[str, None]:
|
||||
"""生成SSE事件流"""
|
||||
AppState.request_count += 1
|
||||
AppState.last_request_time = datetime.now()
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"[STREAM] 收到流式查询请求 #{AppState.request_count}")
|
||||
print(f" 时间: {timestamp}")
|
||||
print(f" 查询: {request.query}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
if not AppState.retriever:
|
||||
error_data = {
|
||||
"type": "error",
|
||||
"message": "检索器未初始化",
|
||||
"progress": 0
|
||||
}
|
||||
yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
# 状态队列,用于接收检索过程中的状态更新
|
||||
status_queue = asyncio.Queue()
|
||||
# 使用线程安全的Future
|
||||
result_future = ThreadFuture()
|
||||
|
||||
# 获取当前运行的事件循环
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
def status_callback(status_type: str, data: Any):
|
||||
"""状态回调函数(在线程中运行)"""
|
||||
elapsed = time.time() - start_time
|
||||
status_data = {
|
||||
"type": status_type,
|
||||
"data": data,
|
||||
"elapsed": round(elapsed, 1),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
# 根据不同类型设置进度
|
||||
progress_map = {
|
||||
"starting": 5,
|
||||
"complexity_check": 10,
|
||||
"sub_queries": 15,
|
||||
"initial_retrieval": 30,
|
||||
"documents": 40,
|
||||
"sufficiency_check": 50,
|
||||
"iteration": 60,
|
||||
"parallel_retrieval": 70,
|
||||
"pagerank": 80,
|
||||
"generating": 90,
|
||||
"answer": 100
|
||||
}
|
||||
status_data["progress"] = progress_map.get(status_type, 50)
|
||||
|
||||
# 将状态放入队列 - 使用线程安全的方式
|
||||
try:
|
||||
# 使用asyncio.run_coroutine_threadsafe确保线程安全
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
status_queue.put(status_data),
|
||||
loop # 使用之前保存的loop
|
||||
)
|
||||
# 等待完成,设置超时防止阻塞
|
||||
future.result(timeout=1.0)
|
||||
except Exception as e:
|
||||
print(f"[WARNING] Failed to send status update: {e}")
|
||||
|
||||
def retriever_with_callback():
|
||||
"""带回调的检索器执行"""
|
||||
try:
|
||||
# 调用增强版检索器(稍后实现)
|
||||
from retriver.langsmith.langsmith_retriever_stream import stream_retrieve
|
||||
result = stream_retrieve(
|
||||
AppState.retriever,
|
||||
request.query,
|
||||
request.mode,
|
||||
status_callback
|
||||
)
|
||||
result_future.set_result(result)
|
||||
except Exception as e:
|
||||
result_future.set_exception(e)
|
||||
|
||||
# 在线程池中执行检索
|
||||
executor.submit(retriever_with_callback)
|
||||
|
||||
# 初始状态已在stream_retrieve中发送,这里不再重复发送
|
||||
# yield f"data: {json.dumps({'type': 'starting', 'message': '正在分析您的问题...', 'progress': 5}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 持续发送状态更新直到检索完成
|
||||
while not result_future.done():
|
||||
try:
|
||||
# 非阻塞获取状态,超时100ms
|
||||
status = await asyncio.wait_for(status_queue.get(), timeout=0.1)
|
||||
|
||||
# 处理不同类型的状态
|
||||
if status["type"] == "sub_queries":
|
||||
# 子问题列表
|
||||
queries = status["data"][:3] # 最多显示3个
|
||||
message = f"📝 问题分解为{len(status['data'])}个子查询"
|
||||
status["formatted_data"] = queries
|
||||
status["message"] = message
|
||||
|
||||
elif status["type"] == "documents":
|
||||
# 文档信息
|
||||
data = status["data"]
|
||||
retrieval_type = data.get("retrieval_type", "检索")
|
||||
retrieval_reason = data.get("retrieval_reason", "")
|
||||
new_docs = data.get("new_docs", 0)
|
||||
total_docs = data.get("count", 0)
|
||||
|
||||
# 只在有新文档时显示
|
||||
if new_docs > 0:
|
||||
if data.get("is_incremental", False):
|
||||
# 增量检索(第二次及以后)
|
||||
message = f"📊 {retrieval_reason}:新增 {new_docs} 篇文档(总计 {total_docs} 篇)"
|
||||
else:
|
||||
# 初始检索
|
||||
message = f"📊 已检索 {total_docs} 篇文档"
|
||||
|
||||
if "sources" in data:
|
||||
sources = data["sources"][:5] # 最多显示5个
|
||||
status["formatted_sources"] = sources
|
||||
message += f" | 主要来源:{', '.join(sources)}"
|
||||
status["message"] = message
|
||||
else:
|
||||
# 没有新文档,不显示
|
||||
continue
|
||||
|
||||
elif status["type"] == "iteration":
|
||||
# 迭代信息
|
||||
current = status["data"]["current"]
|
||||
max_iter = status["data"]["max"]
|
||||
# 只有在实际进行迭代时才显示(current > 0表示已经开始迭代)
|
||||
if current > 0:
|
||||
message = f"🔄 迭代检索:第{current}轮"
|
||||
else:
|
||||
message = f"🔄 开始检索流程"
|
||||
status["message"] = message
|
||||
|
||||
elif status["type"] == "sufficiency_check":
|
||||
# 充分性检查
|
||||
is_sufficient = status["data"].get("is_sufficient", False)
|
||||
confidence = status["data"].get("confidence", 0)
|
||||
message = f"✅ 充分性检查:{'充分' if is_sufficient else '需要更多信息'} (置信度: {confidence:.1%})"
|
||||
status["message"] = message
|
||||
|
||||
elif status["type"] == "generating":
|
||||
# 正在生成答案
|
||||
message = "💡 正在生成答案..."
|
||||
status["message"] = message
|
||||
|
||||
elif status["type"] == "answer":
|
||||
# 完整答案
|
||||
answer_data = {
|
||||
"type": "answer",
|
||||
"data": {
|
||||
"content": status["data"].get("content", ""),
|
||||
"complete": True
|
||||
},
|
||||
"message": "✅ 答案生成完成",
|
||||
"progress": 100,
|
||||
"elapsed": status.get("elapsed", round(time.time() - start_time, 1))
|
||||
}
|
||||
yield f"data: {json.dumps(answer_data, ensure_ascii=False)}\n\n"
|
||||
continue # 跳过默认处理
|
||||
|
||||
# 发送格式化的状态
|
||||
yield f"data: {json.dumps(status, ensure_ascii=False)}\n\n"
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 没有新状态,继续等待
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 状态处理异常: {e}")
|
||||
continue
|
||||
|
||||
# 获取最终结果
|
||||
try:
|
||||
# 在异步上下文中等待线程Future完成
|
||||
result = await loop.run_in_executor(None, result_future.result, 60) # 最多等待60秒
|
||||
|
||||
# 发送最终答案
|
||||
if result:
|
||||
answer_data = {
|
||||
"type": "answer",
|
||||
"data": {
|
||||
"content": result.get("answer", ""),
|
||||
"supporting_facts": result.get("supporting_facts", []),
|
||||
"supporting_events": result.get("supporting_events", []),
|
||||
"total_documents": result.get("total_passages", 0),
|
||||
"iterations": result.get("iterations", 0)
|
||||
},
|
||||
"message": "✅ 答案生成完成",
|
||||
"progress": 100,
|
||||
"elapsed": round(time.time() - start_time, 1)
|
||||
}
|
||||
yield f"data: {json.dumps(answer_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
print(f"[OK] 流式查询完成 (耗时: {round(time.time() - start_time, 1)}秒)")
|
||||
|
||||
except Exception as e:
|
||||
error_data = {
|
||||
"type": "error",
|
||||
"message": f"检索过程出错: {str(e)}",
|
||||
"progress": 0,
|
||||
"elapsed": round(time.time() - start_time, 1)
|
||||
}
|
||||
yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
|
||||
print(f"[ERROR] 流式查询失败: {e}")
|
||||
|
||||
# 发送完成信号
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
# 返回SSE响应
|
||||
return StreamingResponse(
|
||||
generate_sse_events(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no" # 禁用nginx缓冲
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.get("/config", response_model=Dict[str, Any])
|
||||
async def get_config():
|
||||
"""获取当前生产配置信息"""
|
||||
return {
|
||||
"environment": "production",
|
||||
"note": "生产环境使用项目默认配置",
|
||||
"description": "所有参数使用项目模块的内置默认值",
|
||||
"stats": {
|
||||
"request_count": AppState.request_count,
|
||||
"last_request_time": AppState.last_request_time.isoformat() if AppState.last_request_time else None
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@app.get("/outputs", response_model=List[str])
|
||||
async def list_outputs():
|
||||
"""列出所有输出文件"""
|
||||
if not OUTPUT_DIR.exists():
|
||||
return []
|
||||
|
||||
files = [f.name for f in OUTPUT_DIR.glob("output_*.json")]
|
||||
return sorted(files, reverse=True)[:100] # 只返回最近100个
|
||||
|
||||
|
||||
@app.get("/outputs/{filename}")
|
||||
async def get_output(filename: str):
|
||||
"""获取指定输出文件"""
|
||||
file_path = OUTPUT_DIR / filename
|
||||
|
||||
if not file_path.exists():
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
|
||||
return FileResponse(file_path)
|
||||
|
||||
|
||||
# ============= 主函数 =============
|
||||
def main():
|
||||
"""启动生产环境服务"""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="RAG API生产环境服务")
|
||||
parser.add_argument("--host", default="0.0.0.0", help="服务地址")
|
||||
parser.add_argument("--port", type=int, default=8000, help="服务端口")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 生产环境强制关闭自动重载
|
||||
uvicorn.run(
|
||||
"rag_api_server_production:app",
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
reload=False, # 生产环境禁用自动重载
|
||||
workers=1, # 可以根据需要调整worker数量
|
||||
log_level="info"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user