Files
deepagents----/tests/debug_llm_calls.py
2025-11-02 18:06:38 +08:00

309 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
记录LLM调用的详细信息 - 保存为JSON文件
使用方法:
export PYTHONIOENCODING=utf-8 && python tests/debug_llm_calls.py
"""
import sys
import os
import json
from datetime import datetime
from typing import Any, Dict, List
from uuid import UUID
# 添加项目根目录到Python路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult
from src.agents.coordinator import create_research_coordinator
from src.config import Config
class LLMCallLogger(BaseCallbackHandler):
    """Callback handler that records every LLM call for later inspection.

    Each invocation is captured as a dict (prompts or messages, start/end
    timestamps, response generations, tool calls, errors) and appended to
    ``self.calls``; the whole history can be dumped to JSON with
    :meth:`save_to_file`.
    """

    def __init__(self):
        # Completed call records, in invocation order.
        self.calls: List[Dict[str, Any]] = []
        # Record of the call currently in flight (None when idle).
        self.current_call = None
        # Monotonically increasing counter used as call_id.
        self.call_count = 0

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Called when a plain (non-chat) LLM invocation starts."""
        self.call_count += 1
        self.current_call = {
            "call_id": self.call_count,
            "timestamp_start": datetime.now().isoformat(),
            "prompts": prompts,
            # Stringify kwargs so the record stays JSON-serializable;
            # invocation_params is dropped (large and redundant).
            "kwargs": {k: str(v) for k, v in kwargs.items() if k != "invocation_params"},
        }
        print(f"\n{'='*80}")
        print(f"🔵 LLM调用 #{self.call_count} 开始 - {datetime.now().strftime('%H:%M:%S')}")
        print('='*80)
        if prompts:
            print(f"Prompt长度: {len(prompts[0])} 字符")
            print(f"Prompt预览: {prompts[0][:200]}...")

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        **kwargs: Any
    ) -> None:
        """Called when a chat-model invocation starts.

        Converts each message to a plain dict (type name, content, tool
        calls) so the record survives JSON serialization.
        """
        self.call_count += 1
        self.current_call = {
            "call_id": self.call_count,
            "timestamp_start": datetime.now().isoformat(),
            "messages": [
                [
                    {
                        "type": type(msg).__name__,
                        "content": msg.content if hasattr(msg, 'content') else str(msg),
                        "tool_calls": getattr(msg, 'tool_calls', None)
                    }
                    for msg in msg_list
                ]
                for msg_list in messages
            ],
            "kwargs": {k: str(v) for k, v in kwargs.items() if k not in ["invocation_params", "tags", "metadata"]},
        }
        print(f"\n{'='*80}")
        print(f"🔵 Chat模型调用 #{self.call_count} 开始 - {datetime.now().strftime('%H:%M:%S')}")
        print('='*80)
        if messages:
            print(f"消息数量: {len(messages[0])}")
            # Preview only the last three messages of the first batch.
            for i, msg in enumerate(messages[0][-3:], 1):
                msg_type = type(msg).__name__
                print(f" {i}. {msg_type}: {str(msg.content)[:100] if hasattr(msg, 'content') else 'N/A'}...")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Called when an LLM invocation finishes; finalizes the record."""
        if self.current_call:
            self.current_call["timestamp_end"] = datetime.now().isoformat()
            # Flatten response.generations (list of lists) into plain dicts.
            generations = []
            for gen_list in response.generations:
                for gen in gen_list:
                    gen_info = {
                        "text": gen.text if hasattr(gen, 'text') else None,
                    }
                    if hasattr(gen, 'message'):
                        msg = gen.message
                        gen_info["message"] = {
                            "type": type(msg).__name__,
                            "content": msg.content if hasattr(msg, 'content') else None,
                            "tool_calls": [
                                {
                                    "name": tc.get("name"),
                                    "args": tc.get("args"),
                                    "id": tc.get("id")
                                }
                                for tc in (msg.tool_calls if hasattr(msg, 'tool_calls') and msg.tool_calls else [])
                            ] if hasattr(msg, 'tool_calls') else None
                        }
                    generations.append(gen_info)
            self.current_call["response"] = {
                "generations": generations,
                "llm_output": response.llm_output,
            }
            self.calls.append(self.current_call)
            print(f"\n✅ LLM调用 #{self.current_call['call_id']} 完成")
            if generations:
                gen = generations[0]
                if gen.get("message"):
                    msg = gen["message"]
                    print(f"响应类型: {msg['type']}")
                    if msg.get('content'):
                        print(f"内容: {msg['content'][:150]}...")
                    if msg.get('tool_calls'):
                        print(f"工具调用: {len(msg['tool_calls'])}")
                        for tc in msg['tool_calls'][:3]:
                            print(f" - {tc['name']}")
            self.current_call = None

    def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
        """Called when an LLM invocation raises; records the error."""
        if self.current_call:
            self.current_call["timestamp_end"] = datetime.now().isoformat()
            self.current_call["error"] = str(error)
            self.calls.append(self.current_call)
            print(f"\n❌ LLM调用 #{self.current_call['call_id']} 出错: {error}")
            self.current_call = None

    def save_to_file(self, filepath: str):
        """Dump all recorded calls to a JSON file at *filepath*.

        ``default=str`` is deliberate: llm_output and raw tool-call args may
        contain non-JSON-serializable objects; for a debug log, a lossy
        string form beats a TypeError that discards the whole record.
        """
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump({
                "total_calls": len(self.calls),
                "calls": self.calls
            }, f, ensure_ascii=False, indent=2, default=str)
        print(f"\n💾 已保存 {len(self.calls)} 次LLM调用记录到: {filepath}")
def test_with_llm_logging(question: str, depth: str = "quick", max_steps: int = 10):
    """Run the research flow while recording every LLM call.

    Writes a full JSON log and a human-readable text summary under
    ``tests/`` even when the run is interrupted or fails.

    Args:
        question: Research question to investigate.
        depth: Depth mode passed to the coordinator.
        max_steps: Maximum number of stream steps (guards against infinite loops).
    """
    print("\n" + "🔬 " * 40)
    print("智能深度研究系统 - LLM调用记录模式")
    print("🔬 " * 40)
    print(f"\n研究问题: {question}")
    print(f"深度模式: {depth}")
    print(f"最大步骤数: {max_steps}")
    print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    # Collects every LLM call made during the run.
    logger = LLMCallLogger()

    print("\n" + "="*80)
    print("创建Agent...")
    print("="*80)
    try:
        # NOTE(review): this llm instance is never passed to
        # create_research_coordinator, and callbacks are also attached via
        # the stream config below — confirm Config.get_llm() returns a
        # shared instance, otherwise this assignment is a no-op.
        llm = Config.get_llm()
        llm.callbacks = [logger]
        agent = create_research_coordinator(
            question=question,
            depth=depth,
            format="technical",
            min_tier=3
        )
        print("✅ Agent创建成功")
    except Exception as e:
        print(f"❌ Agent创建失败: {e}")
        import traceback
        traceback.print_exc()
        return

    print("\n" + "="*80)
    print(f"执行研究流程(最多{max_steps}步)...")
    print("="*80)

    # Bound BEFORE the try block: the original code only assigned `duration`
    # on the success path, so a KeyboardInterrupt raised a NameError inside
    # the finally clause and the summary was never written.
    start_time = datetime.now()
    duration = 0.0
    step_count = 0
    try:
        # Stream the run but cap step count and wall-clock time.
        for chunk in agent.stream(
            {
                "messages": [
                    {
                        "role": "user",
                        "content": f"请开始研究这个问题:{question}"
                    }
                ]
            },
            config={"callbacks": [logger]}
        ):
            step_count += 1
            print(f"\n{'-'*80}")
            print(f"📍 步骤 #{step_count} - {datetime.now().strftime('%H:%M:%S')}")
            print('-'*80)
            # Show state updates carried by this chunk.
            if isinstance(chunk, dict):
                if 'messages' in chunk:
                    print(f" 消息: {len(chunk['messages'])}")
                if 'files' in chunk:
                    print(f" 文件: {len(chunk['files'])}")
                    for path in list(chunk['files'].keys())[:3]:
                        print(f" - {path}")
            # Step-count guard.
            if step_count >= max_steps:
                print(f"\n⚠️ 达到最大步骤数 {max_steps},停止执行")
                break
            # Wall-clock guard: 2 minutes.
            elapsed = (datetime.now() - start_time).total_seconds()
            if elapsed > 120:
                print(f"\n⚠️ 超过2分钟停止执行")
                break
        duration = (datetime.now() - start_time).total_seconds()
        print("\n" + "="*80)
        print("执行结束")
        print("="*80)
        print(f"总步骤数: {step_count}")
        print(f"LLM调用次数: {len(logger.calls)}")
        print(f"总耗时: {duration:.2f}")
    except KeyboardInterrupt:
        print("\n\n⚠️ 用户中断")
    except Exception as e:
        print(f"\n\n❌ 执行失败: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Recompute elapsed time here so the summary is correct even when
        # the run was interrupted before the success path finished.
        duration = (datetime.now() - start_time).total_seconds()
        # Always persist whatever was captured.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = "tests"
        os.makedirs(output_dir, exist_ok=True)
        log_file = os.path.join(output_dir, f"llm_calls_{timestamp}.json")
        logger.save_to_file(log_file)
        # Also write a human-readable summary alongside the JSON log.
        summary_file = os.path.join(output_dir, f"llm_calls_summary_{timestamp}.txt")
        with open(summary_file, 'w', encoding='utf-8') as f:
            f.write(f"LLM调用记录摘要\n")
            f.write(f"{'='*80}\n\n")
            f.write(f"总调用次数: {len(logger.calls)}\n")
            f.write(f"执行时长: {duration:.2f}\n\n")
            for i, call in enumerate(logger.calls, 1):
                f.write(f"\n{'-'*80}\n")
                f.write(f"调用 #{i}\n")
                f.write(f"{'-'*80}\n")
                f.write(f"开始: {call['timestamp_start']}\n")
                f.write(f"结束: {call.get('timestamp_end', 'N/A')}\n")
                if 'messages' in call:
                    f.write(f"消息数: {len(call['messages'][0]) if call['messages'] else 0}\n")
                if 'response' in call:
                    gens = call['response'].get('generations', [])
                    if gens:
                        gen = gens[0]
                        if gen.get('message'):
                            msg = gen['message']
                            f.write(f"响应类型: {msg['type']}\n")
                            if msg.get('tool_calls'):
                                f.write(f"工具调用: {[tc['name'] for tc in msg['tool_calls']]}\n")
                if 'error' in call:
                    f.write(f"错误: {call['error']}\n")
        print(f"📄 摘要已保存到: {summary_file}")
if __name__ == "__main__":
    # Smoke run: execute only the first few stream steps rather than a
    # complete research pass.
    research_question = "Python asyncio最佳实践"
    test_with_llm_logging(research_question, depth="quick", max_steps=10)