"""
Native Aliyun DashScope LLM implementation.

Calls the Aliyun Tongyi Qianwen (Qwen) API directly, without going through
the OneAPI compatibility layer.
"""

import os
import json
import time
import requests
from typing import List, Dict, Any, Optional
from langchain_core.language_models import BaseLLM
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.outputs import LLMResult, Generation
from dotenv import load_dotenv

# Load environment variables
load_dotenv(os.path.join(os.path.dirname(__file__), '..', '..', '.env'))


class DashScopeLLM(BaseLLM):
    """
    Native Aliyun DashScope LLM implementation.

    Calls the Tongyi Qianwen (Qwen) API directly.
    """

    # Pydantic field definitions
    api_key: str = ""
    model_name: str = "qwen-turbo"
    max_retries: int = 3
    retry_delay: float = 1.0
    api_url: str = ""
    headers: dict = {}
    last_token_usage: dict = {}
    total_token_usage: dict = {}

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: Optional[str] = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        **kwargs
    ):
        """
        Initialize the DashScope LLM.

        Args:
            api_key: Aliyun DashScope API key.
            model_name: Model name, e.g. qwen-turbo, qwen-plus, qwen-max.
            max_retries: Maximum number of retry attempts.
            retry_delay: Delay between retries in seconds.
        """
        # Resolve field values before calling the parent constructor
        api_key_value = api_key or os.getenv('ONEAPI_KEY')  # reuse the existing env variable
        model_name_value = model_name or os.getenv('ONEAPI_MODEL_GEN', 'qwen-turbo')

        if not api_key_value:
            raise ValueError("DashScope API key not set; please set ONEAPI_KEY in the .env file")

        # DashScope API endpoint
        api_url_value = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"

        # Request headers
        headers_value = {
            'Authorization': f'Bearer {api_key_value}',
            'Content-Type': 'application/json',
            'X-DashScope-SSE': 'disable'  # disable streaming responses
        }

        # Token usage statistics
        last_token_usage_value = {}
        total_token_usage_value = {
            'prompt_tokens': 0,
            'completion_tokens': 0,
            'total_tokens': 0,
            'call_count': 0
        }

        # Initialize the parent class, passing all fields
        super().__init__(
            api_key=api_key_value,
            model_name=model_name_value,
            max_retries=max_retries,
            retry_delay=retry_delay,
            api_url=api_url_value,
            headers=headers_value,
            last_token_usage=last_token_usage_value,
            total_token_usage=total_token_usage_value,
            **kwargs
        )

        print(f"[OK] DashScope LLM initialized: {self.model_name}")
        self._test_connection()

    def _test_connection(self):
        """Smoke-test the DashScope connection."""
        try:
            # Send a minimal test request
            test_payload = {
                "model": self.model_name,
                "input": {
                    "messages": [
                        {"role": "user", "content": "hello"}
                    ]
                },
                "parameters": {
                    "max_tokens": 10,
                    "temperature": 0.1
                }
            }

            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=test_payload,
                timeout=(5, 15)  # short timeout for the connection test
            )

            if response.status_code == 200:
                result = response.json()
                if result.get('output') and result['output'].get('text'):
                    print("[OK] DashScope connection test succeeded")
                else:
                    print(f"[WARNING] Unexpected DashScope response format: {result}")
            else:
                print(f"[WARNING] DashScope connection error, status code: {response.status_code}")
                print(f"   Response: {response.text[:200]}")

        except requests.exceptions.Timeout:
            print("[WARNING] DashScope connection timed out; real requests will still be retried")
        except requests.exceptions.ConnectionError:
            print("[WARNING] DashScope connection failed; please check the network")
        except Exception as e:
            print(f"[WARNING] DashScope connection test error: {e}")

    @property
    def _llm_type(self) -> str:
        return "dashscope"

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Generate responses for a list of prompts."""
        generations = []

        for prompt in prompts:
            try:
                # Forward stop sequences so _call_api can include them in the payload
                response_text = self._call_api(prompt, stop=stop, **kwargs)
                generations.append([Generation(text=response_text)])
            except Exception as e:
                print(f"[ERROR] DashScope API call failed: {e}")
                generations.append([Generation(text=f"API call failed: {str(e)}")])

        return LLMResult(generations=generations)

    def _call_api(self, prompt: str, **kwargs) -> str:
        """Call the DashScope text-generation API with retries."""
        # Build the request payload
        payload = {
            "model": self.model_name,
            "input": {
                "messages": [
                    {"role": "user", "content": prompt}
                ]
            },
            "parameters": {
                "max_tokens": kwargs.get("max_tokens", 2048),
                "temperature": kwargs.get("temperature", 0.7),
                "top_p": kwargs.get("top_p", 0.8),
            }
        }

        # Add stop sequences to the payload if provided
        stop = kwargs.get("stop")
        if stop:
            payload["parameters"]["stop"] = stop

        for attempt in range(self.max_retries):
            try:
                response = requests.post(
                    self.api_url,
                    headers=self.headers,
                    json=payload,
                    timeout=(30, 120)  # 30 s to connect, 120 s to read
                )

                if response.status_code == 200:
                    result = response.json()

                    # Check for an API-level error
                    if result.get('code'):
                        error_msg = result.get('message', f'API error code: {result["code"]}')
                        if attempt == self.max_retries - 1:
                            raise RuntimeError(f"DashScope API error: {error_msg}")
                        else:
                            print(f"API error, retrying ({attempt + 2}/{self.max_retries}): {error_msg}")
                            time.sleep(self.retry_delay * (attempt + 1))
                            continue

                    # Extract the response content
                    if result.get('output') and result['output'].get('text'):
                        content = result['output']['text']

                        # Extract token usage information
                        usage_info = result.get('usage', {})
                        self.last_token_usage = {
                            'prompt_tokens': usage_info.get('input_tokens', 0),
                            'completion_tokens': usage_info.get('output_tokens', 0),
                            'total_tokens': usage_info.get('total_tokens', 0)
                        }

                        # Update cumulative statistics
                        self.total_token_usage['prompt_tokens'] += self.last_token_usage['prompt_tokens']
                        self.total_token_usage['completion_tokens'] += self.last_token_usage['completion_tokens']
                        self.total_token_usage['total_tokens'] += self.last_token_usage['total_tokens']
                        self.total_token_usage['call_count'] += 1

                        return content.strip()
                    else:
                        error_msg = f"Malformed API response: {result}"
                        if attempt == self.max_retries - 1:
                            raise RuntimeError(error_msg)
                        else:
                            print(f"Malformed response, retrying ({attempt + 2}/{self.max_retries})")
                            time.sleep(self.retry_delay * (attempt + 1))
                else:
                    error_text = response.text[:500] if response.text else "no response body"
                    error_msg = f"API request failed, status code: {response.status_code}, response: {error_text}"
                    if attempt == self.max_retries - 1:
                        raise RuntimeError(error_msg)
                    else:
                        print(f"Request failed, retrying ({attempt + 2}/{self.max_retries}): status code {response.status_code}")
                        time.sleep(self.retry_delay * (attempt + 1))

            except KeyboardInterrupt:
                print("\n[WARNING] Request interrupted by user")
                raise

            except requests.exceptions.Timeout as e:
                error_msg = f"Request timed out: {str(e)}"
                if attempt == self.max_retries - 1:
                    raise RuntimeError(f"Still timing out after {self.max_retries} retries: {error_msg}")
                else:
                    print(f"Request timed out, retrying ({attempt + 2}/{self.max_retries}): {error_msg}")
                    time.sleep(self.retry_delay * (attempt + 1))

            except requests.exceptions.ConnectionError as e:
                error_msg = f"Connection error: {str(e)}"
                if attempt == self.max_retries - 1:
                    raise RuntimeError(f"Still unable to connect after {self.max_retries} retries: {error_msg}")
                else:
                    print(f"Connection error, retrying ({attempt + 2}/{self.max_retries}): {error_msg}")
                    time.sleep(self.retry_delay * (attempt + 1))

            except requests.RequestException as e:
                error_msg = f"Network request exception: {str(e)}"
                if attempt == self.max_retries - 1:
                    raise RuntimeError(f"Still failing after {self.max_retries} retries: {error_msg}")
                else:
                    print(f"Network exception, retrying ({attempt + 2}/{self.max_retries}): {str(e)[:100]}")
                    time.sleep(self.retry_delay * (attempt + 1))

        raise RuntimeError("All retry attempts failed")

    def generate_response(self, messages, max_new_tokens=4096, temperature=0.7, response_format=None, **kwargs):
        """
        Generate a response; compatible with the original interface.

        Args:
            messages: A message list, or a list of message lists for batch processing.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            response_format: Response format (accepted for compatibility; unused).
            **kwargs: Additional parameters.

        Returns:
            The generated text, or a list of texts for batch input.
        """
        # Detect batch input: a list whose elements are themselves message lists
        is_batch = isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], list)

        if is_batch:
            # Batch processing
            results = []
            for msg_list in messages:
                try:
                    # Build the prompt
                    prompt = self._format_messages(msg_list)
                    result = self._call_api(prompt, max_tokens=max_new_tokens, temperature=temperature)
                    results.append(result)
                except Exception as e:
                    print(f"A single request in the batch failed: {str(e)}")
                    results.append("")
            return results
        else:
            # Single request
            prompt = self._format_messages(messages)
            return self._call_api(prompt, max_tokens=max_new_tokens, temperature=temperature)

    def _format_messages(self, messages):
        """Flatten a chat message list into a single prompt string."""
        formatted_parts = []
        for msg in messages:
            role = msg.get('role', 'user')
            content = msg.get('content', '')
            if role == 'system':
                formatted_parts.append(f"System: {content}")
            elif role == 'user':
                formatted_parts.append(f"User: {content}")
            elif role == 'assistant':
                formatted_parts.append(f"Assistant: {content}")
        return "\n\n".join(formatted_parts)

    def ner(self, text: str) -> str:
        """
        Named entity recognition.

        Args:
            text: Input text.

        Returns:
            The extracted entities, separated by commas.
        """
        messages = [
            {
                "role": "system",
                "content": "Please extract the entities from the following question and output them separated by comma, in the following format: entity1, entity2, ..."
            },
            {
                "role": "user",
                "content": f"Extract the named entities from: {text}"
            }
        ]

        try:
            return self.generate_response(messages, max_new_tokens=1024, temperature=0.7)
        except Exception as e:
            print(f"NER failed: {str(e)}")
            return text  # fall back to the original text on failure

    def filter_triples_with_entity_event(self, question: str, triples: str) -> str:
        """
        Filter triples by entity and event relevance.

        Args:
            question: The query question.
            triples: A JSON string of triples.

        Returns:
            A JSON string of the filtered triples.
        """
        import json_repair  # local import: optional dependency for repairing malformed JSON

        # Improved filter prompt, emphasizing strict subset filtering
        filter_messages = [
            {
                "role": "system",
                "content": """You are a knowledge-graph fact-filtering expert, skilled at selecting facts by their relevance to a question.

Key requirements:
1. You may only select facts from the provided input list - never create or generate new facts
2. Your output must be a strict subset of the input facts
3. Include only facts directly relevant to answering the question
4. When in doubt, prefer selecting fewer facts rather than more
5. Preserve the exact format of each fact: [subject, relation, object]

Filtering rules:
- Select only facts containing entities or relations directly relevant to the question
- Do not modify, rephrase, or create variants of the input facts
- Do not add facts that look relevant but are not in the input
- The number of output facts must be <= the number of input facts

Return a JSON object with a "fact" key whose value is the array of selected facts.

Example:
Input facts: [["A", "relation1", "B"], ["B", "relation2", "C"], ["D", "relation3", "E"]]
Question: What is the relationship between A and B?
Correct output: {"fact": [["A", "relation1", "B"]]}
Incorrect: adding new facts or modifying existing ones"""
            },
            {
                "role": "user",
                "content": f"""Question: {question}

Input facts to filter:
{triples}

Select only the most relevant facts from the input above to answer the question. Remember: the output must be a strict subset!"""
            }
        ]

        try:
            response = self.generate_response(
                filter_messages,
                max_new_tokens=4096,
                temperature=0.0
            )

            # Try to parse the JSON response
            try:
                parsed_response = json.loads(response)
                if 'fact' in parsed_response:
                    return json.dumps(parsed_response, ensure_ascii=False)
                else:
                    # No "fact" key; fall back to an empty result
                    return json.dumps({"fact": []}, ensure_ascii=False)
            except json.JSONDecodeError:
                # If strict JSON parsing fails, try json_repair
                try:
                    parsed_response = json_repair.loads(response)
                    if 'fact' in parsed_response:
                        return json.dumps(parsed_response, ensure_ascii=False)
                    else:
                        return json.dumps({"fact": []}, ensure_ascii=False)
                except Exception:
                    # All parsing failed; return an empty result
                    return json.dumps({"fact": []}, ensure_ascii=False)

        except Exception as e:
            print(f"Triple filtering failed: {str(e)}")
            # If filtering fails, return the original triples
            return triples

    def generate_with_context(self,
                              question: str,
                              context: str,
                              max_new_tokens: int = 1024,
                              temperature: float = 0.7) -> str:
        """
        Generate an answer based on the given context.

        Args:
            question: The question.
            context: The context.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.

        Returns:
            The generated answer.
        """
        messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant. Answer the question based on the provided context. Think step by step."
            },
            {
                "role": "user",
                "content": f"{context}\n\n{question}\nThought:"
            }
        ]

        try:
            return self.generate_response(messages, max_new_tokens, temperature)
        except Exception as e:
            print(f"Context-based generation failed: {str(e)}")
            return "Sorry, I cannot answer this question based on the provided context."

    def get_token_usage(self) -> Dict[str, Any]:
        """Return token usage statistics."""
        return {
            "last_usage": self.last_token_usage.copy(),
            "total_usage": self.total_token_usage.copy(),
            "model_name": self.model_name
        }


def create_dashscope_llm(
    api_key: Optional[str] = None,
    model_name: Optional[str] = None,
    **kwargs
) -> DashScopeLLM:
    """Convenience factory for creating a DashScopeLLM instance."""
    return DashScopeLLM(
        api_key=api_key,
        model_name=model_name,
        **kwargs
    )
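

# Minimal usage sketch (illustrative only): assumes ONEAPI_KEY is set in the
# environment or .env file and that DashScope is reachable over the network.
if __name__ == "__main__":
    llm = create_dashscope_llm(model_name="qwen-turbo")

    # Single chat-style request through the compatibility interface
    answer = llm.generate_response(
        [{"role": "user", "content": "What is a knowledge graph?"}],
        max_new_tokens=256,
        temperature=0.7,
    )
    print(answer)

    # Cumulative token accounting across calls
    print(llm.get_token_usage())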