56 lines
1.7 KiB
Python
56 lines
1.7 KiB
Python
|
|
"""
|
||
|
|
生成LangGraph Mermaid图表的脚本
|
||
|
|
使用LangGraph内置的draw_mermaid()方法
|
||
|
|
"""
|
||
|
|
|
||
|
|
import os
|
||
|
|
import sys
|
||
|
|
|
||
|
|
# 添加项目根目录到Python路径
|
||
|
|
# 从当前文件位置向上找到项目根目录 (Retriver/)
|
||
|
|
project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||
|
|
sys.path.insert(0, project_root)
|
||
|
|
|
||
|
|
def generate_mermaid_graph():
|
||
|
|
"""使用LangGraph内置方法生成Mermaid图表"""
|
||
|
|
|
||
|
|
try:
|
||
|
|
# 导入必要的模块
|
||
|
|
from retriver.langgraph.iterative_retriever import IterativeRetriever
|
||
|
|
|
||
|
|
print("初始化检索器...")
|
||
|
|
# 创建一个最小化配置的迭代检索器实例
|
||
|
|
retriever = IterativeRetriever(
|
||
|
|
keyword="temp",
|
||
|
|
top_k=2,
|
||
|
|
max_iterations=3,
|
||
|
|
max_parallel_retrievals=2
|
||
|
|
)
|
||
|
|
|
||
|
|
print("生成LangGraph Mermaid图表...")
|
||
|
|
# 使用LangGraph内置的draw_mermaid()方法
|
||
|
|
mermaid_graph = retriever.workflow.get_graph().draw_mermaid()
|
||
|
|
|
||
|
|
# 保存到Graph.txt文件
|
||
|
|
output_file = os.path.join(os.path.dirname(__file__), "Graph.txt")
|
||
|
|
with open(output_file, 'w', encoding='utf-8') as f:
|
||
|
|
f.write(mermaid_graph)
|
||
|
|
|
||
|
|
print(f"LangGraph Mermaid图表已保存到: {output_file}")
|
||
|
|
print("\n图表内容:")
|
||
|
|
print(mermaid_graph)
|
||
|
|
|
||
|
|
return mermaid_graph
|
||
|
|
|
||
|
|
except ImportError as e:
|
||
|
|
print(f"模块导入失败: {e}")
|
||
|
|
print("请确保在正确的Python环境中运行此脚本")
|
||
|
|
return None
|
||
|
|
except Exception as e:
|
||
|
|
print(f"生成Mermaid图表失败: {e}")
|
||
|
|
print("可能是因为缺少必要的依赖或配置")
|
||
|
|
return None
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
|
||
|
|
generate_mermaid_graph()
|