""" 阿里云DashScope嵌入模型实现 直接调用阿里云文本嵌入API,不通过OneAPI兼容层 """ import os import sys import json import time import requests import numpy as np from typing import List, Union, Optional from dotenv import load_dotenv # 添加项目根目录到路径 project_root = os.path.join(os.path.dirname(__file__), '..', '..') sys.path.append(project_root) from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel # 加载环境变量 load_dotenv(os.path.join(os.path.dirname(__file__), '..', '..', '.env')) class DashScopeEmbeddingModel(BaseEmbeddingModel): """ 阿里云DashScope嵌入模型实现 直接调用DashScope文本嵌入API """ def __init__( self, api_key: Optional[str] = None, model_name: Optional[str] = None, max_retries: int = 3, retry_delay: float = 1.0 ): """ 初始化DashScope嵌入模型 Args: api_key: 阿里云DashScope API Key model_name: 嵌入模型名称,如 text-embedding-v3 max_retries: 最大重试次数 retry_delay: 重试延迟时间(秒) """ # 初始化父类,传入None作为sentence_encoder(我们不使用它) super().__init__(sentence_encoder=None) self.api_key = api_key or os.getenv('ONEAPI_KEY') # 复用现有环境变量 self.model_name = model_name or os.getenv('ONEAPI_MODEL_EMBED', 'text-embedding-v3') self.max_retries = max_retries self.retry_delay = retry_delay if not self.api_key: raise ValueError("DashScope API Key未设置,请在.env文件中设置ONEAPI_KEY") # DashScope文本嵌入API端点 self.api_url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" # 设置请求头 self.headers = { 'Authorization': f'Bearer {self.api_key}', 'Content-Type': 'application/json' } print(f"[OK] DashScope嵌入模型初始化完成: {self.model_name}") self._test_connection() def encode(self, texts, **kwargs): """ 实现基类的抽象方法 对文本进行编码,返回嵌入向量 Args: texts: 文本或文本列表 **kwargs: 其他参数(batch_size, show_progress_bar, query_type等) Returns: 嵌入向量或嵌入向量列表(numpy数组格式) """ import numpy as np if isinstance(texts, str): # 单个文本 result = self.embed_query(texts) return np.array(result) else: # 文本列表 results = self.embed_texts(texts) return np.array(results) def _test_connection(self): """测试DashScope嵌入API连接""" try: # 发送一个简单的测试请求 test_payload = { "model": self.model_name, "input": { "texts": ["test"] } } response = requests.post( self.api_url, headers=self.headers, json=test_payload, timeout=(5, 15) # 短超时进行连接测试 ) if response.status_code == 200: result = response.json() if result.get('output') and result['output'].get('embeddings'): print("[OK] DashScope嵌入API连接测试成功") else: print(f"[WARNING] DashScope嵌入API响应格式异常: {result}") else: print(f"[WARNING] DashScope嵌入API连接异常,状态码: {response.status_code}") print(f" 响应: {response.text[:200]}") except requests.exceptions.Timeout: print("[WARNING] DashScope嵌入API连接超时,但将在实际请求时重试") except requests.exceptions.ConnectionError: print("[WARNING] DashScope嵌入API连接失败,请检查网络状态") except Exception as e: print(f"[WARNING] DashScope嵌入API连接测试出错: {e}") def embed_texts(self, texts: List[str]) -> List[List[float]]: """ 对文本列表进行嵌入 Args: texts: 文本列表 Returns: 嵌入向量列表 """ import numpy as np if not texts: return [] # DashScope API支持批量处理,但需要分批处理大量文本 batch_size = 10 # 每批处理10个文本 all_embeddings = [] for i in range(0, len(texts), batch_size): batch_texts = texts[i:i + batch_size] batch_embeddings = self._embed_batch(batch_texts) all_embeddings.extend(batch_embeddings) return all_embeddings def _embed_batch(self, texts: List[str]) -> List[List[float]]: """对一批文本进行嵌入""" payload = { "model": self.model_name, "input": { "texts": texts } } for attempt in range(self.max_retries): try: response = requests.post( self.api_url, headers=self.headers, json=payload, timeout=(30, 120) # 连接30秒,读取120秒 ) if response.status_code == 200: result = response.json() # 检查是否有错误 if result.get('code'): error_msg = result.get('message', f'API错误代码: {result["code"]}') if attempt == self.max_retries - 1: raise RuntimeError(f"DashScope嵌入API错误: {error_msg}") else: print(f"API错误,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}") time.sleep(self.retry_delay * (attempt + 1)) continue # 提取嵌入向量 if result.get('output') and result['output'].get('embeddings'): embeddings_data = result['output']['embeddings'] # 提取向量数据 embeddings = [] for embedding_item in embeddings_data: if isinstance(embedding_item, dict) and 'embedding' in embedding_item: embeddings.append(embedding_item['embedding']) elif isinstance(embedding_item, list): embeddings.append(embedding_item) else: print(f"[WARNING] 未知的嵌入格式: {type(embedding_item)}") return embeddings else: error_msg = f"API响应格式错误: {result}" if attempt == self.max_retries - 1: raise RuntimeError(error_msg) else: print(f"响应格式错误,正在重试 ({attempt + 2}/{self.max_retries})") time.sleep(self.retry_delay * (attempt + 1)) else: error_text = response.text[:500] if response.text else "无响应内容" error_msg = f"API请求失败,状态码: {response.status_code}, 响应: {error_text}" if attempt == self.max_retries - 1: raise RuntimeError(error_msg) else: print(f"请求失败,正在重试 ({attempt + 2}/{self.max_retries}): 状态码 {response.status_code}") time.sleep(self.retry_delay * (attempt + 1)) except KeyboardInterrupt: print(f"\n[WARNING] 用户中断请求") raise KeyboardInterrupt("用户中断请求") except requests.exceptions.Timeout as e: error_msg = f"请求超时: {str(e)}" if attempt == self.max_retries - 1: raise RuntimeError(f"经过 {self.max_retries} 次重试后仍超时: {error_msg}") else: print(f"请求超时,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}") time.sleep(self.retry_delay * (attempt + 1)) except requests.exceptions.ConnectionError as e: error_msg = f"连接错误: {str(e)}" if attempt == self.max_retries - 1: raise RuntimeError(f"经过 {self.max_retries} 次重试后仍无法连接: {error_msg}") else: print(f"连接错误,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}") time.sleep(self.retry_delay * (attempt + 1)) except requests.RequestException as e: error_msg = f"网络请求异常: {str(e)}" if attempt == self.max_retries - 1: raise RuntimeError(f"经过 {self.max_retries} 次重试后仍失败: {error_msg}") else: print(f"网络异常,正在重试 ({attempt + 2}/{self.max_retries}): {str(e)[:100]}") time.sleep(self.retry_delay * (attempt + 1)) raise RuntimeError("所有重试都失败了") def embed_query(self, text: str) -> List[float]: """ 对单个查询文本进行嵌入 Args: text: 查询文本 Returns: 嵌入向量 """ embeddings = self.embed_texts([text]) return embeddings[0] if embeddings else [] def create_dashscope_embedding_model( api_key: Optional[str] = None, model_name: Optional[str] = None, **kwargs ) -> DashScopeEmbeddingModel: """创建DashScope嵌入模型实例的便捷函数""" return DashScopeEmbeddingModel( api_key=api_key, model_name=model_name, **kwargs )