# AIEC-RAG/retriver/langgraph/dashscope_llm.py
"""
阿里云DashScope原生LLM实现
直接调用阿里云通义千问API不通过OneAPI兼容层
"""
import os
import json
import time
import requests
from typing import List, Dict, Any, Optional
from langchain_core.language_models import BaseLLM
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.outputs import LLMResult, Generation
from dotenv import load_dotenv

# Load environment variables from the project-level .env file
load_dotenv(os.path.join(os.path.dirname(__file__), '..', '..', '.env'))

class DashScopeLLM(BaseLLM):
    """
    Native LLM implementation for Alibaba Cloud DashScope.
    Calls the Tongyi Qianwen (Qwen) API directly.
    """
    # Pydantic field definitions
    api_key: str = ""
    model_name: str = "qwen-turbo"
    max_retries: int = 3
    retry_delay: float = 1.0
    api_url: str = ""
    headers: dict = {}
    last_token_usage: dict = {}
    total_token_usage: dict = {}

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: Optional[str] = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        **kwargs
    ):
        """
        Initialize the DashScope LLM.

        Args:
            api_key: Alibaba Cloud DashScope API key
            model_name: Model name, e.g. qwen-turbo, qwen-plus, qwen-max
            max_retries: Maximum number of retries
            retry_delay: Delay between retries (seconds)
        """
        # Resolve the field values first
        api_key_value = api_key or os.getenv('ONEAPI_KEY')  # reuse the existing environment variable
        model_name_value = model_name or os.getenv('ONEAPI_MODEL_GEN', 'qwen-turbo')
        if not api_key_value:
            raise ValueError("DashScope API key is not set; please set ONEAPI_KEY in the .env file")
        # DashScope API endpoint
        api_url_value = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
        # Request headers
        headers_value = {
            'Authorization': f'Bearer {api_key_value}',
            'Content-Type': 'application/json',
            'X-DashScope-SSE': 'disable'  # disable streaming responses
        }
        # Token usage statistics
        last_token_usage_value = {}
        total_token_usage_value = {
            'prompt_tokens': 0,
            'completion_tokens': 0,
            'total_tokens': 0,
            'call_count': 0
        }
        # Initialize the parent class, passing all field values
        super().__init__(
            api_key=api_key_value,
            model_name=model_name_value,
            max_retries=max_retries,
            retry_delay=retry_delay,
            api_url=api_url_value,
            headers=headers_value,
            last_token_usage=last_token_usage_value,
            total_token_usage=total_token_usage_value,
            **kwargs
        )
        print(f"[OK] DashScope LLM initialized: {self.model_name}")
        self._test_connection()

    def _test_connection(self):
        """Test connectivity to DashScope."""
        try:
            # Send a minimal test request
            test_payload = {
                "model": self.model_name,
                "input": {
                    "messages": [
                        {"role": "user", "content": "hello"}
                    ]
                },
                "parameters": {
                    "max_tokens": 10,
                    "temperature": 0.1
                }
            }
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=test_payload,
                timeout=(5, 15)  # short timeouts for the connectivity check
            )
            if response.status_code == 200:
                result = response.json()
                if result.get('output') and result['output'].get('text'):
                    print("[OK] DashScope connection test succeeded")
                else:
                    print(f"[WARNING] Unexpected DashScope response format: {result}")
            else:
                print(f"[WARNING] DashScope connection issue, status code: {response.status_code}")
                print(f"    Response: {response.text[:200]}")
        except requests.exceptions.Timeout:
            print("[WARNING] DashScope connection timed out; will retry on actual requests")
        except requests.exceptions.ConnectionError:
            print("[WARNING] DashScope connection failed; please check the network")
        except Exception as e:
            print(f"[WARNING] DashScope connection test error: {e}")

    @property
    def _llm_type(self) -> str:
        return "dashscope"

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Generate responses for a list of prompts."""
        generations = []
        for prompt in prompts:
            try:
                # Forward the stop sequences so _call_api can apply them
                response_text = self._call_api(prompt, stop=stop, **kwargs)
                generations.append([Generation(text=response_text)])
            except Exception as e:
                print(f"[ERROR] DashScope API call failed: {e}")
                generations.append([Generation(text=f"API call failed: {str(e)}")])
        return LLMResult(generations=generations)

    def _call_api(self, prompt: str, **kwargs) -> str:
        """Call the DashScope API with retries."""
        # Build the request payload
        payload = {
            "model": self.model_name,
            "input": {
                "messages": [
                    {"role": "user", "content": prompt}
                ]
            },
            "parameters": {
                "max_tokens": kwargs.get("max_tokens", 2048),
                "temperature": kwargs.get("temperature", 0.7),
                "top_p": kwargs.get("top_p", 0.8),
            }
        }
        # Add stop sequences to the payload if provided
        stop = kwargs.get("stop")
        if stop:
            payload["parameters"]["stop"] = stop

        for attempt in range(self.max_retries):
            try:
                response = requests.post(
                    self.api_url,
                    headers=self.headers,
                    json=payload,
                    timeout=(30, 120)  # 30s to connect, 120s to read
                )
                if response.status_code == 200:
                    result = response.json()
                    # Check for an API-level error
                    if result.get('code'):
                        error_msg = result.get('message', f'API error code: {result["code"]}')
                        if attempt == self.max_retries - 1:
                            raise RuntimeError(f"DashScope API error: {error_msg}")
                        else:
                            print(f"API error, retrying ({attempt + 2}/{self.max_retries}): {error_msg}")
                            time.sleep(self.retry_delay * (attempt + 1))
                            continue
                    # Extract the response content
                    if result.get('output') and result['output'].get('text'):
                        content = result['output']['text']
                        # Record token usage for this call
                        usage_info = result.get('usage', {})
                        self.last_token_usage = {
                            'prompt_tokens': usage_info.get('input_tokens', 0),
                            'completion_tokens': usage_info.get('output_tokens', 0),
                            'total_tokens': usage_info.get('total_tokens', 0)
                        }
                        # Update the cumulative statistics
                        self.total_token_usage['prompt_tokens'] += self.last_token_usage['prompt_tokens']
                        self.total_token_usage['completion_tokens'] += self.last_token_usage['completion_tokens']
                        self.total_token_usage['total_tokens'] += self.last_token_usage['total_tokens']
                        self.total_token_usage['call_count'] += 1
                        return content.strip()
                    else:
                        error_msg = f"Malformed API response: {result}"
                        if attempt == self.max_retries - 1:
                            raise RuntimeError(error_msg)
                        else:
                            print(f"Malformed response, retrying ({attempt + 2}/{self.max_retries})")
                            time.sleep(self.retry_delay * (attempt + 1))
                else:
                    error_text = response.text[:500] if response.text else "no response body"
                    error_msg = f"API request failed, status code: {response.status_code}, response: {error_text}"
                    if attempt == self.max_retries - 1:
                        raise RuntimeError(error_msg)
                    else:
                        print(f"Request failed, retrying ({attempt + 2}/{self.max_retries}): status code {response.status_code}")
                        time.sleep(self.retry_delay * (attempt + 1))
            except KeyboardInterrupt:
                print("\n[WARNING] Request interrupted by user")
                raise
            except requests.exceptions.Timeout as e:
                error_msg = f"Request timed out: {str(e)}"
                if attempt == self.max_retries - 1:
                    raise RuntimeError(f"Still timing out after {self.max_retries} retries: {error_msg}")
                else:
                    print(f"Request timed out, retrying ({attempt + 2}/{self.max_retries}): {error_msg}")
                    time.sleep(self.retry_delay * (attempt + 1))
            except requests.exceptions.ConnectionError as e:
                error_msg = f"Connection error: {str(e)}"
                if attempt == self.max_retries - 1:
                    raise RuntimeError(f"Still unable to connect after {self.max_retries} retries: {error_msg}")
                else:
                    print(f"Connection error, retrying ({attempt + 2}/{self.max_retries}): {error_msg}")
                    time.sleep(self.retry_delay * (attempt + 1))
            except requests.RequestException as e:
                error_msg = f"Network request exception: {str(e)}"
                if attempt == self.max_retries - 1:
                    raise RuntimeError(f"Still failing after {self.max_retries} retries: {error_msg}")
                else:
                    print(f"Network error, retrying ({attempt + 2}/{self.max_retries}): {str(e)[:100]}")
                    time.sleep(self.retry_delay * (attempt + 1))
        raise RuntimeError("All retries failed")

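    # For reference, a successful DashScope response, as consumed above, looks
    # roughly like this (illustrative, abbreviated):
    #   {"output": {"text": "..."},
    #    "usage": {"input_tokens": 12, "output_tokens": 34, "total_tokens": 46}}
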
    def generate_response(self, messages, max_new_tokens=4096, temperature=0.7, response_format=None, **kwargs):
        """
        Generate a response; compatible with the original interface.

        Args:
            messages: A message list, or a list of message lists for batch processing
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            response_format: Response format
            **kwargs: Additional parameters

        Returns:
            The generated text, or a list of texts for batch input
        """
        # Detect batch input (a list of message lists)
        is_batch = isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], list)
        if is_batch:
            # Batch processing
            results = []
            for msg_list in messages:
                try:
                    # Build the prompt
                    prompt = self._format_messages(msg_list)
                    result = self._call_api(prompt, max_tokens=max_new_tokens, temperature=temperature)
                    results.append(result)
                except Exception as e:
                    print(f"A request in the batch failed: {str(e)}")
                    results.append("")
            return results
        else:
            # Single request
            prompt = self._format_messages(messages)
            return self._call_api(prompt, max_tokens=max_new_tokens, temperature=temperature)

    def _format_messages(self, messages):
        """Flatten a message list into a single prompt string."""
        formatted_parts = []
        for msg in messages:
            role = msg.get('role', 'user')
            content = msg.get('content', '')
            if role == 'system':
                formatted_parts.append(f"System: {content}")
            elif role == 'user':
                formatted_parts.append(f"User: {content}")
            elif role == 'assistant':
                formatted_parts.append(f"Assistant: {content}")
        return "\n\n".join(formatted_parts)

    def ner(self, text: str) -> str:
        """
        Named entity recognition.

        Args:
            text: Input text

        Returns:
            The extracted entities, separated by commas
        """
        messages = [
            {
                "role": "system",
                "content": "Please extract the entities from the following question and output them separated by comma, in the following format: entity1, entity2, ..."
            },
            {
                "role": "user",
                "content": f"Extract the named entities from: {text}"
            }
        ]
        try:
            return self.generate_response(messages, max_new_tokens=1024, temperature=0.7)
        except Exception as e:
            print(f"NER failed: {str(e)}")
            return text  # fall back to the original text on failure

    def filter_triples_with_entity_event(self, question: str, triples: str) -> str:
        """
        Filter triples based on entities and events.

        Args:
            question: The query question
            triples: Triples as a JSON string

        Returns:
            The filtered triples as a JSON string
        """
        import json_repair
        # Improved filter prompt emphasizing strict-subset filtering
        filter_messages = [
            {
                "role": "system",
                "content": """You are a knowledge-graph fact-filtering expert, skilled at selecting facts by their relevance to a question.

Key requirements:
1. You may only select facts from the provided input list - never create or generate new facts
2. Your output must be a strict subset of the input facts
3. Include only facts directly relevant to answering the question
4. When in doubt, prefer selecting fewer facts rather than more
5. Preserve the exact format of each fact: [subject, relation, object]

Filtering rules:
- Select only facts containing entities or relations directly relevant to the question
- Do not modify, rephrase, or create variants of the input facts
- Do not add facts that look relevant but are not in the input
- The number of output facts must be <= the number of input facts

Return a JSON object with a "fact" key whose value is the array of selected facts.

Example:
Input facts: [["A", "relation1", "B"], ["B", "relation2", "C"], ["D", "relation3", "E"]]
Question: What is the relationship between A and B?
Correct output: {"fact": [["A", "relation1", "B"]]}
Wrong: adding new facts or modifying existing ones"""
            },
            {
                "role": "user",
                "content": f"""Question: {question}

Input facts to filter:
{triples}

Select only the facts most relevant to answering the question from the input above. Remember: the output must be a strict subset!"""
            }
        ]
        try:
            response = self.generate_response(
                filter_messages,
                max_new_tokens=4096,
                temperature=0.0
            )
            # Try to parse the JSON response
            try:
                parsed_response = json.loads(response)
                if 'fact' in parsed_response:
                    return json.dumps(parsed_response, ensure_ascii=False)
                else:
                    # No "fact" key; return an empty result
                    return json.dumps({"fact": []}, ensure_ascii=False)
            except json.JSONDecodeError:
                # If strict JSON parsing fails, try json_repair
                try:
                    parsed_response = json_repair.loads(response)
                    if 'fact' in parsed_response:
                        return json.dumps(parsed_response, ensure_ascii=False)
                    else:
                        return json.dumps({"fact": []}, ensure_ascii=False)
                except Exception:
                    # All parsing failed; return an empty result
                    return json.dumps({"fact": []}, ensure_ascii=False)
        except Exception as e:
            print(f"Triple filtering failed: {str(e)}")
            # Fall back to the original triples if filtering fails
            return triples

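    # Illustrative call with hypothetical data:
    #   llm.filter_triples_with_entity_event(
    #       "What is the relationship between A and B?",
    #       '[["A", "relation1", "B"], ["D", "relation3", "E"]]')
    # should return a strict subset such as '{"fact": [["A", "relation1", "B"]]}'.
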
    def generate_with_context(self,
                              question: str,
                              context: str,
                              max_new_tokens: int = 1024,
                              temperature: float = 0.7) -> str:
        """
        Generate an answer based on the given context.

        Args:
            question: The question
            context: The context
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature

        Returns:
            The generated answer
        """
        messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant. Answer the question based on the provided context. Think step by step."
            },
            {
                "role": "user",
                "content": f"{context}\n\n{question}\nThought:"
            }
        ]
        try:
            return self.generate_response(messages, max_new_tokens, temperature)
        except Exception as e:
            print(f"Context-based generation failed: {str(e)}")
            return "Sorry, I cannot answer this question based on the provided context."

    def get_token_usage(self) -> Dict[str, Any]:
        """Return token usage statistics."""
        return {
            "last_usage": self.last_token_usage.copy(),
            "total_usage": self.total_token_usage.copy(),
            "model_name": self.model_name
        }

def create_dashscope_llm(
    api_key: Optional[str] = None,
    model_name: Optional[str] = None,
    **kwargs
) -> DashScopeLLM:
    """Convenience factory for creating a DashScopeLLM instance."""
    return DashScopeLLM(
        api_key=api_key,
        model_name=model_name,
        **kwargs
    )
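

# Minimal usage sketch (illustrative, not part of the original module): assumes
# ONEAPI_KEY is set in the environment or .env file and the account can call
# the qwen-turbo model.
if __name__ == "__main__":
    llm = create_dashscope_llm(model_name="qwen-turbo")
    answer = llm.generate_with_context(
        question="What is retrieval-augmented generation?",
        context="Retrieval-augmented generation (RAG) combines a retriever with a text generator.",
    )
    print(answer)
    print(llm.get_token_usage())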