Files
AIEC-RAG/elasticsearch_vectorization/es_client_wrapper.py
2025-09-24 09:29:12 +08:00

215 lines
7.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Elasticsearch客户端包装器
自动选择最佳的连接方式: 官方客户端或HTTP客户端
"""
from typing import Dict, List, Any, Optional, Union
from .config import ElasticsearchConfig
import logging
# Module-level logging setup: configure the root logger at INFO level
# (basicConfig does nothing if the root logger already has handlers) and
# create this module's named logger.
# NOTE(review): configuring the root logger from a library module affects the
# whole application; consider leaving this to the entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
class ESClientWrapper:
    """Elasticsearch client wrapper that picks the best available connection.

    On construction it first tries the official ``elasticsearch`` Python
    client; if that cannot be imported, fails to construct, or does not answer
    ``ping()``, it falls back to the project's ``HTTPESClient``. Every public
    method then dispatches on ``client_type`` so callers see a single API
    regardless of which backend is in use.

    Attributes:
        host, username, password: connection settings; each falls back to the
            matching ``ElasticsearchConfig`` value when the argument is falsy.
        client: the underlying backend client (``None`` until connected).
        client_type: ``"official"`` or ``"http"`` once connected, else ``None``.
    """

    def __init__(self, host: Optional[str] = None, username: Optional[str] = None,
                 password: Optional[str] = None):
        """Resolve connection settings and connect to the cluster.

        Raises:
            ConnectionError: if neither backend can reach Elasticsearch.
        """
        # ``or`` means an empty string also falls back to the configured value.
        self.host = host or ElasticsearchConfig.ES_HOST
        self.username = username or ElasticsearchConfig.ES_USERNAME
        self.password = password or ElasticsearchConfig.ES_PASSWORD
        self.client = None
        self.client_type = None
        self._initialize_client()

    @staticmethod
    def _close_quietly(client: Any) -> None:
        """Best-effort close of a backend client that will not be used.

        ``None`` and objects without a ``close`` method are ignored. Errors
        raised while closing are logged at debug level instead of propagated,
        because the caller is already on a fallback path.
        """
        close = getattr(client, "close", None)
        if close is None:
            return
        try:
            close()
        except Exception as exc:  # cleanup must not mask the real failure
            logger.debug("Failed to close unused ES client: %s", exc)

    def _initialize_client(self) -> None:
        """Connect with the official client, falling back to the HTTP client.

        On success sets ``self.client`` and ``self.client_type``. A backend
        that was constructed but is not used (ping failed or raised) is closed
        before moving on, so its connections are not leaked.

        Raises:
            ConnectionError: if neither backend answers ``ping()``.
        """
        # 1) Official elasticsearch-py client.
        es = None
        try:
            from elasticsearch import Elasticsearch
            config = {
                "hosts": [self.host],
                "basic_auth": (self.username, self.password),
                # NOTE(review): TLS certificate verification is disabled;
                # acceptable for a local/dev cluster, unsafe on untrusted networks.
                "verify_certs": False,
                "ssl_show_warn": False,
                "request_timeout": 120,  # raised to 2 minutes
                "retry_on_timeout": True,
                "max_retries": 3,  # raised to 3 attempts
            }
            es = Elasticsearch(**config)
            if es.ping():
                self.client = es
                self.client_type = "official"
                logger.info("使用官方ES客户端连接成功")
                return
            logger.warning("官方ES客户端ping失败")
        except Exception as e:
            logger.warning(f"官方ES客户端初始化失败: {e}")
        # Only reached when the official client is unusable.
        self._close_quietly(es)

        # 2) Fallback: the project's plain-HTTP client.
        http_client = None
        try:
            from .http_es_client import HTTPESClient
            http_client = HTTPESClient(self.host, self.username, self.password)
            if http_client.ping():
                self.client = http_client
                self.client_type = "http"
                logger.info("使用HTTP ES客户端连接成功")
                return
            logger.warning("HTTP ES客户端ping失败")
        except Exception as e:
            logger.warning(f"HTTP ES客户端初始化失败: {e}")
        self._close_quietly(http_client)

        # Both backends failed. ConnectionError subclasses Exception, so
        # existing ``except Exception`` handlers still catch it.
        logger.error("所有ES客户端初始化都失败了")
        raise ConnectionError("无法连接到Elasticsearch服务")

    def ping(self) -> bool:
        """Return True if the connected backend answers a ping, False if not connected."""
        if not self.client:
            return False
        return self.client.ping()

    def info(self) -> Dict:
        """Return cluster information (both backends expose ``info()``)."""
        return self.client.info()

    def cluster_health(self) -> Dict:
        """Return the cluster health response."""
        if self.client_type == "official":
            return self.client.cluster.health()
        return self.client.cluster_health()

    def create_index(self, index_name: str, mapping: Dict) -> Dict:
        """Create ``index_name``, sending ``mapping`` as the request body.

        ``mapping`` presumably contains ``settings``/``mappings`` sections;
        confirm against callers.
        """
        if self.client_type == "official":
            return self.client.indices.create(index=index_name, body=mapping)
        return self.client.create_index(index_name, mapping)

    def delete_index(self, index_name: str) -> Dict:
        """Delete the index ``index_name``."""
        if self.client_type == "official":
            return self.client.indices.delete(index=index_name)
        return self.client.delete_index(index_name)

    def delete_by_query(self, index_name: str, query: Dict) -> Dict:
        """Delete documents in ``index_name`` matching ``query`` (often used to empty an index).

        For the official client ``query`` is sent as the full request body, so
        it presumably looks like ``{"query": {...}}``; confirm against callers.
        Version conflicts are skipped (``conflicts="proceed"``) and the index is
        refreshed afterwards. The HTTP client must implement its own
        ``delete_by_query(index_name, query)``.
        """
        if self.client_type == "official":
            return self.client.delete_by_query(index=index_name, body=query, conflicts="proceed", refresh=True)
        return self.client.delete_by_query(index_name, query)

    def index_exists(self, index_name: str) -> bool:
        """Return whether ``index_name`` exists.

        The result is coerced with ``bool()`` so both backends return a plain
        bool (the official client returns a response object, not a bool).
        """
        if self.client_type == "official":
            return bool(self.client.indices.exists(index=index_name))
        return bool(self.client.index_exists(index_name))

    def index_document(self, index_name: str, doc_id: str, document: Dict) -> Dict:
        """Index (create or overwrite) a single document under ``doc_id``."""
        if self.client_type == "official":
            return self.client.index(index=index_name, id=doc_id, body=document)
        return self.client.index_document(index_name, doc_id, document)

    def bulk_index(self, index_name: str, documents: List[Dict]) -> Union[Dict, tuple]:
        """Index many documents in one bulk request.

        No ``_id`` is supplied, so Elasticsearch assigns document ids.

        Returns:
            Official backend: the tuple returned by
            ``elasticsearch.helpers.bulk`` (success count and errors).
            HTTP backend: whatever ``HTTPESClient.bulk_index`` returns.
        """
        if self.client_type == "official":
            from elasticsearch.helpers import bulk
            actions = [{"_index": index_name, "_source": doc} for doc in documents]
            return bulk(self.client, actions)
        return self.client.bulk_index(index_name, documents)

    def search(self, index_name: str, query: Dict, size: int = 10) -> Dict:
        """Run ``query`` (the inner query clause) and return up to ``size`` hits."""
        if self.client_type == "official":
            return self.client.search(index=index_name, body={"query": query, "size": size})
        return self.client.search(index_name, query, size)

    def vector_search(self, index_name: str, vector: List[float],
                      field: str = "embedding", size: int = 10) -> Dict:
        """Approximate kNN search on the dense-vector ``field``.

        With the official client this issues a ``knn`` query for the ``size``
        nearest neighbours, considering ``size * 2`` candidates.
        """
        if self.client_type == "official":
            search_body = {
                "knn": {
                    "field": field,
                    "query_vector": vector,
                    "k": size,
                    "num_candidates": size * 2,
                },
                "size": size,
            }
            return self.client.search(index=index_name, body=search_body)
        return self.client.vector_search(index_name, vector, field, size)
def test_wrapper():
    """Smoke-test ESClientWrapper against the configured cluster.

    Prints the client type, cluster name, version and health status.
    Returns the connected wrapper, or None when the ping fails or any
    step raises (the error is printed rather than propagated).
    """
    print("=== 测试ES客户端包装器 ===")
    try:
        wrapper = ESClientWrapper()
        if not wrapper.ping():
            print("✗ 连接失败")
            return None
        print(f"[OK] 连接成功! 使用客户端类型: {wrapper.client_type}")

        # Cluster information.
        cluster_info = wrapper.info()
        print(f"[OK] 集群名称: {cluster_info.get('cluster_name', 'N/A')}")
        version_info = cluster_info.get('version', {})
        print(f"[OK] ES版本: {version_info.get('number', 'N/A')}")

        # Cluster health.
        health_report = wrapper.cluster_health()
        print(f"[OK] 集群状态: {health_report.get('status', 'N/A')}")
        return wrapper
    except Exception as err:
        print(f"✗ 初始化失败: {err}")
        return None
if __name__ == "__main__":
    # Run the smoke test when executed as a script and report the outcome.
    wrapper_under_test = test_wrapper()
    if not wrapper_under_test:
        print("\n✗ ES客户端包装器初始化失败")
        print("\n请按照 setup_local_es.md 中的说明安装本地Elasticsearch")
    else:
        print("\n[OK] ES客户端包装器工作正常!")