647 lines
22 KiB
Python
647 lines
22 KiB
Python
|
|
"""
|
|||
|
|
RAG API 生产环境服务
|
|||
|
|
使用独立的生产配置文件 rag_config_production.yaml
|
|||
|
|
与测试环境配置完全分离
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
import json
|
|||
|
|
import uvicorn
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
# 设置环境变量,让prompt_loader使用生产配置
|
|||
|
|
os.environ['RAG_CONFIG_PATH'] = 'rag_config_production.yaml'
|
|||
|
|
from datetime import datetime
|
|||
|
|
from typing import Dict, Any, List, Optional, AsyncGenerator
|
|||
|
|
from contextlib import asynccontextmanager
|
|||
|
|
from fastapi import FastAPI, HTTPException
|
|||
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|||
|
|
from fastapi.responses import FileResponse, StreamingResponse
|
|||
|
|
from pydantic import BaseModel, Field
|
|||
|
|
from dotenv import load_dotenv
|
|||
|
|
import asyncio
|
|||
|
|
from concurrent.futures import ThreadPoolExecutor, Future as ThreadFuture
|
|||
|
|
import time
|
|||
|
|
|
|||
|
|
# Project root = the directory containing this file. It is appended to sys.path
# so the project-local packages imported further down can be resolved.
project_root = Path(__file__).parent
sys.path.append(str(project_root))

# Load environment variables from <project_root>/.env when that file exists;
# otherwise fall back to whatever is already set in the process environment.
env_path = project_root / '.env'
if not env_path.exists():
    print(f"[WARNING] 未找到.env文件,使用系统环境变量")
else:
    load_dotenv(env_path)
    print(f"[OK] 已加载环境变量文件: {env_path}")
|
|||
|
|
|
|||
|
|
from retriver.langsmith.langsmith_retriever import create_langsmith_retriever, check_langsmith_connection
|
|||
|
|
from task_manager import get_task_manager
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============= Production configuration =============
# Project modules supply their own defaults; only the output location is set here.
# Directory that /query writes its JSON result files into (created at startup).
OUTPUT_DIR = project_root.joinpath("api_outputs")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============= Request / response models =============
class RetrieveRequest(BaseModel):
    """Request body accepted by /query, /retrieve and /retrieve/stream."""

    # The question to answer (required).
    query: str = Field(..., description="查询问题")
    # Mode string forwarded verbatim to the retriever; "0" means automatic.
    mode: str = Field("0", description="调试模式: 0=自动")
    # When true, /query also writes the result as a JSON file.
    save_output: bool = Field(True, description="是否保存输出到文件")
|
|||
|
|
|
|||
|
|
|
|||
|
|
class RetrieveResponse(BaseModel):
    """Response body returned by /query and /retrieve."""

    # True on success; on failure `error` holds the message and `answer` is "".
    success: bool = Field(..., description="是否成功")
    # Echo of the request's query string.
    query: str = Field(..., description="原始查询")
    answer: str = Field(..., description="生成的答案")
    supporting_facts: List[List[str]] = Field(default_factory=list, description="支撑段落")
    supporting_events: List[List[str]] = Field(default_factory=list, description="支撑事件")
    # Free-form extras; the endpoints fill in "environment" and "timestamp".
    metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据")
    # Path of the saved JSON output, when one was written.
    output_file: Optional[str] = Field(default=None, description="输出文件路径")
    error: Optional[str] = Field(default=None, description="错误信息")
|
|||
|
|
|
|||
|
|
|
|||
|
|
class HealthResponse(BaseModel):
    """Payload returned by GET /health."""

    # "healthy" when the retriever is initialized, otherwise "degraded".
    status: str = Field(default="healthy")
    environment: str = Field(default="production")
    # Result of the LangSmith connectivity check made at startup.
    langsmith_connected: bool = Field(default=False)
    # ISO-8601 time at which the health check was answered.
    timestamp: str = Field(default="")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============= Global application state =============
class AppState:
    """Process-wide mutable state, read and written through class attributes."""

    # Retriever created by initialize_retriever(); None until that succeeds.
    retriever: Any = None
    # Outcome of the LangSmith connectivity check performed at startup.
    langsmith_connected: bool = False
    # Number of query / stream requests received so far.
    request_count: int = 0
    # Time of the most recent request, or None before the first one.
    last_request_time: Optional[datetime] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============= Lifecycle management =============
def initialize_retriever():
    """Create the shared retriever and record LangSmith connectivity in AppState.

    Uses the project's default retriever parameters. Any exception is printed
    and leaves ``AppState.retriever`` as None, so the service still starts but
    /health reports "degraded" and query endpoints refuse to run.
    """
    try:
        AppState.langsmith_connected = check_langsmith_connection()

        # `keyword` is a required argument; "test" is the project's default keyword.
        retriever = create_langsmith_retriever(keyword="test")
        AppState.retriever = retriever

        print("[OK] 生产环境检索器初始化成功(使用项目默认参数)")
        print(" 环境: PRODUCTION")
    except Exception as err:
        print(f"[ERROR] 检索器初始化失败: {err}")
        AppState.retriever = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    Code before ``yield`` runs once at startup: it ensures the output
    directory exists and initializes the retriever. Code after ``yield`` runs
    at shutdown.
    """
    banner = "=" * 60
    print(banner)
    print("[STARTING] 启动RAG API服务(生产环境)")
    print(banner)

    # The /query endpoint writes JSON files here, so create it up front.
    OUTPUT_DIR.mkdir(exist_ok=True)
    print(f"[OK] 输出目录: {OUTPUT_DIR}")

    initialize_retriever()
    ready_message = (
        "[OK] 服务就绪(使用项目默认参数)"
        if AppState.retriever
        else "[WARNING] 服务启动但检索器未初始化"
    )
    print(ready_message)
    print(banner)

    yield

    print("\n[?] 服务关闭")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============= FastAPI application =============
# `lifespan` handles startup (output directory, retriever) and shutdown logging.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="RAG API Service (Production)",
    description="RAG检索服务(生产环境)",
)
|
|||
|
|
|
|||
|
|
# Configure CORS.
# Allowed origins come from the CORS_ALLOW_ORIGINS environment variable, a
# comma-separated list such as "https://a.example.com,https://b.example.com".
# When it is unset or empty, every origin is allowed ("*"), as before.
# NOTE(review): wildcard origins combined with allow_credentials=True are very
# permissive; set CORS_ALLOW_ORIGINS to explicit origins in production.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============= API endpoints =============
@app.get("/", response_model=Dict[str, str])
async def root():
    """Root endpoint: identify the service and its version."""
    info = dict(
        service="RAG API Production Service",
        status="running",
        environment="production",
        version="1.0.0",
    )
    return info
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check: "healthy" once the retriever exists, otherwise "degraded"."""
    state = "degraded" if not AppState.retriever else "healthy"
    checked_at = datetime.now().isoformat()
    return HealthResponse(
        timestamp=checked_at,
        langsmith_connected=AppState.langsmith_connected,
        environment="production",
        status=state,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.post("/query", response_model=RetrieveResponse)
async def query(request: RetrieveRequest):
    """RAG query endpoint (production).

    Runs the retriever for ``request.query`` / ``request.mode``, optionally saves
    the result as a JSON file under OUTPUT_DIR, and returns a RetrieveResponse.
    Errors during retrieval or saving are reported in-band
    (``success=False`` with ``error`` set) rather than as an HTTP error status.

    Raises:
        HTTPException: 503 if the retriever was not initialized at startup.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    AppState.request_count += 1
    AppState.last_request_time = datetime.now()
    # Snapshot the counter now: concurrent requests may bump it while this one awaits.
    request_no = AppState.request_count

    print(f"\n{'='*60}")
    print(f"[IN] 收到查询请求 #{request_no}")
    print(f" 时间: {timestamp}")
    print(f" 查询: {request.query}")
    print(f" 环境: PRODUCTION")
    print(f"{'='*60}")

    if not AppState.retriever:
        raise HTTPException(status_code=503, detail="检索器未初始化")

    try:
        # retrieve() is a blocking call. Calling it directly inside this
        # coroutine would stall the event loop (health checks, SSE streams and
        # every other request) until it returns, so run it in a worker thread.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, AppState.retriever.retrieve, request.query, request.mode
        )

        output_file = None
        if request.save_output:
            # `timestamp` only has one-second resolution, so two requests in the
            # same second would overwrite each other's file; the request number
            # keeps file names unique within this process.
            output_file = OUTPUT_DIR / f"output_{timestamp}_{request_no}.json"
            output_data = {
                "timestamp": timestamp,
                "environment": "production",
                "query": request.query,
                "answer": result.get("answer", ""),
                "supporting_facts": result.get("supporting_facts", []),
                "supporting_events": result.get("supporting_events", []),
                "metadata": {
                    "request_count": request_no,
                    "mode": request.mode,
                },
            }
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(output_data, f, ensure_ascii=False, indent=2)
            print(f"[SAVE] 输出已保存: {output_file}")

        response = RetrieveResponse(
            success=True,
            query=request.query,
            answer=result.get("answer", ""),
            supporting_facts=result.get("supporting_facts", []),
            supporting_events=result.get("supporting_events", []),
            metadata={
                "environment": "production",
                "timestamp": timestamp,
            },
            output_file=str(output_file) if output_file else None,
        )

        print(f"[OK] 查询完成 (答案长度: {len(response.answer)} 字符)")
        return response

    except Exception as e:
        import traceback

        print(f"[ERROR] 查询失败: {e}")
        traceback.print_exc()

        return RetrieveResponse(
            success=False,
            query=request.query,
            answer="",
            error=str(e),
            metadata={"environment": "production", "timestamp": timestamp},
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.post("/retrieve", response_model=RetrieveResponse)
async def retrieve(request: RetrieveRequest):
    """Legacy alias of /query, kept for clients of the older API."""
    response = await query(request)
    return response
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============= Streaming retrieval support =============
# Thread pool that runs blocking streaming retrievals off the event loop.
# At most this many retrievals run at once; extra submissions wait in the pool's queue.
_STREAM_MAX_WORKERS = 4
executor = ThreadPoolExecutor(max_workers=_STREAM_MAX_WORKERS)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Progress percentage attached to each status type reported during a streaming
# retrieval; status types not listed here are reported as 50.
_STREAM_PROGRESS = {
    "starting": 5,
    "complexity_check": 10,
    "sub_queries": 15,
    "initial_retrieval": 30,
    "documents": 40,
    "sufficiency_check": 50,
    "iteration": 60,
    "parallel_retrieval": 70,
    "pagerank": 80,
    "generating": 90,
    "answer": 100,
}


def _sse_frame(payload: Any) -> str:
    """Encode *payload* as one SSE ``data:`` frame, keeping non-ASCII text readable."""
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"


def _format_stream_status(status: Dict[str, Any], start_time: float) -> Optional[str]:
    """Convert one raw status update into an SSE frame.

    Adds a human-readable ``message`` (and, for some types, trimmed display
    data) to *status* in place.

    Args:
        status: Dict built by the stream's status callback; must contain "type" and "data".
        start_time: ``time.time()`` at request start, used as a fallback for "elapsed".

    Returns:
        The SSE frame to send, or None when the update should not be shown
        (a "documents" update that brought no new documents).
    """
    status_type = status["type"]
    data = status["data"]

    if status_type == "sub_queries":
        status["formatted_data"] = data[:3]  # show at most 3 sub-queries
        status["message"] = f"📝 问题分解为{len(data)}个子查询"

    elif status_type == "documents":
        new_docs = data.get("new_docs", 0)
        if new_docs <= 0:
            return None  # nothing new was retrieved; don't show this update
        total_docs = data.get("count", 0)
        if data.get("is_incremental", False):
            # Incremental retrieval (second round onwards).
            retrieval_reason = data.get("retrieval_reason", "")
            message = f"📊 {retrieval_reason}:新增 {new_docs} 篇文档(总计 {total_docs} 篇)"
        else:
            # Initial retrieval.
            message = f"📊 已检索 {total_docs} 篇文档"
        if "sources" in data:
            sources = data["sources"][:5]  # show at most 5 sources
            status["formatted_sources"] = sources
            message += f" | 主要来源:{', '.join(sources)}"
        status["message"] = message

    elif status_type == "iteration":
        current = data["current"]
        # current > 0 means an actual additional retrieval round has started.
        status["message"] = f"🔄 迭代检索:第{current}轮" if current > 0 else f"🔄 开始检索流程"

    elif status_type == "sufficiency_check":
        is_sufficient = data.get("is_sufficient", False)
        confidence = data.get("confidence", 0)
        status["message"] = f"✅ 充分性检查:{'充分' if is_sufficient else '需要更多信息'} (置信度: {confidence:.1%})"

    elif status_type == "generating":
        status["message"] = "💡 正在生成答案..."

    elif status_type == "answer":
        # Answer reported by the retriever itself; sent in its own shape.
        return _sse_frame({
            "type": "answer",
            "data": {
                "content": data.get("content", ""),
                "complete": True,
            },
            "message": "✅ 答案生成完成",
            "progress": 100,
            "elapsed": status.get("elapsed", round(time.time() - start_time, 1)),
        })

    return _sse_frame(status)


@app.post("/retrieve/stream")
async def retrieve_stream(request: RetrieveRequest):
    """Streaming retrieval endpoint with Server-Sent Events (SSE) progress.

    The stream emits, as ``data:`` frames:
      1. a ``task_created`` event carrying the task ID,
      2. status updates (sub-query decomposition, main source documents,
         current retrieval round, total documents, elapsed time),
      3. the final answer (or an error / cancellation event),
      4. a terminating ``data: [DONE]`` frame.

    The task ID is also returned in the ``X-Task-Id`` response header, so the
    client can call /task/cancel/{task_id}.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    start_time = time.time()

    # Register the task up front so its ID can be placed in the response headers.
    task_manager = get_task_manager()
    task_id = task_manager.create_task(request.query)

    async def generate_sse_events() -> AsyncGenerator[str, None]:
        """Yield SSE frames for one retrieval and always release the task afterwards."""
        AppState.request_count += 1
        AppState.last_request_time = datetime.now()

        print(f"\n{'='*60}")
        print(f"[STREAM] 收到流式查询请求 #{AppState.request_count}")
        print(f" 时间: {timestamp}")
        print(f" 查询: {request.query}")
        print(f" 任务ID: {task_id[:8]}...")
        print(f"{'='*60}")

        # Settled by the worker thread with the retrieval result or its exception.
        result_future = ThreadFuture()
        worker_started = False

        try:
            # Send the task ID first so the client can cancel the task.
            yield _sse_frame({'type': 'task_created', 'task_id': task_id})

            if not AppState.retriever:
                yield _sse_frame({
                    "type": "error",
                    "message": "检索器未初始化",
                    "progress": 0,
                })
                yield "data: [DONE]\n\n"
                return

            # Status updates produced by the worker thread and consumed below.
            status_queue = asyncio.Queue()
            loop = asyncio.get_running_loop()

            def status_callback(status_type: str, data: Any):
                """Forward one status update from the worker thread into status_queue."""
                status_data = {
                    "type": status_type,
                    "data": data,
                    "elapsed": round(time.time() - start_time, 1),
                    "timestamp": time.time(),
                    "progress": _STREAM_PROGRESS.get(status_type, 50),
                }
                try:
                    # asyncio.Queue must only be touched from the event loop:
                    # schedule the put there and wait up to 1 s for it to finish.
                    asyncio.run_coroutine_threadsafe(
                        status_queue.put(status_data), loop
                    ).result(timeout=1.0)
                except Exception as e:
                    print(f"[WARNING] Failed to send status update: {e}")

            def retriever_with_callback():
                """Run the blocking streaming retrieval and settle result_future."""
                try:
                    task_manager.start_task(task_id)

                    from retriver.langsmith.langsmith_retriever_stream import stream_retrieve
                    result = stream_retrieve(
                        AppState.retriever,
                        request.query,
                        request.mode,
                        status_callback,
                        # presumably lets the retriever observe cancellation — TODO confirm
                        task_id=task_id,
                    )
                    task_manager.complete_task(task_id, result)
                except Exception as e:
                    try:
                        task_manager.fail_task(task_id, str(e))
                    finally:
                        # Settle the future even if fail_task itself raises;
                        # otherwise the polling loop below would never finish.
                        result_future.set_exception(e)
                else:
                    result_future.set_result(result)

            # The initial "starting" status is emitted by stream_retrieve itself.
            executor.submit(retriever_with_callback)
            worker_started = True

            def render(status):
                """Format a status update; formatting errors are logged and the update skipped."""
                try:
                    return _format_stream_status(status, start_time)
                except Exception as e:
                    print(f"[ERROR] 状态处理异常: {e}")
                    return None

            # Relay status updates until the worker finishes or the task is cancelled.
            while not result_future.done():
                if task_manager.should_stop(task_id):
                    print(f"[STREAM] 任务 {task_id[:8]}... 被取消")
                    yield _sse_frame({
                        "type": "cancelled",
                        "message": "任务已被取消",
                        "progress": 0,
                    })
                    yield "data: [DONE]\n\n"
                    return

                try:
                    # Wait at most 100 ms so cancellation is re-checked regularly.
                    status = await asyncio.wait_for(status_queue.get(), timeout=0.1)
                except asyncio.TimeoutError:
                    continue

                frame = render(status)
                if frame is not None:
                    yield frame

            # Flush updates that were queued but not yet relayed when the worker
            # finished (e.g. while a slow client was still receiving earlier frames).
            while not status_queue.empty():
                frame = render(status_queue.get_nowait())
                if frame is not None:
                    yield frame

            try:
                # The loop above only exits once the future is settled, so this
                # returns immediately (or re-raises the worker's exception).
                result = result_future.result()
                if result:
                    yield _sse_frame({
                        "type": "answer",
                        "data": {
                            "content": result.get("answer", ""),
                            "supporting_facts": result.get("supporting_facts", []),
                            "supporting_events": result.get("supporting_events", []),
                            "total_documents": result.get("total_passages", 0),
                            "iterations": result.get("iterations", 0),
                        },
                        "message": "✅ 答案生成完成",
                        "progress": 100,
                        "elapsed": round(time.time() - start_time, 1),
                    })
                print(f"[OK] 流式查询完成 (耗时: {round(time.time() - start_time, 1)}秒)")
            except Exception as e:
                yield _sse_frame({
                    "type": "error",
                    "message": f"检索过程出错: {str(e)}",
                    "progress": 0,
                    "elapsed": round(time.time() - start_time, 1),
                })
                print(f"[ERROR] 流式查询失败: {e}")

            # Completion signal.
            yield "data: [DONE]\n\n"

        finally:
            if worker_started:
                # The worker may still be running (task cancelled or client
                # disconnected). Release the task only after its final
                # complete_task/fail_task call; the callback runs immediately if
                # the future is already settled.
                result_future.add_done_callback(
                    lambda _fut: task_manager.cleanup_task(task_id)
                )
            else:
                task_manager.cleanup_task(task_id)

    # Return the SSE response, exposing the task ID in a header as well.
    return StreamingResponse(
        generate_sse_events(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # disable nginx response buffering
            "X-Task-Id": task_id,
        },
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.post("/task/cancel/{task_id}")
async def cancel_task(task_id: str):
    """Cancel a running task.

    Args:
        task_id: Task ID as sent by /retrieve/stream (``task_created`` event
            or ``X-Task-Id`` header).

    Returns:
        ``{"success": True, ...}`` when the cancellation was accepted; otherwise
        ``{"success": False, ...}`` including the task's current info.

    Raises:
        HTTPException: 404 when the task is unknown.
    """
    manager = get_task_manager()

    if manager.cancel_task(task_id):
        print(f"[API] 成功取消任务: {task_id[:8]}...")
        return {"success": True, "message": "任务已取消", "task_id": task_id}

    # Cancellation refused: report the task's state, or 404 if it doesn't exist.
    task_info = manager.get_task_info(task_id)
    if not task_info:
        raise HTTPException(status_code=404, detail=f"任务不存在: {task_id}")
    return {
        "success": False,
        "message": f"任务无法取消,当前状态: {task_info['status']}",
        "task_id": task_id,
        "task_info": task_info,
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.get("/task/{task_id}")
async def get_task_status(task_id: str):
    """Return the task manager's information for *task_id*.

    Raises:
        HTTPException: 404 when no such task is known.
    """
    task_info = get_task_manager().get_task_info(task_id)
    if not task_info:
        raise HTTPException(status_code=404, detail=f"任务不存在: {task_id}")
    return task_info
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.get("/config", response_model=Dict[str, Any])
async def get_config():
    """Describe the production configuration together with basic request statistics."""
    last_seen = AppState.last_request_time
    return dict(
        environment="production",
        note="生产环境使用项目默认配置",
        description="所有参数使用项目模块的内置默认值",
        stats=dict(
            request_count=AppState.request_count,
            last_request_time=last_seen.isoformat() if last_seen else None,
        ),
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.get("/outputs", response_model=List[str])
async def list_outputs():
    """List saved output file names in reverse name order, at most 100.

    Names start with a YYYYmmdd_HHMMSS timestamp, so reverse order is newest first.
    """
    if OUTPUT_DIR.exists():
        names = sorted((p.name for p in OUTPUT_DIR.glob("output_*.json")), reverse=True)
        return names[:100]
    return []
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.get("/outputs/{filename}")
async def get_output(filename: str):
    """Return one saved output file from OUTPUT_DIR.

    Only regular files located directly inside OUTPUT_DIR are served. The
    requested path is resolved and its parent compared with the resolved
    OUTPUT_DIR, so names such as ".." (or, on Windows, backslash paths like
    "..\\.env", which the single-segment route still accepts) cannot escape
    the directory.

    Raises:
        HTTPException: 404 when the file does not exist, is not a regular
            file, or lies outside OUTPUT_DIR.
    """
    output_root = OUTPUT_DIR.resolve()
    file_path = (output_root / filename).resolve()

    # Reject anything that resolves outside OUTPUT_DIR, plus directories,
    # which FileResponse cannot serve.
    if file_path.parent != output_root or not file_path.is_file():
        raise HTTPException(status_code=404, detail="文件不存在")

    return FileResponse(file_path)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============= Entry point =============
def main():
    """Start the production server with uvicorn (auto-reload always disabled).

    Command-line options:
        --host  bind address (default 0.0.0.0)
        --port  bind port (default 8000)
    """
    import argparse

    parser = argparse.ArgumentParser(description="RAG API生产环境服务")
    parser.add_argument("--host", default="0.0.0.0", help="服务地址")
    parser.add_argument("--port", type=int, default=8000, help="服务端口")
    args = parser.parse_args()

    # Pass the already-imported `app` object rather than the import string
    # "rag_api_server_production:app". With the string, uvicorn imports this
    # file a second time under that module name: module-level side effects
    # (.env loading, a second ThreadPoolExecutor, startup prints) run twice, and
    # startup breaks if the file is renamed. uvicorn only accepts an app object
    # when reload is off and workers == 1, which is the production setting here.
    # Keep workers=1: AppState (and presumably the task manager) live in this
    # process, so e.g. cancel requests must reach the same worker.
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        reload=False,
        workers=1,
        log_level="info",
    )


if __name__ == "__main__":
    main()
|