420 lines
13 KiB
Python
420 lines
13 KiB
Python
"""
|
||
RAG API production service - Redis cache edition.

Based on the original rag_api_server_production.py, with Redis caching added.
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import json
|
||
import uvicorn
|
||
import hashlib
|
||
import redis
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
from typing import Dict, Any, List, Optional
|
||
from contextlib import asynccontextmanager
|
||
from fastapi import FastAPI, HTTPException
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import FileResponse
|
||
from pydantic import BaseModel, Field
|
||
from dotenv import load_dotenv
|
||
|
||
# Select the production RAG configuration. The value is assigned
# unconditionally, so anything inherited from the environment is replaced.
os.environ.update({'RAG_CONFIG_PATH': 'rag_config_production.yaml'})

# Directory containing this file; appended to sys.path so that sibling
# packages can be imported further down.
project_root = Path(__file__).parent
sys.path.append(str(project_root))

# Load variables from an optional .env file located next to this module.
env_path = project_root.joinpath('.env')
if env_path.exists():
    load_dotenv(env_path)

# Imported only after the sys.path / environment setup above has run.
from retriver.langsmith.langsmith_retriever import (
    check_langsmith_connection,
    create_langsmith_retriever,
)
|
||
|
||
# ============= Redis configuration =============
def _env_int(name: str, default: int) -> int:
    """Read an integer setting from the environment.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset or blank.

    Returns:
        The parsed integer, or ``default`` when nothing usable is set.

    Raises:
        ValueError: If the variable holds a non-blank, non-integer value.
    """
    raw = os.getenv(name, "").strip()
    if not raw:
        # A blank value (e.g. "REDIS_PORT=" in an env file) means "not set".
        return default
    try:
        return int(raw)
    except ValueError:
        raise ValueError(f"{name} must be an integer, got {raw!r}") from None


def _env_flag(name: str, default: bool) -> bool:
    """Read an on/off switch from the environment.

    Recognises 1/true/yes/on and 0/false/no/off (case-insensitive); any other
    value, including unset or blank, yields ``default``.
    """
    raw = os.getenv(name, "").strip().lower()
    if raw in {"1", "true", "yes", "on"}:
        return True
    if raw in {"0", "false", "no", "off"}:
        return False
    return default


# The default host is a Docker-internal IP address; override with REDIS_HOST.
REDIS_HOST = os.getenv('REDIS_HOST', '172.18.0.3')
REDIS_PORT = _env_int('REDIS_PORT', 6379)
# Redis password; None when REDIS_PASSWORD is not set.
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD', None)
# Lifetime of a cached entry in seconds (default: one hour).
CACHE_TTL = _env_int('CACHE_TTL', 3600)
# Master switch for caching. Still defaults to True, but can now be turned
# off with CACHE_ENABLED=false instead of editing the source.
CACHE_ENABLED = _env_flag('CACHE_ENABLED', True)
|
||
|
||
# ============= Production settings =============
# Directory (next to this module) that holds the JSON output files.
OUTPUT_DIR = project_root.joinpath("api_outputs")
|
||
|
||
# ============= Request / response models =============
class RetrieveRequest(BaseModel):
    """Request body accepted by the retrieval endpoints."""

    # Question text to look up (required).
    query: str = Field(default=..., description="查询问题")
    # Mode selector; per the field description, "0" means automatic.
    mode: str = Field("0", description="调试模式: 0=自动")
    # Whether the result should also be written to an output file.
    save_output: bool = Field(True, description="是否保存输出到文件")
    # Whether the cache may be consulted and populated for this request.
    use_cache: bool = Field(True, description="是否使用缓存")
|
||
|
||
|
||
class RetrieveResponse(BaseModel):
    """Response body returned by the retrieval endpoints."""

    # True when the request was served without an error.
    success: bool = Field(default=..., description="是否成功")
    # The query string exactly as it was received.
    query: str = Field(default=..., description="原始查询")
    # Generated answer text.
    answer: str = Field(default=..., description="生成的答案")
    # Supporting passages; defaults to an empty list.
    supporting_facts: List[List[str]] = Field(
        default_factory=list, description="支撑段落"
    )
    # Supporting events; defaults to an empty list.
    supporting_events: List[List[str]] = Field(
        default_factory=list, description="支撑事件"
    )
    # Free-form metadata about how the response was produced.
    metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据")
    # Path of the saved output file, when one was written.
    output_file: Optional[str] = Field(default=None, description="输出文件路径")
    # Error message, populated only on failure.
    error: Optional[str] = Field(default=None, description="错误信息")
|
||
|
||
|
||
class HealthResponse(BaseModel):
    """Response body of the health-check endpoint."""

    # Overall service status; defaults to "healthy".
    status: str = "healthy"
    # Deployment environment label.
    environment: str = "production"

    # Connectivity flags for the external dependencies.
    langsmith_connected: bool = False
    redis_connected: bool = False

    # Time at which the health report was produced (string form).
    timestamp: str = ""
|
||
|
||
|
||
# ============= Redis cache management =============
class CacheManager:
    """Best-effort Redis cache for RAG query results.

    Every operation degrades gracefully: when caching is disabled or Redis
    cannot be reached, ``enabled`` is False and reads/writes become no-ops.
    Redis errors are logged and never propagated to the caller.
    """

    # Namespace prefix of every key written by this manager.
    KEY_PREFIX = "rag:"
    # Number of keys scanned / deleted per round trip in clear().
    SCAN_BATCH_SIZE = 500

    def __init__(self):
        """Start in the disabled state, then attempt to connect to Redis."""
        self.enabled = False
        self.client = None
        self.connect()

    def connect(self):
        """Create the Redis client and verify the connection with a PING.

        Sets ``self.enabled`` to True only when the PING succeeds. Any failure
        is logged and leaves the manager disabled (no exception is raised).
        """
        if not CACHE_ENABLED:
            print("[INFO] 缓存已禁用")
            return

        try:
            self.client = redis.Redis(
                host=REDIS_HOST,
                port=REDIS_PORT,
                password=REDIS_PASSWORD,
                db=0,
                decode_responses=True,
                socket_connect_timeout=5,
                socket_timeout=5,
            )
            # Fail fast here rather than on the first request.
            self.client.ping()
            self.enabled = True
            print(f"[OK] Redis缓存已连接: {REDIS_HOST}:{REDIS_PORT}")
        except Exception as e:
            print(f"[WARNING] Redis连接失败: {e}")
            print("[INFO] 系统将在无缓存模式下运行")
            self.enabled = False

    def get_key(self, query: str, mode: str) -> str:
        """Build the cache key for a (query, mode) pair.

        The key is ``KEY_PREFIX`` followed by the MD5 hex digest of
        ``"{query}:{mode}"``. The format is kept unchanged so entries that are
        already cached remain reachable.
        """
        content = f"{query}:{mode}"
        return f"{self.KEY_PREFIX}{hashlib.md5(content.encode()).hexdigest()}"

    def get(self, query: str, mode: str) -> Optional[Dict]:
        """Return the cached payload for (query, mode), or None.

        None is returned when the cache is disabled, the key is absent, or
        reading/decoding fails (the failure is logged).
        """
        if not self.enabled:
            return None

        try:
            key = self.get_key(query, mode)
            cached_data = self.client.get(key)
            if cached_data:
                print(f"[CACHE HIT] 查询: {query[:50]}...")
                return json.loads(cached_data)
        except Exception as e:
            print(f"[CACHE ERROR] 读取失败: {e}")

        return None

    def set(self, query: str, mode: str, data: Dict):
        """Store ``data`` as JSON under the (query, mode) key for CACHE_TTL seconds.

        Does nothing when the cache is disabled; write failures are logged.
        """
        if not self.enabled:
            return

        try:
            key = self.get_key(query, mode)
            payload = json.dumps(data, ensure_ascii=False)
            self.client.setex(key, CACHE_TTL, payload)
            print(f"[CACHE SET] 已缓存: {query[:50]}... (TTL: {CACHE_TTL}秒)")
        except Exception as e:
            print(f"[CACHE ERROR] 写入失败: {e}")

    def get_stats(self) -> Dict:
        """Return cache statistics taken from Redis INFO and DBSIZE.

        Returns:
            ``{"enabled": False}`` when disabled; a dict with ``keys``,
            ``used_memory``, ``hits``, ``misses`` and ``hit_rate`` (percent,
            two decimals) on success; ``{"enabled": False, "error": ...}``
            when Redis cannot be queried.
        """
        if not self.enabled:
            return {"enabled": False}

        try:
            info = self.client.info()
            hits = info.get("keyspace_hits", 0)
            misses = info.get("keyspace_misses", 0)
            return {
                "enabled": True,
                # DBSIZE counts every key in the database, not only "rag:" keys.
                "keys": self.client.dbsize(),
                "used_memory": info.get("used_memory_human"),
                "hits": hits,
                "misses": misses,
                # max(..., 1) avoids a division by zero before any lookup.
                "hit_rate": round(hits / max(hits + misses, 1) * 100, 2),
            }
        except Exception as e:
            # Previously a bare ``except:`` hid the cause (and also caught
            # KeyboardInterrupt / SystemExit); log it instead.
            print(f"[CACHE ERROR] 统计失败: {e}")
            return {"enabled": False, "error": "无法获取统计"}

    def clear(self):
        """Delete every cache entry written by this manager.

        Only keys under ``KEY_PREFIX`` are removed. The previous FLUSHDB call
        wiped the whole Redis database, including keys owned by other
        applications sharing it.

        Returns:
            Number of keys deleted (0 when disabled or when clearing failed).
        """
        if not self.enabled:
            return 0

        removed = 0
        try:
            pending = []
            pattern = f"{self.KEY_PREFIX}*"
            for key in self.client.scan_iter(match=pattern, count=self.SCAN_BATCH_SIZE):
                pending.append(key)
                if len(pending) >= self.SCAN_BATCH_SIZE:
                    # Delete in batches to bound the size of each command.
                    removed += self.client.delete(*pending)
                    pending = []
            if pending:
                removed += self.client.delete(*pending)
            print("[CACHE] 缓存已清空")
        except Exception as e:
            print(f"[CACHE ERROR] 清空失败: {e}")
        return removed
|
||
|
||
|
||
# ============= Global state =============
class AppState:
    """Class-level attributes used as shared mutable application state.

    Nothing is instantiated; the handlers and initialisers in this module read
    and write these attributes directly on the class.
    """

    # Components filled in by the initialisation functions.
    retriever = None
    cache_manager = None
    langsmith_connected = False

    # Counters and bookkeeping updated by the request handlers.
    request_count = 0
    cache_hits = 0
    last_request_time = None
|
||
|
||
|
||
# ============= Lifecycle management =============
def initialize_retriever():
    """Check the LangSmith connection and create the retriever.

    Both results are stored on AppState. If either step raises, the error is
    logged and AppState.retriever is reset to None; nothing is re-raised.
    """
    try:
        connected = check_langsmith_connection()
        AppState.langsmith_connected = connected

        retriever = create_langsmith_retriever(keyword="test")
        AppState.retriever = retriever

        print("[OK] 检索器初始化成功")
    except Exception as exc:
        print(f"[ERROR] 检索器初始化失败: {exc}")
        AppState.retriever = None
|
||
|
||
|
||
def initialize_cache():
    """Create the CacheManager and publish it on AppState."""
    manager = CacheManager()
    AppState.cache_manager = manager
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the app serves requests and log on shutdown.

    Startup creates the output directory, initialises the retriever and the
    cache, and prints the active Redis settings. After the ``yield`` only a
    shutdown message is printed.
    """
    banner = "=" * 60

    print(banner)
    print("[STARTING] 启动RAG API服务(Redis缓存版)")
    print(banner)

    # Make sure the output directory exists.
    OUTPUT_DIR.mkdir(exist_ok=True)

    # Bring up the components, retriever first and cache second.
    initialize_retriever()
    initialize_cache()

    print(f"[INFO] Redis配置: {REDIS_HOST}:{REDIS_PORT}")
    print(f"[INFO] 缓存TTL: {CACHE_TTL}秒")
    print(banner)

    yield

    print("\n[INFO] 服务关闭")
|
||
|
||
|
||
# ============= FastAPI application =============
app = FastAPI(
    lifespan=lifespan,
    title="RAG API Service (Redis Cache)",
    description="RAG检索服务(Redis缓存版)",
    version="2.0.0",
)

# CORS configuration: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive policy; confirm that this is intended for production.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
||
|
||
# ============= API endpoints =============
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report service health.

    The status is "healthy" when a retriever is available and "degraded"
    otherwise; the LangSmith and Redis connectivity flags are taken from
    AppState.
    """
    cache = AppState.cache_manager
    redis_ok = cache.enabled if cache else False
    service_status = "healthy" if AppState.retriever else "degraded"

    return HealthResponse(
        status=service_status,
        environment="production",
        langsmith_connected=AppState.langsmith_connected,
        redis_connected=redis_ok,
        timestamp=datetime.now().isoformat(),
    )
|
||
|
||
|
||
def _cache_usable(request: RetrieveRequest) -> bool:
    """Return True when the request allows caching and the cache is connected."""
    manager = AppState.cache_manager
    return bool(request.use_cache and manager and manager.enabled)


def _write_output_file(now: datetime, timestamp: str, request: RetrieveRequest,
                       answer: str, facts: list, events: list) -> Path:
    """Write one request's result to a uniquely named JSON file in OUTPUT_DIR."""
    # The name keeps the historical "output_<YYYYmmdd_HHMMSS>" prefix and adds
    # microseconds plus the process id, so requests handled within the same
    # second (or by different workers) no longer overwrite each other.
    unique_suffix = f"{now.strftime('%f')}_{os.getpid()}"
    output_file = OUTPUT_DIR / f"output_{timestamp}_{unique_suffix}.json"
    # The directory is created at startup; recreate it if it was removed since.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    output_data = {
        "timestamp": timestamp,
        "query": request.query,
        "answer": answer,
        "supporting_facts": facts,
        "supporting_events": events,
        "metadata": {
            "cached": False,
            "mode": request.mode,
        },
    }
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=2)
    return output_file


@app.post("/retrieve", response_model=RetrieveResponse)
async def retrieve(request: RetrieveRequest):
    """RAG query endpoint with Redis caching.

    Flow: look up the cache (when allowed), otherwise run the retriever, store
    the result in the cache, optionally save it to a JSON file, and return it.

    Args:
        request: Query, mode and the save_output / use_cache switches.

    Returns:
        RetrieveResponse. On a cache hit ``metadata["cached"]`` is True and no
        output file is written. If processing raises, a response with
        ``success=False`` and the error text is returned instead of an HTTP
        error.

    Raises:
        HTTPException: 503 when the retriever has not been initialised.
    """
    now = datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    AppState.request_count += 1
    AppState.last_request_time = now

    print(f"\n{'='*60}")
    print(f"[IN] 请求 #{AppState.request_count}")
    print(f" 查询: {request.query}")
    print(f" 使用缓存: {request.use_cache}")

    if not AppState.retriever:
        raise HTTPException(status_code=503, detail="检索器未初始化")

    try:
        use_cache = _cache_usable(request)

        # 1. Try the cache first.
        if use_cache:
            cached_result = AppState.cache_manager.get(request.query, request.mode)
            if cached_result:
                AppState.cache_hits += 1
                print(f"[OK] 缓存命中 (命中率: {AppState.cache_hits}/{AppState.request_count})")

                # Serve the cached payload directly.
                return RetrieveResponse(
                    success=True,
                    query=request.query,
                    answer=cached_result.get("answer", ""),
                    supporting_facts=cached_result.get("supporting_facts", []),
                    supporting_events=cached_result.get("supporting_events", []),
                    metadata={
                        "environment": "production",
                        "timestamp": timestamp,
                        "cached": True,
                        "cache_hit_rate": f"{AppState.cache_hits}/{AppState.request_count}",
                    },
                )

        # 2. Run the retrieval.
        # NOTE(review): this is a synchronous call inside an async handler, so
        # it occupies the event loop while it runs; confirm this is acceptable.
        print("[PROCESSING] 执行RAG检索...")
        result = AppState.retriever.retrieve(request.query, request.mode)

        answer = result.get("answer", "")
        facts = result.get("supporting_facts", [])
        events = result.get("supporting_events", [])

        # 3. Populate the cache.
        if use_cache:
            AppState.cache_manager.set(
                request.query,
                request.mode,
                {
                    "answer": answer,
                    "supporting_facts": facts,
                    "supporting_events": events,
                },
            )

        # 4. Optionally persist the result to disk.
        output_file = None
        if request.save_output:
            output_file = _write_output_file(now, timestamp, request, answer, facts, events)

        response = RetrieveResponse(
            success=True,
            query=request.query,
            answer=answer,
            supporting_facts=facts,
            supporting_events=events,
            metadata={
                "environment": "production",
                "timestamp": timestamp,
                "cached": False,
            },
            output_file=str(output_file) if output_file else None,
        )

        print(f"[OK] 查询完成 (答案长度: {len(response.answer)} 字符)")
        return response

    except Exception as e:
        # Failures are reported in the response body rather than as HTTP errors.
        print(f"[ERROR] 查询失败: {e}")
        return RetrieveResponse(
            success=False,
            query=request.query,
            answer="",
            error=str(e),
            metadata={"environment": "production", "timestamp": timestamp},
        )
|
||
|
||
|
||
@app.get("/cache/stats")
async def cache_stats():
    """Return Redis cache statistics merged with this process's counters."""
    manager = AppState.cache_manager
    if not manager:
        return {"error": "缓存未初始化"}

    stats = manager.get_stats()

    total = AppState.request_count
    hits = AppState.cache_hits
    # max(total, 1) keeps the percentage defined before the first request.
    session_rate = round(hits / max(total, 1) * 100, 2)

    stats["total_requests"] = total
    stats["cache_hits"] = hits
    stats["hit_rate_session"] = f"{session_rate}%"
    return stats
|
||
|
||
|
||
@app.delete("/cache/clear")
async def clear_cache():
    """Clear the cache through the cache manager.

    Returns an error payload when no cache manager exists; otherwise asks the
    manager to clear and reports success.
    """
    manager = AppState.cache_manager
    if manager:
        manager.clear()
        return {"success": True, "message": "缓存已清空"}

    return {"error": "缓存未初始化"}
|
||
|
||
|
||
# ============= Legacy-compatible endpoint =============
@app.post("/query", response_model=RetrieveResponse)
async def query(request: RetrieveRequest):
    """Legacy alias of the /retrieve endpoint; delegates to retrieve()."""
    response = await retrieve(request)
    return response
|
||
|
||
|
||
# ============= Entry point =============
def main():
    """Parse the command-line options and start the uvicorn server.

    Options: ``--host`` (default 0.0.0.0), ``--port`` (default 8000) and
    ``--workers`` (default 4).
    """
    import argparse

    parser = argparse.ArgumentParser(description="RAG API服务(Redis缓存版)")
    parser.add_argument("--host", default="0.0.0.0", help="服务地址")
    parser.add_argument("--port", type=int, default=8000, help="服务端口")
    parser.add_argument("--workers", type=int, default=4, help="Worker进程数")

    args = parser.parse_args()

    # uvicorn needs an import string (not the app object) to start multiple
    # workers. Derive it from this file's own name instead of hard-coding
    # "rag_api_server_with_cache", so a renamed copy still launches itself.
    app_target = f"{Path(__file__).stem}:app"

    uvicorn.run(
        app_target,
        host=args.host,
        port=args.port,
        reload=False,
        workers=args.workers,
        log_level="info",
    )


if __name__ == "__main__":
    main()