Files
AIEC-RAG---/AIEC-RAG/rag_api_server_with_cache.py
2025-09-25 10:33:37 +08:00

420 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAG API 生产环境服务 - Redis缓存版
基于原始 rag_api_server_production.py 添加Redis缓存
"""
import asyncio
import hashlib
import json
import os
import sys
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, List, Optional

import redis
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
# Pin the production RAG config before any project module is imported and reads it.
os.environ['RAG_CONFIG_PATH'] = 'rag_config_production.yaml'

# The directory that holds this script is the project root; put it on sys.path
# so the local `retriver` package imported below can be resolved.
project_root = Path(__file__).parents[0]
sys.path.append(os.fspath(project_root))

# Optional `.env` file next to this script.
# NOTE(review): python-dotenv does not override variables that are already set
# (its default), so a RAG_CONFIG_PATH inside .env is ignored because of the
# assignment above — confirm that forcing the production config is intended.
env_path = project_root.joinpath('.env')
if env_path.exists():
    load_dotenv(env_path)
from retriver.langsmith.langsmith_retriever import create_langsmith_retriever, check_langsmith_connection
# ============= Redis configuration =============
def env_flag(name: str, default: bool) -> bool:
    """Read a boolean switch from environment variable *name*.

    Unset or blank -> *default*; "0", "false", "no", "off" (any case,
    surrounding whitespace ignored) -> False; any other value -> True.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    return raw.strip().lower() not in {"0", "false", "no", "off"}


REDIS_HOST = os.getenv('REDIS_HOST', '172.18.0.3')  # default is a Docker-internal IP
REDIS_PORT = int(os.getenv('REDIS_PORT', 6379))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD', None)  # unset -> None is passed to redis.Redis
CACHE_TTL = int(os.getenv('CACHE_TTL', 3600))  # entry lifetime in seconds (default 1 hour)
# Master cache switch. It used to be hard-coded to True, which made the
# "cache disabled" branch of CacheManager.connect unreachable. It can now be
# turned off with e.g. CACHE_ENABLED=false; the default is still True.
CACHE_ENABLED = env_flag('CACHE_ENABLED', True)
# ============= Production settings =============
# Directory for the per-request JSON files written by /retrieve
# (created by the lifespan hook at startup).
OUTPUT_DIR = Path(project_root, "api_outputs")
# ============= 请求/响应模型 =============
class RetrieveRequest(BaseModel):
    """检索请求模型"""
    # Request body accepted by POST /retrieve and POST /query ("retrieval request model").
    # The class docstring is kept verbatim: pydantic uses it as the JSON-schema
    # description that FastAPI shows in the OpenAPI docs.

    # Question text; forwarded to the retriever and part of the Redis cache key.
    query: str = Field(..., description="查询问题")
    # Mode string forwarded to retriever.retrieve() and also part of the cache key;
    # per the field description, "0" means automatic.
    mode: str = Field(default="0", description="调试模式: 0=自动")
    # When True, a freshly computed (non-cached) result is also written as JSON
    # under OUTPUT_DIR.
    save_output: bool = Field(default=True, description="是否保存输出到文件")
    # When False, the endpoint neither reads from nor writes to the Redis cache.
    use_cache: bool = Field(default=True, description="是否使用缓存")
class RetrieveResponse(BaseModel):
    """检索响应模型"""
    # Response body of POST /retrieve and POST /query ("retrieval response model").
    # The class docstring is kept verbatim because pydantic uses it as the
    # JSON-schema description shown in the OpenAPI docs.

    # False when the endpoint caught an exception; `error` then carries the message.
    success: bool = Field(..., description="是否成功")
    # The query exactly as received in the request.
    query: str = Field(..., description="原始查询")
    # Generated answer; "" when the request failed.
    answer: str = Field(..., description="生成的答案")
    # Supporting passages taken from the retriever result (or the cache).
    # NOTE(review): the layout of each inner list is not visible here — confirm
    # against the retriever's output.
    supporting_facts: List[List[str]] = Field(default_factory=list, description="支撑段落")
    # Supporting events from the retriever result (or the cache); same layout caveat.
    supporting_events: List[List[str]] = Field(default_factory=list, description="支撑事件")
    # Free-form info filled by the endpoint: environment, timestamp, a `cached`
    # flag, and on cache hits a "hits/requests" counter string.
    metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据")
    # Path of the JSON file written for this request, if one was saved.
    output_file: Optional[str] = Field(None, description="输出文件路径")
    # Error message when success is False; None otherwise.
    error: Optional[str] = Field(None, description="错误信息")
class HealthResponse(BaseModel):
    """健康检查响应"""
    # Body returned by GET /health ("health check response"). The docstring is
    # kept verbatim because pydantic exposes it as the schema description.

    # "healthy" when the retriever was initialized, "degraded" otherwise
    # (decided by the /health endpoint).
    status: str = "healthy"
    # Deployment label; /health always sends "production".
    environment: str = "production"
    # Result of the LangSmith connectivity check made at startup.
    langsmith_connected: bool = False
    # Whether the cache manager holds a live Redis connection.
    redis_connected: bool = False
    # ISO-8601 string of the local time at which the check ran.
    timestamp: str = ""
# ============= Redis缓存管理 =============
class CacheManager:
    """Redis-backed cache for RAG query results.

    Failures never propagate to callers. If caching is switched off or Redis
    cannot be reached at construction time, ``enabled`` stays False and every
    operation becomes a no-op (``get`` returns None). Errors raised by Redis
    later on are logged and swallowed.
    """

    # Namespace for every key written by this manager. ``clear()`` deletes only
    # keys under this prefix, so other data sharing the Redis DB survives.
    KEY_PREFIX = "rag:"
    # Number of keys requested per SCAN step and deleted per DEL call in ``clear()``.
    _CLEAR_BATCH = 500

    def __init__(self):
        """Create the manager and try to connect right away."""
        self.enabled = False
        self.client = None
        self.connect()

    def connect(self):
        """Connect to Redis and ping it; on any failure fall back to no-cache mode."""
        if not CACHE_ENABLED:
            print("[INFO] 缓存已禁用")
            return
        try:
            self.client = redis.Redis(
                host=REDIS_HOST,
                port=REDIS_PORT,
                password=REDIS_PASSWORD,
                db=0,
                decode_responses=True,
                socket_connect_timeout=5,
                socket_timeout=5,
            )
            # Verify the connection now rather than on the first request.
            self.client.ping()
            self.enabled = True
            print(f"[OK] Redis缓存已连接: {REDIS_HOST}:{REDIS_PORT}")
        except Exception as e:
            print(f"[WARNING] Redis连接失败: {e}")
            print("[INFO] 系统将在无缓存模式下运行")
            self.enabled = False
            # Do not keep a half-initialized client around.
            self.client = None

    def get_key(self, query: str, mode: str) -> str:
        """Return the Redis key for a (query, mode) pair.

        The pair is JSON-encoded before hashing so the boundary between the two
        fields is unambiguous. Plain ``f"{query}:{mode}"`` concatenation made,
        for example, ("a:b", "c") and ("a", "b:c") share one key. MD5 serves
        only as a compact key digest here, not for security.
        """
        content = json.dumps([query, mode], ensure_ascii=False)
        digest = hashlib.md5(content.encode("utf-8")).hexdigest()
        return f"{self.KEY_PREFIX}{digest}"

    def get(self, query: str, mode: str) -> Optional[Dict]:
        """Return the cached result for (query, mode), or None on miss/error/disabled."""
        if not self.enabled:
            return None
        try:
            cached_data = self.client.get(self.get_key(query, mode))
            if cached_data:
                print(f"[CACHE HIT] 查询: {query[:50]}...")
                return json.loads(cached_data)
        except Exception as e:
            print(f"[CACHE ERROR] 读取失败: {e}")
        return None

    def set(self, query: str, mode: str, data: Dict):
        """Store *data* (JSON-serialisable) for (query, mode) with a CACHE_TTL expiry."""
        if not self.enabled:
            return
        try:
            self.client.setex(
                self.get_key(query, mode),
                CACHE_TTL,
                json.dumps(data, ensure_ascii=False),
            )
            print(f"[CACHE SET] 已缓存: {query[:50]}... (TTL: {CACHE_TTL}秒)")
        except Exception as e:
            print(f"[CACHE ERROR] 写入失败: {e}")

    def get_stats(self) -> Dict:
        """Return Redis-side statistics for the /cache/stats endpoint.

        "keys" is the size of the whole DB (DBSIZE), and hits/misses come from
        server-wide INFO counters, so both may include non-RAG traffic.
        """
        if not self.enabled:
            return {"enabled": False}
        try:
            info = self.client.info()
            dbsize = self.client.dbsize()
            hits = info.get("keyspace_hits", 0)
            misses = info.get("keyspace_misses", 0)
            return {
                "enabled": True,
                "keys": dbsize,
                "used_memory": info.get("used_memory_human"),
                "hits": hits,
                "misses": misses,
                # max(..., 1) avoids division by zero on a fresh server.
                "hit_rate": round(hits / max(hits + misses, 1) * 100, 2),
            }
        except Exception as e:
            # Keep the response shape stable but log the cause. The former bare
            # `except:` also swallowed KeyboardInterrupt/SystemExit.
            print(f"[CACHE ERROR] 统计失败: {e}")
            return {"enabled": False, "error": "无法获取统计"}

    def clear(self) -> int:
        """Delete every cache entry under KEY_PREFIX and return how many were removed.

        This deliberately avoids FLUSHDB, which would wipe the entire Redis
        database, including unrelated data stored in db 0. Returns 0 when the
        cache is disabled. On an error the problem is logged and the number of
        keys deleted so far is returned.
        """
        if not self.enabled:
            return 0
        deleted = 0
        try:
            batch = []
            for key in self.client.scan_iter(match=f"{self.KEY_PREFIX}*", count=self._CLEAR_BATCH):
                batch.append(key)
                if len(batch) >= self._CLEAR_BATCH:
                    deleted += self.client.delete(*batch)
                    batch = []
            if batch:
                deleted += self.client.delete(*batch)
            print("[CACHE] 缓存已清空")
        except Exception as e:
            print(f"[CACHE ERROR] 清空失败: {e}")
        return deleted
# ============= 全局状态管理 =============
class AppState:
    """Process-wide mutable state shared by the endpoints.

    All state lives in class attributes and is accessed through the class
    itself (``AppState.x``); this file never instantiates it. Each uvicorn
    worker process therefore has its own independent copy.
    """

    # Components wired up by the startup hook.
    retriever: Any = None
    langsmith_connected: bool = False
    cache_manager: Optional["CacheManager"] = None

    # Per-process request statistics updated by /retrieve.
    request_count: int = 0
    cache_hits: int = 0
    last_request_time: Optional[datetime] = None
# ============= 生命周期管理 =============
def initialize_retriever():
    """Build the LangSmith retriever and record connectivity on AppState.

    Any Exception is caught. The message is printed and ``AppState.retriever``
    is reset to None, which /health reports as "degraded" and /retrieve turns
    into HTTP 503.
    """
    try:
        AppState.langsmith_connected = check_langsmith_connection()
        # NOTE(review): keyword="test" is hard-coded — confirm this is the
        # intended production setting.
        built = create_langsmith_retriever(keyword="test")
        AppState.retriever = built
        print("[OK] 检索器初始化成功")
    except Exception as exc:
        print(f"[ERROR] 检索器初始化失败: {exc}")
        AppState.retriever = None
def initialize_cache():
    """Create the process-wide CacheManager and publish it on AppState.

    CacheManager connects in its constructor and catches connection errors
    itself, falling back to a disabled state.
    """
    manager = CacheManager()
    AppState.cache_manager = manager
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook for the FastAPI app.

    Startup prints a banner, ensures OUTPUT_DIR exists, initializes the
    retriever and then the cache, and logs the Redis settings. Shutdown only
    prints a message; no resources are released here.
    """
    rule = "=" * 60
    print(rule)
    print("[STARTING] 启动RAG API服务Redis缓存版")
    print(rule)

    # Output directory for saved per-request results.
    OUTPUT_DIR.mkdir(exist_ok=True)

    # Order matters only for log readability: retriever first, then cache.
    for init_step in (initialize_retriever, initialize_cache):
        init_step()

    print(f"[INFO] Redis配置: {REDIS_HOST}:{REDIS_PORT}")
    print(f"[INFO] 缓存TTL: {CACHE_TTL}")
    print(rule)

    yield

    print("\n[INFO] 服务关闭")
# ============= FastAPI应用 =============
app = FastAPI(
    lifespan=lifespan,
    title="RAG API Service (Redis Cache)",
    version="2.0.0",
    description="RAG检索服务Redis缓存版",
)

# CORS: every origin, method and header is allowed, with credentials.
# NOTE(review): this is wide open. Combined with unauthenticated endpoints such
# as DELETE /cache/clear, browser pages from any origin can call this API.
# Consider tightening allow_origins if the service is reachable beyond a
# trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# ============= API端点 =============
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """健康检查"""
    # Liveness/readiness probe. The docstring above is kept verbatim because
    # FastAPI publishes it as the OpenAPI description. The service reports
    # "degraded" when the retriever failed to initialize.
    manager = AppState.cache_manager
    redis_ok = manager.enabled if manager else False
    overall = "degraded" if not AppState.retriever else "healthy"
    return HealthResponse(
        status=overall,
        environment="production",
        langsmith_connected=AppState.langsmith_connected,
        redis_connected=redis_ok,
        timestamp=datetime.now().isoformat(),
    )
# Dedicated single-thread executor for the blocking retriever call. One worker
# keeps retrievals serialized within a process, as they were when the call ran
# directly on the event loop. The loop itself stays free to answer /health,
# cache hits and other requests while a long retrieval is running.
_RETRIEVE_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="rag-retrieve")


def _cached_response(request: RetrieveRequest, cached_result: Dict[str, Any], timestamp: str) -> RetrieveResponse:
    """Build the /retrieve response for a cache hit."""
    return RetrieveResponse(
        success=True,
        query=request.query,
        answer=cached_result.get("answer", ""),
        supporting_facts=cached_result.get("supporting_facts", []),
        supporting_events=cached_result.get("supporting_events", []),
        metadata={
            "environment": "production",
            "timestamp": timestamp,
            "cached": True,
            "cache_hit_rate": f"{AppState.cache_hits}/{AppState.request_count}",
        },
    )


def _save_output(request: RetrieveRequest, result: Dict[str, Any], timestamp: str) -> Optional[Path]:
    """Write one freshly computed result to a uniquely named JSON file in OUTPUT_DIR.

    Returns the file path, or None when the write fails with an OSError.
    Saving is a side artifact: a disk error is logged instead of turning an
    answer that was already computed (and possibly already cached) into a
    failed request.
    """
    # The second-resolution timestamp alone collides when two requests (or two
    # uvicorn workers) finish within the same second, and the later file would
    # silently overwrite the earlier one. A random suffix keeps names unique.
    output_file = OUTPUT_DIR / f"output_{timestamp}_{uuid.uuid4().hex[:8]}.json"
    output_data = {
        "timestamp": timestamp,
        "query": request.query,
        "answer": result.get("answer", ""),
        "supporting_facts": result.get("supporting_facts", []),
        "supporting_events": result.get("supporting_events", []),
        "metadata": {
            "cached": False,
            "mode": request.mode,
        },
    }
    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, ensure_ascii=False, indent=2)
    except OSError as e:
        print(f"[WARNING] 输出文件保存失败: {e}")
        return None
    return output_file


@app.post("/retrieve", response_model=RetrieveResponse)
async def retrieve(request: RetrieveRequest):
    """RAG查询接口带缓存"""
    # Answer a query. It is served from Redis when possible; otherwise the
    # retriever runs, the result is cached and optionally saved to OUTPUT_DIR.
    # Raises HTTPException(503) if the retriever failed to initialize. Any
    # other error is reported in-band as success=False (HTTP 200).
    # The docstring above is kept verbatim because FastAPI publishes it in OpenAPI.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    AppState.request_count += 1
    AppState.last_request_time = datetime.now()
    print(f"\n{'='*60}")
    print(f"[IN] 请求 #{AppState.request_count}")
    print(f" 查询: {request.query}")
    print(f" 使用缓存: {request.use_cache}")
    retriever = AppState.retriever
    if not retriever:
        raise HTTPException(status_code=503, detail="检索器未初始化")

    cache = AppState.cache_manager
    # Evaluated once: CacheManager only changes `enabled` while connecting.
    use_cache = bool(request.use_cache and cache and cache.enabled)
    try:
        # 1. Serve from the cache when possible.
        if use_cache:
            cached_result = cache.get(request.query, request.mode)
            if cached_result:
                AppState.cache_hits += 1
                print(f"[OK] 缓存命中 (命中率: {AppState.cache_hits}/{AppState.request_count})")
                return _cached_response(request, cached_result, timestamp)

        # 2. Run the blocking retriever off the event loop.
        print("[PROCESSING] 执行RAG检索...")
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            _RETRIEVE_EXECUTOR, retriever.retrieve, request.query, request.mode
        )

        # 3. Store the fresh result in the cache.
        if use_cache:
            cache_data = {
                "answer": result.get("answer", ""),
                "supporting_facts": result.get("supporting_facts", []),
                "supporting_events": result.get("supporting_events", []),
            }
            cache.set(request.query, request.mode, cache_data)

        # 4. Optionally save the result to disk (best-effort).
        output_file = _save_output(request, result, timestamp) if request.save_output else None

        response = RetrieveResponse(
            success=True,
            query=request.query,
            answer=result.get("answer", ""),
            supporting_facts=result.get("supporting_facts", []),
            supporting_events=result.get("supporting_events", []),
            metadata={
                "environment": "production",
                "timestamp": timestamp,
                "cached": False,
            },
            output_file=str(output_file) if output_file else None,
        )
        print(f"[OK] 查询完成 (答案长度: {len(response.answer)} 字符)")
        return response
    except Exception as e:
        print(f"[ERROR] 查询失败: {e}")
        return RetrieveResponse(
            success=False,
            query=request.query,
            answer="",
            error=str(e),
            metadata={"environment": "production", "timestamp": timestamp},
        )
@app.get("/cache/stats")
async def cache_stats():
    """获取缓存统计"""
    # Redis-side statistics merged with this process's own request counters.
    # The counters live on AppState, so with several uvicorn workers each
    # worker reports only its own numbers.
    manager = AppState.cache_manager
    if not manager:
        return {"error": "缓存未初始化"}
    stats = manager.get_stats()
    total = AppState.request_count
    hits = AppState.cache_hits
    session_rate = round(hits / max(total, 1) * 100, 2)
    stats["total_requests"] = total
    stats["cache_hits"] = hits
    stats["hit_rate_session"] = f"{session_rate}%"
    return stats
@app.delete("/cache/clear")
async def clear_cache():
    """清空缓存"""
    # Wipe cached RAG results. The docstring above is kept verbatim because
    # FastAPI publishes it in OpenAPI.
    # NOTE(review): this endpoint is unauthenticated — consider protecting it.
    manager = AppState.cache_manager
    if not manager:
        return {"error": "缓存未初始化"}
    if not manager.enabled:
        # CacheManager.clear() silently does nothing when Redis is disabled or
        # unreachable. Report that instead of a misleading success.
        return {"success": False, "message": "缓存未启用"}
    manager.clear()
    return {"success": True, "message": "缓存已清空"}
# ============= 兼容旧接口 =============
@app.post("/query", response_model=RetrieveResponse)
async def query(request: RetrieveRequest):
    """查询接口(兼容旧版)"""
    # Legacy alias of /retrieve with the same request/response contract. It
    # delegates directly to the retrieve handler. The docstring is kept
    # verbatim for the OpenAPI docs.
    response = await retrieve(request)
    return response
# ============= 主函数 =============
def main():
    """Parse command-line options and serve the API with uvicorn.

    Options: --host (default 0.0.0.0), --port (default 8000) and --workers
    (default 4). Each worker is a separate process with its own AppState,
    retriever and Redis connection.
    """
    import argparse

    parser = argparse.ArgumentParser(description="RAG API服务Redis缓存版")
    parser.add_argument("--host", default="0.0.0.0", help="服务地址")
    parser.add_argument("--port", type=int, default=8000, help="服务端口")
    parser.add_argument("--workers", type=int, default=4, help="Worker进程数")
    args = parser.parse_args()

    # uvicorn needs an "module:attr" import string (not the app object) to
    # start multiple workers. The string is derived from this file's own name
    # instead of being hard-coded, so a copied or renamed script does not fail
    # to start or silently serve a different module's app.
    app_import_path = f"{Path(__file__).stem}:app"
    uvicorn.run(
        app_import_path,
        host=args.host,
        port=args.port,
        reload=False,
        workers=args.workers,
        log_level="info",
    )


if __name__ == "__main__":
    main()