Files
AIEC-RAG/elasticsearch_vectorization/http_es_client.py
2025-09-24 09:29:12 +08:00

208 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
基于HTTP请求的简单ES客户端
用于替代官方ES客户端解决网络连接问题
"""
import base64
import json
from typing import Dict, List, Any, Optional
from urllib.parse import quote

import requests

from config import ElasticsearchConfig
class HTTPESClient:
    """Minimal Elasticsearch client built directly on HTTP requests.

    All calls share one ``requests.Session``. Every request is sent with
    ``verify=False``, i.e. TLS certificate verification is switched off.
    """

    def __init__(self, host: Optional[str] = None, username: Optional[str] = None,
                 password: Optional[str] = None):
        """Create the client and prepare its HTTP session.

        Args:
            host: Base URL of the cluster. Falls back to
                ``ElasticsearchConfig.ES_HOST``; trailing slashes are removed.
            username: Basic-auth user. Falls back to
                ``ElasticsearchConfig.ES_USERNAME``.
            password: Basic-auth password. Falls back to
                ``ElasticsearchConfig.ES_PASSWORD``.
        """
        self.host = (host or ElasticsearchConfig.ES_HOST).rstrip('/')
        self.username = username or ElasticsearchConfig.ES_USERNAME
        self.password = password or ElasticsearchConfig.ES_PASSWORD

        # Authentication headers are only installed when both credentials exist.
        self.session = requests.Session()
        if self.username and self.password:
            self.session.headers.update({
                'Authorization': self._basic_auth_header(self.username, self.password),
                'Content-Type': 'application/json',
            })

        # Requests are sent with verify=False, so silence the warning urllib3
        # would otherwise emit for each unverified HTTPS call.
        # NOTE(review): disabling certificate verification is insecure outside
        # of trusted networks -- confirm this is intended for production use.
        import urllib3
        urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    @staticmethod
    def _basic_auth_header(username: str, password: str) -> str:
        """Return the HTTP Basic ``Authorization`` header value for the credentials."""
        # UTF-8 instead of ASCII: ASCII credentials yield exactly the same bytes,
        # while non-ASCII ones no longer raise UnicodeEncodeError.
        token = base64.b64encode(f"{username}:{password}".encode('utf-8')).decode('ascii')
        return f"Basic {token}"

    def _make_request(self, method: str, path: str, data: Optional[Dict] = None,
                      timeout: int = 30) -> Dict:
        """Send one HTTP request to the cluster and return the decoded reply.

        Args:
            method: HTTP verb (case-insensitive): GET, POST, PUT, DELETE or HEAD.
            path: URL path appended to ``self.host`` (must start with ``/``).
            data: JSON body. Only sent for POST and PUT; ignored for other verbs.
            timeout: Request timeout in seconds.

        Returns:
            The parsed JSON body, or ``{"status_code": <int>}`` when the body is
            empty. HEAD always returns ``{"status_code": <int>}`` and never
            raises on a 4xx/5xx status.

        Raises:
            Exception: On timeout, connection failure, HTTP error status,
                unsupported method or any other failure. The original error is
                attached as ``__cause__``.
        """
        url = f"{self.host}{path}"
        verb = method.upper()
        try:
            if verb == 'GET':
                response = self.session.get(url, timeout=timeout, verify=False)
            elif verb == 'POST':
                response = self.session.post(url, json=data, timeout=timeout, verify=False)
            elif verb == 'PUT':
                response = self.session.put(url, json=data, timeout=timeout, verify=False)
            elif verb == 'DELETE':
                response = self.session.delete(url, timeout=timeout, verify=False)
            elif verb == 'HEAD':
                # HEAD is used for existence checks: report the status code
                # instead of raising, so 404 can be read by the caller.
                response = self.session.head(url, timeout=timeout, verify=False)
                return {"status_code": response.status_code}
            else:
                raise ValueError(f"不支持的HTTP方法: {method}")

            response.raise_for_status()
            if response.content:
                return response.json()
            return {"status_code": response.status_code}
        except requests.exceptions.Timeout as e:
            raise Exception(f"请求超时: {url}") from e
        except requests.exceptions.ConnectionError as e:
            raise Exception(f"连接错误: {url}") from e
        except requests.exceptions.HTTPError as e:
            raise Exception(f"HTTP错误 {response.status_code}: {response.text}") from e
        except Exception as e:
            raise Exception(f"请求失败: {e}") from e

    def ping(self) -> bool:
        """Return True when the cluster root endpoint answers with a ``cluster_name``."""
        try:
            result = self._make_request("GET", "/", timeout=10)
            return "cluster_name" in result
        except Exception:
            # Any request failure means "not reachable". KeyboardInterrupt and
            # SystemExit are deliberately not swallowed.
            return False

    def info(self) -> Dict:
        """Return the JSON document served at the cluster root (``GET /``)."""
        return self._make_request("GET", "/")

    def cluster_health(self) -> Dict:
        """Return the response of ``GET /_cluster/health``."""
        return self._make_request("GET", "/_cluster/health")

    def create_index(self, index_name: str, mapping: Dict) -> Dict:
        """Create ``index_name`` by PUT-ing ``mapping`` as the request body."""
        return self._make_request("PUT", f"/{index_name}", mapping)

    def delete_index(self, index_name: str) -> Dict:
        """Delete the index ``index_name``."""
        return self._make_request("DELETE", f"/{index_name}")

    def delete_by_query(self, index_name: str, query: Dict) -> Dict:
        """Delete the documents of ``index_name`` matched by ``query``.

        ``query`` is sent unchanged as the request body of
        ``POST /{index}/_delete_by_query`` (commonly used to empty an index),
        so it must already be a complete body such as
        ``{"query": {"match_all": {}}}``.
        """
        return self._make_request("POST", f"/{index_name}/_delete_by_query", query)

    def index_exists(self, index_name: str) -> bool:
        """Return True when ``HEAD /{index_name}`` answers with status 200."""
        try:
            result = self._make_request("HEAD", f"/{index_name}")
            return result.get("status_code") == 200
        except Exception:
            # A failed request is reported as "does not exist"; interpreter-exit
            # signals are not swallowed.
            return False

    def index_document(self, index_name: str, doc_id: str, document: Dict) -> Dict:
        """Store ``document`` under ``doc_id`` in ``index_name``.

        ``doc_id`` is percent-encoded so that characters such as ``/``, ``?``
        or ``#`` stay part of the id instead of altering the URL.
        """
        safe_id = quote(str(doc_id), safe='')
        return self._make_request("PUT", f"/{index_name}/_doc/{safe_id}", document)

    def bulk_index(self, index_name: str, documents: List[Dict]) -> Dict:
        """Index ``documents`` into ``index_name`` with one ``_bulk`` request.

        Args:
            index_name: Target index for every document.
            documents: Documents to index; ids are left for the server to assign.

        Returns:
            The parsed ``_bulk`` response. For an empty ``documents`` no request
            is sent and a synthetic ``{"took": 0, "errors": False, "items": []}``
            is returned.

        Raises:
            requests.exceptions.RequestException: On transport failure or an
                HTTP error status (this method does not go through
                ``_make_request``).

        NOTE(review): per-item failures are only visible through the
        ``errors`` / ``items`` fields of the returned dict; they are not
        raised here.
        """
        if not documents:
            # A bulk body with no actions would be just "\n"; skip the round trip.
            return {"took": 0, "errors": False, "items": []}

        # The action line is the same for every document, so serialise it once.
        # json.dumps keeps its default ensure_ascii=True, so the body is pure ASCII.
        action_line = json.dumps({"index": {"_index": index_name}})
        bulk_lines = []
        for doc in documents:
            bulk_lines.append(action_line)
            bulk_lines.append(json.dumps(doc))
        # NDJSON: one JSON object per line, terminated by a final newline.
        bulk_body = "\n".join(bulk_lines) + "\n"

        url = f"{self.host}/_bulk"
        headers = self.session.headers.copy()
        headers['Content-Type'] = 'application/x-ndjson'
        response = self.session.post(url, data=bulk_body, headers=headers,
                                     timeout=60, verify=False)
        response.raise_for_status()
        return response.json()

    def search(self, index_name: str, query: Dict, size: int = 10) -> Dict:
        """Run a search against ``index_name``.

        ``query`` may be given in three shapes:

        * a full request body containing ``"size"``: sent unchanged and the
          ``size`` argument is ignored;
        * a full request body containing ``"query"`` but no ``"size"``: sent
          as a copy with ``size`` added (the caller's dict is not modified);
        * a bare query clause (e.g. ``{"match_all": {}}``): wrapped as
          ``{"query": <clause>, "size": size}``.
        """
        if "size" in query:
            search_body = query
        elif "query" in query:
            # Already a request body; wrapping it again would nest
            # "query" inside "query" and produce an invalid request.
            search_body = {**query, "size": size}
        else:
            search_body = {"query": query, "size": size}
        return self._make_request("POST", f"/{index_name}/_search", search_body)

    def vector_search(self, index_name: str, vector: List[float],
                      field: str = "embedding", size: int = 10,
                      num_candidates: Optional[int] = None) -> Dict:
        """Run a kNN search for ``vector`` on ``field`` of ``index_name``.

        Args:
            index_name: Index to search.
            vector: Query vector.
            field: Name of the vector field.
            size: Number of hits to return; also used as ``k``.
            num_candidates: Value sent as ``knn.num_candidates``. Defaults to
                ``size * 2`` when omitted.
        """
        if num_candidates is None:
            num_candidates = size * 2
        search_body = {
            "knn": {
                "field": field,
                "query_vector": vector,
                "k": size,
                "num_candidates": num_candidates,
            },
            "size": size,
        }
        return self._make_request("POST", f"/{index_name}/_search", search_body)
def test_http_client():
    """Probe a fixed list of hosts and return the first client that responds.

    For each candidate host a client is created and pinged. On the first
    successful ping the cluster name, version and health status are printed
    and that client is returned. Returns None when no host is usable.
    """
    print("=== 测试HTTP ES客户端 ===")

    # Candidate endpoints, tried in this order.
    candidate_hosts = [
        "http://101.200.154.78:9200",
        "http://127.0.0.1:9200",
        "http://localhost:9200"
    ]

    for host in candidate_hosts:
        print(f"\n尝试连接: {host}")
        es = HTTPESClient(host=host)
        try:
            if not es.ping():
                print("✗ ping失败")
                continue

            print("[OK] 连接成功!")

            # Cluster information
            info = es.info()
            cluster_name = info.get('cluster_name', 'N/A')
            print(f"[OK] 集群名称: {cluster_name}")
            version_number = info.get('version', {}).get('number', 'N/A')
            print(f"[OK] ES版本: {version_number}")

            # Cluster health
            health = es.cluster_health()
            status = health.get('status', 'N/A')
            print(f"[OK] 集群状态: {status}")
            return es
        except Exception as err:
            print(f"✗ 连接失败: {err}")

    return None
if __name__ == "__main__":
    # Script entry point: look for a reachable cluster and report the outcome.
    working_client = test_http_client()
    print("\n[OK] 找到可用的ES连接!" if working_client else "\n✗ 未找到可用的ES连接")