first commit
This commit is contained in:
118
AIEC-RAG/test_stream_real.py
Normal file
118
AIEC-RAG/test_stream_real.py
Normal file
@ -0,0 +1,118 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
测试流式检索的真实进度反馈
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
|
||||
import httpx
|
||||
|
||||
async def test_stream_real():
|
||||
"""测试真实的流式进度"""
|
||||
|
||||
url = "http://localhost:8080/retrieve/stream"
|
||||
headers = {
|
||||
"Accept": "text/event-stream",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
data = {
|
||||
"query": "数据管理实践的白皮书是什么",
|
||||
"mode": "0",
|
||||
"save_output": False
|
||||
}
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"开始流式检索测试 - {datetime.now()}")
|
||||
print(f"查询: {data['query']}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
async with client.stream("POST", url, json=data, headers=headers) as response:
|
||||
print(f"状态码: {response.status_code}")
|
||||
print(f"响应头: {dict(response.headers)}\n")
|
||||
|
||||
if response.status_code != 200:
|
||||
content = await response.aread()
|
||||
print(f"错误响应: {content.decode()}")
|
||||
return
|
||||
|
||||
print("接收流式数据:\n")
|
||||
buffer = ""
|
||||
event_count = 0
|
||||
|
||||
async for chunk in response.aiter_text():
|
||||
buffer += chunk
|
||||
|
||||
# 处理完整的SSE消息
|
||||
while "\n\n" in buffer:
|
||||
message, buffer = buffer.split("\n\n", 1)
|
||||
|
||||
if message.startswith("data: "):
|
||||
data_str = message[6:]
|
||||
|
||||
if data_str == "[DONE]":
|
||||
print(f"\n流式传输完成")
|
||||
break
|
||||
|
||||
try:
|
||||
event = json.loads(data_str)
|
||||
event_count += 1
|
||||
|
||||
# 打印事件详情
|
||||
timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3]
|
||||
event_type = event.get("type", "unknown")
|
||||
progress = event.get("progress", 0)
|
||||
elapsed = event.get("elapsed", 0)
|
||||
|
||||
print(f"[{timestamp}] Event #{event_count}")
|
||||
print(f" 类型: {event_type}")
|
||||
print(f" 进度: {progress}%")
|
||||
print(f" 耗时: {elapsed}s")
|
||||
|
||||
# 打印特定类型的详情
|
||||
if event_type == "sub_queries" and "data" in event:
|
||||
queries = event["data"]
|
||||
print(f" 子查询数: {len(queries)}")
|
||||
for i, q in enumerate(queries[:3], 1):
|
||||
print(f" {i}. {q}")
|
||||
|
||||
elif event_type == "documents" and "data" in event:
|
||||
doc_data = event["data"]
|
||||
print(f" 文档数: {doc_data.get('count', 0)}")
|
||||
if "sources" in doc_data:
|
||||
sources = doc_data["sources"][:3]
|
||||
print(f" 来源: {', '.join(sources)}")
|
||||
|
||||
elif event_type == "iteration" and "data" in event:
|
||||
iter_data = event["data"]
|
||||
print(f" 当前轮次: {iter_data.get('current')}/{iter_data.get('max')}")
|
||||
|
||||
elif event_type == "sufficiency_check" and "data" in event:
|
||||
check = event["data"]
|
||||
print(f" 充分性: {check.get('is_sufficient')}")
|
||||
print(f" 置信度: {check.get('confidence')}%")
|
||||
|
||||
elif event_type == "answer" and "data" in event:
|
||||
answer_data = event["data"]
|
||||
content = answer_data.get("content", "")[:200]
|
||||
print(f" 答案预览: {content}...")
|
||||
|
||||
print() # 空行分隔
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON解析错误: {e}")
|
||||
print(f"原始数据: {data_str}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n测试失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_stream_real())
|
||||
Reference in New Issue
Block a user