Files
AIEC-new/AIEC-RAG/retriver/langgraph/dashscope_embedding.py
2025-10-17 09:31:28 +08:00

269 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
阿里云DashScope嵌入模型实现
直接调用阿里云文本嵌入API不通过OneAPI兼容层
"""
import os
import sys
import json
import time
import requests
import numpy as np
from typing import List, Union, Optional
from dotenv import load_dotenv
# 添加项目根目录到路径
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.append(project_root)
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
# 加载环境变量
load_dotenv(os.path.join(os.path.dirname(__file__), '..', '..', '.env'))
class DashScopeEmbeddingModel(BaseEmbeddingModel):
    """
    Embedding model backed by the Alibaba Cloud DashScope text-embedding API.

    Requests go straight to the DashScope REST endpoint (no OneAPI layer).
    The API key and model name default to the ``ONEAPI_KEY`` and
    ``ONEAPI_MODEL_EMBED`` environment variables, reused from the existing
    OneAPI configuration.
    """

    # 4xx statuses worth retrying: 408 (request timeout) and 429 (rate limited).
    # Any other 4xx means the request itself is wrong (bad key, bad model name,
    # bad input), so retrying only wastes back-off time.
    _RETRYABLE_CLIENT_STATUSES = frozenset({408, 429})

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: Optional[str] = None,
        max_retries: int = 3,
        retry_delay: float = 1.0
    ):
        """
        Initialize the DashScope embedding model.

        Args:
            api_key: DashScope API key; falls back to the ``ONEAPI_KEY`` env var.
            model_name: Embedding model name, e.g. ``text-embedding-v3``; falls
                back to ``ONEAPI_MODEL_EMBED`` and then ``text-embedding-v3``.
            max_retries: Maximum number of attempts per batch request. At least
                one request is always made.
            retry_delay: Base back-off in seconds. The wait after the k-th failed
                attempt is ``retry_delay * k``.

        Raises:
            ValueError: If no API key is given or found in the environment.

        Note:
            Sends one probe request to the API (``_test_connection``). A failed
            probe only prints a warning.
        """
        # The parent expects a sentence encoder; this model calls the HTTP API instead.
        super().__init__(sentence_encoder=None)
        self.api_key = api_key or os.getenv('ONEAPI_KEY')  # reuse the existing env var
        self.model_name = model_name or os.getenv('ONEAPI_MODEL_EMBED', 'text-embedding-v3')
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        if not self.api_key:
            raise ValueError("DashScope API Key未设置请在.env文件中设置ONEAPI_KEY")
        # DashScope text-embedding endpoint
        self.api_url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
        self.headers = {
            'Authorization': f'Bearer {self.api_key}',
            'Content-Type': 'application/json'
        }
        print(f"[OK] DashScope嵌入模型初始化完成: {self.model_name}")
        self._test_connection()

    def encode(self, texts, **kwargs):
        """
        Encode one text or a collection of texts (base-class hook).

        Args:
            texts: A single string, or any iterable of strings (list, tuple,
                NumPy string array, ...).
            **kwargs: Accepted for interface compatibility (e.g. ``batch_size``,
                ``show_progress_bar``, ``query_type``) but ignored. Request
                batching is handled by ``embed_texts``.

        Returns:
            numpy.ndarray: A 1-D vector for a single string. Otherwise one row
            per input text (shape ``(0,)`` for empty input).
        """
        if isinstance(texts, str):
            return np.array(self.embed_query(texts))
        # Materialize to a list: embed_texts relies on truthiness and slicing,
        # which raise on NumPy arrays and don't exist on generators.
        return np.array(self.embed_texts(list(texts)))

    def _test_connection(self):
        """
        Send a one-text probe request and print whether the API responded as expected.

        Best-effort: timeouts, connection failures and any other ``Exception``
        are reported as printed warnings, never raised.
        """
        try:
            test_payload = {
                "model": self.model_name,
                "input": {
                    "texts": ["test"]
                }
            }
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=test_payload,
                timeout=(5, 15)  # short timeouts: this is only a connectivity probe
            )
            if response.status_code == 200:
                result = response.json()
                if result.get('output') and result['output'].get('embeddings'):
                    print("[OK] DashScope嵌入API连接测试成功")
                else:
                    print(f"[WARNING] DashScope嵌入API响应格式异常: {result}")
            else:
                print(f"[WARNING] DashScope嵌入API连接异常状态码: {response.status_code}")
                print(f"   响应: {response.text[:200]}")
        except requests.exceptions.Timeout:
            print("[WARNING] DashScope嵌入API连接超时但将在实际请求时重试")
        except requests.exceptions.ConnectionError:
            print("[WARNING] DashScope嵌入API连接失败请检查网络状态")
        except Exception as e:
            print(f"[WARNING] DashScope嵌入API连接测试出错: {e}")

    def embed_texts(self, texts: List[str], batch_size: int = 10) -> List[List[float]]:
        """
        Embed a list of texts, sending them to the API in batches.

        Args:
            texts: Texts to embed.
            batch_size: Texts per API request. Defaults to 10, the previously
                hard-coded value (presumably the per-request limit of the
                configured model; confirm against the DashScope docs before
                raising it).

        Returns:
            One embedding vector per input text, in input order.

        Raises:
            ValueError: If ``batch_size`` is not positive.
            RuntimeError: If any batch ultimately fails (see ``_embed_batch``).
        """
        if not texts:
            return []
        if batch_size < 1:
            raise ValueError(f"batch_size必须为正整数: {batch_size}")
        all_embeddings: List[List[float]] = []
        for start in range(0, len(texts), batch_size):
            all_embeddings.extend(self._embed_batch(texts[start:start + batch_size]))
        return all_embeddings

    def _embed_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Embed one batch of texts with a single API call, retrying on failure.

        Args:
            texts: Texts to embed in one request.

        Returns:
            One embedding vector per input text, in input order.

        Raises:
            RuntimeError: If every attempt fails, or immediately on a
                non-retryable client error (4xx other than 408/429).
            KeyboardInterrupt: Re-raised unchanged after printing a notice.
        """
        payload = {
            "model": self.model_name,
            "input": {
                "texts": texts
            }
        }
        try:
            return self._post_with_retries(payload, len(texts))
        except KeyboardInterrupt:
            # Also covers Ctrl-C during back-off sleeps; bare raise keeps the traceback.
            print("\n[WARNING] 用户中断请求")
            raise

    def _post_with_retries(self, payload: dict, expected_count: int) -> List[List[float]]:
        """
        POST ``payload`` until the API returns exactly ``expected_count`` vectors.

        Args:
            payload: Request body for the embedding endpoint.
            expected_count: Number of texts in the payload.

        Returns:
            The embedding vectors, ordered to match the input texts.

        Raises:
            RuntimeError: On a non-retryable client error, or when the last
                allowed attempt fails.
        """
        # Always make at least one request, even if max_retries was set to <= 0.
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            progress = f"({attempt + 2}/{attempts})"
            try:
                response = requests.post(
                    self.api_url,
                    headers=self.headers,
                    json=payload,
                    timeout=(30, 120)  # 30 s to connect, 120 s to read
                )
            except requests.exceptions.Timeout as e:
                # Must precede ConnectionError: ConnectTimeout subclasses both.
                error_msg = f"请求超时: {e}"
                self._retry_or_raise(
                    attempt, attempts,
                    f"经过 {attempts} 次重试后仍超时: {error_msg}",
                    f"请求超时,正在重试 {progress}: {error_msg}",
                    cause=e,
                )
                continue
            except requests.exceptions.ConnectionError as e:
                error_msg = f"连接错误: {e}"
                self._retry_or_raise(
                    attempt, attempts,
                    f"经过 {attempts} 次重试后仍无法连接: {error_msg}",
                    f"连接错误,正在重试 {progress}: {error_msg}",
                    cause=e,
                )
                continue
            except requests.RequestException as e:
                self._retry_or_raise(
                    attempt, attempts,
                    f"经过 {attempts} 次重试后仍失败: 网络请求异常: {e}",
                    f"网络异常,正在重试 {progress}: {str(e)[:100]}",
                    cause=e,
                )
                continue

            status = response.status_code
            if status != 200:
                error_text = response.text[:500] if response.text else "无响应内容"
                error_msg = f"API请求失败状态码: {status}, 响应: {error_text}"
                if 400 <= status < 500 and status not in self._RETRYABLE_CLIENT_STATUSES:
                    # The request itself is rejected; retrying cannot succeed.
                    raise RuntimeError(error_msg)
                self._retry_or_raise(
                    attempt, attempts, error_msg,
                    f"请求失败,正在重试 {progress}: 状态码 {status}",
                )
                continue

            try:
                result = response.json()
            except ValueError:  # body is not JSON (JSONDecodeError subclasses ValueError)
                result = None
            if not isinstance(result, dict):
                self._retry_or_raise(
                    attempt, attempts,
                    f"API响应格式错误: {response.text[:500]}",
                    f"响应格式错误,正在重试 {progress}",
                )
                continue

            # A 200 body can still carry an API-level error code.
            if result.get('code'):
                error_msg = result.get('message', f'API错误代码: {result["code"]}')
                self._retry_or_raise(
                    attempt, attempts,
                    f"DashScope嵌入API错误: {error_msg}",
                    f"API错误正在重试 {progress}: {error_msg}",
                )
                continue

            embeddings = self._parse_embeddings(result, expected_count)
            if embeddings is not None:
                return embeddings
            self._retry_or_raise(
                attempt, attempts,
                f"API响应格式错误: {result}",
                f"响应格式错误,正在重试 {progress}",
            )
        # Defensive only: the last attempt always returns or raises above.
        raise RuntimeError("所有重试都失败了")

    @staticmethod
    def _parse_embeddings(result: dict, expected_count: int) -> Optional[List[List[float]]]:
        """
        Extract the vectors from a successful API response body.

        Items may be ``{"text_index": i, "embedding": [...]}`` dicts or bare
        lists. Vectors are ordered by ``text_index``, falling back to response
        position when an item has no integer index.

        Returns:
            The ordered vectors, or ``None`` when the body has no embeddings,
            contains an item of unknown shape, or holds a different number of
            vectors than ``expected_count``. Returning a short list in those
            cases would silently misalign vectors with their texts.
        """
        output = result.get('output')
        items = output.get('embeddings') if isinstance(output, dict) else None
        if not items:
            return None
        indexed = []
        for position, item in enumerate(items):
            if isinstance(item, dict) and 'embedding' in item:
                index = item.get('text_index')
                if not isinstance(index, int):
                    index = position
                indexed.append((index, item['embedding']))
            elif isinstance(item, list):
                indexed.append((position, item))
            else:
                print(f"[WARNING] 未知的嵌入格式: {type(item)}")
                return None
        if len(indexed) != expected_count:
            return None
        indexed.sort(key=lambda pair: pair[0])
        return [vector for _, vector in indexed]

    def _retry_or_raise(
        self,
        attempt: int,
        attempts: int,
        final_error: str,
        retry_notice: str,
        cause: Optional[BaseException] = None
    ) -> None:
        """
        Handle a failed attempt: raise on the last one, otherwise back off.

        Args:
            attempt: Zero-based index of the attempt that just failed.
            attempts: Total number of attempts allowed.
            final_error: Message of the ``RuntimeError`` raised on the last attempt.
            retry_notice: Message printed before sleeping and retrying.
            cause: Exception behind the failure, chained onto the raised error.

        Raises:
            RuntimeError: If ``attempt`` was the last allowed attempt.
        """
        if attempt >= attempts - 1:
            error = RuntimeError(final_error)
            if cause is not None:
                raise error from cause
            raise error
        print(retry_notice)
        # Linear back-off: retry_delay, 2 * retry_delay, ...
        time.sleep(self.retry_delay * (attempt + 1))

    def embed_query(self, text: str) -> List[float]:
        """
        Embed a single query text.

        Args:
            text: Query text.

        Returns:
            Its embedding vector.
        """
        embeddings = self.embed_texts([text])
        return embeddings[0] if embeddings else []
def create_dashscope_embedding_model(
    api_key: Optional[str] = None,
    model_name: Optional[str] = None,
    **kwargs
) -> DashScopeEmbeddingModel:
    """
    Convenience factory for :class:`DashScopeEmbeddingModel`.

    ``api_key`` and ``model_name`` plus any extra keyword arguments (for
    example ``max_retries`` or ``retry_delay``) are forwarded unchanged to the
    constructor.
    """
    constructor_options = dict(kwargs, api_key=api_key, model_name=model_name)
    return DashScopeEmbeddingModel(**constructor_options)