first commit
This commit is contained in:
483
retriver/langgraph/dashscope_llm.py
Normal file
483
retriver/langgraph/dashscope_llm.py
Normal file
@ -0,0 +1,483 @@
|
||||
"""
|
||||
阿里云DashScope原生LLM实现
|
||||
直接调用阿里云通义千问API,不通过OneAPI兼容层
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import requests
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.outputs import LLMResult, Generation
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv(os.path.join(os.path.dirname(__file__), '..', '..', '.env'))
|
||||
|
||||
|
||||
class DashScopeLLM(BaseLLM):
|
||||
"""
|
||||
阿里云DashScope原生LLM实现
|
||||
直接调用通义千问API
|
||||
"""
|
||||
|
||||
# Pydantic字段定义
|
||||
api_key: str = ""
|
||||
model_name: str = "qwen-turbo"
|
||||
max_retries: int = 3
|
||||
retry_delay: float = 1.0
|
||||
api_url: str = ""
|
||||
headers: dict = {}
|
||||
last_token_usage: dict = {}
|
||||
total_token_usage: dict = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
model_name: Optional[str] = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
初始化DashScope LLM
|
||||
|
||||
Args:
|
||||
api_key: 阿里云DashScope API Key
|
||||
model_name: 模型名称,如 qwen-turbo, qwen-plus, qwen-max
|
||||
max_retries: 最大重试次数
|
||||
retry_delay: 重试延迟时间(秒)
|
||||
"""
|
||||
# 先设置字段值
|
||||
api_key_value = api_key or os.getenv('ONEAPI_KEY') # 复用现有环境变量
|
||||
model_name_value = model_name or os.getenv('ONEAPI_MODEL_GEN', 'qwen-turbo')
|
||||
|
||||
if not api_key_value:
|
||||
raise ValueError("DashScope API Key未设置,请在.env文件中设置ONEAPI_KEY")
|
||||
|
||||
# DashScope API端点
|
||||
api_url_value = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
||||
|
||||
# 设置请求头
|
||||
headers_value = {
|
||||
'Authorization': f'Bearer {api_key_value}',
|
||||
'Content-Type': 'application/json',
|
||||
'X-DashScope-SSE': 'disable' # 禁用流式响应
|
||||
}
|
||||
|
||||
# Token使用统计
|
||||
last_token_usage_value = {}
|
||||
total_token_usage_value = {
|
||||
'prompt_tokens': 0,
|
||||
'completion_tokens': 0,
|
||||
'total_tokens': 0,
|
||||
'call_count': 0
|
||||
}
|
||||
|
||||
# 初始化父类,传递所有字段
|
||||
super().__init__(
|
||||
api_key=api_key_value,
|
||||
model_name=model_name_value,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
api_url=api_url_value,
|
||||
headers=headers_value,
|
||||
last_token_usage=last_token_usage_value,
|
||||
total_token_usage=total_token_usage_value,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
print(f"[OK] DashScope LLM初始化完成: {self.model_name}")
|
||||
self._test_connection()
|
||||
|
||||
def _test_connection(self):
|
||||
"""测试DashScope连接"""
|
||||
try:
|
||||
# 发送一个简单的测试请求
|
||||
test_payload = {
|
||||
"model": self.model_name,
|
||||
"input": {
|
||||
"messages": [
|
||||
{"role": "user", "content": "hello"}
|
||||
]
|
||||
},
|
||||
"parameters": {
|
||||
"max_tokens": 10,
|
||||
"temperature": 0.1
|
||||
}
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
self.api_url,
|
||||
headers=self.headers,
|
||||
json=test_payload,
|
||||
timeout=(5, 15) # 短超时进行连接测试
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
if result.get('output') and result['output'].get('text'):
|
||||
print("[OK] DashScope连接测试成功")
|
||||
else:
|
||||
print(f"[WARNING] DashScope响应格式异常: {result}")
|
||||
else:
|
||||
print(f"[WARNING] DashScope连接异常,状态码: {response.status_code}")
|
||||
print(f" 响应: {response.text[:200]}")
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
print("[WARNING] DashScope连接超时,但将在实际请求时重试")
|
||||
except requests.exceptions.ConnectionError:
|
||||
print("[WARNING] DashScope连接失败,请检查网络状态")
|
||||
except Exception as e:
|
||||
print(f"[WARNING] DashScope连接测试出错: {e}")
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "dashscope"
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""生成响应"""
|
||||
generations = []
|
||||
|
||||
for prompt in prompts:
|
||||
try:
|
||||
response_text = self._call_api(prompt, **kwargs)
|
||||
generations.append([Generation(text=response_text)])
|
||||
except Exception as e:
|
||||
print(f"[ERROR] DashScope API调用失败: {e}")
|
||||
generations.append([Generation(text=f"API调用失败: {str(e)}")])
|
||||
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
def _call_api(self, prompt: str, **kwargs) -> str:
|
||||
"""调用DashScope API"""
|
||||
# 构建请求payload
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"input": {
|
||||
"messages": [
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
},
|
||||
"parameters": {
|
||||
"max_tokens": kwargs.get("max_tokens", 2048),
|
||||
"temperature": kwargs.get("temperature", 0.7),
|
||||
"top_p": kwargs.get("top_p", 0.8),
|
||||
}
|
||||
}
|
||||
|
||||
# 如果有stop参数,添加到payload中
|
||||
stop = kwargs.get("stop")
|
||||
if stop:
|
||||
payload["parameters"]["stop"] = stop
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
response = requests.post(
|
||||
self.api_url,
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
timeout=(30, 120) # 连接30秒,读取120秒
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
|
||||
# 检查是否有错误
|
||||
if result.get('code'):
|
||||
error_msg = result.get('message', f'API错误代码: {result["code"]}')
|
||||
if attempt == self.max_retries - 1:
|
||||
raise RuntimeError(f"DashScope API错误: {error_msg}")
|
||||
else:
|
||||
print(f"API错误,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}")
|
||||
time.sleep(self.retry_delay * (attempt + 1))
|
||||
continue
|
||||
|
||||
# 提取响应内容
|
||||
if result.get('output') and result['output'].get('text'):
|
||||
content = result['output']['text']
|
||||
|
||||
# 提取Token使用信息
|
||||
usage_info = result.get('usage', {})
|
||||
self.last_token_usage = {
|
||||
'prompt_tokens': usage_info.get('input_tokens', 0),
|
||||
'completion_tokens': usage_info.get('output_tokens', 0),
|
||||
'total_tokens': usage_info.get('total_tokens', 0)
|
||||
}
|
||||
|
||||
# 更新累计统计
|
||||
self.total_token_usage['prompt_tokens'] += self.last_token_usage['prompt_tokens']
|
||||
self.total_token_usage['completion_tokens'] += self.last_token_usage['completion_tokens']
|
||||
self.total_token_usage['total_tokens'] += self.last_token_usage['total_tokens']
|
||||
self.total_token_usage['call_count'] += 1
|
||||
|
||||
return content.strip()
|
||||
else:
|
||||
error_msg = f"API响应格式错误: {result}"
|
||||
if attempt == self.max_retries - 1:
|
||||
raise RuntimeError(error_msg)
|
||||
else:
|
||||
print(f"响应格式错误,正在重试 ({attempt + 2}/{self.max_retries})")
|
||||
time.sleep(self.retry_delay * (attempt + 1))
|
||||
else:
|
||||
error_text = response.text[:500] if response.text else "无响应内容"
|
||||
error_msg = f"API请求失败,状态码: {response.status_code}, 响应: {error_text}"
|
||||
if attempt == self.max_retries - 1:
|
||||
raise RuntimeError(error_msg)
|
||||
else:
|
||||
print(f"请求失败,正在重试 ({attempt + 2}/{self.max_retries}): 状态码 {response.status_code}")
|
||||
time.sleep(self.retry_delay * (attempt + 1))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n[WARNING] 用户中断请求")
|
||||
raise KeyboardInterrupt("用户中断请求")
|
||||
|
||||
except requests.exceptions.Timeout as e:
|
||||
error_msg = f"请求超时: {str(e)}"
|
||||
if attempt == self.max_retries - 1:
|
||||
raise RuntimeError(f"经过 {self.max_retries} 次重试后仍超时: {error_msg}")
|
||||
else:
|
||||
print(f"请求超时,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}")
|
||||
time.sleep(self.retry_delay * (attempt + 1))
|
||||
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
error_msg = f"连接错误: {str(e)}"
|
||||
if attempt == self.max_retries - 1:
|
||||
raise RuntimeError(f"经过 {self.max_retries} 次重试后仍无法连接: {error_msg}")
|
||||
else:
|
||||
print(f"连接错误,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}")
|
||||
time.sleep(self.retry_delay * (attempt + 1))
|
||||
|
||||
except requests.RequestException as e:
|
||||
error_msg = f"网络请求异常: {str(e)}"
|
||||
if attempt == self.max_retries - 1:
|
||||
raise RuntimeError(f"经过 {self.max_retries} 次重试后仍失败: {error_msg}")
|
||||
else:
|
||||
print(f"网络异常,正在重试 ({attempt + 2}/{self.max_retries}): {str(e)[:100]}")
|
||||
time.sleep(self.retry_delay * (attempt + 1))
|
||||
|
||||
raise RuntimeError("所有重试都失败了")
|
||||
|
||||
def generate_response(self, messages, max_new_tokens=4096, temperature=0.7, response_format=None, **kwargs):
|
||||
"""
|
||||
生成响应,兼容原始接口
|
||||
|
||||
Args:
|
||||
messages: 消息列表或批量消息列表
|
||||
max_new_tokens: 最大生成token数
|
||||
temperature: 温度参数
|
||||
response_format: 响应格式
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
生成的文本或文本列表
|
||||
"""
|
||||
# 检查是否为批量处理
|
||||
is_batch = isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], list)
|
||||
|
||||
if is_batch:
|
||||
# 批量处理
|
||||
results = []
|
||||
for msg_list in messages:
|
||||
try:
|
||||
# 构造prompt
|
||||
prompt = self._format_messages(msg_list)
|
||||
result = self._call_api(prompt, max_tokens=max_new_tokens, temperature=temperature)
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
print(f"批量处理中的单个请求失败: {str(e)}")
|
||||
results.append("")
|
||||
return results
|
||||
else:
|
||||
# 单个处理
|
||||
prompt = self._format_messages(messages)
|
||||
return self._call_api(prompt, max_tokens=max_new_tokens, temperature=temperature)
|
||||
|
||||
def _format_messages(self, messages):
|
||||
"""将消息列表格式化为单一的prompt"""
|
||||
formatted_parts = []
|
||||
for msg in messages:
|
||||
role = msg.get('role', 'user')
|
||||
content = msg.get('content', '')
|
||||
if role == 'system':
|
||||
formatted_parts.append(f"System: {content}")
|
||||
elif role == 'user':
|
||||
formatted_parts.append(f"User: {content}")
|
||||
elif role == 'assistant':
|
||||
formatted_parts.append(f"Assistant: {content}")
|
||||
return "\n\n".join(formatted_parts)
|
||||
|
||||
def ner(self, text: str) -> str:
|
||||
"""
|
||||
命名实体识别
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
Returns:
|
||||
提取的实体,用逗号分隔
|
||||
"""
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Please extract the entities from the following question and output them separated by comma, in the following format: entity1, entity2, ..."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Extract the named entities from: {text}"
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
return self.generate_response(messages, max_new_tokens=1024, temperature=0.7)
|
||||
except Exception as e:
|
||||
print(f"NER失败: {str(e)}")
|
||||
return text # 如果失败,返回原文本
|
||||
|
||||
def filter_triples_with_entity_event(self, question: str, triples: str) -> str:
|
||||
"""
|
||||
基于实体和事件过滤三元组
|
||||
|
||||
Args:
|
||||
question: 查询问题
|
||||
triples: 三元组JSON字符串
|
||||
|
||||
Returns:
|
||||
过滤后的三元组JSON字符串
|
||||
"""
|
||||
import json
|
||||
import json_repair
|
||||
|
||||
# 改进的过滤提示词 - 中文版本,强调严格子集过滤
|
||||
filter_messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": """你是一个知识图谱事实过滤专家,擅长根据问题相关性筛选事实。
|
||||
|
||||
关键要求:
|
||||
1. 你只能从提供的输入列表中选择事实 - 绝对不能创建或生成新的事实
|
||||
2. 你的输出必须是输入事实的严格子集
|
||||
3. 只包含与回答问题直接相关的事实
|
||||
4. 如果不确定,宁可选择更少的事实,也不要选择更多
|
||||
5. 保持每个事实的准确格式:[主语, 关系, 宾语]
|
||||
|
||||
过滤规则:
|
||||
- 只选择包含与问题直接相关的实体或关系的事实
|
||||
- 不能修改、改写或创建输入事实的变体
|
||||
- 不能添加看起来相关但不在输入中的事实
|
||||
- 输出事实数量必须 ≤ 输入事实数量
|
||||
|
||||
返回格式为包含"fact"键的JSON对象,值为选中的事实数组。
|
||||
|
||||
示例:
|
||||
输入事实:[["A", "关系1", "B"], ["B", "关系2", "C"], ["D", "关系3", "E"]]
|
||||
问题:A和B是什么关系?
|
||||
正确输出:{"fact": [["A", "关系1", "B"]]}
|
||||
错误做法:添加新事实或修改现有事实"""
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""问题:{question}
|
||||
|
||||
待筛选的输入事实:
|
||||
{triples}
|
||||
|
||||
从上述输入中仅选择最相关的事实来回答问题。记住:只能是严格的子集!"""
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
response = self.generate_response(
|
||||
filter_messages,
|
||||
max_new_tokens=4096,
|
||||
temperature=0.0
|
||||
)
|
||||
|
||||
# 尝试解析JSON响应
|
||||
try:
|
||||
parsed_response = json.loads(response)
|
||||
if 'fact' in parsed_response:
|
||||
return json.dumps(parsed_response, ensure_ascii=False)
|
||||
else:
|
||||
# 如果没有fact字段,尝试修复
|
||||
return json.dumps({"fact": []}, ensure_ascii=False)
|
||||
except json.JSONDecodeError:
|
||||
# 如果JSON解析失败,尝试使用json_repair
|
||||
try:
|
||||
parsed_response = json_repair.loads(response)
|
||||
if 'fact' in parsed_response:
|
||||
return json.dumps(parsed_response, ensure_ascii=False)
|
||||
else:
|
||||
return json.dumps({"fact": []}, ensure_ascii=False)
|
||||
except:
|
||||
# 如果所有解析都失败,返回空结果
|
||||
return json.dumps({"fact": []}, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
print(f"三元组过滤失败: {str(e)}")
|
||||
# 如果过滤失败,返回原始三元组
|
||||
return triples
|
||||
|
||||
def generate_with_context(self,
|
||||
question: str,
|
||||
context: str,
|
||||
max_new_tokens: int = 1024,
|
||||
temperature: float = 0.7) -> str:
|
||||
"""
|
||||
基于上下文生成回答
|
||||
|
||||
Args:
|
||||
question: 问题
|
||||
context: 上下文
|
||||
max_new_tokens: 最大生成token数
|
||||
temperature: 温度参数
|
||||
|
||||
Returns:
|
||||
生成的回答
|
||||
"""
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant. Answer the question based on the provided context. Think step by step."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{context}\n\n{question}\nThought:"
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
return self.generate_response(messages, max_new_tokens, temperature)
|
||||
except Exception as e:
|
||||
print(f"基于上下文生成失败: {str(e)}")
|
||||
return "抱歉,我无法基于提供的上下文回答这个问题。"
|
||||
|
||||
def get_token_usage(self) -> Dict[str, Any]:
|
||||
"""获取Token使用统计"""
|
||||
return {
|
||||
"last_usage": self.last_token_usage.copy(),
|
||||
"total_usage": self.total_token_usage.copy(),
|
||||
"model_name": self.model_name
|
||||
}
|
||||
|
||||
|
||||
def create_dashscope_llm(
|
||||
api_key: Optional[str] = None,
|
||||
model_name: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> DashScopeLLM:
|
||||
"""创建DashScope LLM实例的便捷函数"""
|
||||
return DashScopeLLM(
|
||||
api_key=api_key,
|
||||
model_name=model_name,
|
||||
**kwargs
|
||||
)
|
||||
Reference in New Issue
Block a user