first commit

tests/debug_llm_calls.py (new file, 308 lines)

@@ -0,0 +1,308 @@
"""
Log detailed information about every LLM call and save it to a JSON file.

Usage:
export PYTHONIOENCODING=utf-8 && python tests/debug_llm_calls.py

Output: tests/llm_calls_<timestamp>.json plus a plain-text summary in
tests/llm_calls_summary_<timestamp>.txt
"""

import sys
import os
import json
from datetime import datetime
from typing import Any, Dict, List

# Add the project root to the Python path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult

from src.agents.coordinator import create_research_coordinator
from src.config import Config


class LLMCallLogger(BaseCallbackHandler):
    """Callback handler that records every LLM call"""

    def __init__(self):
        self.calls: List[Dict[str, Any]] = []
        self.current_call = None
        self.call_count = 0

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Called when a completion-style LLM starts"""
        self.call_count += 1
        self.current_call = {
            "call_id": self.call_count,
            "timestamp_start": datetime.now().isoformat(),
            "prompts": prompts,
            "kwargs": {k: str(v) for k, v in kwargs.items() if k != "invocation_params"},
        }
        print(f"\n{'='*80}")
        print(f"🔵 LLM call #{self.call_count} started - {datetime.now().strftime('%H:%M:%S')}")
        print('='*80)
        if prompts:
            print(f"Prompt length: {len(prompts[0])} characters")
            print(f"Prompt preview: {prompts[0][:200]}...")

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        **kwargs: Any
    ) -> None:
        """Called when a chat model starts"""
        self.call_count += 1
        self.current_call = {
            "call_id": self.call_count,
            "timestamp_start": datetime.now().isoformat(),
            "messages": [
                [
                    {
                        "type": type(msg).__name__,
                        "content": msg.content if hasattr(msg, 'content') else str(msg),
                        "tool_calls": getattr(msg, 'tool_calls', None)
                    }
                    for msg in msg_list
                ]
                for msg_list in messages
            ],
            "kwargs": {k: str(v) for k, v in kwargs.items() if k not in ["invocation_params", "tags", "metadata"]},
        }
        print(f"\n{'='*80}")
        print(f"🔵 Chat model call #{self.call_count} started - {datetime.now().strftime('%H:%M:%S')}")
        print('='*80)
        if messages:
            print(f"Message count: {len(messages[0])}")
            # Show only the last three messages to keep the console readable
            for i, msg in enumerate(messages[0][-3:], 1):
                msg_type = type(msg).__name__
                print(f"  {i}. {msg_type}: {str(msg.content)[:100] if hasattr(msg, 'content') else 'N/A'}...")

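    # Shape of one recorded chat-model call as it ends up in the JSON file
    # (illustrative values, not captured output):
    # {
    #   "call_id": 1,
    #   "timestamp_start": "2024-01-01T12:00:00",
    #   "messages": [[{"type": "HumanMessage", "content": "...", "tool_calls": null}]],
    #   "kwargs": {"run_id": "..."}
    # }
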
    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Called when an LLM call finishes"""
        if self.current_call:
            self.current_call["timestamp_end"] = datetime.now().isoformat()

            # Extract the generations from the response
            generations = []
            for gen_list in response.generations:
                for gen in gen_list:
                    gen_info = {
                        "text": gen.text if hasattr(gen, 'text') else None,
                    }
                    if hasattr(gen, 'message'):
                        msg = gen.message
                        gen_info["message"] = {
                            "type": type(msg).__name__,
                            "content": msg.content if hasattr(msg, 'content') else None,
                            "tool_calls": [
                                {
                                    "name": tc.get("name"),
                                    "args": tc.get("args"),
                                    "id": tc.get("id")
                                }
                                for tc in (getattr(msg, 'tool_calls', None) or [])
                            ] if hasattr(msg, 'tool_calls') else None
                        }
                    generations.append(gen_info)

            self.current_call["response"] = {
                "generations": generations,
                "llm_output": response.llm_output,
            }

            self.calls.append(self.current_call)

            print(f"\n✅ LLM call #{self.current_call['call_id']} finished")
            if generations:
                gen = generations[0]
                if gen.get("message"):
                    msg = gen["message"]
                    print(f"Response type: {msg['type']}")
                    if msg.get('content'):
                        print(f"Content: {msg['content'][:150]}...")
                    if msg.get('tool_calls'):
                        print(f"Tool calls: {len(msg['tool_calls'])}")
                        for tc in msg['tool_calls'][:3]:
                            print(f"  - {tc['name']}")

            self.current_call = None

    def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
        """Called when an LLM call fails"""
        if self.current_call:
            self.current_call["timestamp_end"] = datetime.now().isoformat()
            self.current_call["error"] = str(error)
            self.calls.append(self.current_call)
            print(f"\n❌ LLM call #{self.current_call['call_id']} failed: {error}")
            self.current_call = None

    def save_to_file(self, filepath: str):
        """Save the recorded calls to a JSON file"""
        with open(filepath, 'w', encoding='utf-8') as f:
            # default=str guards against values that are not JSON-serializable
            json.dump({
                "total_calls": len(self.calls),
                "calls": self.calls
            }, f, ensure_ascii=False, indent=2, default=str)
        print(f"\n💾 Saved {len(self.calls)} LLM call records to: {filepath}")


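# Minimal standalone usage sketch of LLMCallLogger with any LangChain chat
# model (assumptions: langchain-openai is installed, OPENAI_API_KEY is set,
# and the model name is illustrative):
#
#     from langchain_openai import ChatOpenAI
#     logger = LLMCallLogger()
#     llm = ChatOpenAI(model="gpt-4o-mini", callbacks=[logger])
#     llm.invoke("Hello")
#     logger.save_to_file("tests/llm_calls_demo.json")

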
def test_with_llm_logging(question: str, depth: str = "quick", max_steps: int = 10):
    """
    Run the research flow and log every LLM call.

    Args:
        question: research question
        depth: depth mode
        max_steps: maximum number of execution steps (guards against infinite loops)
    """
print("\n" + "🔬 " * 40)
|
||||
print("智能深度研究系统 - LLM调用记录模式")
|
||||
print("🔬 " * 40)
|
||||
|
||||
print(f"\n研究问题: {question}")
|
||||
print(f"深度模式: {depth}")
|
||||
print(f"最大步骤数: {max_steps}")
|
||||
print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
# 创建日志记录器
|
||||
logger = LLMCallLogger()
|
||||
|
||||
# 创建Agent(带callback)
|
||||
print("\n" + "="*80)
|
||||
print("创建Agent...")
|
||||
print("="*80)
|
||||
|
||||
    try:
        # Get the LLM and attach the callback; the logger is also passed via
        # the stream config below, in case the agent builds its own LLM
        llm = Config.get_llm()
        llm.callbacks = [logger]

        # Create the agent
        agent = create_research_coordinator(
            question=question,
            depth=depth,
            format="technical",
            min_tier=3
        )
        print("✅ Agent created")
    except Exception as e:
        print(f"❌ Agent creation failed: {e}")
        import traceback
        traceback.print_exc()
        return

    # Run the research flow
    print("\n" + "="*80)
    print(f"Running the research flow (at most {max_steps} steps)...")
    print("="*80)

    start_time = datetime.now()
    step_count = 0
    duration = 0.0  # defined up front so the summary in `finally` never hits a NameError

    try:
        # Stream the graph, but cap the number of steps
        for chunk in agent.stream(
            {
                "messages": [
                    {
                        "role": "user",
                        "content": f"Please start researching this question: {question}"
                    }
                ]
            },
            config={"callbacks": [logger]}
        ):
            step_count += 1
            print(f"\n{'─'*80}")
            print(f"📍 Step #{step_count} - {datetime.now().strftime('%H:%M:%S')}")
            print('─'*80)

            # Show state updates
            if isinstance(chunk, dict):
                if 'messages' in chunk:
                    print(f"  Messages: {len(chunk['messages'])}")
                if 'files' in chunk:
                    print(f"  Files: {len(chunk['files'])}")
                    for path in list(chunk['files'].keys())[:3]:
                        print(f"    - {path}")

            # Cap the step count
            if step_count >= max_steps:
                print(f"\n⚠️ Reached the maximum of {max_steps} steps, stopping")
                break

            # Timeout guard
            elapsed = (datetime.now() - start_time).total_seconds()
            if elapsed > 120:  # 2 minutes
                print("\n⚠️ Exceeded 2 minutes, stopping")
                break

        duration = (datetime.now() - start_time).total_seconds()

        print("\n" + "="*80)
        print("Run finished")
        print("="*80)
        print(f"Total steps: {step_count}")
        print(f"LLM calls: {len(logger.calls)}")
        print(f"Total time: {duration:.2f}s")

    except KeyboardInterrupt:
        print("\n\n⚠️ Interrupted by user")
    except Exception as e:
        print(f"\n\n❌ Run failed: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # If the run was interrupted before `duration` was computed, fall back
        # to the elapsed time so the summary is still accurate
        if not duration:
            duration = (datetime.now() - start_time).total_seconds()

        # Save the log
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = "tests"
        os.makedirs(output_dir, exist_ok=True)

        log_file = os.path.join(output_dir, f"llm_calls_{timestamp}.json")
        logger.save_to_file(log_file)

        # Also save a plain-text summary
        summary_file = os.path.join(output_dir, f"llm_calls_summary_{timestamp}.txt")
        with open(summary_file, 'w', encoding='utf-8') as f:
            f.write("LLM call log summary\n")
            f.write(f"{'='*80}\n\n")
            f.write(f"Total calls: {len(logger.calls)}\n")
            f.write(f"Run time: {duration:.2f}s\n\n")

            for i, call in enumerate(logger.calls, 1):
                f.write(f"\n{'─'*80}\n")
                f.write(f"Call #{i}\n")
                f.write(f"{'─'*80}\n")
                f.write(f"Start: {call['timestamp_start']}\n")
                f.write(f"End: {call.get('timestamp_end', 'N/A')}\n")

                if 'messages' in call:
                    f.write(f"Messages: {len(call['messages'][0]) if call['messages'] else 0}\n")

                if 'response' in call:
                    gens = call['response'].get('generations', [])
                    if gens:
                        gen = gens[0]
                        if gen.get('message'):
                            msg = gen['message']
                            f.write(f"Response type: {msg['type']}\n")
                            if msg.get('tool_calls'):
                                f.write(f"Tool calls: {[tc['name'] for tc in msg['tool_calls']]}\n")

                if 'error' in call:
                    f.write(f"Error: {call['error']}\n")

        print(f"📄 Summary saved to: {summary_file}")


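# To inspect a saved log afterwards (sketch; the filename is illustrative):
#
#     import json
#     with open("tests/llm_calls_20240101_120000.json", encoding="utf-8") as f:
#         log = json.load(f)
#     print(log["total_calls"], len(log["calls"]))

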
if __name__ == "__main__":
    question = "Python asyncio best practices"

    # Only run the first few steps, not a full research pass
    test_with_llm_logging(question, depth="quick", max_steps=10)