558 lines
19 KiB
Python
558 lines
19 KiB
Python
"""
|
||
RAG API 生产环境服务
|
||
使用独立的生产配置文件 rag_config_production.yaml
|
||
与测试环境配置完全分离
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import json
|
||
import uvicorn
|
||
from pathlib import Path
|
||
|
||
# 设置环境变量,让prompt_loader使用生产配置
|
||
os.environ['RAG_CONFIG_PATH'] = 'rag_config_production.yaml'
|
||
from datetime import datetime
|
||
from typing import Dict, Any, List, Optional, AsyncGenerator
|
||
from contextlib import asynccontextmanager
|
||
from fastapi import FastAPI, HTTPException
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import FileResponse, StreamingResponse
|
||
from pydantic import BaseModel, Field
|
||
from dotenv import load_dotenv
|
||
import asyncio
|
||
from concurrent.futures import ThreadPoolExecutor, Future as ThreadFuture
|
||
import time
|
||
|
||
# 项目根目录
|
||
project_root = Path(__file__).parent
|
||
sys.path.append(str(project_root))
|
||
|
||
# 加载环境变量
|
||
env_path = project_root / '.env'
|
||
if env_path.exists():
|
||
load_dotenv(env_path)
|
||
print(f"[OK] 已加载环境变量文件: {env_path}")
|
||
else:
|
||
print(f"[WARNING] 未找到.env文件,使用系统环境变量")
|
||
|
||
from retriver.langsmith.langsmith_retriever import create_langsmith_retriever, check_langsmith_connection
|
||
|
||
|
||
# ============= 生产环境配置 =============
|
||
# 使用项目模块的默认参数,不硬编码
|
||
OUTPUT_DIR = project_root / "api_outputs"
|
||
|
||
|
||
# ============= 请求/响应模型 =============
|
||
class RetrieveRequest(BaseModel):
|
||
"""检索请求模型"""
|
||
query: str = Field(..., description="查询问题")
|
||
mode: str = Field(default="0", description="调试模式: 0=自动")
|
||
save_output: bool = Field(default=True, description="是否保存输出到文件")
|
||
|
||
|
||
class RetrieveResponse(BaseModel):
|
||
"""检索响应模型"""
|
||
success: bool = Field(..., description="是否成功")
|
||
query: str = Field(..., description="原始查询")
|
||
answer: str = Field(..., description="生成的答案")
|
||
supporting_facts: List[List[str]] = Field(default_factory=list, description="支撑段落")
|
||
supporting_events: List[List[str]] = Field(default_factory=list, description="支撑事件")
|
||
metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据")
|
||
output_file: Optional[str] = Field(None, description="输出文件路径")
|
||
error: Optional[str] = Field(None, description="错误信息")
|
||
|
||
|
||
class HealthResponse(BaseModel):
|
||
"""健康检查响应"""
|
||
status: str = "healthy"
|
||
environment: str = "production"
|
||
langsmith_connected: bool = False
|
||
timestamp: str = ""
|
||
|
||
|
||
# ============= 全局状态管理 =============
|
||
class AppState:
|
||
"""应用状态管理"""
|
||
retriever = None
|
||
langsmith_connected = False
|
||
request_count = 0
|
||
last_request_time = None
|
||
|
||
|
||
# ============= 生命周期管理 =============
|
||
def initialize_retriever():
|
||
"""初始化检索器(使用项目默认配置)"""
|
||
try:
|
||
# 检查LangSmith连接
|
||
AppState.langsmith_connected = check_langsmith_connection()
|
||
|
||
# 创建检索器,使用项目默认参数
|
||
# keyword 是必需参数,使用项目默认值 "test"
|
||
AppState.retriever = create_langsmith_retriever(
|
||
keyword="test" # 项目默认关键词
|
||
)
|
||
|
||
print("[OK] 生产环境检索器初始化成功(使用项目默认参数)")
|
||
print(" 环境: PRODUCTION")
|
||
except Exception as e:
|
||
print(f"[ERROR] 检索器初始化失败: {e}")
|
||
AppState.retriever = None
|
||
|
||
|
||
@asynccontextmanager
|
||
async def lifespan(app: FastAPI):
|
||
"""应用生命周期管理"""
|
||
# 启动时执行
|
||
print("="*60)
|
||
print("[STARTING] 启动RAG API服务(生产环境)")
|
||
print("="*60)
|
||
|
||
# 创建输出目录
|
||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||
print(f"[OK] 输出目录: {OUTPUT_DIR}")
|
||
|
||
# 初始化检索器
|
||
initialize_retriever()
|
||
|
||
if AppState.retriever:
|
||
print("[OK] 服务就绪(使用项目默认参数)")
|
||
else:
|
||
print("[WARNING] 服务启动但检索器未初始化")
|
||
print("="*60)
|
||
|
||
yield
|
||
|
||
# 关闭时执行
|
||
print("\n[?] 服务关闭")
|
||
|
||
|
||
# ============= FastAPI应用 =============
|
||
app = FastAPI(
|
||
title="RAG API Service (Production)",
|
||
description="RAG检索服务(生产环境)",
|
||
version="1.0.0",
|
||
lifespan=lifespan
|
||
)
|
||
|
||
# 配置CORS
|
||
app.add_middleware(
|
||
CORSMiddleware,
|
||
allow_origins=["*"], # 生产环境可以改为具体域名
|
||
allow_credentials=True,
|
||
allow_methods=["*"],
|
||
allow_headers=["*"],
|
||
)
|
||
|
||
|
||
# ============= API端点 =============
|
||
@app.get("/", response_model=Dict[str, str])
|
||
async def root():
|
||
"""根路径"""
|
||
return {
|
||
"service": "RAG API Production Service",
|
||
"status": "running",
|
||
"environment": "production",
|
||
"version": "1.0.0"
|
||
}
|
||
|
||
|
||
@app.get("/health", response_model=HealthResponse)
|
||
async def health_check():
|
||
"""健康检查"""
|
||
return HealthResponse(
|
||
status="healthy" if AppState.retriever else "degraded",
|
||
environment="production",
|
||
langsmith_connected=AppState.langsmith_connected,
|
||
timestamp=datetime.now().isoformat()
|
||
)
|
||
|
||
|
||
@app.post("/query", response_model=RetrieveResponse)
|
||
async def query(request: RetrieveRequest):
|
||
"""RAG查询接口(生产环境)"""
|
||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||
AppState.request_count += 1
|
||
AppState.last_request_time = datetime.now()
|
||
|
||
print(f"\n{'='*60}")
|
||
print(f"[IN] 收到查询请求 #{AppState.request_count}")
|
||
print(f" 时间: {timestamp}")
|
||
print(f" 查询: {request.query}")
|
||
print(f" 环境: PRODUCTION")
|
||
print(f"{'='*60}")
|
||
|
||
if not AppState.retriever:
|
||
raise HTTPException(status_code=503, detail="检索器未初始化")
|
||
|
||
try:
|
||
# 执行检索,使用项目的默认参数
|
||
result = AppState.retriever.retrieve(
|
||
request.query,
|
||
request.mode
|
||
)
|
||
|
||
# 保存输出
|
||
output_file = None
|
||
if request.save_output:
|
||
output_file = OUTPUT_DIR / f"output_{timestamp}.json"
|
||
output_data = {
|
||
"timestamp": timestamp,
|
||
"environment": "production",
|
||
"query": request.query,
|
||
"answer": result.get("answer", ""),
|
||
"supporting_facts": result.get("supporting_facts", []),
|
||
"supporting_events": result.get("supporting_events", []),
|
||
"metadata": {
|
||
"request_count": AppState.request_count,
|
||
"mode": request.mode
|
||
}
|
||
}
|
||
|
||
with open(output_file, 'w', encoding='utf-8') as f:
|
||
json.dump(output_data, f, ensure_ascii=False, indent=2)
|
||
|
||
print(f"[SAVE] 输出已保存: {output_file}")
|
||
|
||
response = RetrieveResponse(
|
||
success=True,
|
||
query=request.query,
|
||
answer=result.get("answer", ""),
|
||
supporting_facts=result.get("supporting_facts", []),
|
||
supporting_events=result.get("supporting_events", []),
|
||
metadata={
|
||
"environment": "production",
|
||
"timestamp": timestamp
|
||
},
|
||
output_file=str(output_file) if output_file else None
|
||
)
|
||
|
||
print(f"[OK] 查询完成 (答案长度: {len(response.answer)} 字符)")
|
||
return response
|
||
|
||
except Exception as e:
|
||
print(f"[ERROR] 查询失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
return RetrieveResponse(
|
||
success=False,
|
||
query=request.query,
|
||
answer="",
|
||
error=str(e),
|
||
metadata={"environment": "production", "timestamp": timestamp}
|
||
)
|
||
|
||
|
||
@app.post("/retrieve", response_model=RetrieveResponse)
|
||
async def retrieve(request: RetrieveRequest):
|
||
"""检索接口(兼容旧版API)"""
|
||
return await query(request)
|
||
|
||
|
||
# ============= 流式检索支持 =============
|
||
# 创建线程池用于异步执行检索
|
||
executor = ThreadPoolExecutor(max_workers=4)
|
||
|
||
|
||
@app.post("/retrieve/stream")
|
||
async def retrieve_stream(request: RetrieveRequest):
|
||
"""
|
||
流式检索接口 - 支持SSE实时状态反馈
|
||
|
||
返回6类关键信息:
|
||
1. 子问题分解列表
|
||
2. 主要来源文档
|
||
3. 当前检索轮次
|
||
4. 总文档数
|
||
5. 耗时统计
|
||
6. 最终答案
|
||
"""
|
||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||
start_time = time.time()
|
||
|
||
async def generate_sse_events() -> AsyncGenerator[str, None]:
|
||
"""生成SSE事件流"""
|
||
AppState.request_count += 1
|
||
AppState.last_request_time = datetime.now()
|
||
|
||
print(f"\n{'='*60}")
|
||
print(f"[STREAM] 收到流式查询请求 #{AppState.request_count}")
|
||
print(f" 时间: {timestamp}")
|
||
print(f" 查询: {request.query}")
|
||
print(f"{'='*60}")
|
||
|
||
if not AppState.retriever:
|
||
error_data = {
|
||
"type": "error",
|
||
"message": "检索器未初始化",
|
||
"progress": 0
|
||
}
|
||
yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
|
||
yield "data: [DONE]\n\n"
|
||
return
|
||
|
||
# 状态队列,用于接收检索过程中的状态更新
|
||
status_queue = asyncio.Queue()
|
||
# 使用线程安全的Future
|
||
result_future = ThreadFuture()
|
||
|
||
# 获取当前运行的事件循环
|
||
loop = asyncio.get_running_loop()
|
||
|
||
def status_callback(status_type: str, data: Any):
|
||
"""状态回调函数(在线程中运行)"""
|
||
elapsed = time.time() - start_time
|
||
status_data = {
|
||
"type": status_type,
|
||
"data": data,
|
||
"elapsed": round(elapsed, 1),
|
||
"timestamp": time.time()
|
||
}
|
||
|
||
# 根据不同类型设置进度
|
||
progress_map = {
|
||
"starting": 5,
|
||
"complexity_check": 10,
|
||
"sub_queries": 15,
|
||
"initial_retrieval": 30,
|
||
"documents": 40,
|
||
"sufficiency_check": 50,
|
||
"iteration": 60,
|
||
"parallel_retrieval": 70,
|
||
"pagerank": 80,
|
||
"generating": 90,
|
||
"answer": 100
|
||
}
|
||
status_data["progress"] = progress_map.get(status_type, 50)
|
||
|
||
# 将状态放入队列 - 使用线程安全的方式
|
||
try:
|
||
# 使用asyncio.run_coroutine_threadsafe确保线程安全
|
||
future = asyncio.run_coroutine_threadsafe(
|
||
status_queue.put(status_data),
|
||
loop # 使用之前保存的loop
|
||
)
|
||
# 等待完成,设置超时防止阻塞
|
||
future.result(timeout=1.0)
|
||
except Exception as e:
|
||
print(f"[WARNING] Failed to send status update: {e}")
|
||
|
||
def retriever_with_callback():
|
||
"""带回调的检索器执行"""
|
||
try:
|
||
# 调用增强版检索器(稍后实现)
|
||
from retriver.langsmith.langsmith_retriever_stream import stream_retrieve
|
||
result = stream_retrieve(
|
||
AppState.retriever,
|
||
request.query,
|
||
request.mode,
|
||
status_callback
|
||
)
|
||
result_future.set_result(result)
|
||
except Exception as e:
|
||
result_future.set_exception(e)
|
||
|
||
# 在线程池中执行检索
|
||
executor.submit(retriever_with_callback)
|
||
|
||
# 初始状态已在stream_retrieve中发送,这里不再重复发送
|
||
# yield f"data: {json.dumps({'type': 'starting', 'message': '正在分析您的问题...', 'progress': 5}, ensure_ascii=False)}\n\n"
|
||
|
||
# 持续发送状态更新直到检索完成
|
||
while not result_future.done():
|
||
try:
|
||
# 非阻塞获取状态,超时100ms
|
||
status = await asyncio.wait_for(status_queue.get(), timeout=0.1)
|
||
|
||
# 处理不同类型的状态
|
||
if status["type"] == "sub_queries":
|
||
# 子问题列表
|
||
queries = status["data"][:3] # 最多显示3个
|
||
message = f"📝 问题分解为{len(status['data'])}个子查询"
|
||
status["formatted_data"] = queries
|
||
status["message"] = message
|
||
|
||
elif status["type"] == "documents":
|
||
# 文档信息
|
||
data = status["data"]
|
||
retrieval_type = data.get("retrieval_type", "检索")
|
||
retrieval_reason = data.get("retrieval_reason", "")
|
||
new_docs = data.get("new_docs", 0)
|
||
total_docs = data.get("count", 0)
|
||
|
||
# 只在有新文档时显示
|
||
if new_docs > 0:
|
||
if data.get("is_incremental", False):
|
||
# 增量检索(第二次及以后)
|
||
message = f"📊 {retrieval_reason}:新增 {new_docs} 篇文档(总计 {total_docs} 篇)"
|
||
else:
|
||
# 初始检索
|
||
message = f"📊 已检索 {total_docs} 篇文档"
|
||
|
||
if "sources" in data:
|
||
sources = data["sources"][:5] # 最多显示5个
|
||
status["formatted_sources"] = sources
|
||
message += f" | 主要来源:{', '.join(sources)}"
|
||
status["message"] = message
|
||
else:
|
||
# 没有新文档,不显示
|
||
continue
|
||
|
||
elif status["type"] == "iteration":
|
||
# 迭代信息
|
||
current = status["data"]["current"]
|
||
max_iter = status["data"]["max"]
|
||
# 只有在实际进行迭代时才显示(current > 0表示已经开始迭代)
|
||
if current > 0:
|
||
message = f"🔄 迭代检索:第{current}轮"
|
||
else:
|
||
message = f"🔄 开始检索流程"
|
||
status["message"] = message
|
||
|
||
elif status["type"] == "sufficiency_check":
|
||
# 充分性检查
|
||
is_sufficient = status["data"].get("is_sufficient", False)
|
||
confidence = status["data"].get("confidence", 0)
|
||
message = f"✅ 充分性检查:{'充分' if is_sufficient else '需要更多信息'} (置信度: {confidence:.1%})"
|
||
status["message"] = message
|
||
|
||
elif status["type"] == "generating":
|
||
# 正在生成答案
|
||
message = "💡 正在生成答案..."
|
||
status["message"] = message
|
||
|
||
elif status["type"] == "answer":
|
||
# 完整答案
|
||
answer_data = {
|
||
"type": "answer",
|
||
"data": {
|
||
"content": status["data"].get("content", ""),
|
||
"complete": True
|
||
},
|
||
"message": "✅ 答案生成完成",
|
||
"progress": 100,
|
||
"elapsed": status.get("elapsed", round(time.time() - start_time, 1))
|
||
}
|
||
yield f"data: {json.dumps(answer_data, ensure_ascii=False)}\n\n"
|
||
continue # 跳过默认处理
|
||
|
||
# 发送格式化的状态
|
||
yield f"data: {json.dumps(status, ensure_ascii=False)}\n\n"
|
||
|
||
except asyncio.TimeoutError:
|
||
# 没有新状态,继续等待
|
||
continue
|
||
except Exception as e:
|
||
print(f"[ERROR] 状态处理异常: {e}")
|
||
continue
|
||
|
||
# 获取最终结果
|
||
try:
|
||
# 在异步上下文中等待线程Future完成
|
||
result = await loop.run_in_executor(None, result_future.result, 60) # 最多等待60秒
|
||
|
||
# 发送最终答案
|
||
if result:
|
||
answer_data = {
|
||
"type": "answer",
|
||
"data": {
|
||
"content": result.get("answer", ""),
|
||
"supporting_facts": result.get("supporting_facts", []),
|
||
"supporting_events": result.get("supporting_events", []),
|
||
"total_documents": result.get("total_passages", 0),
|
||
"iterations": result.get("iterations", 0)
|
||
},
|
||
"message": "✅ 答案生成完成",
|
||
"progress": 100,
|
||
"elapsed": round(time.time() - start_time, 1)
|
||
}
|
||
yield f"data: {json.dumps(answer_data, ensure_ascii=False)}\n\n"
|
||
|
||
print(f"[OK] 流式查询完成 (耗时: {round(time.time() - start_time, 1)}秒)")
|
||
|
||
except Exception as e:
|
||
error_data = {
|
||
"type": "error",
|
||
"message": f"检索过程出错: {str(e)}",
|
||
"progress": 0,
|
||
"elapsed": round(time.time() - start_time, 1)
|
||
}
|
||
yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
|
||
print(f"[ERROR] 流式查询失败: {e}")
|
||
|
||
# 发送完成信号
|
||
yield "data: [DONE]\n\n"
|
||
|
||
# 返回SSE响应
|
||
return StreamingResponse(
|
||
generate_sse_events(),
|
||
media_type="text/event-stream",
|
||
headers={
|
||
"Cache-Control": "no-cache",
|
||
"Connection": "keep-alive",
|
||
"X-Accel-Buffering": "no" # 禁用nginx缓冲
|
||
}
|
||
)
|
||
|
||
|
||
@app.get("/config", response_model=Dict[str, Any])
|
||
async def get_config():
|
||
"""获取当前生产配置信息"""
|
||
return {
|
||
"environment": "production",
|
||
"note": "生产环境使用项目默认配置",
|
||
"description": "所有参数使用项目模块的内置默认值",
|
||
"stats": {
|
||
"request_count": AppState.request_count,
|
||
"last_request_time": AppState.last_request_time.isoformat() if AppState.last_request_time else None
|
||
}
|
||
}
|
||
|
||
|
||
@app.get("/outputs", response_model=List[str])
|
||
async def list_outputs():
|
||
"""列出所有输出文件"""
|
||
if not OUTPUT_DIR.exists():
|
||
return []
|
||
|
||
files = [f.name for f in OUTPUT_DIR.glob("output_*.json")]
|
||
return sorted(files, reverse=True)[:100] # 只返回最近100个
|
||
|
||
|
||
@app.get("/outputs/{filename}")
|
||
async def get_output(filename: str):
|
||
"""获取指定输出文件"""
|
||
file_path = OUTPUT_DIR / filename
|
||
|
||
if not file_path.exists():
|
||
raise HTTPException(status_code=404, detail="文件不存在")
|
||
|
||
return FileResponse(file_path)
|
||
|
||
|
||
# ============= 主函数 =============
|
||
def main():
|
||
"""启动生产环境服务"""
|
||
import argparse
|
||
|
||
parser = argparse.ArgumentParser(description="RAG API生产环境服务")
|
||
parser.add_argument("--host", default="0.0.0.0", help="服务地址")
|
||
parser.add_argument("--port", type=int, default=8000, help="服务端口")
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 生产环境强制关闭自动重载
|
||
uvicorn.run(
|
||
"rag_api_server_production:app",
|
||
host=args.host,
|
||
port=args.port,
|
||
reload=False, # 生产环境禁用自动重载
|
||
workers=1, # 可以根据需要调整worker数量
|
||
log_level="info"
|
||
)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main() |