169 lines
5.0 KiB
Python
169 lines
5.0 KiB
Python
|
|
"""
ES vector retriever.

Performs vector-similarity retrieval directly against the ES vector store.
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
from typing import List, Dict, Any, Optional
|
|||
|
|
from langchain_core.documents import Document
|
|||
|
|
|
|||
|
|
# Append the directory two levels above this file to sys.path so the project-local imports below resolve
|
|||
|
|
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
|
|||
|
|
sys.path.append(project_root)
|
|||
|
|
|
|||
|
|
from retriver.langgraph.dashscope_embedding import DashScopeEmbeddingModel
|
|||
|
|
from elasticsearch_vectorization.es_client_wrapper import ESClientWrapper
|
|||
|
|
from elasticsearch_vectorization.config import ElasticsearchConfig
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ESVectorRetriever:
    """Retriever that runs vector-similarity searches directly against an ES index.

    The query text is embedded with ``DashScopeEmbeddingModel`` and the
    resulting vector is passed to ``ESClientWrapper.vector_search`` against
    the passages index derived from ``keyword``.
    """

    def __init__(
        self,
        keyword: str,
        top_k: int = 3,
        oneapi_key: Optional[str] = None,
        oneapi_base_url: Optional[str] = None,
        embed_model_name: Optional[str] = None
    ):
        """Create the embedding model and ES client and resolve the index name.

        Args:
            keyword: Keyword used to derive the ES passages index name.
            top_k: Number of documents requested per search.
            oneapi_key: OneAPI key, forwarded to the embedding model as
                ``api_key``.
            oneapi_base_url: OneAPI base URL. NOTE(review): accepted for
                interface compatibility but not forwarded anywhere in this
                class -- confirm whether the embedding model should get it.
            embed_model_name: Embedding model name, forwarded to the
                embedding model as ``model_name``.
        """
        self.keyword = keyword
        self.top_k = top_k

        # Embedding model that turns query text into a vector.
        self.embedding_model = DashScopeEmbeddingModel(
            api_key=oneapi_key,
            model_name=embed_model_name
        )

        # ES client used for both vector search and index statistics.
        self.es_client = ESClientWrapper()

        # Name of the passages index that every search targets.
        self.passages_index = ElasticsearchConfig.get_passages_index_name(keyword)

        print(f"ES向量检索器初始化完成,目标索引: {self.passages_index}")

    def retrieve(self, query: str) -> List[Document]:
        """Embed ``query`` and return the matching passages as documents.

        Args:
            query: Query text.

        Returns:
            One ``Document`` per search hit, carrying the hit's content and
            a metadata dict (passage id, file id, evidence, score, source
            tag). Any failure along the way is printed and an empty list is
            returned instead of raising.
        """
        try:
            # Embed the single query; ``encode`` takes a batch, so unwrap it.
            query_embedding = self.embedding_model.encode([query], normalize_embeddings=True)[0]

            # The search call is given a plain list, whatever ``encode`` returned.
            if hasattr(query_embedding, 'tolist'):
                query_vector = query_embedding.tolist()
            else:
                query_vector = list(query_embedding)

            search_result = self.es_client.vector_search(
                index_name=self.passages_index,
                vector=query_vector,
                field="embedding",
                size=self.top_k
            )

            documents = []
            hits = search_result.get("hits", {}).get("hits", [])

            for hit in hits:
                source = hit["_source"]
                score = hit["_score"]

                doc = Document(
                    page_content=source.get("content", ""),
                    metadata={
                        "passage_id": source.get("passage_id", ""),
                        "file_id": source.get("file_id", ""),
                        "evidence": source.get("evidence", ""),
                        "score": score,
                        "source": "es_vector_search"
                    }
                )
                documents.append(doc)

            print(f"ES向量检索完成,找到 {len(documents)} 个相关文档")
            return documents

        except Exception as e:
            # Best-effort contract: callers get an empty result, not an error.
            print(f"ES向量检索失败: {e}")
            return []

    def test_connection(self) -> bool:
        """Ping ES and report whether the call succeeded.

        Returns:
            Whatever ``es_client.ping()`` returns, or ``False`` if it raises
            an ordinary exception.
        """
        try:
            return self.es_client.ping()
        except Exception:
            # Not a bare ``except``: KeyboardInterrupt and SystemExit must
            # propagate instead of being reported as "not connected".
            return False

    def get_index_stats(self) -> Dict[str, Any]:
        """Return the document count of the passages index.

        Returns:
            A dict with ``index_name``, ``document_count`` and ``top_k``.
            If the lookup fails, ``document_count`` is 0 and an ``error``
            key holds the exception text.
        """
        try:
            query = {"match_all": {}}
            # size=0: only the total is needed, not the documents themselves.
            result = self.es_client.search(self.passages_index, query, size=0)
            total = result.get("hits", {}).get("total", 0)

            # ``total`` may be a plain number or a dict carrying ``value``.
            if isinstance(total, dict):
                count = total.get("value", 0)
            else:
                count = total

            return {
                "index_name": self.passages_index,
                "document_count": count,
                "top_k": self.top_k
            }
        except Exception as e:
            print(f"获取索引统计信息失败: {e}")
            return {
                "index_name": self.passages_index,
                "document_count": 0,
                "top_k": self.top_k,
                "error": str(e)
            }
|
|||
|
|
|
|||
|
|
|
|||
|
|
def create_es_vector_retriever(
    keyword: str,
    top_k: int = 3,
    **kwargs
) -> ESVectorRetriever:
    """Convenience factory that builds an ``ESVectorRetriever``.

    Args:
        keyword: Keyword used to derive the ES index name.
        top_k: Number of documents to return.
        **kwargs: Any further keyword arguments, passed through to the
            ``ESVectorRetriever`` constructor unchanged.

    Returns:
        The newly constructed ``ESVectorRetriever`` instance.
    """
    # Gather every constructor argument in one mapping, then build in one call.
    init_args = {"keyword": keyword, "top_k": top_k, **kwargs}
    return ESVectorRetriever(**init_args)
|