Files
AIEC-RAG---/AIEC-RAG/retriver/langgraph/es_vector_retriever.py
2025-09-25 10:33:37 +08:00

169 lines
5.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
ES向量检索器
用于直接与ES向量库进行向量匹配检索
"""
import os
import sys
from typing import List, Dict, Any, Optional
from langchain_core.documents import Document
# 添加路径
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.append(project_root)
from retriver.langgraph.dashscope_embedding import DashScopeEmbeddingModel
from elasticsearch_vectorization.es_client_wrapper import ESClientWrapper
from elasticsearch_vectorization.config import ElasticsearchConfig
class ESVectorRetriever:
    """Retrieve passages from Elasticsearch by vector similarity.

    The query text is embedded with ``DashScopeEmbeddingModel`` and matched
    against the ``embedding`` field of the passages index that
    ``ElasticsearchConfig.get_passages_index_name`` derives from ``keyword``.
    """

    def __init__(
        self,
        keyword: str,
        top_k: int = 3,
        oneapi_key: Optional[str] = None,
        oneapi_base_url: Optional[str] = None,
        embed_model_name: Optional[str] = None
    ):
        """
        Initialise the ES vector retriever.

        Args:
            keyword: Keyword used to derive the ES passages index name.
            top_k: Number of documents to return per query (sent to ES as
                the search ``size``).
            oneapi_key: API key, forwarded to the embedding model as
                ``api_key``.
            oneapi_base_url: Currently NOT used by this class; accepted for
                interface compatibility. NOTE(review): confirm whether it
                should be forwarded to the embedding model.
            embed_model_name: Embedding model name, forwarded as
                ``model_name``.
        """
        self.keyword = keyword
        self.top_k = top_k
        # Embedding model used to turn query text into a search vector.
        self.embedding_model = DashScopeEmbeddingModel(
            api_key=oneapi_key,
            model_name=embed_model_name
        )
        # ES client wrapper; constructed with its own default configuration.
        self.es_client = ESClientWrapper()
        # Name of the passages index searched by this retriever.
        self.passages_index = ElasticsearchConfig.get_passages_index_name(keyword)
        print(f"ES向量检索器初始化完成目标索引: {self.passages_index}")

    def retrieve(self, query: str) -> List[Document]:
        """
        Embed ``query`` and return the ``top_k`` most similar passages.

        Args:
            query: Query text.

        Returns:
            One ``Document`` per ES hit. ``page_content`` is the hit's
            ``content`` field. The metadata holds ``passage_id``,
            ``file_id``, ``evidence``, the ES ``score`` (``None`` if the
            hit carries no score) and ``source="es_vector_search"``.
            On any failure (embedding, search, or parsing), the error is
            printed and an empty list is returned; nothing is raised.
        """
        try:
            # encode() takes a batch; embed the single query and take its vector.
            query_embedding = self.embedding_model.encode([query], normalize_embeddings=True)[0]
            # The search body needs a plain list. The encoder may return an
            # array-like exposing .tolist() or an ordinary sequence.
            if hasattr(query_embedding, 'tolist'):
                query_vector = query_embedding.tolist()
            else:
                query_vector = list(query_embedding)

            search_result = self.es_client.vector_search(
                index_name=self.passages_index,
                vector=query_vector,
                field="embedding",
                size=self.top_k
            )

            documents = []
            hits = search_result.get("hits", {}).get("hits", [])
            for hit in hits:
                # Read hits tolerantly: a single hit without _source/_score
                # must not raise and discard the whole (otherwise valid)
                # result set via the broad handler below.
                source = hit.get("_source") or {}
                doc = Document(
                    page_content=source.get("content", ""),
                    metadata={
                        "passage_id": source.get("passage_id", ""),
                        "file_id": source.get("file_id", ""),
                        "evidence": source.get("evidence", ""),
                        "score": hit.get("_score"),
                        "source": "es_vector_search"
                    }
                )
                documents.append(doc)
            print(f"ES向量检索完成找到 {len(documents)} 个相关文档")
            return documents
        except Exception as e:
            # Deliberate best-effort: callers receive an empty list on failure.
            print(f"ES向量检索失败: {e}")
            return []

    def test_connection(self) -> bool:
        """Return the result of the ES client's ``ping()``, or False if it raises."""
        try:
            return self.es_client.ping()
        except Exception:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt and
            # SystemExit are no longer swallowed by a health check.
            return False

    def get_index_stats(self) -> Dict[str, Any]:
        """
        Return basic statistics for the passages index.

        Returns:
            A dict with ``index_name``, ``document_count`` and ``top_k``.
            On failure, ``document_count`` is 0 and an ``error`` message is
            included; the error is printed, not raised.
        """
        try:
            query = {"match_all": {}}
            # size=0: only the hit total is needed, not the documents.
            result = self.es_client.search(self.passages_index, query, size=0)
            total = result.get("hits", {}).get("total", 0)
            # Different ES versions report the total either as a plain int or
            # as an object such as {"value": n, "relation": ...}.
            # NOTE(review): newer ES versions may cap this count (default
            # 10000) unless track_total_hits is enabled. Confirm whether
            # ESClientWrapper.search sets it.
            count = total.get("value", 0) if isinstance(total, dict) else total
            return {
                "index_name": self.passages_index,
                "document_count": count,
                "top_k": self.top_k
            }
        except Exception as e:
            print(f"获取索引统计信息失败: {e}")
            return {
                "index_name": self.passages_index,
                "document_count": 0,
                "top_k": self.top_k,
                "error": str(e)
            }
def create_es_vector_retriever(
    keyword: str,
    top_k: int = 3,
    **kwargs
) -> ESVectorRetriever:
    """
    Convenience factory that builds an ``ESVectorRetriever``.

    Args:
        keyword: Keyword from which the ES passages index name is derived.
        top_k: Number of documents each query should return.
        **kwargs: Any further ``ESVectorRetriever`` constructor arguments
            (``oneapi_key``, ``oneapi_base_url``, ``embed_model_name``),
            passed through unchanged.

    Returns:
        A newly constructed ``ESVectorRetriever``.
    """
    # ``keyword`` and ``top_k`` bind to the named parameters, so they can
    # never also appear in ``kwargs``; merging the two is collision-free.
    constructor_args = {"keyword": keyword, "top_k": top_k}
    constructor_args.update(kwargs)
    return ESVectorRetriever(**constructor_args)