first commit

This commit is contained in:
闫旭隆
2025-10-17 09:31:28 +08:00
commit 4698145045
589 changed files with 196795 additions and 0 deletions

View File

@ -0,0 +1,269 @@
"""
阿里云DashScope嵌入模型实现
直接调用阿里云文本嵌入API不通过OneAPI兼容层
"""
import os
import sys
import json
import time
import requests
import numpy as np
from typing import List, Union, Optional
from dotenv import load_dotenv
# 添加项目根目录到路径
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.append(project_root)
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
# 加载环境变量
load_dotenv(os.path.join(os.path.dirname(__file__), '..', '..', '.env'))
class DashScopeEmbeddingModel(BaseEmbeddingModel):
"""
阿里云DashScope嵌入模型实现
直接调用DashScope文本嵌入API
"""
def __init__(
self,
api_key: Optional[str] = None,
model_name: Optional[str] = None,
max_retries: int = 3,
retry_delay: float = 1.0
):
"""
初始化DashScope嵌入模型
Args:
api_key: 阿里云DashScope API Key
model_name: 嵌入模型名称,如 text-embedding-v3
max_retries: 最大重试次数
retry_delay: 重试延迟时间(秒)
"""
# 初始化父类传入None作为sentence_encoder我们不使用它
super().__init__(sentence_encoder=None)
self.api_key = api_key or os.getenv('ONEAPI_KEY') # 复用现有环境变量
self.model_name = model_name or os.getenv('ONEAPI_MODEL_EMBED', 'text-embedding-v3')
self.max_retries = max_retries
self.retry_delay = retry_delay
if not self.api_key:
raise ValueError("DashScope API Key未设置请在.env文件中设置ONEAPI_KEY")
# DashScope文本嵌入API端点
self.api_url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
# 设置请求头
self.headers = {
'Authorization': f'Bearer {self.api_key}',
'Content-Type': 'application/json'
}
print(f"[OK] DashScope嵌入模型初始化完成: {self.model_name}")
self._test_connection()
def encode(self, texts, **kwargs):
"""
实现基类的抽象方法
对文本进行编码,返回嵌入向量
Args:
texts: 文本或文本列表
**kwargs: 其他参数batch_size, show_progress_bar, query_type等
Returns:
嵌入向量或嵌入向量列表numpy数组格式
"""
import numpy as np
if isinstance(texts, str):
# 单个文本
result = self.embed_query(texts)
return np.array(result)
else:
# 文本列表
results = self.embed_texts(texts)
return np.array(results)
def _test_connection(self):
"""测试DashScope嵌入API连接"""
try:
# 发送一个简单的测试请求
test_payload = {
"model": self.model_name,
"input": {
"texts": ["test"]
}
}
response = requests.post(
self.api_url,
headers=self.headers,
json=test_payload,
timeout=(5, 15) # 短超时进行连接测试
)
if response.status_code == 200:
result = response.json()
if result.get('output') and result['output'].get('embeddings'):
print("[OK] DashScope嵌入API连接测试成功")
else:
print(f"[WARNING] DashScope嵌入API响应格式异常: {result}")
else:
print(f"[WARNING] DashScope嵌入API连接异常状态码: {response.status_code}")
print(f" 响应: {response.text[:200]}")
except requests.exceptions.Timeout:
print("[WARNING] DashScope嵌入API连接超时但将在实际请求时重试")
except requests.exceptions.ConnectionError:
print("[WARNING] DashScope嵌入API连接失败请检查网络状态")
except Exception as e:
print(f"[WARNING] DashScope嵌入API连接测试出错: {e}")
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""
对文本列表进行嵌入
Args:
texts: 文本列表
Returns:
嵌入向量列表
"""
import numpy as np
if not texts:
return []
# DashScope API支持批量处理但需要分批处理大量文本
batch_size = 10 # 每批处理10个文本
all_embeddings = []
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i + batch_size]
batch_embeddings = self._embed_batch(batch_texts)
all_embeddings.extend(batch_embeddings)
return all_embeddings
def _embed_batch(self, texts: List[str]) -> List[List[float]]:
"""对一批文本进行嵌入"""
payload = {
"model": self.model_name,
"input": {
"texts": texts
}
}
for attempt in range(self.max_retries):
try:
response = requests.post(
self.api_url,
headers=self.headers,
json=payload,
timeout=(30, 120) # 连接30秒读取120秒
)
if response.status_code == 200:
result = response.json()
# 检查是否有错误
if result.get('code'):
error_msg = result.get('message', f'API错误代码: {result["code"]}')
if attempt == self.max_retries - 1:
raise RuntimeError(f"DashScope嵌入API错误: {error_msg}")
else:
print(f"API错误正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}")
time.sleep(self.retry_delay * (attempt + 1))
continue
# 提取嵌入向量
if result.get('output') and result['output'].get('embeddings'):
embeddings_data = result['output']['embeddings']
# 提取向量数据
embeddings = []
for embedding_item in embeddings_data:
if isinstance(embedding_item, dict) and 'embedding' in embedding_item:
embeddings.append(embedding_item['embedding'])
elif isinstance(embedding_item, list):
embeddings.append(embedding_item)
else:
print(f"[WARNING] 未知的嵌入格式: {type(embedding_item)}")
return embeddings
else:
error_msg = f"API响应格式错误: {result}"
if attempt == self.max_retries - 1:
raise RuntimeError(error_msg)
else:
print(f"响应格式错误,正在重试 ({attempt + 2}/{self.max_retries})")
time.sleep(self.retry_delay * (attempt + 1))
else:
error_text = response.text[:500] if response.text else "无响应内容"
error_msg = f"API请求失败状态码: {response.status_code}, 响应: {error_text}"
if attempt == self.max_retries - 1:
raise RuntimeError(error_msg)
else:
print(f"请求失败,正在重试 ({attempt + 2}/{self.max_retries}): 状态码 {response.status_code}")
time.sleep(self.retry_delay * (attempt + 1))
except KeyboardInterrupt:
print(f"\n[WARNING] 用户中断请求")
raise KeyboardInterrupt("用户中断请求")
except requests.exceptions.Timeout as e:
error_msg = f"请求超时: {str(e)}"
if attempt == self.max_retries - 1:
raise RuntimeError(f"经过 {self.max_retries} 次重试后仍超时: {error_msg}")
else:
print(f"请求超时,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}")
time.sleep(self.retry_delay * (attempt + 1))
except requests.exceptions.ConnectionError as e:
error_msg = f"连接错误: {str(e)}"
if attempt == self.max_retries - 1:
raise RuntimeError(f"经过 {self.max_retries} 次重试后仍无法连接: {error_msg}")
else:
print(f"连接错误,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}")
time.sleep(self.retry_delay * (attempt + 1))
except requests.RequestException as e:
error_msg = f"网络请求异常: {str(e)}"
if attempt == self.max_retries - 1:
raise RuntimeError(f"经过 {self.max_retries} 次重试后仍失败: {error_msg}")
else:
print(f"网络异常,正在重试 ({attempt + 2}/{self.max_retries}): {str(e)[:100]}")
time.sleep(self.retry_delay * (attempt + 1))
raise RuntimeError("所有重试都失败了")
def embed_query(self, text: str) -> List[float]:
"""
对单个查询文本进行嵌入
Args:
text: 查询文本
Returns:
嵌入向量
"""
embeddings = self.embed_texts([text])
return embeddings[0] if embeddings else []
def create_dashscope_embedding_model(
api_key: Optional[str] = None,
model_name: Optional[str] = None,
**kwargs
) -> DashScopeEmbeddingModel:
"""创建DashScope嵌入模型实例的便捷函数"""
return DashScopeEmbeddingModel(
api_key=api_key,
model_name=model_name,
**kwargs
)