""" 阿里云DashScope原生LLM实现 直接调用阿里云通义千问API,不通过OneAPI兼容层 """ import os import sys import json import time import requests from typing import List, Dict, Any, Optional, Union from langchain_core.language_models import BaseLLM from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.outputs import LLMResult, Generation from dotenv import load_dotenv # 加载环境变量 load_dotenv(os.path.join(os.path.dirname(__file__), '..', '..', '.env')) class DashScopeLLM(BaseLLM): """ 阿里云DashScope原生LLM实现 直接调用通义千问API """ # Pydantic字段定义 api_key: str = "" model_name: str = "qwen-turbo" max_retries: int = 3 retry_delay: float = 1.0 api_url: str = "" headers: dict = {} last_token_usage: dict = {} total_token_usage: dict = {} def __init__( self, api_key: Optional[str] = None, model_name: Optional[str] = None, max_retries: int = 3, retry_delay: float = 1.0, **kwargs ): """ 初始化DashScope LLM Args: api_key: 阿里云DashScope API Key model_name: 模型名称,如 qwen-turbo, qwen-plus, qwen-max max_retries: 最大重试次数 retry_delay: 重试延迟时间(秒) """ # 先设置字段值 api_key_value = api_key or os.getenv('ONEAPI_KEY') # 复用现有环境变量 model_name_value = model_name or os.getenv('ONEAPI_MODEL_GEN', 'qwen-turbo') if not api_key_value: raise ValueError("DashScope API Key未设置,请在.env文件中设置ONEAPI_KEY") # DashScope API端点 api_url_value = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" # 设置请求头 headers_value = { 'Authorization': f'Bearer {api_key_value}', 'Content-Type': 'application/json', 'X-DashScope-SSE': 'disable' # 禁用流式响应 } # Token使用统计 last_token_usage_value = {} total_token_usage_value = { 'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0, 'call_count': 0 } # 初始化父类,传递所有字段 super().__init__( api_key=api_key_value, model_name=model_name_value, max_retries=max_retries, retry_delay=retry_delay, api_url=api_url_value, headers=headers_value, last_token_usage=last_token_usage_value, total_token_usage=total_token_usage_value, **kwargs ) print(f"[OK] DashScope LLM初始化完成: {self.model_name}") self._test_connection() def _test_connection(self): """测试DashScope连接""" try: # 发送一个简单的测试请求 test_payload = { "model": self.model_name, "input": { "messages": [ {"role": "user", "content": "hello"} ] }, "parameters": { "max_tokens": 10, "temperature": 0.1 } } response = requests.post( self.api_url, headers=self.headers, json=test_payload, timeout=(5, 15) # 短超时进行连接测试 ) if response.status_code == 200: result = response.json() if result.get('output') and result['output'].get('text'): print("[OK] DashScope连接测试成功") else: print(f"[WARNING] DashScope响应格式异常: {result}") else: print(f"[WARNING] DashScope连接异常,状态码: {response.status_code}") print(f" 响应: {response.text[:200]}") except requests.exceptions.Timeout: print("[WARNING] DashScope连接超时,但将在实际请求时重试") except requests.exceptions.ConnectionError: print("[WARNING] DashScope连接失败,请检查网络状态") except Exception as e: print(f"[WARNING] DashScope连接测试出错: {e}") @property def _llm_type(self) -> str: return "dashscope" def _generate( self, prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> LLMResult: """生成响应""" generations = [] for prompt in prompts: try: response_text = self._call_api(prompt, **kwargs) generations.append([Generation(text=response_text)]) except Exception as e: print(f"[ERROR] DashScope API调用失败: {e}") generations.append([Generation(text=f"API调用失败: {str(e)}")]) return LLMResult(generations=generations) def _call_api(self, prompt: str, **kwargs) -> str: """调用DashScope API""" # 构建请求payload payload = { "model": self.model_name, "input": { "messages": [ {"role": "user", "content": prompt} ] }, "parameters": { "max_tokens": kwargs.get("max_tokens", 2048), "temperature": kwargs.get("temperature", 0.7), "top_p": kwargs.get("top_p", 0.8), } } # 如果有stop参数,添加到payload中 stop = kwargs.get("stop") if stop: payload["parameters"]["stop"] = stop for attempt in range(self.max_retries): try: response = requests.post( self.api_url, headers=self.headers, json=payload, timeout=(30, 120) # 连接30秒,读取120秒 ) if response.status_code == 200: result = response.json() # 检查是否有错误 if result.get('code'): error_msg = result.get('message', f'API错误代码: {result["code"]}') if attempt == self.max_retries - 1: raise RuntimeError(f"DashScope API错误: {error_msg}") else: print(f"API错误,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}") time.sleep(self.retry_delay * (attempt + 1)) continue # 提取响应内容 if result.get('output') and result['output'].get('text'): content = result['output']['text'] # 提取Token使用信息 usage_info = result.get('usage', {}) self.last_token_usage = { 'prompt_tokens': usage_info.get('input_tokens', 0), 'completion_tokens': usage_info.get('output_tokens', 0), 'total_tokens': usage_info.get('total_tokens', 0) } # 更新累计统计 self.total_token_usage['prompt_tokens'] += self.last_token_usage['prompt_tokens'] self.total_token_usage['completion_tokens'] += self.last_token_usage['completion_tokens'] self.total_token_usage['total_tokens'] += self.last_token_usage['total_tokens'] self.total_token_usage['call_count'] += 1 return content.strip() else: error_msg = f"API响应格式错误: {result}" if attempt == self.max_retries - 1: raise RuntimeError(error_msg) else: print(f"响应格式错误,正在重试 ({attempt + 2}/{self.max_retries})") time.sleep(self.retry_delay * (attempt + 1)) else: error_text = response.text[:500] if response.text else "无响应内容" error_msg = f"API请求失败,状态码: {response.status_code}, 响应: {error_text}" if attempt == self.max_retries - 1: raise RuntimeError(error_msg) else: print(f"请求失败,正在重试 ({attempt + 2}/{self.max_retries}): 状态码 {response.status_code}") time.sleep(self.retry_delay * (attempt + 1)) except KeyboardInterrupt: print(f"\n[WARNING] 用户中断请求") raise KeyboardInterrupt("用户中断请求") except requests.exceptions.Timeout as e: error_msg = f"请求超时: {str(e)}" if attempt == self.max_retries - 1: raise RuntimeError(f"经过 {self.max_retries} 次重试后仍超时: {error_msg}") else: print(f"请求超时,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}") time.sleep(self.retry_delay * (attempt + 1)) except requests.exceptions.ConnectionError as e: error_msg = f"连接错误: {str(e)}" if attempt == self.max_retries - 1: raise RuntimeError(f"经过 {self.max_retries} 次重试后仍无法连接: {error_msg}") else: print(f"连接错误,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}") time.sleep(self.retry_delay * (attempt + 1)) except requests.RequestException as e: error_msg = f"网络请求异常: {str(e)}" if attempt == self.max_retries - 1: raise RuntimeError(f"经过 {self.max_retries} 次重试后仍失败: {error_msg}") else: print(f"网络异常,正在重试 ({attempt + 2}/{self.max_retries}): {str(e)[:100]}") time.sleep(self.retry_delay * (attempt + 1)) raise RuntimeError("所有重试都失败了") def generate_response(self, messages, max_new_tokens=4096, temperature=0.7, response_format=None, **kwargs): """ 生成响应,兼容原始接口 Args: messages: 消息列表或批量消息列表 max_new_tokens: 最大生成token数 temperature: 温度参数 response_format: 响应格式 **kwargs: 其他参数 Returns: 生成的文本或文本列表 """ # 检查是否为批量处理 is_batch = isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], list) if is_batch: # 批量处理 results = [] for msg_list in messages: try: # 构造prompt prompt = self._format_messages(msg_list) result = self._call_api(prompt, max_tokens=max_new_tokens, temperature=temperature) results.append(result) except Exception as e: print(f"批量处理中的单个请求失败: {str(e)}") results.append("") return results else: # 单个处理 prompt = self._format_messages(messages) return self._call_api(prompt, max_tokens=max_new_tokens, temperature=temperature) def _format_messages(self, messages): """将消息列表格式化为单一的prompt""" formatted_parts = [] for msg in messages: role = msg.get('role', 'user') content = msg.get('content', '') if role == 'system': formatted_parts.append(f"System: {content}") elif role == 'user': formatted_parts.append(f"User: {content}") elif role == 'assistant': formatted_parts.append(f"Assistant: {content}") return "\n\n".join(formatted_parts) def ner(self, text: str) -> str: """ 命名实体识别 Args: text: 输入文本 Returns: 提取的实体,用逗号分隔 """ messages = [ { "role": "system", "content": "Please extract the entities from the following question and output them separated by comma, in the following format: entity1, entity2, ..." }, { "role": "user", "content": f"Extract the named entities from: {text}" } ] try: return self.generate_response(messages, max_new_tokens=1024, temperature=0.7) except Exception as e: print(f"NER失败: {str(e)}") return text # 如果失败,返回原文本 def filter_triples_with_entity_event(self, question: str, triples: str) -> str: """ 基于实体和事件过滤三元组 Args: question: 查询问题 triples: 三元组JSON字符串 Returns: 过滤后的三元组JSON字符串 """ import json import json_repair # 改进的过滤提示词 - 中文版本,强调严格子集过滤 filter_messages = [ { "role": "system", "content": """你是一个知识图谱事实过滤专家,擅长根据问题相关性筛选事实。 关键要求: 1. 你只能从提供的输入列表中选择事实 - 绝对不能创建或生成新的事实 2. 你的输出必须是输入事实的严格子集 3. 只包含与回答问题直接相关的事实 4. 如果不确定,宁可选择更少的事实,也不要选择更多 5. 保持每个事实的准确格式:[主语, 关系, 宾语] 过滤规则: - 只选择包含与问题直接相关的实体或关系的事实 - 不能修改、改写或创建输入事实的变体 - 不能添加看起来相关但不在输入中的事实 - 输出事实数量必须 ≤ 输入事实数量 返回格式为包含"fact"键的JSON对象,值为选中的事实数组。 示例: 输入事实:[["A", "关系1", "B"], ["B", "关系2", "C"], ["D", "关系3", "E"]] 问题:A和B是什么关系? 正确输出:{"fact": [["A", "关系1", "B"]]} 错误做法:添加新事实或修改现有事实""" }, { "role": "user", "content": f"""问题:{question} 待筛选的输入事实: {triples} 从上述输入中仅选择最相关的事实来回答问题。记住:只能是严格的子集!""" } ] try: response = self.generate_response( filter_messages, max_new_tokens=4096, temperature=0.0 ) # 尝试解析JSON响应 try: parsed_response = json.loads(response) if 'fact' in parsed_response: return json.dumps(parsed_response, ensure_ascii=False) else: # 如果没有fact字段,尝试修复 return json.dumps({"fact": []}, ensure_ascii=False) except json.JSONDecodeError: # 如果JSON解析失败,尝试使用json_repair try: parsed_response = json_repair.loads(response) if 'fact' in parsed_response: return json.dumps(parsed_response, ensure_ascii=False) else: return json.dumps({"fact": []}, ensure_ascii=False) except: # 如果所有解析都失败,返回空结果 return json.dumps({"fact": []}, ensure_ascii=False) except Exception as e: print(f"三元组过滤失败: {str(e)}") # 如果过滤失败,返回原始三元组 return triples def generate_with_context(self, question: str, context: str, max_new_tokens: int = 1024, temperature: float = 0.7) -> str: """ 基于上下文生成回答 Args: question: 问题 context: 上下文 max_new_tokens: 最大生成token数 temperature: 温度参数 Returns: 生成的回答 """ messages = [ { "role": "system", "content": "You are a helpful assistant. Answer the question based on the provided context. Think step by step." }, { "role": "user", "content": f"{context}\n\n{question}\nThought:" } ] try: return self.generate_response(messages, max_new_tokens, temperature) except Exception as e: print(f"基于上下文生成失败: {str(e)}") return "抱歉,我无法基于提供的上下文回答这个问题。" def get_token_usage(self) -> Dict[str, Any]: """获取Token使用统计""" return { "last_usage": self.last_token_usage.copy(), "total_usage": self.total_token_usage.copy(), "model_name": self.model_name } def create_dashscope_llm( api_key: Optional[str] = None, model_name: Optional[str] = None, **kwargs ) -> DashScopeLLM: """创建DashScope LLM实例的便捷函数""" return DashScopeLLM( api_key=api_key, model_name=model_name, **kwargs )