first commit
This commit is contained in:
214
AIEC-RAG/elasticsearch_vectorization/es_client_wrapper.py
Normal file
214
AIEC-RAG/elasticsearch_vectorization/es_client_wrapper.py
Normal file
@ -0,0 +1,214 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Elasticsearch客户端包装器
|
||||
自动选择最佳的连接方式(官方客户端或HTTP客户端)
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any, Optional, Union
|
||||
from .config import ElasticsearchConfig
|
||||
import logging
|
||||
|
||||
# 设置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ESClientWrapper:
|
||||
def delete_by_query(self, index_name: str, query: Dict) -> Dict:
|
||||
"""
|
||||
根据查询条件删除索引中的文档(常用于全量清空)
|
||||
"""
|
||||
if self.client_type == "official":
|
||||
# 官方ES客户端
|
||||
return self.client.delete_by_query(index=index_name, body=query, conflicts="proceed", refresh=True)
|
||||
else:
|
||||
# HTTP客户端需实现delete_by_query方法
|
||||
return self.client.delete_by_query(index_name, query)
|
||||
"""ES客户端包装器,自动选择最佳连接方式"""
|
||||
|
||||
def __init__(self, host: str = None, username: str = None, password: str = None):
|
||||
self.host = host or ElasticsearchConfig.ES_HOST
|
||||
self.username = username or ElasticsearchConfig.ES_USERNAME
|
||||
self.password = password or ElasticsearchConfig.ES_PASSWORD
|
||||
|
||||
self.client = None
|
||||
self.client_type = None
|
||||
|
||||
# 尝试初始化客户端
|
||||
self._initialize_client()
|
||||
|
||||
def _initialize_client(self):
|
||||
"""初始化ES客户端"""
|
||||
# 首先尝试官方ES客户端
|
||||
try:
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
config = {
|
||||
"hosts": [self.host],
|
||||
"basic_auth": (self.username, self.password),
|
||||
"verify_certs": False,
|
||||
"ssl_show_warn": False,
|
||||
"request_timeout": 120, # 增加至2分钟
|
||||
"retry_on_timeout": True,
|
||||
"max_retries": 3 # 增加至3次
|
||||
}
|
||||
|
||||
es = Elasticsearch(**config)
|
||||
if es.ping():
|
||||
self.client = es
|
||||
self.client_type = "official"
|
||||
logger.info("使用官方ES客户端连接成功")
|
||||
return
|
||||
else:
|
||||
logger.warning("官方ES客户端ping失败")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"官方ES客户端初始化失败: {e}")
|
||||
|
||||
# 如果官方客户端失败,尝试HTTP客户端
|
||||
try:
|
||||
from .http_es_client import HTTPESClient
|
||||
|
||||
http_client = HTTPESClient(self.host, self.username, self.password)
|
||||
if http_client.ping():
|
||||
self.client = http_client
|
||||
self.client_type = "http"
|
||||
logger.info("使用HTTP ES客户端连接成功")
|
||||
return
|
||||
else:
|
||||
logger.warning("HTTP ES客户端ping失败")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"HTTP ES客户端初始化失败: {e}")
|
||||
|
||||
# 如果都失败了
|
||||
logger.error("所有ES客户端初始化都失败了")
|
||||
raise Exception("无法连接到Elasticsearch服务")
|
||||
|
||||
def ping(self) -> bool:
|
||||
"""测试连接"""
|
||||
if not self.client:
|
||||
return False
|
||||
return self.client.ping()
|
||||
|
||||
def info(self) -> Dict:
|
||||
"""获取ES集群信息"""
|
||||
if self.client_type == "official":
|
||||
return self.client.info()
|
||||
else:
|
||||
return self.client.info()
|
||||
|
||||
def cluster_health(self) -> Dict:
|
||||
"""获取集群健康状态"""
|
||||
if self.client_type == "official":
|
||||
return self.client.cluster.health()
|
||||
else:
|
||||
return self.client.cluster_health()
|
||||
|
||||
def create_index(self, index_name: str, mapping: Dict) -> Dict:
|
||||
"""创建索引"""
|
||||
if self.client_type == "official":
|
||||
return self.client.indices.create(index=index_name, body=mapping)
|
||||
else:
|
||||
return self.client.create_index(index_name, mapping)
|
||||
|
||||
def delete_index(self, index_name: str) -> Dict:
|
||||
"""删除索引"""
|
||||
if self.client_type == "official":
|
||||
return self.client.indices.delete(index=index_name)
|
||||
else:
|
||||
return self.client.delete_index(index_name)
|
||||
|
||||
def index_exists(self, index_name: str) -> bool:
|
||||
"""检查索引是否存在"""
|
||||
if self.client_type == "official":
|
||||
return self.client.indices.exists(index=index_name)
|
||||
else:
|
||||
return self.client.index_exists(index_name)
|
||||
|
||||
def index_document(self, index_name: str, doc_id: str, document: Dict) -> Dict:
|
||||
"""索引文档"""
|
||||
if self.client_type == "official":
|
||||
return self.client.index(index=index_name, id=doc_id, body=document)
|
||||
else:
|
||||
return self.client.index_document(index_name, doc_id, document)
|
||||
|
||||
def bulk_index(self, index_name: str, documents: List[Dict]) -> Dict:
|
||||
"""批量索引文档"""
|
||||
if self.client_type == "official":
|
||||
from elasticsearch.helpers import bulk
|
||||
|
||||
actions = []
|
||||
for doc in documents:
|
||||
action = {
|
||||
"_index": index_name,
|
||||
"_source": doc
|
||||
}
|
||||
actions.append(action)
|
||||
|
||||
return bulk(self.client, actions)
|
||||
else:
|
||||
return self.client.bulk_index(index_name, documents)
|
||||
|
||||
def search(self, index_name: str, query: Dict, size: int = 10) -> Dict:
|
||||
"""搜索文档"""
|
||||
if self.client_type == "official":
|
||||
return self.client.search(index=index_name, body={"query": query, "size": size})
|
||||
else:
|
||||
return self.client.search(index_name, query, size)
|
||||
|
||||
def vector_search(self, index_name: str, vector: List[float],
|
||||
field: str = "embedding", size: int = 10) -> Dict:
|
||||
"""向量搜索"""
|
||||
if self.client_type == "official":
|
||||
search_body = {
|
||||
"knn": {
|
||||
"field": field,
|
||||
"query_vector": vector,
|
||||
"k": size,
|
||||
"num_candidates": size * 2
|
||||
},
|
||||
"size": size
|
||||
}
|
||||
return self.client.search(index=index_name, body=search_body)
|
||||
else:
|
||||
return self.client.vector_search(index_name, vector, field, size)
|
||||
|
||||
|
||||
def test_wrapper():
|
||||
"""测试包装器"""
|
||||
print("=== 测试ES客户端包装器 ===")
|
||||
|
||||
try:
|
||||
client = ESClientWrapper()
|
||||
|
||||
if client.ping():
|
||||
print(f"[OK] 连接成功! 使用客户端类型: {client.client_type}")
|
||||
|
||||
# 获取集群信息
|
||||
info = client.info()
|
||||
print(f"[OK] 集群名称: {info.get('cluster_name', 'N/A')}")
|
||||
print(f"[OK] ES版本: {info.get('version', {}).get('number', 'N/A')}")
|
||||
|
||||
# 获取集群健康状态
|
||||
health = client.cluster_health()
|
||||
print(f"[OK] 集群状态: {health.get('status', 'N/A')}")
|
||||
|
||||
return client
|
||||
else:
|
||||
print("✗ 连接失败")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 初始化失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
working_client = test_wrapper()
|
||||
if working_client:
|
||||
print("\n[OK] ES客户端包装器工作正常!")
|
||||
else:
|
||||
print("\n✗ ES客户端包装器初始化失败")
|
||||
print("\n请按照 setup_local_es.md 中的说明安装本地Elasticsearch")
|
||||
Reference in New Issue
Block a user