"""ES vector retriever.

Runs vector-similarity searches directly against an Elasticsearch passages
index and returns the hits as LangChain ``Document`` objects.
"""
import os
import sys
from typing import List, Dict, Any, Optional

from langchain_core.documents import Document

# Add the project root (two directories up) to sys.path so the project-local
# packages imported below resolve when this file is used outside its package.
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.append(project_root)

from retriver.langgraph.dashscope_embedding import DashScopeEmbeddingModel
from elasticsearch_vectorization.es_client_wrapper import ESClientWrapper
from elasticsearch_vectorization.config import ElasticsearchConfig


class ESVectorRetriever:
    """Retriever that performs vector-similarity matching directly against ES."""

    def __init__(
        self,
        keyword: str,
        top_k: int = 3,
        oneapi_key: Optional[str] = None,
        oneapi_base_url: Optional[str] = None,
        embed_model_name: Optional[str] = None
    ):
        """
        Initialize the ES vector retriever.

        Args:
            keyword: Keyword used to derive the ES passages index name.
            top_k: Number of documents to return per query.
            oneapi_key: API key passed to the embedding model.
            oneapi_base_url: OneAPI base URL. Accepted for interface
                compatibility but currently NOT forwarded to the embedding
                model.
            embed_model_name: Name of the embedding model.
        """
        self.keyword = keyword
        self.top_k = top_k

        # Embedding model used to turn queries into vectors.
        # NOTE(review): oneapi_base_url is not passed here; confirm whether
        # DashScopeEmbeddingModel should receive it.
        self.embedding_model = DashScopeEmbeddingModel(
            api_key=oneapi_key,
            model_name=embed_model_name
        )

        # ES client wrapper used for searches and connectivity checks.
        self.es_client = ESClientWrapper()

        # Index that holds the passage vectors for this keyword.
        self.passages_index = ElasticsearchConfig.get_passages_index_name(keyword)

        print(f"ES向量检索器初始化完成,目标索引: {self.passages_index}")

    @staticmethod
    def _hit_to_document(hit: Dict[str, Any]) -> Optional[Document]:
        """Convert one raw ES hit into a Document.

        Returns None for a hit that carries no ``_source``, so that the
        caller can skip it without discarding the other hits.
        """
        source = hit.get("_source")
        if not source:
            return None
        return Document(
            page_content=source.get("content", ""),
            metadata={
                "passage_id": source.get("passage_id", ""),
                "file_id": source.get("file_id", ""),
                "evidence": source.get("evidence", ""),
                # _score can be absent or null (e.g. with custom sorting).
                "score": hit.get("_score"),
                "source": "es_vector_search"
            }
        )

    def retrieve(self, query: str) -> List[Document]:
        """
        Retrieve documents relevant to ``query`` by vector similarity.

        Args:
            query: Query text.

        Returns:
            Up to ``top_k`` documents. Each document's metadata holds
            passage_id, file_id, evidence, score and source. This is a
            best-effort call: on any failure the error is printed and an empty
            list is returned. An empty or whitespace-only query returns an
            empty list without calling the embedding service.
        """
        if not query or not query.strip():
            return []

        try:
            # Embed the query (normalized) and take the single resulting vector.
            query_embedding = self.embedding_model.encode([query], normalize_embeddings=True)[0]

            # The ES client expects a plain list; numpy-like arrays expose tolist().
            if hasattr(query_embedding, 'tolist'):
                query_vector = query_embedding.tolist()
            else:
                query_vector = list(query_embedding)

            search_result = self.es_client.vector_search(
                index_name=self.passages_index,
                vector=query_vector,
                field="embedding",
                size=self.top_k
            )

            hits = search_result.get("hits", {}).get("hits", [])
            # Convert each hit independently, so one malformed hit does not
            # throw away the valid ones.
            documents = [
                doc for doc in (self._hit_to_document(hit) for hit in hits)
                if doc is not None
            ]

            print(f"ES向量检索完成,找到 {len(documents)} 个相关文档")
            return documents

        except Exception as e:
            # Deliberate best-effort boundary: report and return no results.
            print(f"ES向量检索失败: {e}")
            return []

    def test_connection(self) -> bool:
        """Return True if the ES client answers a ping, False on any error."""
        try:
            return self.es_client.ping()
        except Exception:
            # Do not use a bare except here: it would also swallow
            # KeyboardInterrupt/SystemExit.
            return False

    def get_index_stats(self) -> Dict[str, Any]:
        """
        Return basic statistics for the passages index.

        Returns:
            A dict with index_name, document_count and top_k. On failure,
            document_count is 0 and an "error" key holds the message.
        """
        try:
            query = {"match_all": {}}
            result = self.es_client.search(self.passages_index, query, size=0)
            total = result.get("hits", {}).get("total", 0)

            # hits.total is an object {"value": ..., "relation": ...} in newer
            # ES versions and a plain integer in older ones.
            if isinstance(total, dict):
                count = total.get("value", 0)
            else:
                count = total

            return {
                "index_name": self.passages_index,
                "document_count": count,
                "top_k": self.top_k
            }
        except Exception as e:
            print(f"获取索引统计信息失败: {e}")
            return {
                "index_name": self.passages_index,
                "document_count": 0,
                "top_k": self.top_k,
                "error": str(e)
            }


def create_es_vector_retriever(
    keyword: str,
    top_k: int = 3,
    **kwargs
) -> ESVectorRetriever:
    """
    Convenience factory for an ESVectorRetriever.

    Args:
        keyword: Keyword used to derive the ES passages index name.
        top_k: Number of documents to return per query.
        **kwargs: Forwarded to ESVectorRetriever (oneapi_key,
            oneapi_base_url, embed_model_name).

    Returns:
        A configured ESVectorRetriever instance.
    """
    return ESVectorRetriever(
        keyword=keyword,
        top_k=top_k,
        **kwargs
    )