first commit

This commit is contained in:
闫旭隆
2025-09-25 10:33:37 +08:00
commit 34839c2654
387 changed files with 149159 additions and 0 deletions

View File

@ -0,0 +1,169 @@
"""
ES vector retriever.

Used to run vector-matching retrieval directly against the ES vector store.
"""
import os
import sys
from typing import List, Dict, Any, Optional
from langchain_core.documents import Document
# Append the directory two levels above this file to the import search path
# (presumably the project root, so the project-local imports that follow
# can be resolved — TODO confirm against the repository layout).
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.append(project_root)
from retriver.langgraph.dashscope_embedding import DashScopeEmbeddingModel
from elasticsearch_vectorization.es_client_wrapper import ESClientWrapper
from elasticsearch_vectorization.config import ElasticsearchConfig
class ESVectorRetriever:
    """Retriever that matches query embeddings directly against an ES index.

    A query string is embedded with ``DashScopeEmbeddingModel`` and the
    resulting vector is passed to ``ESClientWrapper.vector_search`` against
    the passages index derived from ``keyword``.
    """

    def __init__(
        self,
        keyword: str,
        top_k: int = 3,
        oneapi_key: Optional[str] = None,
        oneapi_base_url: Optional[str] = None,
        embed_model_name: Optional[str] = None
    ):
        """Set up the embedding model, the ES client and the target index.

        Args:
            keyword: ES index keyword; the passages index name is obtained
                from ``ElasticsearchConfig.get_passages_index_name(keyword)``.
            top_k: Number of documents to return per query.
            oneapi_key: OneAPI key, forwarded to the embedding model as
                ``api_key``.
            oneapi_base_url: OneAPI base URL. Accepted for interface
                compatibility; this class does not forward it anywhere.
            embed_model_name: Embedding model name, forwarded to the
                embedding model as ``model_name``.
        """
        self.keyword = keyword
        self.top_k = top_k

        # NOTE(review): oneapi_base_url is intentionally not passed here
        # because the embedding model's constructor signature is not visible
        # from this file — confirm whether it should be forwarded.
        self.embedding_model = DashScopeEmbeddingModel(
            api_key=oneapi_key,
            model_name=embed_model_name
        )

        self.es_client = ESClientWrapper()
        self.passages_index = ElasticsearchConfig.get_passages_index_name(keyword)

        print(f"ES向量检索器初始化完成目标索引: {self.passages_index}")

    def retrieve(self, query: str) -> List[Document]:
        """Retrieve the documents most similar to ``query``.

        Args:
            query: Query text.

        Returns:
            A list of ``Document`` objects, one per search hit. Each has the
            hit's ``content`` field as ``page_content`` and ``passage_id``,
            ``file_id``, ``evidence``, ``score`` and ``source`` in its
            metadata. An empty list is returned when the query is blank or
            when any step fails; failures are printed, never raised.
        """
        # A blank query cannot produce a meaningful embedding, so skip the
        # embedding and search round-trips entirely.
        if not isinstance(query, str) or not query.strip():
            print("ES向量检索失败: 查询文本为空")
            return []

        # Best-effort boundary: any embedding/search/parsing error is
        # reported and turned into an empty result instead of propagating.
        try:
            query_embedding = self.embedding_model.encode(
                [query], normalize_embeddings=True
            )[0]

            # The search call needs a plain list; array-like embeddings
            # expose ``tolist`` while other sequences are copied with list().
            if hasattr(query_embedding, 'tolist'):
                query_vector = query_embedding.tolist()
            else:
                query_vector = list(query_embedding)

            search_result = self.es_client.vector_search(
                index_name=self.passages_index,
                vector=query_vector,
                field="embedding",
                size=self.top_k
            )

            documents = []
            hits = search_result.get("hits", {}).get("hits", [])
            for hit in hits:
                source = hit["_source"]
                score = hit["_score"]
                documents.append(
                    Document(
                        page_content=source.get("content", ""),
                        metadata={
                            "passage_id": source.get("passage_id", ""),
                            "file_id": source.get("file_id", ""),
                            "evidence": source.get("evidence", ""),
                            "score": score,
                            "source": "es_vector_search"
                        }
                    )
                )

            print(f"ES向量检索完成找到 {len(documents)} 个相关文档")
            return documents
        except Exception as e:
            print(f"ES向量检索失败: {e}")
            return []

    def test_connection(self) -> bool:
        """Check the ES connection.

        Returns:
            The result of the client's ``ping()``, or ``False`` if the call
            raises an ordinary exception.
        """
        try:
            return self.es_client.ping()
        except Exception:
            # Only ordinary errors mean "not reachable"; KeyboardInterrupt
            # and SystemExit must still propagate to the caller.
            return False

    def get_index_stats(self) -> Dict[str, Any]:
        """Return basic statistics for the passages index.

        Runs a ``match_all`` search with ``size=0`` and reads the total hit
        count from the response.

        Returns:
            A dict with ``index_name``, ``document_count`` and ``top_k``.
            When the search fails, ``document_count`` is 0 and an ``error``
            key holds the exception message.
        """
        try:
            query = {"match_all": {}}
            result = self.es_client.search(self.passages_index, query, size=0)
            total = result.get("hits", {}).get("total", 0)

            # "total" is either a plain number or a dict carrying the count
            # under "value", depending on the ES version.
            if isinstance(total, dict):
                count = total.get("value", 0)
            else:
                count = total

            return {
                "index_name": self.passages_index,
                "document_count": count,
                "top_k": self.top_k
            }
        except Exception as e:
            print(f"获取索引统计信息失败: {e}")
            return {
                "index_name": self.passages_index,
                "document_count": 0,
                "top_k": self.top_k,
                "error": str(e)
            }
def create_es_vector_retriever(
    keyword: str,
    top_k: int = 3,
    **kwargs
) -> ESVectorRetriever:
    """Convenience factory that builds an ``ESVectorRetriever``.

    Args:
        keyword: ES index keyword.
        top_k: Number of documents to return.
        **kwargs: Extra keyword arguments forwarded unchanged to the
            ``ESVectorRetriever`` constructor.

    Returns:
        The newly constructed ES vector retriever instance.
    """
    # Merge the explicit parameters with the pass-through options and hand
    # everything to the constructor as keyword arguments.
    init_args = dict(kwargs, keyword=keyword, top_k=top_k)
    retriever = ESVectorRetriever(**init_args)
    return retriever