"""
RAG API 生产环境服务 - Redis缓存版

基于原始 rag_api_server_production.py 添加Redis缓存

RAG API production service (Redis-cache edition): the original
rag_api_server_production.py with a Redis response cache added.
"""
import hashlib
import json
import os
import sys
import uuid
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import redis
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field

# Set environment variables.
# Pin the RAG configuration file for this production server. This runs before
# load_dotenv() below and unconditionally overwrites any value already present.
# NOTE(review): python-dotenv's load_dotenv does not override already-set
# variables by default, so a RAG_CONFIG_PATH inside .env is presumably ignored — confirm.
# NOTE(review): the path is relative; how it is resolved (CWD vs. project dir)
# depends on the retriever code — confirm.
os.environ['RAG_CONFIG_PATH'] = 'rag_config_production.yaml'

# Project root: the directory containing this file. It is appended to sys.path
# so the project-local `retriver` package imported below can be found.
project_root = Path(__file__).parent
sys.path.append(str(project_root))

# Load environment variables from <project_root>/.env when that file exists.
env_path = project_root / '.env'
if env_path.exists():
    load_dotenv(env_path)

# Deliberately imported after the setup above: it needs project_root on
# sys.path and presumably reads RAG_CONFIG_PATH / .env values — verify.
from retriver.langsmith.langsmith_retriever import create_langsmith_retriever, check_langsmith_connection

# ============= Redis configuration =============
def env_flag(name: str, default: bool) -> bool:
    """Read a boolean switch from the environment.

    Unset or blank -> ``default``; "0", "false", "no" or "off" (any case,
    surrounding whitespace ignored) -> False; any other value -> True.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    return raw.strip().lower() not in {"0", "false", "no", "off"}


# NOTE(review): the default host is a Docker-internal bridge IP, which can
# change when containers are recreated; prefer setting REDIS_HOST explicitly.
REDIS_HOST = os.getenv('REDIS_HOST', '172.18.0.3')  # Docker-internal IP
REDIS_PORT = int(os.getenv('REDIS_PORT', 6379))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD', None)  # Redis password (None when unset)
CACHE_TTL = int(os.getenv('CACHE_TTL', 3600))  # cache entry lifetime in seconds (1 hour)
# Master switch for the Redis cache. It used to be hard-coded to True; it still
# defaults to enabled but can now be turned off with CACHE_ENABLED=false.
CACHE_ENABLED = env_flag('CACHE_ENABLED', True)

# ============= Production settings =============
# Directory that receives the JSON files written by /retrieve when
# save_output is true; the lifespan hook creates it at startup.
OUTPUT_DIR = project_root.joinpath("api_outputs")

# ============= Request / response models =============
class RetrieveRequest(BaseModel):
    """检索请求模型"""

    # Body accepted by POST /retrieve and the legacy POST /query.
    # The docstring above is kept verbatim because pydantic publishes the class
    # docstring as the schema description in the generated OpenAPI document.

    # The question to answer (required).
    query: str = Field(default=..., description="查询问题")
    # Passed unchanged to retriever.retrieve(); also part of the cache key.
    mode: str = Field("0", description="调试模式: 0=自动")
    # Whether a JSON copy of a freshly computed result is written to OUTPUT_DIR.
    save_output: bool = Field(True, description="是否保存输出到文件")
    # Whether this request reads from / writes to the Redis cache.
    use_cache: bool = Field(True, description="是否使用缓存")

class RetrieveResponse(BaseModel):
    """检索响应模型"""

    # Returned by /retrieve and /query for cache hits, fresh results and
    # in-band failures (success=False). The docstring is kept verbatim because
    # it becomes the OpenAPI schema description.

    success: bool = Field(default=..., description="是否成功")
    query: str = Field(default=..., description="原始查询")
    answer: str = Field(default=..., description="生成的答案")
    # Lists of string lists, copied from the retriever result; empty when absent.
    supporting_facts: List[List[str]] = Field(default_factory=list, description="支撑段落")
    supporting_events: List[List[str]] = Field(default_factory=list, description="支撑事件")
    # Free-form context such as environment, timestamp and the cached flag.
    metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据")
    # Path of the JSON output file when one was written, else None.
    output_file: Optional[str] = Field(default=None, description="输出文件路径")
    # Exception text when success is False.
    error: Optional[str] = Field(default=None, description="错误信息")

class HealthResponse(BaseModel):
    """健康检查响应"""

    # Payload of GET /health. The defaults describe a healthy production
    # instance with no external service confirmed yet; the docstring is kept
    # verbatim because it is published in the OpenAPI schema.
    status: str = Field(default="healthy")
    environment: str = Field(default="production")
    langsmith_connected: bool = Field(default=False)
    redis_connected: bool = Field(default=False)
    timestamp: str = Field(default="")

# ============= Redis cache management =============
class CacheManager:
    """Redis-backed JSON cache for RAG results.

    Entries are stored under the ``rag:`` key prefix with a TTL of CACHE_TTL
    seconds. If the cache is switched off (CACHE_ENABLED) or Redis cannot be
    reached, the manager stays disabled and every operation degrades to a
    no-op or a cache miss, so the service keeps working without a cache.
    """

    # Namespace for every key this manager writes; clear() deletes only these.
    KEY_PREFIX = "rag:"
    # Maximum number of keys passed to a single DEL while clearing.
    _CLEAR_BATCH = 500

    def __init__(self):
        """Create the manager and immediately try to connect to Redis."""
        self.enabled = False
        self.client = None
        self.connect()

    def connect(self):
        """Connect to Redis and verify the connection with PING.

        On success sets ``enabled`` to True. On any failure it logs a warning,
        drops the client and leaves the manager disabled. This is best effort:
        startup must not fail just because the cache is unavailable.
        """
        if not CACHE_ENABLED:
            print("[INFO] 缓存已禁用")
            return

        try:
            self.client = redis.Redis(
                host=REDIS_HOST,
                port=REDIS_PORT,
                password=REDIS_PASSWORD,
                db=0,
                decode_responses=True,  # get() yields str, ready for json.loads
                socket_connect_timeout=5,
                socket_timeout=5,
            )
            # Verify the connection.
            self.client.ping()
            self.enabled = True
            print(f"[OK] Redis缓存已连接: {REDIS_HOST}:{REDIS_PORT}")
        except Exception as e:
            print(f"[WARNING] Redis连接失败: {e}")
            print("[INFO] 系统将在无缓存模式下运行")
            self.enabled = False
            # Do not keep an unusable client around while disabled.
            self.client = None

    def get_key(self, query: str, mode: str) -> str:
        """Build the Redis key for a (query, mode) pair.

        The pair is serialised as a JSON array before hashing so the boundary
        between the two fields is unambiguous. A plain ``f"{query}:{mode}"``
        maps ("a:b", "c") and ("a", "b:c") to the same key. MD5 serves only
        as a compact fingerprint here, not as a security measure.
        """
        content = json.dumps([query, mode], ensure_ascii=False)
        digest = hashlib.md5(content.encode("utf-8")).hexdigest()
        return f"{self.KEY_PREFIX}{digest}"

    def get(self, query: str, mode: str) -> Optional[Dict]:
        """Return the cached payload for (query, mode).

        Returns None on a miss, when disabled, or when the read/decode fails.
        Failures are logged and treated as a miss.
        """
        if not self.enabled:
            return None

        try:
            cached_data = self.client.get(self.get_key(query, mode))
            if cached_data:
                print(f"[CACHE HIT] 查询: {query[:50]}...")
                return json.loads(cached_data)
        except Exception as e:
            print(f"[CACHE ERROR] 读取失败: {e}")

        return None

    def set(self, query: str, mode: str, data: Dict):
        """Store ``data`` as JSON under (query, mode) with a TTL of CACHE_TTL seconds.

        No-op when disabled; write errors are logged and swallowed (best effort).
        """
        if not self.enabled:
            return

        try:
            self.client.setex(
                self.get_key(query, mode),
                CACHE_TTL,
                json.dumps(data, ensure_ascii=False),
            )
            print(f"[CACHE SET] 已缓存: {query[:50]}... (TTL: {CACHE_TTL}秒)")
        except Exception as e:
            print(f"[CACHE ERROR] 写入失败: {e}")

    def get_stats(self) -> Dict:
        """Return Redis statistics, or {"enabled": False} when disabled.

        ``keys`` is DBSIZE of db 0 (all keys, not only ``rag:`` ones).
        ``hits``/``misses``/``hit_rate`` come from the server-wide INFO counters.
        On failure returns {"enabled": False, "error": ...}.
        """
        if not self.enabled:
            return {"enabled": False}

        try:
            info = self.client.info()
            dbsize = self.client.dbsize()
            hits = info.get("keyspace_hits", 0)
            misses = info.get("keyspace_misses", 0)
            return {
                "enabled": True,
                "keys": dbsize,
                "used_memory": info.get("used_memory_human"),
                "hits": hits,
                "misses": misses,
                # max(..., 1) avoids division by zero on a fresh server.
                "hit_rate": round(hits / max(hits + misses, 1) * 100, 2),
            }
        except Exception as e:
            # Previously a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
            print(f"[CACHE ERROR] 统计失败: {e}")
            return {"enabled": False, "error": "无法获取统计"}

    def clear(self) -> int:
        """Delete every cache entry under KEY_PREFIX and return how many were removed.

        Uses SCAN plus batched DEL rather than FLUSHDB, so other data stored in
        the same Redis database is left untouched. Returns 0 when disabled.
        Errors are logged, and the number deleted so far is returned.
        """
        if not self.enabled:
            return 0

        deleted = 0
        try:
            batch = []
            for key in self.client.scan_iter(match=f"{self.KEY_PREFIX}*", count=self._CLEAR_BATCH):
                batch.append(key)
                if len(batch) >= self._CLEAR_BATCH:
                    deleted += self.client.delete(*batch)
                    batch = []
            if batch:
                deleted += self.client.delete(*batch)
            print(f"[CACHE] 缓存已清空 ({deleted} keys)")
        except Exception as e:
            print(f"[CACHE ERROR] 清空失败: {e}")
        return deleted

# ============= Global application state =============
class AppState:
    """Mutable process-level state shared by the endpoints.

    This file uses it only through class attributes (``AppState.x = ...``).
    main() starts uvicorn with several workers by default; each worker is its
    own process, so these values (including the counters) are per-process.
    """

    # Retriever built by initialize_retriever(); None before startup or on failure.
    retriever: Any = None
    # Result of check_langsmith_connection() at startup.
    langsmith_connected: bool = False
    # CacheManager created by initialize_cache().
    cache_manager: Optional["CacheManager"] = None
    # Number of /retrieve calls seen by this process.
    request_count: int = 0
    # Number of those calls answered from the Redis cache.
    cache_hits: int = 0
    # Time of the most recent /retrieve call.
    last_request_time: Optional[datetime] = None

# ============= Lifecycle management =============
def initialize_retriever():
    """Probe LangSmith connectivity and build the global retriever.

    Stores the connectivity flag in ``AppState.langsmith_connected`` and the
    retriever in ``AppState.retriever``. Any exception is logged and leaves
    ``AppState.retriever`` as None. /health then reports "degraded" and
    /retrieve answers 503, instead of startup failing.
    """
    try:
        AppState.langsmith_connected = check_langsmith_connection()
        # NOTE(review): keyword="test" is hard-coded; confirm it is intended in production.
        built = create_langsmith_retriever(keyword="test")
    except Exception as exc:
        print(f"[ERROR] 检索器初始化失败: {exc}")
        AppState.retriever = None
    else:
        AppState.retriever = built
        print("[OK] 检索器初始化成功")

def initialize_cache():
    """Create the global CacheManager; it tries to connect to Redis on construction."""
    manager = CacheManager()
    AppState.cache_manager = manager

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook for the FastAPI application.

    Startup: print a banner, ensure OUTPUT_DIR exists, build the retriever,
    connect the cache and print the cache settings. Both initializers catch
    their own errors, so a failing dependency does not abort startup.
    Shutdown: print a closing message; nothing is released explicitly.
    """
    banner = "=" * 60

    print(banner)
    print("[STARTING] 启动RAG API服务(Redis缓存版)")
    print(banner)

    # Target directory for the optional JSON outputs of /retrieve.
    OUTPUT_DIR.mkdir(exist_ok=True)

    initialize_retriever()
    initialize_cache()

    for line in (
        f"[INFO] Redis配置: {REDIS_HOST}:{REDIS_PORT}",
        f"[INFO] 缓存TTL: {CACHE_TTL}秒",
        banner,
    ):
        print(line)

    yield  # the application serves requests while suspended here

    print("\n[INFO] 服务关闭")

# ============= FastAPI application =============
app = FastAPI(
    title="RAG API Service (Redis Cache)",
    description="RAG检索服务(Redis缓存版)",
    version="2.0.0",
    lifespan=lifespan,
)

# CORS configuration.
# Allowed origins come from CORS_ALLOW_ORIGINS (comma-separated); default "*".
# Credentialed CORS is enabled only for an explicit origin list. Combining a
# wildcard origin with allow_credentials=True makes Starlette echo back any
# caller's Origin, which lets arbitrary websites send cookie-bearing requests.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ============= API endpoints =============
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report service health together with LangSmith and Redis connectivity.

    ``status`` is "healthy" when a retriever was built at startup and
    "degraded" otherwise; the endpoint itself always answers normally.
    """
    manager = AppState.cache_manager
    redis_ok = False
    if manager:
        redis_ok = manager.enabled

    overall = "healthy" if AppState.retriever else "degraded"

    return HealthResponse(
        timestamp=datetime.now().isoformat(),
        status=overall,
        environment="production",
        langsmith_connected=AppState.langsmith_connected,
        redis_connected=redis_ok,
    )

def _result_payload(result: Dict[str, Any]) -> Dict[str, Any]:
    """Extract answer / supporting_facts / supporting_events from a retriever result.

    Missing keys default to "" / [] / []; the same subset is cached, saved
    and returned to the client.
    """
    return {
        "answer": result.get("answer", ""),
        "supporting_facts": result.get("supporting_facts", []),
        "supporting_events": result.get("supporting_events", []),
    }


def _cache_usable(request: "RetrieveRequest") -> bool:
    """True when the request opted into caching and the cache manager is connected."""
    manager = AppState.cache_manager
    return bool(request.use_cache and manager and manager.enabled)


def _save_output(request: "RetrieveRequest", payload: Dict[str, Any], timestamp: str) -> Path:
    """Write the query and its result to a new JSON file in OUTPUT_DIR and return the path.

    The timestamp has one-second resolution. A random suffix therefore keeps
    concurrent requests (including ones in other worker processes) from
    overwriting each other's output file.
    """
    output_file = OUTPUT_DIR / f"output_{timestamp}_{uuid.uuid4().hex[:8]}.json"
    output_data = {
        "timestamp": timestamp,
        "query": request.query,
        **payload,
        "metadata": {
            "cached": False,
            "mode": request.mode,
        },
    }
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=2)
    return output_file


@app.post("/retrieve", response_model=RetrieveResponse)
async def retrieve(request: RetrieveRequest):
    """RAG query endpoint with a Redis read-through cache.

    Flow:
      1. When ``use_cache`` is set and Redis is connected, serve a cached hit.
      2. Otherwise run the retriever in the threadpool.
      3. Store the result in the cache when caching is usable.
      4. Optionally write the result to a JSON file in OUTPUT_DIR.

    Raises:
        HTTPException: 503 when the retriever was not initialised at startup.

    Any other failure is reported in-band as ``success=False`` with ``error``
    set, keeping the original response contract.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    AppState.request_count += 1
    AppState.last_request_time = datetime.now()

    print(f"\n{'='*60}")
    print(f"[IN] 请求 #{AppState.request_count}")
    print(f" 查询: {request.query}")
    print(f" 使用缓存: {request.use_cache}")

    if not AppState.retriever:
        raise HTTPException(status_code=503, detail="检索器未初始化")

    try:
        # 1. Cache lookup.
        if _cache_usable(request):
            cached_result = AppState.cache_manager.get(request.query, request.mode)
            if cached_result:
                AppState.cache_hits += 1
                print(f"[OK] 缓存命中 (命中率: {AppState.cache_hits}/{AppState.request_count})")
                # Cache hits return immediately, so save_output is not applied here.
                return RetrieveResponse(
                    success=True,
                    query=request.query,
                    answer=cached_result.get("answer", ""),
                    supporting_facts=cached_result.get("supporting_facts", []),
                    supporting_events=cached_result.get("supporting_events", []),
                    metadata={
                        "environment": "production",
                        "timestamp": timestamp,
                        "cached": True,
                        "cache_hit_rate": f"{AppState.cache_hits}/{AppState.request_count}",
                    },
                )

        # 2. Run retrieval. retriever.retrieve() is synchronous. Calling it
        #    directly inside this coroutine would block the event loop and stall
        #    every other request (including /health) for the whole retrieval.
        # NOTE(review): requests now run retrieve() concurrently in worker
        # threads; confirm the retriever is thread-safe.
        print("[PROCESSING] 执行RAG检索...")
        result = await run_in_threadpool(
            AppState.retriever.retrieve,
            request.query,
            request.mode,
        )
        payload = _result_payload(result)

        # 3. Store in the cache.
        if _cache_usable(request):
            AppState.cache_manager.set(request.query, request.mode, payload)

        # 4. Optionally save the output.
        output_file = _save_output(request, payload, timestamp) if request.save_output else None

        response = RetrieveResponse(
            success=True,
            query=request.query,
            **payload,
            metadata={
                "environment": "production",
                "timestamp": timestamp,
                "cached": False,
            },
            output_file=str(output_file) if output_file else None,
        )

        print(f"[OK] 查询完成 (答案长度: {len(response.answer)} 字符)")
        return response

    except Exception as e:
        print(f"[ERROR] 查询失败: {e}")
        return RetrieveResponse(
            success=False,
            query=request.query,
            answer="",
            error=str(e),
            metadata={"environment": "production", "timestamp": timestamp},
        )

@app.get("/cache/stats")
async def cache_stats():
    """Return the cache manager's Redis statistics merged with this process's counters.

    Adds ``total_requests``, ``cache_hits`` and ``hit_rate_session`` (a
    percentage string) to whatever CacheManager.get_stats() reports.
    """
    manager = AppState.cache_manager
    if not manager:
        return {"error": "缓存未初始化"}

    stats = manager.get_stats()

    hits, total = AppState.cache_hits, AppState.request_count
    # max(total, 1) avoids dividing by zero before the first request.
    session_pct = round(hits / max(total, 1) * 100, 2)

    stats["total_requests"] = total
    stats["cache_hits"] = hits
    stats["hit_rate_session"] = f"{session_pct}%"
    return stats

@app.delete("/cache/clear")
async def clear_cache():
    """Clear the RAG cache.

    Returns {"error": ...} when the cache manager was never created or Redis
    is not connected, instead of falsely reporting success. CacheManager.clear()
    silently does nothing while disabled.

    NOTE(review): this endpoint is unauthenticated; anyone who can reach the
    service can trigger CacheManager.clear().
    """
    manager = AppState.cache_manager
    if not manager:
        return {"error": "缓存未初始化"}
    if not manager.enabled:
        return {"error": "缓存未启用"}

    manager.clear()
    return {"success": True, "message": "缓存已清空"}

# ============= Legacy-compatible endpoint =============
@app.post("/query", response_model=RetrieveResponse)
async def query(request: RetrieveRequest):
    """Legacy alias of /retrieve: same request body, same response, same caching."""
    response = await retrieve(request)
    return response

# ============= Entry point =============
def main():
    """Parse command-line options and serve the app with uvicorn.

    Options: --host (default 0.0.0.0), --port (default 8000) and --workers
    (default 4). Each worker is a separate process with its own AppState,
    retriever and counters.
    """
    import argparse

    parser = argparse.ArgumentParser(description="RAG API服务(Redis缓存版)")
    parser.add_argument("--host", default="0.0.0.0", help="服务地址")
    parser.add_argument("--port", type=int, default=8000, help="服务端口")
    parser.add_argument("--workers", type=int, default=4, help="Worker进程数")
    args = parser.parse_args()

    # uvicorn requires an import string (not the app object) when using
    # multiple workers. Derive the module name from this file rather than
    # hard-coding "rag_api_server_with_cache", so a renamed file still starts.
    # This file's directory (project_root) is appended to sys.path at import
    # time, which keeps the module importable.
    app_import = f"{Path(__file__).stem}:app"

    uvicorn.run(
        app_import,
        host=args.host,
        port=args.port,
        reload=False,
        workers=args.workers,
        log_level="info",
    )


if __name__ == "__main__":
    main()