"""Alibaba Cloud DashScope embedding model implementation.

Calls the DashScope text-embedding REST API directly, without going through
the OneAPI compatibility layer.
"""

import os
import sys
import json
import time
from typing import List, Union, Optional

import requests
import numpy as np
from dotenv import load_dotenv

# Make the project root importable (needed for the ``atlas_rag`` import below).
# Resolve it to an absolute path so the sys.path entry stays valid if the
# working directory changes later, and skip the append when the entry is
# already present (e.g. on module reload) to avoid duplicates.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel

# Load environment variables from the .env file in the project root.
load_dotenv(os.path.join(project_root, '.env'))
class DashScopeEmbeddingModel(BaseEmbeddingModel):
    """Embedding model that calls the Alibaba Cloud DashScope text-embedding API directly.

    Texts are posted to the DashScope REST endpoint in batches of
    ``batch_size``. Timeouts, connection errors, HTTP 408/429/5xx responses,
    API-level error codes and malformed responses are retried (up to
    ``max_retries`` attempts in total) with a linearly growing delay of
    ``retry_delay``, ``2 * retry_delay``, ...; other HTTP 4xx errors fail
    immediately because retrying cannot fix them.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: Optional[str] = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        batch_size: int = 10,
        test_connection: bool = True,
    ):
        """Initialize the DashScope embedding model.

        Args:
            api_key: DashScope API key. Falls back to the ``ONEAPI_KEY``
                environment variable, then to ``DASHSCOPE_API_KEY``.
            model_name: Embedding model name, e.g. ``text-embedding-v3``.
                Falls back to ``ONEAPI_MODEL_EMBED``, then ``text-embedding-v3``.
            max_retries: Total number of attempts per batch; values below 1
                are treated as 1 so at least one request is always made.
            retry_delay: Base delay in seconds; the n-th retry waits
                ``retry_delay * n`` seconds.
            batch_size: Number of texts sent per API request (must be >= 1).
            test_connection: When true (the default), send a probe request
                during construction and print the outcome.

        Raises:
            ValueError: if no API key is available or ``batch_size`` < 1.
        """
        # Pass None as the parent's sentence_encoder: this model does not use one.
        super().__init__(sentence_encoder=None)

        # Reuses the existing OneAPI environment variables; DASHSCOPE_API_KEY is
        # consulted only as a last resort.
        self.api_key = api_key or os.getenv('ONEAPI_KEY') or os.getenv('DASHSCOPE_API_KEY')
        self.model_name = model_name or os.getenv('ONEAPI_MODEL_EMBED', 'text-embedding-v3')
        # With a value below 1 the request loop would never run at all.
        self.max_retries = max(1, max_retries)
        self.retry_delay = retry_delay
        self.batch_size = batch_size

        if not self.api_key:
            raise ValueError("DashScope API Key未设置,请在.env文件中设置ONEAPI_KEY")
        if batch_size < 1:
            raise ValueError(f"batch_size必须为正整数,当前值: {batch_size}")

        # DashScope text-embedding endpoint.
        self.api_url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"

        self.headers = {
            'Authorization': f'Bearer {self.api_key}',
            'Content-Type': 'application/json'
        }

        print(f"[OK] DashScope嵌入模型初始化完成: {self.model_name}")
        if test_connection:
            self._test_connection()

    def encode(self, texts, **kwargs):
        """Encode one text or a collection of texts (implements the base-class hook).

        Extra keyword arguments (``batch_size``, ``show_progress_bar``,
        ``query_type``, ...) are accepted for interface compatibility but are
        ignored; request batching is governed by the instance's ``batch_size``.

        Args:
            texts: A single string, or an iterable of strings.
            **kwargs: Ignored.

        Returns:
            np.ndarray: a 1-D vector for a single string; otherwise an array
            with one row per input text (an empty array for empty input).
        """
        if isinstance(texts, str):
            return np.array(self.embed_query(texts))
        return np.array(self.embed_texts(texts))

    def _build_payload(self, texts: List[str]) -> dict:
        """Build the DashScope request body for ``texts``."""
        return {
            "model": self.model_name,
            "input": {
                "texts": texts
            }
        }

    def _test_connection(self):
        """Send a one-text probe request and print whether the API responds.

        Best-effort only: every failure is printed as a warning and swallowed;
        nothing is raised.
        """
        try:
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=self._build_payload(["test"]),
                timeout=(5, 15)  # short timeouts: this is only a probe
            )

            if response.status_code == 200:
                result = response.json()
                if result.get('output') and result['output'].get('embeddings'):
                    print("[OK] DashScope嵌入API连接测试成功")
                else:
                    print(f"[WARNING] DashScope嵌入API响应格式异常: {result}")
            else:
                print(f"[WARNING] DashScope嵌入API连接异常,状态码: {response.status_code}")
                print(f" 响应: {response.text[:200]}")

        except requests.exceptions.Timeout:
            print("[WARNING] DashScope嵌入API连接超时,但将在实际请求时重试")
        except requests.exceptions.ConnectionError:
            print("[WARNING] DashScope嵌入API连接失败,请检查网络状态")
        except Exception as e:
            print(f"[WARNING] DashScope嵌入API连接测试出错: {e}")

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of texts, splitting it into API-sized batches.

        Args:
            texts: The texts to embed. Any iterable of strings is accepted
                (it is materialized into a list first, so numpy arrays and
                generators work too).

        Returns:
            One embedding per input text, in input order; ``[]`` for no input.

        Raises:
            RuntimeError: if any batch ultimately fails.
        """
        texts = list(texts)
        if not texts:
            return []

        all_embeddings: List[List[float]] = []
        for start in range(0, len(texts), self.batch_size):
            all_embeddings.extend(self._embed_batch(texts[start:start + self.batch_size]))
        return all_embeddings

    def _sleep_before_retry(self, attempt: int) -> None:
        """Wait ``retry_delay * (attempt + 1)`` seconds (linear back-off)."""
        time.sleep(self.retry_delay * (attempt + 1))

    @staticmethod
    def _is_retryable_status(status_code: int) -> bool:
        """Return True if a non-200 HTTP status is worth retrying.

        4xx client errors (bad key, unknown model, invalid input) fail the same
        way on every attempt, except 408 (request timeout) and 429 (rate limit).
        """
        return not (400 <= status_code < 500) or status_code in (408, 429)

    @staticmethod
    def _extract_embeddings(result, expected: int) -> List[List[float]]:
        """Extract the vectors from a decoded response body, in input order.

        Items may be dicts with an ``embedding`` key or bare lists. When a dict
        item carries a ``text_index`` (the position of its text in the
        request), it is used to restore input order; otherwise response order
        is kept.

        Args:
            result: The decoded JSON body (``None`` if it was not valid JSON).
            expected: Number of texts that were sent.

        Raises:
            ValueError: if embeddings are missing, an item has an unknown
                shape, or the vector count differs from ``expected`` (which
                would otherwise silently misalign vectors with their texts).
        """
        output = result.get('output') if isinstance(result, dict) else None
        items = output.get('embeddings') if isinstance(output, dict) else None
        if not items:
            raise ValueError(f"API响应格式错误: {result}")

        indexed = []
        for position, item in enumerate(items):
            if isinstance(item, dict) and 'embedding' in item:
                index = item.get('text_index')
                indexed.append((position if index is None else index, item['embedding']))
            elif isinstance(item, list):
                indexed.append((position, item))
            else:
                raise ValueError(f"未知的嵌入格式: {type(item)}")

        if len(indexed) != expected:
            raise ValueError(f"返回的嵌入数量({len(indexed)})与输入文本数量({expected})不一致")

        indexed.sort(key=lambda pair: pair[0])  # stable sort keeps ties in response order
        return [vector for _, vector in indexed]

    def _embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed one batch of texts, retrying transient failures.

        Args:
            texts: The texts for a single API request.

        Returns:
            One embedding per input text, in input order.

        Raises:
            RuntimeError: if the batch still fails on the last attempt, or
                immediately on a non-retryable HTTP status.
        """
        payload = self._build_payload(texts)
        attempts = max(1, self.max_retries)

        try:
            for attempt in range(attempts):
                final = attempt == attempts - 1

                try:
                    response = requests.post(
                        self.api_url,
                        headers=self.headers,
                        json=payload,
                        timeout=(30, 120)  # 30 s to connect, 120 s to read
                    )
                except requests.exceptions.Timeout as e:
                    error_msg = f"请求超时: {str(e)}"
                    if final:
                        raise RuntimeError(f"经过 {attempts} 次重试后仍超时: {error_msg}") from e
                    print(f"请求超时,正在重试 ({attempt + 2}/{attempts}): {error_msg}")
                    self._sleep_before_retry(attempt)
                    continue
                except requests.exceptions.ConnectionError as e:
                    error_msg = f"连接错误: {str(e)}"
                    if final:
                        raise RuntimeError(f"经过 {attempts} 次重试后仍无法连接: {error_msg}") from e
                    print(f"连接错误,正在重试 ({attempt + 2}/{attempts}): {error_msg}")
                    self._sleep_before_retry(attempt)
                    continue
                except requests.RequestException as e:
                    error_msg = f"网络请求异常: {str(e)}"
                    if final:
                        raise RuntimeError(f"经过 {attempts} 次重试后仍失败: {error_msg}") from e
                    print(f"网络异常,正在重试 ({attempt + 2}/{attempts}): {str(e)[:100]}")
                    self._sleep_before_retry(attempt)
                    continue

                if response.status_code != 200:
                    error_text = response.text[:500] if response.text else "无响应内容"
                    error_msg = f"API请求失败,状态码: {response.status_code}, 响应: {error_text}"
                    if final or not self._is_retryable_status(response.status_code):
                        raise RuntimeError(error_msg)
                    print(f"请求失败,正在重试 ({attempt + 2}/{attempts}): 状态码 {response.status_code}")
                    self._sleep_before_retry(attempt)
                    continue

                try:
                    result = response.json()
                except ValueError:
                    # A 200 whose body is not JSON (requests' JSONDecodeError is a
                    # ValueError); handled below as a response-format error.
                    result = None

                # The API can report an error code inside a 200 response body.
                if isinstance(result, dict) and result.get('code'):
                    error_msg = result.get('message', f'API错误代码: {result["code"]}')
                    if final:
                        raise RuntimeError(f"DashScope嵌入API错误: {error_msg}")
                    print(f"API错误,正在重试 ({attempt + 2}/{attempts}): {error_msg}")
                    self._sleep_before_retry(attempt)
                    continue

                try:
                    return self._extract_embeddings(result, len(texts))
                except ValueError as e:
                    if final:
                        raise RuntimeError(str(e)) from e
                    print(f"响应格式错误,正在重试 ({attempt + 2}/{attempts})")
                    self._sleep_before_retry(attempt)
        except KeyboardInterrupt:
            # Report the interruption, then propagate the original exception.
            print("\n[WARNING] 用户中断请求")
            raise

        # Safety net only: the last attempt above always returns or raises.
        raise RuntimeError("所有重试都失败了")

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: The query text.

        Returns:
            Its embedding vector, or ``[]`` if no embedding was produced.
        """
        embeddings = self.embed_texts([text])
        return embeddings[0] if embeddings else []
def create_dashscope_embedding_model(
    api_key: Optional[str] = None,
    model_name: Optional[str] = None,
    **kwargs
) -> DashScopeEmbeddingModel:
    """Convenience factory that builds a ``DashScopeEmbeddingModel``.

    Args:
        api_key: Forwarded as-is; ``None`` lets the model resolve its own key.
        model_name: Forwarded as-is; ``None`` lets the model pick its default.
        **kwargs: Any further constructor arguments, forwarded unchanged.

    Returns:
        A newly constructed ``DashScopeEmbeddingModel``.
    """
    constructor_args = dict(kwargs)
    constructor_args["api_key"] = api_key
    constructor_args["model_name"] = model_name
    return DashScopeEmbeddingModel(**constructor_args)