"""
|
||
阿里云DashScope原生LLM实现
|
||
直接调用阿里云通义千问API,不通过OneAPI兼容层
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import json
|
||
import time
|
||
import requests
|
||
from typing import List, Dict, Any, Optional, Union, Iterator, Callable
|
||
from langchain_core.language_models import BaseLLM
|
||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||
from langchain_core.outputs import LLMResult, Generation
|
||
from dotenv import load_dotenv
|
||
|
||
# 加载环境变量
|
||
load_dotenv(os.path.join(os.path.dirname(__file__), '..', '..', '.env'))
|
||
|
||
|
||
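
# A minimal sketch of the .env entries this module expects (variable names are
# taken from the os.getenv calls below; the key value is a placeholder):
#
#   ONEAPI_KEY=sk-your-dashscope-api-key
#   ONEAPI_MODEL_GEN=qwen-turbo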


class DashScopeLLM(BaseLLM):
    """
    Alibaba Cloud DashScope native LLM implementation.

    Calls the Tongyi Qianwen (Qwen) API directly.
    """

    # Pydantic field declarations
    api_key: str = ""
    model_name: str = "qwen-turbo"
    max_retries: int = 3
    retry_delay: float = 1.0
    api_url: str = ""
    headers: dict = {}
    last_token_usage: dict = {}
    total_token_usage: dict = {}

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: Optional[str] = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        **kwargs
    ):
        """
        Initialize the DashScope LLM.

        Args:
            api_key: Alibaba Cloud DashScope API key
            model_name: Model name, e.g. qwen-turbo, qwen-plus, qwen-max
            max_retries: Maximum number of retries
            retry_delay: Delay between retries (seconds)
        """
        # Resolve field values before calling the Pydantic constructor
        api_key_value = api_key or os.getenv('ONEAPI_KEY')  # reuse the existing env variable
        model_name_value = model_name or os.getenv('ONEAPI_MODEL_GEN', 'qwen-turbo')

        if not api_key_value:
            raise ValueError("DashScope API key is not set; please set ONEAPI_KEY in the .env file")

        # DashScope API endpoint
        api_url_value = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"

        # Request headers
        headers_value = {
            'Authorization': f'Bearer {api_key_value}',
            'Content-Type': 'application/json',
            'X-DashScope-SSE': 'disable'  # disable streaming responses
        }

        # Token usage statistics
        last_token_usage_value = {}
        total_token_usage_value = {
            'prompt_tokens': 0,
            'completion_tokens': 0,
            'total_tokens': 0,
            'call_count': 0
        }

        # Initialize the parent class, passing all fields
        super().__init__(
            api_key=api_key_value,
            model_name=model_name_value,
            max_retries=max_retries,
            retry_delay=retry_delay,
            api_url=api_url_value,
            headers=headers_value,
            last_token_usage=last_token_usage_value,
            total_token_usage=total_token_usage_value,
            **kwargs
        )

        print(f"[OK] DashScope LLM initialized: {self.model_name}")
        self._test_connection()

    def _test_connection(self):
        """Test the DashScope connection."""
        try:
            # Send a simple test request
            test_payload = {
                "model": self.model_name,
                "input": {
                    "messages": [
                        {"role": "user", "content": "hello"}
                    ]
                },
                "parameters": {
                    "max_tokens": 10,
                    "temperature": 0.1
                }
            }

            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=test_payload,
                timeout=(5, 15)  # short timeouts for the connectivity check
            )

            if response.status_code == 200:
                result = response.json()
                if result.get('output') and result['output'].get('text'):
                    print("[OK] DashScope connection test succeeded")
                else:
                    print(f"[WARNING] Unexpected DashScope response format: {result}")
            else:
                print(f"[WARNING] DashScope connection issue, status code: {response.status_code}")
                print(f"   Response: {response.text[:200]}")

        except requests.exceptions.Timeout:
            print("[WARNING] DashScope connection timed out; real requests will still be retried")
        except requests.exceptions.ConnectionError:
            print("[WARNING] DashScope connection failed; please check the network")
        except Exception as e:
            print(f"[WARNING] DashScope connection test error: {e}")
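
    # For reference, a successful text-generation response roughly has this
    # shape (a sketch limited to the fields this class actually reads; other
    # fields are omitted):
    #
    #   {
    #     "output": {"text": "..."},
    #     "usage": {"input_tokens": 12, "output_tokens": 34, "total_tokens": 46}
    #   }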

    @property
    def _llm_type(self) -> str:
        return "dashscope"

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Generate responses for a batch of prompts."""
        generations = []

        for prompt in prompts:
            try:
                # Forward stop sequences so _call_api can include them in the payload
                response_text = self._call_api(prompt, stop=stop, **kwargs)
                generations.append([Generation(text=response_text)])
            except Exception as e:
                print(f"[ERROR] DashScope API call failed: {e}")
                generations.append([Generation(text=f"API call failed: {str(e)}")])

        return LLMResult(generations=generations)

    def _call_api(self, prompt: str, **kwargs) -> str:
        """Call the DashScope API with retries."""
        # Build the request payload
        payload = {
            "model": self.model_name,
            "input": {
                "messages": [
                    {"role": "user", "content": prompt}
                ]
            },
            "parameters": {
                "max_tokens": kwargs.get("max_tokens", 2048),
                "temperature": kwargs.get("temperature", 0.7),
                "top_p": kwargs.get("top_p", 0.8),
            }
        }

        # Add stop sequences to the payload if provided
        stop = kwargs.get("stop")
        if stop:
            payload["parameters"]["stop"] = stop

        for attempt in range(self.max_retries):
            try:
                response = requests.post(
                    self.api_url,
                    headers=self.headers,
                    json=payload,
                    timeout=(30, 120)  # 30 s to connect, 120 s to read
                )

                if response.status_code == 200:
                    result = response.json()

                    # Check for an API-level error
                    if result.get('code'):
                        error_msg = result.get('message', f'API error code: {result["code"]}')
                        if attempt == self.max_retries - 1:
                            raise RuntimeError(f"DashScope API error: {error_msg}")
                        else:
                            print(f"API error, retrying ({attempt + 2}/{self.max_retries}): {error_msg}")
                            time.sleep(self.retry_delay * (attempt + 1))
                            continue

                    # Extract the response content
                    if result.get('output') and result['output'].get('text'):
                        content = result['output']['text']

                        # Extract token usage information
                        usage_info = result.get('usage', {})
                        self.last_token_usage = {
                            'prompt_tokens': usage_info.get('input_tokens', 0),
                            'completion_tokens': usage_info.get('output_tokens', 0),
                            'total_tokens': usage_info.get('total_tokens', 0)
                        }

                        # Update cumulative statistics
                        self.total_token_usage['prompt_tokens'] += self.last_token_usage['prompt_tokens']
                        self.total_token_usage['completion_tokens'] += self.last_token_usage['completion_tokens']
                        self.total_token_usage['total_tokens'] += self.last_token_usage['total_tokens']
                        self.total_token_usage['call_count'] += 1

                        return content.strip()
                    else:
                        error_msg = f"Malformed API response: {result}"
                        if attempt == self.max_retries - 1:
                            raise RuntimeError(error_msg)
                        else:
                            print(f"Malformed response, retrying ({attempt + 2}/{self.max_retries})")
                            time.sleep(self.retry_delay * (attempt + 1))
                else:
                    error_text = response.text[:500] if response.text else "no response body"
                    error_msg = f"API request failed, status code: {response.status_code}, response: {error_text}"
                    if attempt == self.max_retries - 1:
                        raise RuntimeError(error_msg)
                    else:
                        print(f"Request failed, retrying ({attempt + 2}/{self.max_retries}): status code {response.status_code}")
                        time.sleep(self.retry_delay * (attempt + 1))

            except KeyboardInterrupt:
                print("\n[WARNING] Request interrupted by user")
                raise

            except requests.exceptions.Timeout as e:
                error_msg = f"Request timed out: {str(e)}"
                if attempt == self.max_retries - 1:
                    raise RuntimeError(f"Still timing out after {self.max_retries} retries: {error_msg}")
                else:
                    print(f"Request timed out, retrying ({attempt + 2}/{self.max_retries}): {error_msg}")
                    time.sleep(self.retry_delay * (attempt + 1))

            except requests.exceptions.ConnectionError as e:
                error_msg = f"Connection error: {str(e)}"
                if attempt == self.max_retries - 1:
                    raise RuntimeError(f"Still unable to connect after {self.max_retries} retries: {error_msg}")
                else:
                    print(f"Connection error, retrying ({attempt + 2}/{self.max_retries}): {error_msg}")
                    time.sleep(self.retry_delay * (attempt + 1))

            except requests.RequestException as e:
                error_msg = f"Network request exception: {str(e)}"
                if attempt == self.max_retries - 1:
                    raise RuntimeError(f"Still failing after {self.max_retries} retries: {error_msg}")
                else:
                    print(f"Network exception, retrying ({attempt + 2}/{self.max_retries}): {str(e)[:100]}")
                    time.sleep(self.retry_delay * (attempt + 1))

        raise RuntimeError("All retries failed")
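
    # The retry loop above uses a linear backoff: sleep(retry_delay * (attempt + 1)).
    # With the defaults (max_retries=3, retry_delay=1.0) a fully failing call
    # waits ~1.0 s after attempt 1 and ~2.0 s after attempt 2 before raising.
    # A quick sanity check of that schedule:
    #
    #   delays = [1.0 * (attempt + 1) for attempt in range(3 - 1)]  # [1.0, 2.0]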

    def generate_response(self, messages, max_new_tokens=4096, temperature=0.7, response_format=None, **kwargs):
        """
        Generate a response; compatible with the original interface.

        Args:
            messages: A message list, or a list of message lists for batching
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            response_format: Response format
            **kwargs: Additional parameters

        Returns:
            The generated text, or a list of texts for batch input
        """
        # Detect batch input (a list of message lists)
        is_batch = isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], list)

        if is_batch:
            # Batch processing
            results = []
            for msg_list in messages:
                try:
                    # Build the prompt
                    prompt = self._format_messages(msg_list)
                    result = self._call_api(prompt, max_tokens=max_new_tokens, temperature=temperature)
                    results.append(result)
                except Exception as e:
                    print(f"A single request in the batch failed: {str(e)}")
                    results.append("")
            return results
        else:
            # Single request
            prompt = self._format_messages(messages)
            return self._call_api(prompt, max_tokens=max_new_tokens, temperature=temperature)

    def _format_messages(self, messages):
        """Flatten a message list into a single prompt string."""
        formatted_parts = []
        for msg in messages:
            role = msg.get('role', 'user')
            content = msg.get('content', '')
            if role == 'system':
                formatted_parts.append(f"System: {content}")
            elif role == 'user':
                formatted_parts.append(f"User: {content}")
            elif role == 'assistant':
                formatted_parts.append(f"Assistant: {content}")
        return "\n\n".join(formatted_parts)
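
    # Example of the flattening performed by _format_messages (derived from
    # the role branches above):
    #
    #   [{"role": "system", "content": "Be brief."},
    #    {"role": "user", "content": "Hi"}]
    #
    # becomes the single prompt string:
    #
    #   "System: Be brief.\n\nUser: Hi"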

    def ner(self, text: str) -> str:
        """
        Named entity recognition.

        Args:
            text: Input text

        Returns:
            The extracted entities, separated by commas
        """
        messages = [
            {
                "role": "system",
                "content": "Please extract the entities from the following question and output them separated by comma, in the following format: entity1, entity2, ..."
            },
            {
                "role": "user",
                "content": f"Extract the named entities from: {text}"
            }
        ]

        try:
            return self.generate_response(messages, max_new_tokens=1024, temperature=0.7)
        except Exception as e:
            print(f"NER failed: {str(e)}")
            return text  # fall back to the original text on failure
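
    # A usage sketch (the output format follows the system prompt above; the
    # exact entities returned depend on the model):
    #
    #   llm.ner("Where was Marie Curie born?")
    #   # -> "Marie Curie"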

    def filter_triples_with_entity_event(self, question: str, triples: str) -> str:
        """
        Filter triples by entity and event relevance to the question.

        Args:
            question: The query question
            triples: Triples as a JSON string

        Returns:
            The filtered triples as a JSON string
        """
        import json_repair

        # Improved filtering prompt, deliberately kept in Chinese for the Qwen
        # models; it emphasizes strict subset filtering.
        filter_messages = [
            {
                "role": "system",
                "content": """你是一个知识图谱事实过滤专家,擅长根据问题相关性筛选事实。

关键要求:
1. 你只能从提供的输入列表中选择事实 - 绝对不能创建或生成新的事实
2. 你的输出必须是输入事实的严格子集
3. 只包含与回答问题直接相关的事实
4. 如果不确定,宁可选择更少的事实,也不要选择更多
5. 保持每个事实的准确格式:[主语, 关系, 宾语]

过滤规则:
- 只选择包含与问题直接相关的实体或关系的事实
- 不能修改、改写或创建输入事实的变体
- 不能添加看起来相关但不在输入中的事实
- 输出事实数量必须 ≤ 输入事实数量

返回格式为包含"fact"键的JSON对象,值为选中的事实数组。

示例:
输入事实:[["A", "关系1", "B"], ["B", "关系2", "C"], ["D", "关系3", "E"]]
问题:A和B是什么关系?
正确输出:{"fact": [["A", "关系1", "B"]]}
错误做法:添加新事实或修改现有事实"""
            },
            {
                "role": "user",
                "content": f"""问题:{question}

待筛选的输入事实:
{triples}

从上述输入中仅选择最相关的事实来回答问题。记住:只能是严格的子集!"""
            }
        ]

        try:
            response = self.generate_response(
                filter_messages,
                max_new_tokens=4096,
                temperature=0.0
            )

            # Try to parse the JSON response
            try:
                parsed_response = json.loads(response)
                if 'fact' in parsed_response:
                    return json.dumps(parsed_response, ensure_ascii=False)
                else:
                    # No "fact" key; fall back to an empty result
                    return json.dumps({"fact": []}, ensure_ascii=False)
            except json.JSONDecodeError:
                # Plain JSON parsing failed; try json_repair
                try:
                    parsed_response = json_repair.loads(response)
                    if 'fact' in parsed_response:
                        return json.dumps(parsed_response, ensure_ascii=False)
                    else:
                        return json.dumps({"fact": []}, ensure_ascii=False)
                except Exception:
                    # All parsing failed; return an empty result
                    return json.dumps({"fact": []}, ensure_ascii=False)

        except Exception as e:
            print(f"Triple filtering failed: {str(e)}")
            # On failure, return the original triples unchanged
            return triples
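
    # A filtering sketch mirroring the example embedded in the system prompt
    # above (illustrative values only):
    #
    #   triples = '[["A", "关系1", "B"], ["B", "关系2", "C"], ["D", "关系3", "E"]]'
    #   llm.filter_triples_with_entity_event("A和B是什么关系?", triples)
    #   # -> '{"fact": [["A", "关系1", "B"]]}'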

    def generate_with_context(self,
                              question: str,
                              context: str,
                              max_new_tokens: int = 1024,
                              temperature: float = 0.7) -> str:
        """
        Generate an answer based on the given context.

        Args:
            question: The question
            context: The context
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature

        Returns:
            The generated answer
        """
        messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant. Answer the question based on the provided context. Think step by step."
            },
            {
                "role": "user",
                "content": f"{context}\n\n{question}\nThought:"
            }
        ]

        try:
            return self.generate_response(messages, max_new_tokens, temperature)
        except Exception as e:
            print(f"Context-based generation failed: {str(e)}")
            return "Sorry, I cannot answer this question based on the provided context."
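
    # A usage sketch (hypothetical context and question):
    #
    #   answer = llm.generate_with_context(
    #       question="Who founded the company?",
    #       context="Acme Corp was founded by Jane Doe in 1999.",
    #   )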

    def get_token_usage(self) -> Dict[str, Any]:
        """Return token usage statistics."""
        return {
            "last_usage": self.last_token_usage.copy(),
            "total_usage": self.total_token_usage.copy(),
            "model_name": self.model_name
        }
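
    # The returned dict has this shape (field names follow the counters
    # maintained in _call_api; the numbers are illustrative):
    #
    #   {
    #     "last_usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46},
    #     "total_usage": {"prompt_tokens": 120, "completion_tokens": 340,
    #                     "total_tokens": 460, "call_count": 10},
    #     "model_name": "qwen-turbo"
    #   }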

    def stream_generate(
        self,
        prompt: str,
        stream_callback: Optional[Callable[[str], None]] = None,
        **kwargs
    ) -> str:
        """
        Streaming generation via the OpenAI-compatible endpoint.

        Args:
            prompt: The input prompt
            stream_callback: Streaming callback, invoked with each token chunk
            **kwargs: Additional parameters

        Returns:
            The complete generated text
        """
        try:
            # Use the OpenAI Python SDK for streaming calls
            from openai import OpenAI

            # Create an OpenAI-compatible client
            client = OpenAI(
                api_key=self.api_key,
                base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
            )

            # Start the streaming request
            stream = client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "user", "content": prompt}
                ],
                stream=True,
                max_tokens=kwargs.get("max_tokens", 2048),
                temperature=kwargs.get("temperature", 0.7),
                top_p=kwargs.get("top_p", 0.8),
                stream_options={"include_usage": True}
            )

            full_text = ""
            chunk_count = 0
            print("[DEBUG] Receiving OpenAI-compatible stream...")

            for chunk in stream:
                # Extract the incremental content
                if chunk.choices and len(chunk.choices) > 0:
                    delta = chunk.choices[0].delta
                    if delta and delta.content:
                        text_chunk = delta.content
                        full_text += text_chunk
                        chunk_count += 1

                        if stream_callback:
                            stream_callback(text_chunk)

                # Check for usage information
                if hasattr(chunk, 'usage') and chunk.usage:
                    self._update_token_usage({
                        'input_tokens': chunk.usage.prompt_tokens,
                        'output_tokens': chunk.usage.completion_tokens,
                        'total_tokens': chunk.usage.total_tokens
                    })

            print(f"[DEBUG] Stream finished: {chunk_count} chunks, {len(full_text)} characters in total")
            return full_text.strip()

        except ImportError:
            print("[WARNING] OpenAI SDK not installed, falling back to non-streaming")
            return self._call_api(prompt, **kwargs)
        except Exception as e:
            print(f"[WARNING] OpenAI-compatible streaming failed: {e}, falling back to non-streaming")
            return self._call_api(prompt, **kwargs)
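
    # A streaming usage sketch: print chunks as they arrive (illustrative
    # callback; any callable taking a str works):
    #
    #   text = llm.stream_generate(
    #       "Tell me a joke.",
    #       stream_callback=lambda chunk: print(chunk, end="", flush=True),
    #   )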

    def _update_token_usage(self, usage_info: Dict[str, Any]):
        """Update token usage statistics."""
        self.last_token_usage = {
            'prompt_tokens': usage_info.get('input_tokens', 0),
            'completion_tokens': usage_info.get('output_tokens', 0),
            'total_tokens': usage_info.get('total_tokens', 0)
        }

        self.total_token_usage['prompt_tokens'] += self.last_token_usage['prompt_tokens']
        self.total_token_usage['completion_tokens'] += self.last_token_usage['completion_tokens']
        self.total_token_usage['total_tokens'] += self.last_token_usage['total_tokens']
        self.total_token_usage['call_count'] += 1

    def invoke(
        self,
        input: Union[str, List[str]],
        config: Optional[dict] = None,
        **kwargs
    ) -> Union[str, LLMResult]:
        """
        Unified invocation interface, supporting both streaming and non-streaming.

        Args:
            input: Input text or a list of texts
            config: Configuration dict, may carry stream_callback etc.
            **kwargs: Additional parameters

        Returns:
            The generated text, or an LLMResult object
        """
        # Extract the streaming callback from the config
        stream_callback = None
        if config and config.get('metadata', {}).get('stream_callback'):
            stream_callback = config['metadata']['stream_callback']

        # String input
        if isinstance(input, str):
            if stream_callback:
                # Streaming generation
                return self.stream_generate(input, stream_callback, **kwargs)
            else:
                # Plain generation
                return self._call_api(input, **kwargs)

        # List input (batch): delegate to _generate
        return self._generate(input, **kwargs)
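
    # An invoke sketch showing how the streaming callback is threaded through
    # config["metadata"] (the key name matches the lookup above):
    #
    #   llm.invoke(
    #       "Summarize this paragraph...",
    #       config={"metadata": {"stream_callback": lambda c: print(c, end="")}},
    #   )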


def create_dashscope_llm(
    api_key: Optional[str] = None,
    model_name: Optional[str] = None,
    **kwargs
) -> DashScopeLLM:
    """Convenience function for creating a DashScope LLM instance."""
    return DashScopeLLM(
        api_key=api_key,
        model_name=model_name,
        **kwargs
    )
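

# A minimal end-to-end sketch, assuming ONEAPI_KEY is set in the .env file
# loaded above (illustrative only; this block makes live API calls):
if __name__ == "__main__":
    llm = create_dashscope_llm()
    # Plain single-prompt call via the LangChain-style interface
    print(llm.invoke("Say hello in one sentence."))
    # Chat-style call through the compatibility interface
    print(llm.generate_response([{"role": "user", "content": "What is 2 + 2?"}]))
    # Cumulative token accounting
    print(llm.get_token_usage())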