first commit
This commit is contained in:
12
elasticsearch_vectorization/__init__.py
Normal file
12
elasticsearch_vectorization/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
"""
|
||||
Elasticsearch 向量化模块
|
||||
包含ES客户端封装和配置
|
||||
"""
|
||||
|
||||
from .config import ElasticsearchConfig
|
||||
from .es_client_wrapper import ESClientWrapper
|
||||
|
||||
# Names exported by the package; both are imported from sibling modules
# earlier in this file.
__all__ = [
    "ElasticsearchConfig",
    "ESClientWrapper"
]
|
||||
BIN
elasticsearch_vectorization/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
elasticsearch_vectorization/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
elasticsearch_vectorization/__pycache__/config.cpython-311.pyc
Normal file
BIN
elasticsearch_vectorization/__pycache__/config.cpython-311.pyc
Normal file
Binary file not shown.
Binary file not shown.
208
elasticsearch_vectorization/config.py
Normal file
208
elasticsearch_vectorization/config.py
Normal file
@ -0,0 +1,208 @@
|
||||
"""
|
||||
Elasticsearch配置文件
|
||||
包含云端ES连接配置和相关常量
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
class ElasticsearchConfig:
    """Settings for the cloud Elasticsearch cluster plus related constants.

    The three connection attributes are resolved once, when the class body
    runs: each reads its environment variable and falls back to the literal
    given here when the variable is unset.
    """

    # --- Connection (environment variables take precedence) ---
    # NOTE(review): the fallback host and credentials are committed in plain
    # text; consider supplying them through the environment only.
    ES_HOST = os.environ.get("ELASTICSEARCH_HOST", "http://101.200.154.78:9200")
    ES_USERNAME = os.environ.get("ELASTICSEARCH_USERNAME", "elastic")
    ES_PASSWORD = os.environ.get("ELASTICSEARCH_PASSWORD", "Abcd123456")

    # Alternative: go through an SSH tunnel when the port is not reachable.
    # ES_HOST = "http://localhost:9200"  # tunnel mapped to the local machine

    # --- Index naming ---
    INDEX_PREFIX = "hipporag"
    EDGES_INDEX_SUFFIX = "edges"
    PASSAGES_INDEX_SUFFIX = "passages"

    # --- Vectors ---
    VECTOR_DIMENSION = 1024  # dimension of text-embedding-v3 embeddings
    SIMILARITY_METRIC = "cosine"

    # --- Batch sizes ---
    DEFAULT_BATCH_SIZE = 100
    DEFAULT_TEXT_BATCH_SIZE = 40
    DEFAULT_EDGE_BATCH_SIZE = 256

    # --- Search ---
    DEFAULT_TOP_K = 10
    DEFAULT_NUM_CANDIDATES_MULTIPLIER = 2

    @classmethod
    def get_edges_index_name(cls, keyword: str) -> str:
        """Return the edges index name, ``<prefix>_<keyword>_<edges suffix>``."""
        return "{}_{}_{}".format(cls.INDEX_PREFIX, keyword, cls.EDGES_INDEX_SUFFIX)

    @classmethod
    def get_passages_index_name(cls, keyword: str) -> str:
        """Return the passages index name, ``<prefix>_<keyword>_<passages suffix>``."""
        return "{}_{}_{}".format(cls.INDEX_PREFIX, keyword, cls.PASSAGES_INDEX_SUFFIX)

    @classmethod
    def get_es_config(cls) -> dict:
        """Return the connection options as a keyword-argument dictionary.

        NOTE(review): both ``basic_auth``/``http_auth`` and
        ``request_timeout``/``timeout`` are supplied; newer elasticsearch-py
        releases may reject such pairs -- confirm against the installed
        client version before passing this dict to the client constructor.
        """
        credentials = (cls.ES_USERNAME, cls.ES_PASSWORD)
        five_minutes = 300
        options = dict(
            hosts=[cls.ES_HOST],
            basic_auth=credentials,
            verify_certs=False,  # e.g. when a self-signed certificate is used
            ssl_show_warn=False,
            request_timeout=five_minutes,  # per-request timeout
            retry_on_timeout=True,
            max_retries=5,
            sniff_on_start=False,  # node sniffing disabled to avoid network issues
            sniff_on_connection_fail=False,  # no sniffing after a failed connection
            sniff_timeout=10,
            http_compress=True,  # HTTP compression enabled
            http_auth=credentials,
            timeout=five_minutes,  # connection timeout
            connection_class=None,  # use the default connection class
            selector_class=None,  # use the default selector
            dead_timeout=60,  # dead-connection timeout
            retry_on_status={502, 503, 504, 408, 429},  # HTTP statuses to retry
            maxsize=25,  # connection pool size
        )
        return options
|
||||
|
||||
|
||||
class SSHTunnelConfig:
    """Settings for reaching Elasticsearch through an SSH tunnel."""

    # SSH connection (values must be obtained from the administrator).
    SSH_HOST = "101.200.154.78"  # SSH server address
    SSH_PORT = 22  # SSH port
    SSH_USERNAME = "your_username"  # placeholder: replace with the real user
    SSH_KEY_PATH = "path/to/your/private_key.pem"  # placeholder: replace with the real key path
    # SSH_PASSWORD = "your_password"  # define this for password authentication

    # Tunnel mapping.
    LOCAL_PORT = 9200  # local port
    REMOTE_HOST = "localhost"  # Elasticsearch address as seen from the server
    REMOTE_PORT = 9200  # Elasticsearch port on the server

    @classmethod
    def get_ssh_config(cls) -> dict:
        """Return the tunnel settings as a dictionary.

        The six connection/mapping keys are always present.  One credential
        key is added on top: ``ssh_key_path`` when ``SSH_KEY_PATH`` differs
        from the shipped placeholder, otherwise ``ssh_password`` when the
        class defines ``SSH_PASSWORD``.  If neither applies, no credential
        key is included.
        """
        placeholder_key = "path/to/your/private_key.pem"

        tunnel = dict(
            ssh_host=cls.SSH_HOST,
            ssh_port=cls.SSH_PORT,
            ssh_username=cls.SSH_USERNAME,
            local_port=cls.LOCAL_PORT,
            remote_host=cls.REMOTE_HOST,
            remote_port=cls.REMOTE_PORT,
        )

        # A missing attribute is treated like the placeholder, so the
        # password branch is still considered in that case.
        key_path = getattr(cls, "SSH_KEY_PATH", placeholder_key)
        if key_path != placeholder_key:
            tunnel["ssh_key_path"] = key_path
        elif hasattr(cls, "SSH_PASSWORD"):
            tunnel["ssh_password"] = cls.SSH_PASSWORD

        return tunnel
|
||||
|
||||
|
||||
class HippoRAGConfig:
    """Retrieval settings for HippoRAG."""

    # Inference settings.
    TOPK_EDGES = 10
    TOPK_NODES = 30
    WEIGHT_ADJUST = 0.05
    PPR_ALPHA = 0.85
    PPR_MAX_ITER = 100
    PPR_TOL = 1e-6

    # Retrieval mode; one of: query2edge, query2node, ner2node.
    RETRIEVAL_MODE = "query2edge"

    @classmethod
    def get_oneapi_config(cls) -> dict:
        """Return the OneAPI settings read from the environment at call time.

        ``api_key`` and ``base_url`` are ``None`` when their variables are
        unset; the two model names fall back to built-in defaults.
        """
        env = os.environ
        oneapi = {}
        oneapi["api_key"] = env.get("ONEAPI_KEY")
        oneapi["base_url"] = env.get("ONEAPI_BASE_URL")
        oneapi["model_embed"] = env.get("ONEAPI_MODEL_EMBED", "text-embedding-v3")
        oneapi["model_gen"] = env.get("ONEAPI_MODEL_GEN", "qwen2-7b-instruct")
        return oneapi
|
||||
|
||||
|
||||
# Index mapping templates (compatibility version).
def _vector_index_mapping(leading_fields: dict) -> dict:
    """Build an index body from index-specific fields plus the shared tail.

    Every index ends with the same three properties (``embedding``,
    ``created_at``, ``keyword``) and uses the same shard/replica settings;
    each call creates fresh dictionaries so the templates share no state.
    """
    properties = dict(leading_fields)
    properties["embedding"] = {
        "type": "dense_vector",
        "dims": 1024,
        "index": True,
        "similarity": "cosine",
    }
    properties["created_at"] = {"type": "date"}
    properties["keyword"] = {"type": "keyword"}
    return {
        "mappings": {"properties": properties},
        # Zero replicas; the original notes this avoids a yellow cluster status.
        "settings": {"number_of_shards": 1, "number_of_replicas": 0},
    }


# Mapping for the edge index.
EDGES_INDEX_MAPPING = _vector_index_mapping({
    "edge_index": {"type": "integer"},
    "head_node_id": {"type": "keyword"},
    "tail_node_id": {"type": "keyword"},
    "head_entity": {"type": "keyword"},
    "relation": {"type": "keyword"},
    "tail_entity": {"type": "keyword"},
    "triple_text": {"type": "text"},
})

# Mapping for the passage index.
PASSAGES_INDEX_MAPPING = _vector_index_mapping({
    "passage_id": {"type": "keyword"},
    "content": {"type": "text"},
    "file_id": {"type": "keyword"},
    "evidence": {"type": "keyword"},
})

# Mapping for the node index.
NODES_INDEX_MAPPING = _vector_index_mapping({
    "node_id": {"type": "keyword"},
    "name": {"type": "text"},
    "type": {"type": "keyword"},
    "concepts": {"type": "text"},
    "synsets": {"type": "text"},
})
|
||||
214
elasticsearch_vectorization/es_client_wrapper.py
Normal file
214
elasticsearch_vectorization/es_client_wrapper.py
Normal file
@ -0,0 +1,214 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Elasticsearch客户端包装器
|
||||
自动选择最佳的连接方式(官方客户端或HTTP客户端)
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any, Optional, Union
|
||||
from .config import ElasticsearchConfig
|
||||
import logging
|
||||
|
||||
# Logging setup.
# NOTE(review): basicConfig runs when this module is imported, so it
# configures the root logger for the importing application as well.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by ESClientWrapper below.
logger = logging.getLogger(__name__)
|
||||
|
||||
class ESClientWrapper:
    """Elasticsearch client wrapper that picks a working connection method.

    On construction it first tries the official ``elasticsearch`` client and,
    if that cannot be imported, created or pinged, falls back to the
    project's ``HTTPESClient``.  ``client_type`` records which one is active
    (``"official"`` or ``"http"``) and every public method dispatches on it.
    """

    def __init__(self, host: Optional[str] = None, username: Optional[str] = None,
                 password: Optional[str] = None):
        """Resolve the connection settings and connect immediately.

        Args:
            host: Cluster URL; falls back to ``ElasticsearchConfig.ES_HOST``.
            username: User name; falls back to ``ElasticsearchConfig.ES_USERNAME``.
            password: Password; falls back to ``ElasticsearchConfig.ES_PASSWORD``.

        Raises:
            ConnectionError: If neither client type can reach the cluster.
        """
        self.host = host or ElasticsearchConfig.ES_HOST
        self.username = username or ElasticsearchConfig.ES_USERNAME
        self.password = password or ElasticsearchConfig.ES_PASSWORD

        self.client = None
        self.client_type = None

        # Connect right away so a broken setup fails at construction time.
        self._initialize_client()

    def _initialize_client(self):
        """Connect with the official client, falling back to the HTTP client.

        Raises:
            ConnectionError: If both attempts fail.
        """
        if self._try_official_client() or self._try_http_client():
            return

        logger.error("所有ES客户端初始化都失败了")
        # ConnectionError is an Exception subclass, so callers that catch
        # Exception keep working.
        raise ConnectionError("无法连接到Elasticsearch服务")

    def _try_official_client(self) -> bool:
        """Try the official client; return True when it answers a ping."""
        try:
            # Imported lazily so a missing package only disables this path.
            from elasticsearch import Elasticsearch

            config = {
                "hosts": [self.host],
                "basic_auth": (self.username, self.password),
                "verify_certs": False,
                "ssl_show_warn": False,
                "request_timeout": 120,  # two minutes
                "retry_on_timeout": True,
                "max_retries": 3,
            }

            es = Elasticsearch(**config)
            if es.ping():
                self.client = es
                self.client_type = "official"
                logger.info("使用官方ES客户端连接成功")
                return True
            logger.warning("官方ES客户端ping失败")
        except Exception as e:
            # Best effort: any failure here just means "try the next client".
            logger.warning("官方ES客户端初始化失败: %s", e)
        return False

    def _try_http_client(self) -> bool:
        """Try the plain-HTTP client; return True when it answers a ping."""
        try:
            from .http_es_client import HTTPESClient

            http_client = HTTPESClient(self.host, self.username, self.password)
            if http_client.ping():
                self.client = http_client
                self.client_type = "http"
                logger.info("使用HTTP ES客户端连接成功")
                return True
            logger.warning("HTTP ES客户端ping失败")
        except Exception as e:
            # Best effort: the caller reports the overall failure.
            logger.warning("HTTP ES客户端初始化失败: %s", e)
        return False

    def ping(self) -> bool:
        """Return whether the active client answers a ping (False if none)."""
        if not self.client:
            return False
        return self.client.ping()

    def info(self) -> Dict:
        """Return the cluster information reported by the active client."""
        # Both client types expose the same ``info()`` call.
        return self.client.info()

    def cluster_health(self) -> Dict:
        """Return the cluster health reported by the active client."""
        if self.client_type == "official":
            return self.client.cluster.health()
        return self.client.cluster_health()

    def create_index(self, index_name: str, mapping: Dict) -> Dict:
        """Create ``index_name`` using ``mapping`` as the request body."""
        if self.client_type == "official":
            return self.client.indices.create(index=index_name, body=mapping)
        return self.client.create_index(index_name, mapping)

    def delete_index(self, index_name: str) -> Dict:
        """Delete the index ``index_name``."""
        if self.client_type == "official":
            return self.client.indices.delete(index=index_name)
        return self.client.delete_index(index_name)

    def index_exists(self, index_name: str) -> bool:
        """Return True when the index ``index_name`` exists."""
        if self.client_type == "official":
            found = self.client.indices.exists(index=index_name)
        else:
            found = self.client.index_exists(index_name)
        # Coerce to a real bool so the result matches the annotation even
        # when the client returns a truthy response object.
        return bool(found)

    def index_document(self, index_name: str, doc_id: str, document: Dict) -> Dict:
        """Index ``document`` under ``doc_id`` in ``index_name``."""
        if self.client_type == "official":
            return self.client.index(index=index_name, id=doc_id, body=document)
        return self.client.index_document(index_name, doc_id, document)

    def bulk_index(self, index_name: str, documents: List[Dict]) -> Dict:
        """Index many documents into ``index_name`` in one bulk operation.

        Returns:
            Whatever the active client returns.
            NOTE(review): on the official path this is the result of
            ``elasticsearch.helpers.bulk``, which may not be a dict --
            confirm what callers expect.
        """
        if self.client_type == "official":
            from elasticsearch.helpers import bulk

            actions = [{"_index": index_name, "_source": doc} for doc in documents]
            return bulk(self.client, actions)
        return self.client.bulk_index(index_name, documents)

    def delete_by_query(self, index_name: str, query: Dict) -> Dict:
        """Delete the documents in ``index_name`` that match ``query``.

        Commonly used to empty an index.  ``query`` is sent as the request
        body as-is.
        """
        if self.client_type == "official":
            return self.client.delete_by_query(
                index=index_name, body=query, conflicts="proceed", refresh=True
            )
        # The HTTP client provides its own delete_by_query implementation.
        return self.client.delete_by_query(index_name, query)

    def search(self, index_name: str, query: Dict, size: int = 10) -> Dict:
        """Search ``index_name``.

        Args:
            index_name: Index to search.
            query: Either a bare query clause, or a complete request body
                that already contains ``"size"``.
            size: Number of hits; used only when ``query`` has no ``"size"``.
        """
        if self.client_type == "official":
            # Same rule as HTTPESClient.search: a body that already carries
            # "size" is sent unchanged instead of being nested under "query".
            if "size" in query:
                search_body = query
            else:
                search_body = {"query": query, "size": size}
            return self.client.search(index=index_name, body=search_body)
        return self.client.search(index_name, query, size)

    def vector_search(self, index_name: str, vector: List[float],
                      field: str = "embedding", size: int = 10) -> Dict:
        """Run a kNN search for ``vector`` against ``field``.

        Requests ``size`` neighbours using ``size * 2`` candidates.
        """
        if self.client_type == "official":
            search_body = {
                "knn": {
                    "field": field,
                    "query_vector": vector,
                    "k": size,
                    "num_candidates": size * 2,
                },
                "size": size,
            }
            return self.client.search(index=index_name, body=search_body)
        return self.client.vector_search(index_name, vector, field, size)
|
||||
|
||||
|
||||
def test_wrapper():
    """Smoke-test the wrapper against the configured cluster.

    Prints the connection result, cluster name, version and health status.

    Returns:
        The connected ``ESClientWrapper``, or ``None`` when construction
        fails, the ping fails, or any later call raises.
    """
    print("=== 测试ES客户端包装器 ===")

    try:
        wrapper = ESClientWrapper()

        if not wrapper.ping():
            print("✗ 连接失败")
            return None

        print(f"[OK] 连接成功! 使用客户端类型: {wrapper.client_type}")

        # Cluster information.
        cluster_info = wrapper.info()
        cluster_name = cluster_info.get('cluster_name', 'N/A')
        es_version = cluster_info.get('version', {}).get('number', 'N/A')
        print(f"[OK] 集群名称: {cluster_name}")
        print(f"[OK] ES版本: {es_version}")

        # Cluster health.
        cluster_status = wrapper.cluster_health().get('status', 'N/A')
        print(f"[OK] 集群状态: {cluster_status}")

        return wrapper

    except Exception as err:
        print(f"✗ 初始化失败: {err}")
        return None
|
||||
|
||||
|
||||
# Script entry point: run the smoke test and report the outcome.
if __name__ == "__main__":
    working_client = test_wrapper()
    if not working_client:
        print("\n✗ ES客户端包装器初始化失败")
        print("\n请按照 setup_local_es.md 中的说明安装本地Elasticsearch")
    else:
        print("\n[OK] ES客户端包装器工作正常!")
|
||||
207
elasticsearch_vectorization/http_es_client.py
Normal file
207
elasticsearch_vectorization/http_es_client.py
Normal file
@ -0,0 +1,207 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
基于HTTP请求的简单ES客户端
|
||||
用于替代官方ES客户端,解决网络连接问题
|
||||
"""
|
||||
|
||||
import base64
import json
from typing import Dict, List, Any, Optional

import requests

# Prefer the package-relative import so this module loads when imported as
# part of the package; fall back to the plain import when the file is run
# directly as a script.
try:
    from .config import ElasticsearchConfig
except ImportError:
    from config import ElasticsearchConfig
|
||||
|
||||
class HTTPESClient:
    """Simple Elasticsearch client built on plain HTTP requests.

    Stand-in for the official client: every method maps to one REST call
    sent through a shared ``requests.Session``.  TLS certificate
    verification is disabled on all requests.
    """

    # HTTP verbs understood by ``_make_request``.
    _SUPPORTED_METHODS = frozenset({"GET", "POST", "PUT", "DELETE", "HEAD"})

    def __init__(self, host: Optional[str] = None, username: Optional[str] = None,
                 password: Optional[str] = None):
        """Store the connection settings and prepare the HTTP session.

        Args:
            host: Cluster URL; falls back to ``ElasticsearchConfig.ES_HOST``.
                A trailing slash is stripped.
            username: User name; falls back to ``ElasticsearchConfig.ES_USERNAME``.
            password: Password; falls back to ``ElasticsearchConfig.ES_PASSWORD``.
        """
        self.host = (host or ElasticsearchConfig.ES_HOST).rstrip('/')
        self.username = username or ElasticsearchConfig.ES_USERNAME
        self.password = password or ElasticsearchConfig.ES_PASSWORD

        # Session carrying the Basic-auth and JSON content-type headers.
        self.session = requests.Session()
        if self.username and self.password:
            # UTF-8 is a superset of ASCII, so ASCII credentials produce the
            # same header as before while non-ASCII ones no longer raise.
            auth_bytes = f"{self.username}:{self.password}".encode('utf-8')
            auth_b64 = base64.b64encode(auth_bytes).decode('ascii')
            self.session.headers.update({
                'Authorization': f'Basic {auth_b64}',
                'Content-Type': 'application/json'
            })

        # Requests are sent with verify=False; silence the resulting warning.
        import urllib3
        urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    def _make_request(self, method: str, path: str, data: Optional[Dict] = None,
                      timeout: int = 30) -> Dict:
        """Send one HTTP request and return the decoded response.

        Args:
            method: HTTP verb (case-insensitive): GET, POST, PUT, DELETE or HEAD.
            path: URL path appended to the host, starting with ``/``.
            data: JSON body; only sent for POST and PUT.
            timeout: Request timeout in seconds.

        Returns:
            The parsed JSON body, or ``{"status_code": ...}`` for HEAD
            requests and for responses without a body.  HEAD responses are
            returned without a status check.

        Raises:
            ValueError: If ``method`` is not a supported verb.
            Exception: On timeout, connection error, HTTP error status, or
                any other failure; the original error is chained as the cause.
        """
        verb = method.upper()
        if verb not in self._SUPPORTED_METHODS:
            raise ValueError(f"不支持的HTTP方法: {method}")

        url = f"{self.host}{path}"
        request_kwargs = {"timeout": timeout, "verify": False}
        if verb in ("POST", "PUT"):
            request_kwargs["json"] = data

        try:
            response = getattr(self.session, verb.lower())(url, **request_kwargs)

            if verb == "HEAD":
                # The status code is the answer for HEAD, so 4xx is not an error here.
                return {"status_code": response.status_code}

            response.raise_for_status()

            if response.content:
                return response.json()
            return {"status_code": response.status_code}

        except requests.exceptions.Timeout as exc:
            raise Exception(f"请求超时: {url}") from exc
        except requests.exceptions.ConnectionError as exc:
            raise Exception(f"连接错误: {url}") from exc
        except requests.exceptions.HTTPError as exc:
            raise Exception(f"HTTP错误 {response.status_code}: {response.text}") from exc
        except Exception as exc:
            raise Exception(f"请求失败: {exc}") from exc

    def ping(self) -> bool:
        """Return True when ``GET /`` succeeds and reports a cluster name."""
        try:
            result = self._make_request("GET", "/", timeout=10)
            return "cluster_name" in result
        except Exception:
            # Any request failure means "not reachable"; interrupts propagate.
            return False

    def info(self) -> Dict:
        """Return the cluster information from ``GET /``."""
        return self._make_request("GET", "/")

    def cluster_health(self) -> Dict:
        """Return the cluster health from ``GET /_cluster/health``."""
        return self._make_request("GET", "/_cluster/health")

    def create_index(self, index_name: str, mapping: Dict) -> Dict:
        """Create ``index_name`` using ``mapping`` as the request body."""
        return self._make_request("PUT", f"/{index_name}", mapping)

    def delete_index(self, index_name: str) -> Dict:
        """Delete the index ``index_name``."""
        return self._make_request("DELETE", f"/{index_name}")

    def index_exists(self, index_name: str) -> bool:
        """Return True when ``HEAD /<index>`` answers with status 200."""
        try:
            result = self._make_request("HEAD", f"/{index_name}")
            return result.get("status_code") == 200
        except Exception:
            # Request failures are reported as "does not exist".
            return False

    def index_document(self, index_name: str, doc_id: str, document: Dict) -> Dict:
        """Index ``document`` under ``doc_id`` in ``index_name``."""
        return self._make_request("PUT", f"/{index_name}/_doc/{doc_id}", document)

    def bulk_index(self, index_name: str, documents: List[Dict]) -> Dict:
        """Index many documents with a single ``_bulk`` request.

        Args:
            index_name: Target index for every document.
            documents: Document bodies; ids are left for the server to assign.

        Returns:
            The parsed ``_bulk`` response.  For an empty ``documents`` list
            no request is sent and an empty success response is returned.

        Raises:
            requests.exceptions.RequestException: On transport failure or an
                HTTP error status.
        """
        if not documents:
            # Nothing to send, so skip the request entirely.
            return {"took": 0, "errors": False, "items": []}

        # The action line is identical for every document; build it once.
        action_line = json.dumps({"index": {"_index": index_name}})
        bulk_lines = []
        for doc in documents:
            bulk_lines.append(action_line)
            bulk_lines.append(json.dumps(doc))

        # The body is newline-delimited and ends with a trailing newline.
        bulk_body = "\n".join(bulk_lines) + "\n"

        url = f"{self.host}/_bulk"
        headers = self.session.headers.copy()
        headers['Content-Type'] = 'application/x-ndjson'

        response = self.session.post(url, data=bulk_body.encode('utf-8'), headers=headers,
                                     timeout=60, verify=False)
        response.raise_for_status()
        return response.json()

    def delete_by_query(self, index_name: str, query: Dict) -> Dict:
        """Delete the documents in ``index_name`` that match ``query``.

        Commonly used to empty an index.  Sends the same ``conflicts`` and
        ``refresh`` options that ``ESClientWrapper`` passes to the official
        client, so both client types behave alike.
        """
        # POST /{index}/_delete_by_query
        path = f"/{index_name}/_delete_by_query?conflicts=proceed&refresh=true"
        return self._make_request("POST", path, query)

    def search(self, index_name: str, query: Dict, size: int = 10) -> Dict:
        """Search ``index_name``.

        Args:
            index_name: Index to search.
            query: Either a bare query clause, or a complete request body
                that already contains ``"size"``.
            size: Number of hits; used only when ``query`` has no ``"size"``.
        """
        # A body that already carries "size" is sent unchanged.
        if "size" in query:
            search_body = query
        else:
            search_body = {"query": query, "size": size}
        return self._make_request("POST", f"/{index_name}/_search", search_body)

    def vector_search(self, index_name: str, vector: List[float],
                      field: str = "embedding", size: int = 10) -> Dict:
        """Run a kNN search for ``vector`` against ``field``.

        Requests ``size`` neighbours using ``size * 2`` candidates.
        """
        search_body = {
            "knn": {
                "field": field,
                "query_vector": vector,
                "k": size,
                "num_candidates": size * 2,
            },
            "size": size,
        }
        return self._make_request("POST", f"/{index_name}/_search", search_body)
|
||||
|
||||
|
||||
def test_http_client():
    """Probe a fixed list of hosts and return the first client that works.

    For each candidate host a client is created and pinged; on success the
    cluster name, version and health status are printed.

    Returns:
        The first ``HTTPESClient`` whose checks all succeed, or ``None``
        when no host can be reached.
    """
    print("=== 测试HTTP ES客户端 ===")

    # Candidate hosts, tried in this order.
    candidate_hosts = (
        "http://101.200.154.78:9200",
        "http://127.0.0.1:9200",
        "http://localhost:9200",
    )

    for candidate in candidate_hosts:
        print(f"\n尝试连接: {candidate}")
        client = HTTPESClient(host=candidate)

        try:
            if not client.ping():
                print("✗ ping失败")
                continue

            print("[OK] 连接成功!")

            # Cluster information.
            cluster_info = client.info()
            cluster_name = cluster_info.get('cluster_name', 'N/A')
            es_version = cluster_info.get('version', {}).get('number', 'N/A')
            print(f"[OK] 集群名称: {cluster_name}")
            print(f"[OK] ES版本: {es_version}")

            # Cluster health.
            cluster_status = client.cluster_health().get('status', 'N/A')
            print(f"[OK] 集群状态: {cluster_status}")

            return client

        except Exception as err:
            # Report the failure and move on to the next host.
            print(f"✗ 连接失败: {err}")

    return None
|
||||
|
||||
|
||||
# Script entry point: probe the candidate hosts and report the outcome.
if __name__ == "__main__":
    working_client = test_http_client()
    if not working_client:
        print("\n✗ 未找到可用的ES连接")
    else:
        print("\n[OK] 找到可用的ES连接!")
|
||||
Reference in New Issue
Block a user