#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Elasticsearch client wrapper.

Automatically chooses the best available connection method: the official
``elasticsearch`` Python client first, falling back to the project's HTTP
client (``HTTPESClient``).
"""

from typing import Dict, List, Any, Optional, Union
from .config import ElasticsearchConfig
import logging

# Configure logging.
# NOTE(review): calling basicConfig at import time configures the root logger
# of any program that imports this module; consider moving it under the
# ``__main__`` guard. Kept here to preserve the current import-time behavior.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ESClientWrapper:
    """ES client wrapper that automatically picks the best connection method.

    After construction, ``client_type`` is ``"official"`` when the official
    ``elasticsearch`` client is in use, or ``"http"`` when the fallback
    ``HTTPESClient`` is in use. Every public method dispatches on it.
    """

    def __init__(
        self,
        host: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
    ):
        """Create the wrapper and connect immediately.

        Args:
            host: ES host URL; falls back to ``ElasticsearchConfig.ES_HOST``.
            username: falls back to ``ElasticsearchConfig.ES_USERNAME``.
            password: falls back to ``ElasticsearchConfig.ES_PASSWORD``.

        Raises:
            Exception: if neither client can connect.
        """
        # ``or`` means any falsy value (None or "") uses the configured default.
        self.host = host or ElasticsearchConfig.ES_HOST
        self.username = username or ElasticsearchConfig.ES_USERNAME
        self.password = password or ElasticsearchConfig.ES_PASSWORD
        self.client = None
        self.client_type = None

        # Try to connect right away.
        self._initialize_client()

    @staticmethod
    def _close_quietly(client: Any) -> None:
        """Best-effort ``close()`` of a client object that will not be used."""
        close = getattr(client, "close", None)
        if not callable(close):
            return
        try:
            close()
        except Exception as e:
            # Cleanup must never interrupt the fallback sequence.
            logger.debug(f"关闭未使用的ES客户端失败: {e}")

    def _initialize_client(self) -> None:
        """Initialize the ES client, preferring the official client.

        On success sets ``self.client`` and ``self.client_type``. A client
        that was constructed but not adopted (failed ping / error) is closed
        so its connection resources are not leaked.

        Raises:
            Exception: if both the official and the HTTP client fail.
        """
        # First choice: the official Elasticsearch client.
        es = None
        try:
            from elasticsearch import Elasticsearch

            config = {
                "hosts": [self.host],
                "basic_auth": (self.username, self.password),
                # NOTE(review): TLS certificate verification is disabled;
                # acceptable for a local dev cluster, unsafe in production.
                "verify_certs": False,
                "ssl_show_warn": False,
                "request_timeout": 120,  # 2 minutes
                "retry_on_timeout": True,
                "max_retries": 3,
            }

            es = Elasticsearch(**config)
            if es.ping():
                self.client = es
                self.client_type = "official"
                logger.info("使用官方ES客户端连接成功")
                return
            logger.warning("官方ES客户端ping失败")
        except Exception as e:
            logger.warning(f"官方ES客户端初始化失败: {e}")
        if es is not None:
            self._close_quietly(es)

        # Fallback: the project's HTTP-based client.
        http_client = None
        try:
            from .http_es_client import HTTPESClient

            http_client = HTTPESClient(self.host, self.username, self.password)
            if http_client.ping():
                self.client = http_client
                self.client_type = "http"
                logger.info("使用HTTP ES客户端连接成功")
                return
            logger.warning("HTTP ES客户端ping失败")
        except Exception as e:
            logger.warning(f"HTTP ES客户端初始化失败: {e}")
        if http_client is not None:
            self._close_quietly(http_client)

        # Both attempts failed.
        logger.error("所有ES客户端初始化都失败了")
        raise Exception("无法连接到Elasticsearch服务")

    def ping(self) -> bool:
        """Test the connection; returns False when no client is set."""
        if self.client is None:
            return False
        return self.client.ping()

    def info(self) -> Dict:
        """Return ES cluster information.

        Both client types expose ``info()`` with the same signature.
        """
        return self.client.info()

    def cluster_health(self) -> Dict:
        """Return the cluster health status."""
        if self.client_type == "official":
            return self.client.cluster.health()
        return self.client.cluster_health()

    def create_index(self, index_name: str, mapping: Dict) -> Dict:
        """Create ``index_name`` with the given mapping/settings body."""
        if self.client_type == "official":
            return self.client.indices.create(index=index_name, body=mapping)
        return self.client.create_index(index_name, mapping)

    def delete_index(self, index_name: str) -> Dict:
        """Delete ``index_name``."""
        if self.client_type == "official":
            return self.client.indices.delete(index=index_name)
        return self.client.delete_index(index_name)

    def index_exists(self, index_name: str) -> bool:
        """Return whether ``index_name`` exists."""
        if self.client_type == "official":
            return self.client.indices.exists(index=index_name)
        return self.client.index_exists(index_name)

    def delete_by_query(self, index_name: str, query: Dict) -> Dict:
        """Delete the documents in ``index_name`` matching ``query``.

        Commonly used to clear an index entirely. On the official client,
        version conflicts are skipped (``conflicts="proceed"``) and the index
        is refreshed afterwards.
        """
        if self.client_type == "official":
            return self.client.delete_by_query(
                index=index_name, body=query, conflicts="proceed", refresh=True
            )
        # The HTTP client must implement delete_by_query itself.
        return self.client.delete_by_query(index_name, query)

    def index_document(self, index_name: str, doc_id: str, document: Dict) -> Dict:
        """Index (create or replace) a single document with id ``doc_id``."""
        if self.client_type == "official":
            return self.client.index(index=index_name, id=doc_id, body=document)
        return self.client.index_document(index_name, doc_id, document)

    def bulk_index(self, index_name: str, documents: List[Dict]) -> Union[Dict, tuple]:
        """Bulk-index ``documents`` into ``index_name``.

        Returns:
            On the official client, whatever ``elasticsearch.helpers.bulk``
            returns (per its docs, a ``(success_count, errors)`` tuple), not a
            dict. On the HTTP client, the result of its ``bulk_index``.
        """
        if self.client_type == "official":
            from elasticsearch.helpers import bulk

            # No "_id" is set, so Elasticsearch assigns document ids.
            actions = [{"_index": index_name, "_source": doc} for doc in documents]
            return bulk(self.client, actions)
        return self.client.bulk_index(index_name, documents)

    def search(self, index_name: str, query: Dict, size: int = 10) -> Dict:
        """Run ``query`` (the ``query`` clause only) and return up to ``size`` hits."""
        if self.client_type == "official":
            return self.client.search(
                index=index_name, body={"query": query, "size": size}
            )
        return self.client.search(index_name, query, size)

    def vector_search(
        self,
        index_name: str,
        vector: List[float],
        field: str = "embedding",
        size: int = 10,
    ) -> Dict:
        """Run an approximate kNN search on ``field`` using ``vector``.

        Args:
            index_name: index to search.
            vector: query embedding; presumably its length must match the
                field's mapped dims — confirm against the index mapping.
            field: dense-vector field name.
            size: number of neighbours (``k``) and hits to return.
        """
        if self.client_type == "official":
            search_body = {
                "knn": {
                    "field": field,
                    "query_vector": vector,
                    "k": size,
                    # Oversample candidates (2 * k) to trade speed for recall.
                    "num_candidates": size * 2,
                },
                "size": size,
            }
            return self.client.search(index=index_name, body=search_body)
        return self.client.vector_search(index_name, vector, field, size)


def test_wrapper():
    """Smoke-test the wrapper.

    Prints the connection result, cluster info and health. Returns the
    connected ``ESClientWrapper``, or None on any failure.
    """
    print("=== 测试ES客户端包装器 ===")

    try:
        client = ESClientWrapper()

        if not client.ping():
            print("✗ 连接失败")
            return None

        print(f"[OK] 连接成功! 使用客户端类型: {client.client_type}")

        # Cluster information.
        info = client.info()
        print(f"[OK] 集群名称: {info.get('cluster_name', 'N/A')}")
        print(f"[OK] ES版本: {info.get('version', {}).get('number', 'N/A')}")

        # Cluster health.
        health = client.cluster_health()
        print(f"[OK] 集群状态: {health.get('status', 'N/A')}")

        return client

    except Exception as e:
        print(f"✗ 初始化失败: {e}")
        return None


if __name__ == "__main__":
    working_client = test_wrapper()
    if working_client:
        print("\n[OK] ES客户端包装器工作正常!")
    else:
        print("\n✗ ES客户端包装器初始化失败")
        print("\n请按照 setup_local_es.md 中的说明安装本地Elasticsearch")