first commit

This commit is contained in:
闫旭隆
2025-09-24 09:29:12 +08:00
parent 6339cdebb9
commit 2308536f66
360 changed files with 136381 additions and 0 deletions

View File

@ -0,0 +1,9 @@
{
"permissions": {
"allow": [
"Bash(python test_numpy_embedding.py:*)"
],
"deny": [],
"ask": []
}
}

57
.dockerignore Normal file
View File

@ -0,0 +1,57 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
*.egg
*.egg-info/
dist/
build/
.pytest_cache/
.coverage
htmlcov/
# Virtual environments
venv/
env/
ENV/
# IDE
.vscode/
.idea/
*.swp
*.swo
# Git
.git/
.gitignore
# Environment variables (contain sensitive information)
.env
.env.*
# Test data
*.pkl
test_*.py
*_test.py
# Logs and output
*.log
api_outputs/
logs/
# Documentation
*.md
docs/
# System files
.DS_Store
Thumbs.db
# Claude-related
.claude/
# Temporary files
*.tmp
*.bak
*.backup

24
.env Normal file
View File

@ -0,0 +1,24 @@
# Alibaba Cloud DashScope API key (called directly; no OneAPI relay required)
ONEAPI_KEY=sk-2c5045478a244022865e65c5c9b6adf7
# LLM names (native Alibaba Cloud DashScope model names)
ONEAPI_MODEL=qwen2-7b-instruct
ONEAPI_MODEL_GEN=qwen2-7b-instruct
ONEAPI_MODEL_MAX=qwen2-7b-instruct
# Embedding model name
ONEAPI_MODEL_EMBED=text-embedding-v3
# LangSmith tracing configuration
LANGCHAIN_TRACING_V2=false
LANGCHAIN_ENDPOINT=https://api.smith.langchain.com
LANGCHAIN_API_KEY=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
LANGCHAIN_PROJECT=hipporag-retriever
# ============= Elasticsearch configuration =============
# Elasticsearch server address
ELASTICSEARCH_HOST=http://101.200.154.78:9200
# Elasticsearch credentials
ELASTICSEARCH_USERNAME=elastic
ELASTICSEARCH_PASSWORD=Abcd123456

87
.env.example Normal file
View File

@ -0,0 +1,87 @@
# =======================
# AIEC-RAG environment configuration example
# =======================
# Copy this file to .env and fill in your actual values
# =======================
# API configuration
# =======================
# OneAPI/DashScope API key
# Obtain one at: https://dashscope.console.aliyun.com/
ONEAPI_KEY=sk-your-api-key-here
# =======================
# Model configuration
# =======================
# Primary language model (used to generate answers)
# Options: qwen2-7b-instruct, qwen-max, gpt-3.5-turbo, etc.
ONEAPI_MODEL=qwen2-7b-instruct
# Generation model (used for tasks such as query decomposition)
ONEAPI_MODEL_GEN=qwen2-7b-instruct
# Maximum-context model (used for long texts)
ONEAPI_MODEL_MAX=qwen2-7b-instruct
# Embedding model (used for vectorization)
# Options: text-embedding-v3, text-embedding-ada-002, etc.
ONEAPI_MODEL_EMBED=text-embedding-v3
# =======================
# Elasticsearch configuration
# =======================
# Elasticsearch server address
# Local deployment: http://localhost:9200
# Remote deployment: http://your-es-server:9200
ELASTICSEARCH_HOST=http://localhost:9200
# Elasticsearch credentials
ELASTICSEARCH_USERNAME=elastic
ELASTICSEARCH_PASSWORD=your-password-here
# =======================
# LangSmith configuration (optional)
# =======================
# Whether to enable LangSmith tracing
# true: enable tracing; false: disable tracing
LANGCHAIN_TRACING_V2=false
# LangSmith API endpoint
LANGCHAIN_ENDPOINT=https://api.smith.langchain.com
# LangSmith API key
# Obtain one at: https://smith.langchain.com/
LANGCHAIN_API_KEY=lsv2_pt_your_api_key_here
# LangSmith project name
LANGCHAIN_PROJECT=aiec-rag
# =======================
# Service configuration (optional)
# =======================
# Service port (default: 8100)
# SERVICE_PORT=8100
# Service host (default: 0.0.0.0)
# SERVICE_HOST=0.0.0.0
# Debug mode (default: false)
# DEBUG_MODE=false
# =======================
# Performance configuration (optional)
# =======================
# Maximum number of concurrent requests
# MAX_CONCURRENT_REQUESTS=10
# Request timeout (seconds)
# REQUEST_TIMEOUT=120
# Cache size (MB)
# CACHE_SIZE=512
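
For reference, a minimal sketch of reading these variables from Python; whether the service uses python-dotenv or plain `os.environ` is an assumption here, not something this file prescribes:

```python
import os
from dotenv import load_dotenv  # assumption: python-dotenv is installed

load_dotenv()  # loads .env from the working directory, if present
oneapi_key = os.environ["ONEAPI_KEY"]  # required
es_host = os.environ.get("ELASTICSEARCH_HOST", "http://localhost:9200")
es_user = os.environ.get("ELASTICSEARCH_USERNAME", "elastic")
es_password = os.environ.get("ELASTICSEARCH_PASSWORD", "")
```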

579
DEPLOYMENT_GUIDE.md Normal file
View File

@ -0,0 +1,579 @@
# AIEC-RAG Deployment Guide
## Table of Contents
1. [Deployment Architecture](#deployment-architecture)
2. [Single-Node Deployment](#single-node-deployment)
3. [Docker Deployment](#docker-deployment)
4. [Production Deployment](#production-deployment)
5. [Performance Tuning](#performance-tuning)
6. [Monitoring](#monitoring)
7. [Backup and Recovery](#backup-and-recovery)
## Deployment Architecture
### Recommended Architecture
```
[Load balancer]
|
┌────────────┼────────────┐
↓ ↓ ↓
[AIEC-RAG-1] [AIEC-RAG-2] [AIEC-RAG-3]
↓ ↓ ↓
└────────────┼────────────┘
[Elasticsearch cluster]
[Vector store]
```
### Minimum Requirements
| Component | CPU | Memory | Storage | Notes |
|-----|-----|------|------|------|
| API service | 4 cores | 8GB | 50GB | Minimum for a single instance |
| Elasticsearch | 4 cores | 16GB | 200GB | SSD recommended |
| Whole system | 8 cores | 32GB | 500GB | Recommended for production |
## Single-Node Deployment
### 1. System Preparation
```bash
# Ubuntu/Debian
sudo apt update
sudo apt install -y python3.8 python3-pip git curl wget
# CentOS/RHEL
sudo yum update -y
sudo yum install -y python38 python38-pip git curl wget
```
### 2. Install Elasticsearch
```bash
# Download and install Elasticsearch 8.x
wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-8.11.0-linux-x86_64.tar.gz
tar -xzf elasticsearch-8.11.0-linux-x86_64.tar.gz
cd elasticsearch-8.11.0
# Configure Elasticsearch
cat >> config/elasticsearch.yml << EOF
network.host: 0.0.0.0
discovery.type: single-node
xpack.security.enabled: true
xpack.security.authc.api_key.enabled: true
EOF
# Start Elasticsearch
./bin/elasticsearch -d
```
### 3. Deploy AIEC-RAG
```bash
# Clone the project
git clone <repository_url>
cd AIEC-RAG
# Create a virtual environment
python3 -m venv venv
source venv/bin/activate
# Install dependencies
pip install -r requirements.txt
# Configure environment variables
cp .env.example .env
# Edit the .env file and fill in your actual configuration
# Start the service
python rag_api_server_production.py
```
### 4. Set Up a systemd Service
Create `/etc/systemd/system/aiec-rag.service`:
```ini
[Unit]
Description=AIEC-RAG Service
After=network.target elasticsearch.service
[Service]
Type=simple
User=aiec
WorkingDirectory=/opt/AIEC-RAG
Environment="PATH=/opt/AIEC-RAG/venv/bin"
ExecStart=/opt/AIEC-RAG/venv/bin/python /opt/AIEC-RAG/rag_api_server_production.py
Restart=always
RestartSec=10
[Install]
WantedBy=multi-user.target
```
Enable and start the service:
```bash
sudo systemctl daemon-reload
sudo systemctl enable aiec-rag
sudo systemctl start aiec-rag
sudo systemctl status aiec-rag
```
## Docker Deployment
### 1. Use a Prebuilt Image
```bash
# Pull the image (if a private registry is available)
docker pull your-registry/aiec-rag:latest
# Or build the image locally
docker build -t aiec-rag:latest .
```
### 2. Docker Compose Deployment
Create `docker-compose.yml`:
```yaml
version: '3.8'
services:
elasticsearch:
image: docker.elastic.co/elasticsearch/elasticsearch:8.11.0
container_name: aiec-elasticsearch
environment:
- discovery.type=single-node
- "ES_JAVA_OPTS=-Xms2g -Xmx2g"
- xpack.security.enabled=true
- ELASTIC_PASSWORD=your_password
volumes:
- es_data:/usr/share/elasticsearch/data
ports:
- "9200:9200"
networks:
- aiec_network
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9200"]
interval: 30s
timeout: 10s
retries: 5
aiec-rag:
build: .
container_name: aiec-rag
depends_on:
elasticsearch:
condition: service_healthy
environment:
- ELASTICSEARCH_HOST=http://elasticsearch:9200
- ELASTICSEARCH_USERNAME=elastic
- ELASTICSEARCH_PASSWORD=your_password
env_file:
- .env
ports:
- "8100:8100"
volumes:
- ./rag_config_production.yaml:/app/rag_config_production.yaml
- ./api_outputs:/app/api_outputs
networks:
- aiec_network
restart: unless-stopped
volumes:
es_data:
driver: local
networks:
aiec_network:
driver: bridge
```
Start the services:
```bash
docker-compose up -d
docker-compose logs -f
```
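Once the stack is up, a quick smoke test against the health endpoint (mapped to host port 8100 in the compose file above) can confirm both services:
```bash
# Service health check
curl -f http://localhost:8100/health
# Elasticsearch cluster health (credentials as configured above)
curl -u elastic:your_password "http://localhost:9200/_cluster/health?pretty"
```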
### 3. Kubernetes Deployment
Create `k8s-deployment.yaml`:
```yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: aiec-rag
labels:
app: aiec-rag
spec:
replicas: 3
selector:
matchLabels:
app: aiec-rag
template:
metadata:
labels:
app: aiec-rag
spec:
containers:
- name: aiec-rag
image: your-registry/aiec-rag:latest
ports:
- containerPort: 8100
env:
- name: ELASTICSEARCH_HOST
value: "http://elasticsearch-service:9200"
envFrom:
- secretRef:
name: aiec-secrets
resources:
requests:
memory: "4Gi"
cpu: "2"
limits:
memory: "8Gi"
cpu: "4"
livenessProbe:
httpGet:
path: /health
port: 8100
initialDelaySeconds: 30
periodSeconds: 10
readinessProbe:
httpGet:
path: /health
port: 8100
initialDelaySeconds: 5
periodSeconds: 5
---
apiVersion: v1
kind: Service
metadata:
name: aiec-rag-service
spec:
selector:
app: aiec-rag
ports:
- protocol: TCP
port: 80
targetPort: 8100
type: LoadBalancer
```
Deploy to Kubernetes:
```bash
# Create the secret
kubectl create secret generic aiec-secrets --from-env-file=.env
# Deploy the application
kubectl apply -f k8s-deployment.yaml
# Check status
kubectl get pods
kubectl get services
```
## Production Deployment
### 1. Load Balancer Configuration
Use Nginx as the load balancer:
```nginx
upstream aiec_backend {
least_conn;
server 10.0.1.10:8100 weight=1 max_fails=3 fail_timeout=30s;
server 10.0.1.11:8100 weight=1 max_fails=3 fail_timeout=30s;
server 10.0.1.12:8100 weight=1 max_fails=3 fail_timeout=30s;
}
server {
listen 80;
server_name api.aiec-rag.com;
location / {
proxy_pass http://aiec_backend;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection 'upgrade';
proxy_set_header Host $host;
proxy_cache_bypass $http_upgrade;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# Timeouts
proxy_connect_timeout 60s;
proxy_send_timeout 120s;
proxy_read_timeout 120s;
}
# Health-check endpoint
location /health {
proxy_pass http://aiec_backend/health;
}
}
```
### 2. SSL/TLS Configuration
```nginx
server {
listen 443 ssl http2;
server_name api.aiec-rag.com;
ssl_certificate /etc/nginx/ssl/aiec-rag.crt;
ssl_certificate_key /etc/nginx/ssl/aiec-rag.key;
ssl_protocols TLSv1.2 TLSv1.3;
ssl_ciphers HIGH:!aNULL:!MD5;
ssl_prefer_server_ciphers on;
# ... remaining directives as in the HTTP server block above
}
```
### 3. Database Tuning
Elasticsearch tuning configuration:
```yaml
# elasticsearch.yml
cluster.name: aiec-rag-cluster
node.name: node-1
# Memory settings
bootstrap.memory_lock: true
# Thread pools
thread_pool:
write:
size: 8
queue_size: 1000
search:
size: 16
queue_size: 1000
# Index settings
index:
number_of_shards: 3
number_of_replicas: 1
refresh_interval: 30s
```
## Performance Tuning
### 1. Python Application Tuning
```bash
# Use Gunicorn with Uvicorn workers to serve the ASGI app (Linux)
gunicorn -w 4 -k uvicorn.workers.UvicornWorker \
--bind 0.0.0.0:8100 \
--timeout 120 \
--keep-alive 5 \
--max-requests 1000 \
--max-requests-jitter 50 \
rag_api_server_production:app
```
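If Gunicorn is adopted, the systemd unit from the single-node section can launch it instead of invoking Python directly. A sketch (paths and worker count are illustrative assumptions):
```ini
# ExecStart variant for /etc/systemd/system/aiec-rag.service
[Service]
ExecStart=/opt/AIEC-RAG/venv/bin/gunicorn -w 4 -k uvicorn.workers.UvicornWorker \
    --bind 0.0.0.0:8100 --timeout 120 rag_api_server_production:app
```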
### 2. System Parameter Tuning
```bash
# /etc/sysctl.conf
net.ipv4.tcp_fin_timeout = 30
net.ipv4.tcp_tw_reuse = 1
# Note: tcp_tw_recycle was removed in Linux 4.12+; omit it on modern kernels
net.ipv4.tcp_tw_recycle = 1
net.ipv4.tcp_max_syn_backlog = 8192
net.ipv4.tcp_max_tw_buckets = 10000
net.core.somaxconn = 65535
net.core.netdev_max_backlog = 65535
# Apply the settings
sudo sysctl -p
```
### 3. Caching Strategy
Add a Redis cache in the application code, for example (a sketch; `compute_embedding` stands in for the actual embedding call):
```python
# Add caching support in the application code
import hashlib
import json
import redis
from functools import lru_cache

redis_client = redis.Redis(
    host='localhost',
    port=6379,
    decode_responses=True,
    max_connections=50
)

@lru_cache(maxsize=128)
def get_cached_embedding(text: str):
    # Cache embedding vectors: in-process via lru_cache, across processes via Redis
    key = "emb:" + hashlib.sha256(text.encode("utf-8")).hexdigest()
    cached = redis_client.get(key)
    if cached is not None:
        return json.loads(cached)
    embedding = compute_embedding(text)  # placeholder: call the real embedding model here
    redis_client.set(key, json.dumps(embedding), ex=3600)
    return embedding
```
## Monitoring
### 1. Prometheus Monitoring
```yaml
# prometheus.yml
scrape_configs:
- job_name: 'aiec-rag'
static_configs:
- targets: ['localhost:8100']
metrics_path: '/metrics'
scrape_interval: 15s
```
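The scrape config above assumes the service exposes a `/metrics` endpoint. If it does not yet, a minimal sketch using the prometheus-client library (an assumption, not necessarily what `rag_api_server_production.py` already provides):
```python
from prometheus_client import make_asgi_app

# Mount a Prometheus metrics endpoint on the existing FastAPI app instance.
app.mount("/metrics", make_asgi_app())
```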
### 2. Log Management
Configure log rotation:
```bash
# /etc/logrotate.d/aiec-rag
/opt/AIEC-RAG/logs/*.log {
daily
rotate 30
compress
delaycompress
missingok
notifempty
create 644 aiec aiec
sharedscripts
postrotate
systemctl reload aiec-rag
endscript
}
```
### 3. Alerting Configuration
```yaml
# alerting_rules.yml
groups:
- name: aiec_alerts
rules:
- alert: HighResponseTime
expr: http_request_duration_seconds{quantile="0.99"} > 5
for: 5m
labels:
severity: warning
annotations:
summary: "High response time on {{ $labels.instance }}"
- alert: ServiceDown
expr: up{job="aiec-rag"} == 0
for: 1m
labels:
severity: critical
annotations:
summary: "AIEC-RAG service is down"
```
## Backup and Recovery
### 1. Data Backup
```bash
#!/bin/bash
# backup.sh
DATE=$(date +%Y%m%d_%H%M%S)
BACKUP_DIR="/backup/aiec-rag"
# Back up Elasticsearch data
curl -X PUT "localhost:9200/_snapshot/backup_repo" -H 'Content-Type: application/json' -d'
{
"type": "fs",
"settings": {
"location": "'$BACKUP_DIR'/elasticsearch"
}
}'
curl -X PUT "localhost:9200/_snapshot/backup_repo/snapshot_$DATE?wait_for_completion=true"
# Back up configuration files
tar -czf $BACKUP_DIR/config_$DATE.tar.gz \
/opt/AIEC-RAG/.env \
/opt/AIEC-RAG/rag_config_production.yaml
echo "Backup completed: $DATE"
```
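Note that filesystem snapshot repositories only work when the target directory is whitelisted via `path.repo` in `elasticsearch.yml`, and the script is typically scheduled with cron (paths below are illustrative):
```bash
# elasticsearch.yml must whitelist the snapshot location, for example:
#   path.repo: ["/backup/aiec-rag/elasticsearch"]
# Example cron entry: run the backup daily at 02:00
0 2 * * * /opt/AIEC-RAG/backup.sh >> /var/log/aiec-backup.log 2>&1
```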
### 2. Restore Procedure
```bash
#!/bin/bash
# restore.sh
SNAPSHOT_NAME=$1
# Restore Elasticsearch data
curl -X POST "localhost:9200/_snapshot/backup_repo/$SNAPSHOT_NAME/_restore"
# Restore configuration files
tar -xzf /backup/aiec-rag/config_latest.tar.gz -C /
# Restart the service
systemctl restart aiec-rag
echo "Restore completed from: $SNAPSHOT_NAME"
```
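A restore fails if the target indices already exist and are open; close or delete them first, or restore with rename settings. Recovery progress can be checked with:
```bash
# Shard recovery (restore) progress
curl -X GET "localhost:9200/_cat/recovery?v"
```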
## Troubleshooting
### Common Issues
1. **Service not responding**
```bash
# Check service status
systemctl status aiec-rag
# View recent logs
journalctl -u aiec-rag -n 100
# Restart the service
systemctl restart aiec-rag
```
2. **Elasticsearch connection failures**
```bash
# Check ES cluster health
curl -X GET "localhost:9200/_cluster/health?pretty"
# Check network connectivity
telnet localhost 9200
```
3. **Out-of-memory errors**
```bash
# Limit thread usage to reduce memory pressure
export PYTHONUNBUFFERED=1
export OMP_NUM_THREADS=4
```
## Security Recommendations
1. **API key management**
- Use a secrets-management service such as HashiCorp Vault
- Rotate API keys regularly
- Never hard-code keys in source code
2. **Network security**
- Restrict access with a firewall
- Configure SSL/TLS encryption
- Enforce rate limiting (see the Nginx sketch after this list)
3. **Data security**
- Encrypt sensitive data
- Back up regularly
- Enforce access control
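For example, a minimal Nginx rate-limiting sketch that builds on the load-balancer config above (zone name and rate are illustrative assumptions):
```nginx
# Shared zone: track clients by IP, allow ~10 requests/second each
limit_req_zone $binary_remote_addr zone=aiec_api:10m rate=10r/s;

server {
    listen 443 ssl http2;
    server_name api.aiec-rag.com;

    location / {
        # Permit short bursts; reject the excess with HTTP 429
        limit_req zone=aiec_api burst=20 nodelay;
        limit_req_status 429;
        proxy_pass http://aiec_backend;
    }
}
```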
---
*For further deployment questions, see the project wiki or contact technical support.*

41
Dockerfile Normal file
View File

@ -0,0 +1,41 @@
# Dockerfile for the Python RAG API service
FROM python:3.10-slim
# Set the working directory
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
build-essential \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy the dependency file
COPY retriver/requirements.txt /app/requirements.txt
# Install Python dependencies
RUN pip install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt && \
pip install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple \
fastapi \
uvicorn[standard] \
pyyaml
# Copy project files
COPY . /app
# Create the output directory
RUN mkdir -p /app/api_outputs
# Set environment variables
ENV PYTHONPATH=/app
ENV RAG_CONFIG_PATH=rag_config_production.yaml
# Expose the service port
EXPOSE 8000
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# Start the service
CMD ["uvicorn", "rag_api_server_production:app", "--host", "0.0.0.0", "--port", "8000"]
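A quick way to try this image locally (host port 8100 is an illustrative choice; the container listens on 8000 per the Dockerfile above):
```bash
docker build -t aiec-rag:latest .
docker run -d --name aiec-rag --env-file .env -p 8100:8000 aiec-rag:latest
curl -f http://localhost:8100/health
```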

61
Dockerfile.cpu-full Normal file
View File

@ -0,0 +1,61 @@
FROM python:3.11-slim
WORKDIR /app
# Configure pip to use the Aliyun mirror
ENV PIP_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/
ENV PIP_TRUSTED_HOST=mirrors.aliyun.com
ENV PIP_NO_CACHE_DIR=1
# Copy the requirements file
COPY requirements-cpu-full.txt requirements.txt
# No apt-get needed; install everything with pip
# Upgrade pip and install all dependencies
RUN pip install --upgrade pip setuptools wheel && \
pip install torch --index-url https://download.pytorch.org/whl/cpu && \
pip install transformers sentence-transformers && \
pip install -r requirements.txt
# Copy project files
COPY . .
# Make sure test_with_concept.pkl is in the expected location
# The file should be in the project root; if it is not, copy it in separately
RUN if [ -f /app/test_with_concept.pkl ]; then \
echo "test_with_concept.pkl found in /app"; \
else \
echo "WARNING: test_with_concept.pkl not found"; \
fi
# Create required directories
RUN mkdir -p /app/api_outputs /app/logs
# Set environment variables
ENV PYTHONPATH=/app
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
ENV RAG_CONFIG_PATH=rag_config_production.yaml
# API configuration
ENV ONEAPI_KEY=sk-2c5045478a244022865e65c5c9b6adf7
ENV ONEAPI_MODEL=qwen2-7b-instruct
ENV ONEAPI_MODEL_GEN=qwen2-7b-instruct
ENV ONEAPI_MODEL_MAX=qwen2-7b-instruct
ENV ONEAPI_MODEL_EMBED=text-embedding-v3
# LangSmith configuration
ENV LANGCHAIN_TRACING_V2=false
ENV LANGCHAIN_ENDPOINT=https://api.smith.langchain.com
ENV LANGCHAIN_API_KEY=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
ENV LANGCHAIN_PROJECT=hipporag-retriever
# Elasticsearch configuration
ENV ELASTICSEARCH_HOST=http://101.200.154.78:9200
ENV ELASTICSEARCH_USERNAME=elastic
ENV ELASTICSEARCH_PASSWORD=Abcd123456
EXPOSE 8000
CMD ["uvicorn", "rag_api_server_production:app", "--host", "0.0.0.0", "--port", "8000"]

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

1
atlas_rag/__init__.py Normal file
View File

@ -0,0 +1 @@
from .logging import setup_logger

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1 @@
from .benchmark import BenchMarkConfig, RAGBenchmark

View File

@ -0,0 +1,236 @@
import os
import json
import numpy as np
from logging import Logger
from atlas_rag.retriever.base import BaseRetriever, BaseEdgeRetriever, BasePassageRetriever
from typing import List
from datetime import datetime
from transformers import AutoModel
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import torch
import torch.nn.functional as F
from atlas_rag.vectorstore.embedding_model import NvEmbed, SentenceEmbedding
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.evaluation.evaluation import QAJudger
from dataclasses import dataclass
from atlas_rag.llm_generator.prompt.react import ReAct
def normalize_embeddings(embeddings):
"""Normalize the embeddings to unit length (L2 norm)."""
if isinstance(embeddings, torch.Tensor):
# Handle PyTorch tensors
norm_emb = F.normalize(embeddings, p=2, dim=1).detach().cpu().numpy()
elif isinstance(embeddings, np.ndarray):
# Handle numpy arrays
norm_emb = F.normalize(torch.tensor(embeddings), p=2, dim=1).detach().cpu().numpy()
else:
raise TypeError(f"Unsupported input type: {type(embeddings)}. Must be torch.Tensor or np.ndarray")
return norm_emb
@dataclass
class BenchMarkConfig:
"""
Configuration class for benchmarking.
Attributes:
dataset_name (str): Name of the dataset. Default is "hotpotqa".
question_file (str): Path to the question file. Default is "hotpotqa".
graph_file (str): Path to the graph file. Default is "hotpotqa_concept.graphml".
include_events (bool): Whether to include events. Default is False.
include_concept (bool): Whether to include concepts. Default is False.
reader_model_name (str): Name of the reader model. Default is "meta-llama/Llama-2-7b-chat-hf".
encoder_model_name (str): Name of the encoder model. Default is "nvidia/NV-Embed-v2".
number_of_samples (int): Number of samples to use from the dataset. Default is -1 (use all samples).
"""
dataset_name: str = "hotpotqa"
question_file: str = "hotpotqa"
include_events: bool = False
include_concept: bool = False
reader_model_name: str = "meta-llama/Llama-2-7b-chat-hf"
encoder_model_name: str = "nvidia/NV-Embed-v2"
number_of_samples: int = -1 # Default to -1 to use all samples
react_max_iterations: int = 5
class RAGBenchmark:
def __init__(self, config:BenchMarkConfig, logger:Logger = None):
self.config = config
self.logger = logger
self.logging : bool = self.logger is not None
def load_encoder_model(self, encoder_model_name, **kwargs):
if encoder_model_name == "nvidia/NV-Embed-v2":
sentence_encoder = AutoModel.from_pretrained("nvidia/NV-Embed-v2", **kwargs)
return NvEmbed(sentence_encoder)
else:
sentence_encoder = SentenceTransformer(encoder_model_name, **kwargs)
return SentenceEmbedding(sentence_encoder)
def run(self, retrievers:List[BaseRetriever],
llm_generator:LLMGenerator,
use_react: bool = False):
qa_judge = QAJudger()
if use_react:
react_agent = ReAct(llm_generator=llm_generator)
result_list = []
with open(self.config.question_file, "r") as f:
data = json.load(f)
print(f"Data loaded from {self.config.question_file}")
if self.config.number_of_samples > 0:
data = data[:self.config.number_of_samples]
print(f"Using only the first {self.config.number_of_samples} samples from the dataset")
for sample in tqdm(data):
question = sample["question"]
answer = sample["answer"]
gold_file_ids = []
if self.config.dataset_name in ("hotpotqa", "2wikimultihopqa"):
for fact in sample["supporting_facts"]:
gold_file_ids.append(fact[0])
elif self.config.dataset_name == "musique":
for paragraph in sample["paragraphs"]:
if paragraph["is_supporting"]:
gold_file_ids.append(paragraph["paragraph_text"])
else:
print("Dataset not supported")
continue
result = {
"question": question,
"answer": answer,
"gold_file_ids": gold_file_ids,
}
if self.logging:
self.logger.info(f"Question: {question}")
for retriever in retrievers:
if use_react:
# Use RAG with ReAct
llm_generated_answer, search_history = react_agent.generate_with_rag_react(
question=question,
retriever=retriever,
max_iterations=self.config.react_max_iterations,
max_new_tokens=2048,
logger=self.logger
)
if self.logging:
    self.logger.info(f"Search history: {search_history}")
    self.logger.info(f"Final answer: {llm_generated_answer}")
# Store search history in results
result[f"{retriever.__class__.__name__}_search_history"] = search_history
# Extract all retrieved contexts from search history
all_contexts = []
for _, action, observation in search_history:
if "search" in action.lower() or "look up" in action.lower():
all_contexts.append(observation)
sorted_context = "\n".join(all_contexts)
sorted_context_ids = [] # We don't track IDs in ReAct mode
else:
# Original RAG implementation
sorted_context, sorted_context_ids = retriever.retrieve(question, topN=5)
if isinstance(retriever, BaseEdgeRetriever):
retrieved_context = ". ".join(sorted_context)
llm_generated_answer = llm_generator.generate_with_context_kg(question, retrieved_context, max_new_tokens=2048, temperature=0.5)
elif isinstance(retriever, BasePassageRetriever):
retrieved_context = "\n".join(sorted_context)
llm_generated_answer = llm_generator.generate_with_context(question, retrieved_context, max_new_tokens=2048, temperature=0.5)
if self.logging:
self.logger.info(f"{retriever.__class__.__name__} retrieved passages: {sorted_context}")
self.logger.info(f"{retriever.__class__.__name__} generated answer: {llm_generated_answer}")
short_answer = qa_judge.split_answer(llm_generated_answer)
em, f1 = qa_judge.judge(short_answer, answer)
result[f"{retriever.__class__.__name__ }_em"] = em
result[f"{retriever.__class__.__name__ }_f1"] = f1
result[f"{retriever.__class__.__name__ }_passages"] = sorted_context
if not use_react:
result[f"{retriever.__class__.__name__ }_id"] = sorted_context_ids
result[f"{retriever.__class__.__name__ }_generated_answer"] = llm_generated_answer
result[f"{retriever.__class__.__name__}_short_answer"] = short_answer
# Calculate recall
if not use_react: # Only calculate recall for non-ReAct mode
if self.config.dataset_name in ("hotpotqa", "2wikimultihopqa"):
recall_2, recall_5 = qa_judge.recall(sorted_context_ids, gold_file_ids)
elif self.config.dataset_name == "musique":
recall_2, recall_5 = qa_judge.recall(sorted_context, gold_file_ids)
result[f"{retriever.__class__.__name__ }_recall@2"] = recall_2
result[f"{retriever.__class__.__name__ }_recall@5"] = recall_5
result_list.append(result)
self.save_results(result_list, [retriever.__class__.__name__ for retriever in retrievers])
def save_results(self, result_list, retriever_names:List[str]):
current_time = datetime.now()
formatted_time = current_time.strftime("%Y%m%d%H%M%S")
dataset_name = self.config.dataset_name
include_events = self.config.include_events
include_concept = self.config.include_concept
encoder_model_name = self.config.encoder_model_name
reader_model_name = self.config.reader_model_name
# use last part of model name as identifier
if "/" in encoder_model_name:
encoder_model_name = encoder_model_name.split("/")[-1]
if "/" in reader_model_name:
reader_model_name = reader_model_name.split("/")[-1]
summary_file = f"./result/{dataset_name}/summary_{formatted_time}_event{include_events}_concept{include_concept}_{encoder_model_name}_{reader_model_name}.json"
if not os.path.exists(os.path.dirname(summary_file)):
os.makedirs(os.path.dirname(summary_file), exist_ok=True)
result_dir = f"./result/{dataset_name}/result_{formatted_time}_event{include_events}_concept{include_concept}_{encoder_model_name}_{reader_model_name}.json"
if not os.path.exists(os.path.dirname(result_dir)):
os.makedirs(os.path.dirname(result_dir), exist_ok=True)
summary_dict = self.calculate_summary(result_list, retriever_names)
with open(summary_file, "w") as f_summary:
json.dump(summary_dict, f_summary)
f_summary.write("\n")
with open(result_dir, "w") as f:
for result in result_list:
json.dump(result, f)
f.write("\n")
def calculate_summary(self, result_list, method):
summary_dict = {}
for retriever_name in method:
if not all(f"{retriever_name}_em" in result for result in result_list):
raise ValueError(f"Missing {retriever_name}_em in results")
if not all(f"{retriever_name}_f1" in result for result in result_list):
raise ValueError(f"Missing {retriever_name}_f1 in results")
average_em = sum([result[f"{retriever_name}_em"] for result in result_list]) / len(result_list)
average_f1 = sum([result[f"{retriever_name}_f1"] for result in result_list]) / len(result_list)
# Only calculate recall metrics if they exist in the results
if all(f"{retriever_name}_recall@2" in result for result in result_list):
average_recall_2 = sum([result[f"{retriever_name}_recall@2"] for result in result_list]) / len(result_list)
average_recall_5 = sum([result[f"{retriever_name}_recall@5"] for result in result_list]) / len(result_list)
summary_dict.update({
f"{retriever_name}_average_f1": average_f1,
f"{retriever_name}_average_em": average_em,
f"{retriever_name}_average_recall@2": average_recall_2,
f"{retriever_name}_average_recall@5": average_recall_5,
})
else:
# For ReAct mode where recall metrics don't exist
summary_dict.update({
f"{retriever_name}_average_f1": average_f1,
f"{retriever_name}_average_em": average_em,
})
return summary_dict
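A hypothetical usage sketch, based only on the BenchMarkConfig fields and the RAGBenchmark.run() signature above (retriever and LLMGenerator construction is project-specific and omitted):
```python
config = BenchMarkConfig(
    dataset_name="hotpotqa",
    question_file="data/hotpotqa_questions.json",  # illustrative path
    number_of_samples=100,
)
benchmark = RAGBenchmark(config)
# retrievers (BaseRetriever subclasses) and an LLMGenerator come from other modules:
# benchmark.run(retrievers=[passage_retriever], llm_generator=llm, use_react=False)
```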

View File

@ -0,0 +1,158 @@
import re
from collections import Counter
from typing import Tuple
import argparse
import json
from tqdm import tqdm
class QAJudger:
def __init__(self):
pass
def split_answer(self, generated_text):
if "Answer:" in generated_text:
generated_text = generated_text.split("Answer:")[-1]
elif "answer:" in generated_text:
generated_text = generated_text.split("answer:")[-1]
# if answer is none
if not generated_text:
return "none"
return generated_text
def normalize_answer(self, answer: str) -> str:
"""Direct copy of the normalization from QAExactMatch/QAF1Score"""
# Lowercase and normalize whitespace
answer = answer.lower()
# Replace hyphens with spaces
answer = answer.replace('-', ' ')
# Remove all other punctuation
answer = re.sub(r'[^\w\s]', '', answer)
# Standardize whitespace
return ' '.join(answer.split())
def judge(self, generated_text: str, reference_text: str) -> Tuple[int, float]:
"""Direct port of the original scoring logic"""
# Extract answer from generated text
pred_answer = self.split_answer(generated_text)
# Normalize both answers
pred_norm = self.normalize_answer(pred_answer)
ref_norm = self.normalize_answer(reference_text)
# Exact match calculation
em = 1 if pred_norm == ref_norm else 0
# F1 calculation (direct port from QAF1Score)
pred_tokens = pred_norm.split()
ref_tokens = ref_norm.split()
common = Counter(pred_tokens) & Counter(ref_tokens)
num_same = sum(common.values())
if num_same == 0:
return em, 0.0
precision = num_same / len(pred_tokens) if pred_tokens else 0.0
recall = num_same / len(ref_tokens) if ref_tokens else 0.0
if (precision + recall) == 0:
f1 = 0.0
else:
f1 = 2 * (precision * recall) / (precision + recall)
return em, f1
def recall_at_k(self, retrieved_text: list, reference_text: list, k: int) -> float:
"""Calculates recall at k based on the top k retrieved texts."""
successful_retrievals = 0
# Limit the retrieved texts to the top k entries
limited_retrieved_text = retrieved_text[:k]
for ref_text in reference_text:
for ret_text in limited_retrieved_text:
if ref_text in ret_text:
successful_retrievals += 1
break
recall = successful_retrievals / len(reference_text) if reference_text else 0
return recall
# recall for 1 answer
def recall(self, retrieved_text: list, reference_text: list) -> Tuple[float, float]:
"""Calculates recall values at different k levels."""
recall_values = {
'recall@2': self.recall_at_k(retrieved_text, reference_text, 2),
'recall@5': self.recall_at_k(retrieved_text, reference_text, 5),
}
return recall_values['recall@2'], recall_values['recall@5']
if __name__ == "__main__":
argument_parser = argparse.ArgumentParser()
argument_parser.add_argument("--file_path", type=str, required=True, help="Path to the JSON file containing results.")
args = argument_parser.parse_args()
# Initialize the QAJudger
llm_judge = QAJudger()
# Load results from the JSON file
result_list = []
with open(args.file_path, 'r') as file:
for line in file:
if line.strip(): # Make sure the line is not empty
try:
result = json.loads(line.strip())
result_list.append(result)
except json.JSONDecodeError as e:
print(f"Error decoding JSON: {e}")
# Debugging output to inspect the loaded data structure
# print("Loaded data structure:", result_list)
# Evaluate each entry in result_list
for result in tqdm(result_list):
if isinstance(result, dict): # Ensure each result is a dictionary
question = result["question"]
answer = result["answer"]
# Evaluate generated answers with Hippo and Hippo2
hippo_generated_answer = result["hippo_generated_answer"]
hippo2_generated_answer = result["hippo2_generated_answer"]
# Split and judge the answers
hippo_short_answer = llm_judge.split_answer(hippo_generated_answer)
hippo_em, hippo_f1 = llm_judge.judge(hippo_short_answer, answer)
hippo2_short_answer = llm_judge.split_answer(hippo2_generated_answer)
hippo2_em, hippo2_f1 = llm_judge.judge(hippo2_short_answer, answer)
# Store the scores back in the result dictionary
result["hippo_em"] = hippo_em
result["hippo_f1"] = hippo_f1
result["hippo2_em"] = hippo2_em
result["hippo2_f1"] = hippo2_f1
result['recall@2'], result['recall@5'] = llm_judge.recall(result['hippo2_id'], result['gold_file_ids'])
result['recall@2_hippo'], result['recall@5_hippo'] = llm_judge.recall(result['hippo_id'], result['gold_file_ids'])
# Calculate averages
average_em_with_hippo = sum(result["hippo_em"] for result in result_list) / len(result_list)
average_em_with_hippo2 = sum(result["hippo2_em"] for result in result_list) / len(result_list)
average_f1_with_hippo = sum(result["hippo_f1"] for result in result_list) / len(result_list)
average_f1_with_hippo2 = sum(result["hippo2_f1"] for result in result_list) / len(result_list)
average_recall2_with_hippo = sum(result['recall@2'] for result in result_list) / len(result_list)
average_recall5_with_hippo = sum(result['recall@5'] for result in result_list) / len(result_list)
average_recall2 = sum(result['recall@2_hippo'] for result in result_list) / len(result_list)
average_recall5 = sum(result['recall@5_hippo'] for result in result_list) / len(result_list)
# Output the averages
print(f"Average EM with Hippo: {average_em_with_hippo:.4f}")
print(f"Average EM with Hippo2: {average_em_with_hippo2:.4f}")
print(f"Average F1 with Hippo: {average_f1_with_hippo:.4f}")
print(f"Average F1 with Hippo2: {average_f1_with_hippo2:.4f}")
print(f"Average Recall@2: {average_recall2:.4f}")
print(f"Average Recall@5: {average_recall5:.4f}")
print(f"Average Recall@2 with Hippo: {average_recall2_with_hippo:.4f}")
print(f"Average Recall@5 with Hippo: {average_recall5_with_hippo:.4f}")

View File

View File

@ -0,0 +1,282 @@
from tqdm import tqdm
import random
import logging
import csv
import os
import hashlib
import re
from atlas_rag.llm_generator import LLMGenerator
from atlas_rag.kg_construction.triple_config import ProcessingConfig
from atlas_rag.kg_construction.utils.csv_processing.csv_to_graphml import get_node_id
from atlas_rag.llm_generator.prompt.triple_extraction_prompt import CONCEPT_INSTRUCTIONS
import pickle
# Increase the field size limit
csv.field_size_limit(10 * 1024 * 1024) # 10 MB limit
def build_batch_data(sessions, batch_size):
batched_sessions = []
for i in range(0, len(sessions), batch_size):
batched_sessions.append(sessions[i:i+batch_size])
return batched_sessions
# Function to compute a hash ID from text
def compute_hash_id(text):
# Use SHA-256 to generate a hash
hash_object = hashlib.sha256(text.encode('utf-8'))
return hash_object.hexdigest() # Return hash as a hex string
def convert_attribute(value):
""" Convert attributes to GDS-compatible types. """
if isinstance(value, list):
return [str(v) for v in value]
elif isinstance(value, (int, float)):
return value
else:
return str(value)
def clean_text(text):
# replace control characters with spaces, drop NUL bytes, and turn semicolons into commas (CSV safety)
new_text = text.replace("\n", " ").replace("\r", " ").replace("\t", " ").replace("\v", " ").replace("\f", " ").replace("\b", " ").replace("\a", " ").replace("\x1b", " ").replace(";", ",")
new_text = new_text.replace("\x00", "")
new_text = re.sub(r'\s+', ' ', new_text).strip()
return new_text
def remove_NUL(text):
return text.replace("\x00", "")
def build_batched_events(all_node_list, batch_size):
"The types are in Entity Event Relation"
event_nodes = [node[0] for node in all_node_list if node[1].lower() == "event"]
batched_events = []
for i in range(0, len(event_nodes), batch_size):
batched_events.append(event_nodes[i:i+batch_size])
return batched_events
def build_batched_entities(all_node_list, batch_size):
entity_nodes = [node[0] for node in all_node_list if node[1].lower() == "entity"]
batched_entities = []
for i in range(0, len(entity_nodes), batch_size):
batched_entities.append(entity_nodes[i:i+batch_size])
return batched_entities
def build_batched_relations(all_node_list, batch_size):
relations = [node[0] for node in all_node_list if node[1].lower() == "relation"]
# relations = list(set(relations))
batched_relations = []
for i in range(0, len(relations), batch_size):
batched_relations.append(relations[i:i+batch_size])
return batched_relations
def batched_inference(model:LLMGenerator, inputs, record=False, **kwargs):
responses = model.generate_response(inputs, return_text_only = not record, **kwargs)
answers = []
if record:
text_responses = [response[0] for response in responses]
usages = [response[1] for response in responses]
else:
text_responses = responses
for i in range(len(text_responses)):
answer = text_responses[i]
answers.append([x.strip().lower() for x in answer.split(",")])
if record:
return answers, usages
else:
return answers
def load_data_with_shard(input_file, shard_idx, num_shards):
with open(input_file, "r") as f:
csv_reader = list(csv.reader(f))
# data = csv_reader
data = csv_reader[1:]
# Random shuffle the data before splitting into shards
random.shuffle(data)
total_lines = len(data)
lines_per_shard = (total_lines + num_shards - 1) // num_shards
start_idx = shard_idx * lines_per_shard
end_idx = min((shard_idx + 1) * lines_per_shard, total_lines)
return data[start_idx:end_idx]
def generate_concept(model: LLMGenerator,
input_file = 'processed_data/triples_csv',
output_folder = 'processed_data/triples_conceptualized',
output_file = 'output.json',
logging_file = 'processed_data/logging.txt',
config:ProcessingConfig=None,
batch_size=32,
shard=0,
num_shards=1,
**kwargs):
log_dir = os.path.dirname(logging_file)
if log_dir and not os.path.exists(log_dir):
os.makedirs(log_dir)
# Create the log file if it doesn't exist
if not os.path.exists(logging_file):
open(logging_file, 'w').close()
language = kwargs.get('language', 'en')
record = kwargs.get('record', False)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler(logging_file)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logging.getLogger().addHandler(file_handler)
with open(f"{config.output_directory}/kg_graphml/{config.filename_pattern}_without_concept.pkl", "rb") as f:
temp_kg = pickle.load(f)
# read data
if not os.path.exists(output_folder):
os.makedirs(output_folder)
all_missing_nodes = load_data_with_shard(
input_file,
shard_idx=shard,
num_shards=num_shards
)
batched_events = build_batched_events(all_missing_nodes, batch_size)
batched_entities = build_batched_entities(all_missing_nodes, batch_size)
batched_relations = build_batched_relations(all_missing_nodes, batch_size)
all_batches = []
all_batches.extend(('event', batch) for batch in batched_events)
all_batches.extend(('entity', batch) for batch in batched_entities)
all_batches.extend(('relation', batch) for batch in batched_relations)
print("all_batches", len(all_batches))
output_file = output_folder + f"/{output_file.rsplit('.', 1)[0]}_shard_{shard}.csv"
with open(output_file, "w", newline='') as file:
csv_writer = csv.writer(file)
csv_writer.writerow(["node", "conceptualized_node", "node_type"])
# for batch_type, batch in tqdm(all_batches, total=total_batches, desc="Generating concepts"):
# don't use tqdm for now
for batch_type, batch in tqdm(all_batches, desc="Shard_{}".format(shard)):
# print("batch_type", batch_type)
# print("batch", batch)
replace_context_token = None
if batch_type == 'event':
template = CONCEPT_INSTRUCTIONS[language]['event']
node_type = 'event'
replace_token = '[EVENT]'
elif batch_type == 'entity':
template = CONCEPT_INSTRUCTIONS[language]['entity']
node_type = 'entity'
replace_token = '[ENTITY]'
replace_context_token = '[CONTEXT]'
elif batch_type == 'relation':
template = CONCEPT_INSTRUCTIONS[language]['relation']
node_type = 'relation'
replace_token = '[RELATION]'
inputs = []
for node in batch:
# sample node from given node and replace context token.
if replace_context_token:
node_id = get_node_id(node)
entity_predecessors = list(temp_kg.predecessors(node_id))
entity_successors = list(temp_kg.successors(node_id))
context = ""
if len(entity_predecessors) > 0:
random_two_neighbors = random.sample(entity_predecessors, min(1, len(entity_predecessors)))
context += ", ".join([f"{temp_kg.nodes[neighbor]['id']} {temp_kg[neighbor][node_id]['relation']}" for neighbor in random_two_neighbors])
if len(entity_successors) > 0:
random_two_neighbors = random.sample(entity_successors, min(1, len(entity_successors)))
context += ", ".join([f"{temp_kg[node_id][neighbor]['relation']} {temp_kg.nodes[neighbor]['id']}" for neighbor in random_two_neighbors])
prompt = template.replace(replace_token, node).replace(replace_context_token, context)
else:
prompt = template.replace(replace_token, node)
constructed_input = [
{"role": "system", "content": "You are a helpful AI assistant."},
{"role": "user", "content": f"{prompt}"},
]
inputs.append(constructed_input)
try:
# print("inputs", inputs)
if record:
# If recording, we will get both answers and responses
answers, usages = batched_inference(model, inputs, record=record, max_workers = config.max_workers)
else:
answers = batched_inference(model, inputs, record=record, max_workers = config.max_workers)
usages = None
# print("answers", answers)
except Exception as e:
logging.error(f"Error processing {batch_type} batch: {e}")
raise e
# try:
# answers = batched_inference(llm, sampling_params, inputs)
# except Exception as e:
# logging.error(f"Error processing {batch_type} batch: {e}")
# continue
for i,(node, answer) in enumerate(zip(batch, answers)):
# print(node, answer, node_type)
if usages is not None:
logging.info(f"Usage log: Node {node}, completion_usage: {usages[i]}")
csv_writer.writerow([node, ", ".join(answer), node_type])
file.flush()
# count unique conceptualized nodes
conceptualized_nodes = []
conceptualized_events = []
conceptualized_entities = []
conceptualized_relations = []
with open(output_file, "r") as file:
reader = csv.reader(file)
next(reader)
for row in reader:
conceptualized_nodes.extend(row[1].split(","))
if row[2] == "event":
conceptualized_events.extend(row[1].split(","))
elif row[2] == "entity":
conceptualized_entities.extend(row[1].split(","))
elif row[2] == "relation":
conceptualized_relations.extend(row[1].split(","))
conceptualized_nodes = [x.strip() for x in conceptualized_nodes]
conceptualized_events = [x.strip() for x in conceptualized_events]
conceptualized_entities = [x.strip() for x in conceptualized_entities]
conceptualized_relations = [x.strip() for x in conceptualized_relations]
unique_conceptualized_nodes = list(set(conceptualized_nodes))
unique_conceptualized_events = list(set(conceptualized_events))
unique_conceptualized_entities = list(set(conceptualized_entities))
unique_conceptualized_relations = list(set(conceptualized_relations))
print(f"Number of unique conceptualized nodes: {len(unique_conceptualized_nodes)}")
print(f"Number of unique conceptualized events: {len(unique_conceptualized_events)}")
print(f"Number of unique conceptualized entities: {len(unique_conceptualized_entities)}")
print(f"Number of unique conceptualized relations: {len(unique_conceptualized_relations)}")
return

View File

@ -0,0 +1,153 @@
import ast
import uuid
import csv
from tqdm import tqdm
import hashlib
import os
def generate_uuid():
"""Generate a random UUID"""
return str(uuid.uuid4())
def parse_concepts(s):
"""Parse concepts field and filter empty values"""
try:
parsed = ast.literal_eval(s) if s and s != '[]' else []
return [c.strip() for c in parsed if c.strip()]
except:
return []
# Function to compute a hash ID from text
def compute_hash_id(text):
# Use SHA-256 to generate a hash
text = text + '_concept'
hash_object = hashlib.sha256(text.encode('utf-8'))
return hash_object.hexdigest() # Return hash as a hex string
def all_concept_triples_csv_to_csv(node_file, edge_file, concepts_file, output_node_file, output_edge_file, output_full_concept_triple_edges):
# to deal add output the concepts nodes, edges, and new full_triple_edges,
# we need to read the concepts maps to the memory, as it is usually not too large.
# Then we need to iterate over the triple nodes to create concept edges
# Finally we iterate over the triple edges to create the full_triple_edges
# Read missing concept
# relation_concepts_mapping = {}
# all_missing_concepts = []
# check if all output directories exist
output_dir = os.path.dirname(output_node_file)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
output_dir = os.path.dirname(output_edge_file)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
output_dir = os.path.dirname(output_full_concept_triple_edges)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
node_to_concepts = {}
relation_to_concepts = {}
all_concepts = set()
with open(concepts_file, 'r', encoding='utf-8') as f:
reader = csv.DictReader(f)
# Load missing concepts list
print("Loading concepts...")
for row in tqdm(reader):
if row['node_type'] == 'relation':
relation = row['node']
concepts = [c.strip() for c in row['conceptualized_node'].split(',') if c.strip()]
if relation not in relation_to_concepts:
relation_to_concepts[relation] = concepts
else:
relation_to_concepts[relation].extend(concepts)
relation_to_concepts[relation] = list(set(relation_to_concepts[relation]))
else:
node = row['node']
concepts = [c.strip() for c in row['conceptualized_node'].split(',') if c.strip()]
if node not in node_to_concepts:
node_to_concepts[node] = concepts
else:
node_to_concepts[node].extend(concepts)
node_to_concepts[node] = list(set(node_to_concepts[node]))
print("Loading concepts done.")
print(f"Relation to concepts: {len(relation_to_concepts)}")
print(f"Node to concepts: {len(node_to_concepts)}")
# Read triple nodes and write to output concept edges files
print("Processing triple nodes...")
with open(node_file, 'r', encoding='utf-8') as f:
reader = csv.reader(f)
# name:ID,type,concepts,synsets,:LABEL
header = next(reader)
with open (output_edge_file, 'w', newline='', encoding='utf-8') as f_out:
writer = csv.writer(f_out, quoting=csv.QUOTE_ALL)
writer.writerow([':START_ID', ':END_ID', 'relation', ':TYPE'])
for row in tqdm(reader):
node_name = row[0]
if node_name in node_to_concepts:
for concept in node_to_concepts[node_name]:
concept_id = compute_hash_id(concept)
writer.writerow([row[0], concept_id, 'has_concept', 'Concept'])
all_concepts.add(concept)
for concept in parse_concepts(row[2]):
concept_id = compute_hash_id(concept)
writer.writerow([row[0], concept_id, 'has_concept', 'Concept'])
all_concepts.add(concept)
# Read the concept nodes and write to output concept nodes file
print("Processing concept nodes...")
with open (output_node_file, 'w', newline='', encoding='utf-8') as f_out:
writer = csv.writer(f_out, quoting=csv.QUOTE_ALL)
writer.writerow(['concept_id:ID', 'name', ':LABEL'])
for concept in tqdm(all_concepts):
concept_id = compute_hash_id(concept)
writer.writerow([concept_id, concept, 'Concept'])
# Read triple edges and write to output full concept triple edges file
print("Processing triple edges...")
with open(edge_file, 'r', encoding='utf-8') as f:
with open(output_full_concept_triple_edges, 'w', newline='', encoding='utf-8') as f_out:
reader = csv.reader(f)
writer = csv.writer(f_out, quoting=csv.QUOTE_ALL)
header = next(reader)
writer.writerow([':START_ID', ':END_ID', 'relation', 'concepts', 'synsets', ':TYPE'])
for row in tqdm(reader):
src_id = row[0]
end_id = row[1]
relation = row[2]
concepts = row[3]
synsets = row[4]
original_concepts = parse_concepts(concepts)
if relation in relation_to_concepts:
for concept in relation_to_concepts[relation]:
if concept not in original_concepts:
original_concepts.append(concept)
original_concepts = list(set(original_concepts))
writer.writerow([src_id, end_id, relation, original_concepts, synsets, 'Relation'])
return

View File

@ -0,0 +1,267 @@
import time
import uvicorn
from fastapi import FastAPI, HTTPException, Response
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from logging import Logger
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGRetriever, BaseLargeKGEdgeRetriever
from atlas_rag.kg_construction.neo4j.utils import start_up_large_kg_index_graph
from atlas_rag.llm_generator import LLMGenerator
from neo4j import Driver
from dataclasses import dataclass
import traceback
@dataclass
class LargeKGConfig:
largekg_retriever: BaseLargeKGRetriever | BaseLargeKGEdgeRetriever = None
reader_llm_generator : LLMGenerator = None
driver: Driver = None
logger: Logger = None
is_felm: bool = False
is_mmlu: bool = False
rag_exemption_list = [
"""I will show you a question and a list of text segments. All the segments can be concatenated to form a complete answer to the question. Your task is to assess whether each text segment contains errors or not. \nPlease generate using the following format:\nAnswer: List the ids of the segments with errors (separated by commas). Please only output the ids, no more details. If all the segments are correct, output \"ALL_CORRECT\".\n\nHere is one example:\nQuestion: 8923164*7236571?\nSegments: \n1. The product of 8923164 and 7236571 is: 6,461,216,222,844\n2. So, 8923164 multiplied by 7236571 is equal to 6,461,216,222,844.\n\nBelow are your outputs:\nAnswer: 1,2\nIt means segment 1,2 contain errors.""",
"""I will show you a question and a list of text segments. All the segments can be concatenated to form a complete answer to the question. Your task is to determine whether each text segment contains factual errors or not. \nPlease generate using the following format:\nAnswer: List the ids of the segments with errors (separated by commas). Please only output the ids, no more details. If all the segments are correct, output \"ALL_CORRECT\".\n\nHere is one example:\nQuestion: A company offers a 10% discount on all purchases over $100. A customer purchases three items, each costing $80. Does the customer qualify for the discount?\nSegments: \n1. To solve this problem, we need to use deductive reasoning. We know that the company offers a 10% discount on purchases over $100, so we need to calculate the total cost of the customer's purchase.\n2. The customer purchased three items, each costing $80, so the total cost of the purchase is: 3 x $80 = $200.\n3. Since the total cost of the purchase is greater than $100, the customer qualifies for the discount. \n4. To calculate the discounted price, we can multiply the total cost by 0.1 (which represents the 10% discount): $200 x 0.1 = $20.\n5. So the customer is eligible for a discount of $20, and the final cost of the purchase would be: $200 - $20 = $180.\n6. Therefore, the customer would pay a total of $216 for the three items with the discount applied.\n\nBelow are your outputs:\nAnswer: 2,3,4,5,6\nIt means segment 2,3,4,5,6 contains errors.""",
]
mmlu_check_list = [
"""Given the following question and four candidate answers (A, B, C and D), choose the answer."""
]
app = FastAPI()
@app.on_event("startup")
async def startup():
global large_kg_config
start_up_large_kg_index_graph(large_kg_config.driver)
@app.on_event("shutdown")
async def shutdown():
global large_kg_config
print("Shutting down the model...")
del large_kg_config
class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "test"
root: Optional[str] = None
parent: Optional[str] = None
permission: Optional[list] = None
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = []
class ChatMessage(BaseModel):
role: Literal["user", "system", "assistant"]
content: str = None
name: Optional[str] = None
class DeltaMessage(BaseModel):
role: Optional[Literal["user", "assistant", "system"]] = None
content: Optional[str] = None
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
temperature: Optional[float] = 0.8
top_p: Optional[float] = 0.8
max_tokens: Optional[int] = None
stream: Optional[bool] = False
tools: Optional[Union[dict, List[dict]]] = None
repetition_penalty: Optional[float] = 1.1
retriever_config: Optional[dict] = {
"topN": 5,
"number_of_source_nodes_per_ner": 10,
"sampling_area": 250,
"Dmax": 2,
"Wmax": 3
}
class Config:
extra = "allow"
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Literal["stop", "length", "function_call"]
class ChatCompletionResponseStreamChoice(BaseModel):
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]]
index: int
class ChatCompletionResponse(BaseModel):
model: str
id: str
object: Literal["chat.completion", "chat.completion.chunk"]
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
usage: Optional[UsageInfo] = None
@app.get("/health")
async def health_check():
return Response(status_code=200)
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
global large_kg_config
try:
if len(request.messages) < 1 or request.messages[-1].role == "assistant":
print(request)
# raise HTTPException(status_code=400, detail="Invalid request")
if large_kg_config.logger is not None:
large_kg_config.logger.info(f"Request: {request}")
gen_params = dict(
messages=request.messages,
temperature=0.8,
top_p=request.top_p,
max_tokens=request.max_tokens or 1024,
echo=False,
stream=False,
repetition_penalty=request.repetition_penalty,
tools=request.tools,
)
last_message = request.messages[-1]
system_prompt = 'You are a helpful assistant.'
question = last_message.content if last_message.role == 'user' else request.messages[-2].content
is_exemption = any(exemption in question for exemption in LargeKGConfig.rag_exemption_list)
is_mmlu = any(exemption in question for exemption in LargeKGConfig.mmlu_check_list)
print(f"Is exemption: {is_exemption}, Is MMLU: {is_mmlu}")
if is_mmlu:
rag_text = question
else:
parts = question.rsplit("Question:", 1)
rag_text = parts[-1] if len(parts) > 1 else None
print(f"RAG text: {rag_text}")
if not is_exemption:
passages, passages_score = large_kg_config.largekg_retriever.retrieve_passages(rag_text)
context = "No retrieved context, Please answer the question with your own knowledge." if not passages else "\n".join([f"Passage {i+1}: {text}" for i, text in enumerate(reversed(passages))])
if is_mmlu:
rag_chat_content = [
{
"role": "system",
"content": f"{system_prompt}"
},
{
"role": "user",
"content": f"""Here is the context: {context} \n\n
If the context is not useful, you can answer the question with your own knowledge. \n {question}\nThink step by step. Your response should end with 'The answer is ([the_answer_letter])' where the [the_answer_letter] is one of A, B, C and D."""
}
]
elif not is_exemption:
rag_chat_content = [
{
"role": "system",
"content": f"{system_prompt}"
},
{
"role": "user",
"content": f"""{question} Reference doc: {context}"""
}
]
else:
rag_chat_content = [
{
"role": "system",
"content": f"{system_prompt}"
},
{
"role": "user",
"content": f"""{question} """
}
]
if large_kg_config.logger is not None:
large_kg_config.logger.info(rag_chat_content)
response = large_kg_config.reader_llm_generator.generate_response(
batch_messages=rag_chat_content,
max_new_tokens=gen_params["max_tokens"],
temperature=gen_params["temperature"],
frequency_penalty = 1.1
)
message = ChatMessage(
role="assistant",
content=response.strip()
)
choice_data = ChatCompletionResponseChoice(
index=0,
message=message,
finish_reason="stop"
)
return ChatCompletionResponse(
model=request.model,
id="",
object="chat.completion",
choices=[choice_data]
)
except Exception as e:
print("ERROR: ", e)
print("Caught error")
traceback.print_exc()
system_prompt = 'You are a helpful assistant.'
gen_params = dict(
messages=request.messages,
temperature=0.8,
top_p=request.top_p,
max_tokens=request.max_tokens or 1024,
echo=False,
stream=False,
repetition_penalty=request.repetition_penalty,
tools=request.tools,
)
last_message = request.messages[-1]
system_prompt = 'You are a helpful assistant.'
question = last_message.content if last_message.role == 'user' else request.messages[-2].content
rag_chat_content = [
{
"role": "system",
"content": f"{system_prompt}"
},
{
"role": "user",
"content": f"""{question} """
}
]
response = large_kg_config.reader_llm_generator.generate_response(
batch_messages=rag_chat_content,
max_new_tokens=gen_params["max_tokens"],
temperature=gen_params["temperature"],
frequency_penalty = 1.1
)
message = ChatMessage(
role="assistant",
content=response.strip()
)
choice_data = ChatCompletionResponseChoice(
index=0,
message=message,
finish_reason="stop"
)
return ChatCompletionResponse(
model=request.model,
id="",
object="chat.completion",
choices=[choice_data]
)
def start_app(user_config:LargeKGConfig, host="0.0.0.0", port=10090, reload=False):
"""Function to start the FastAPI application."""
global large_kg_config
large_kg_config = user_config # Use the passed context if provided
uvicorn.run(f"atlas_rag.kg_construction.neo4j.neo4j_api:app", host=host, port=port, reload=reload)

View File

@ -0,0 +1,73 @@
import faiss
from neo4j import Driver
import time
from graphdatascience import GraphDataScience
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGRetriever
def build_projection_graph(driver: GraphDataScience):
project_graph_1 = "largekgrag_graph"
is_project_graph_1_exist = False
# is_project_graph_2_exist = False
result = driver.graph.list()
for index, row in result.iterrows():
if row['graphName'] == project_graph_1:
is_project_graph_1_exist = True
# if row['graphName'] == project_graph_2:
# is_project_graph_2_exist = True
if not is_project_graph_1_exist:
start_time = time.time()
node_properties = ["Node"]
relation_projection = [ "Relation"]
result = driver.graph.project(
project_graph_1,
node_properties,
relation_projection
)
graph = driver.graph.get(project_graph_1)
print(f"Projection graph {project_graph_1} created in {time.time() - start_time:.2f} seconds")
def build_neo4j_label_index(driver: Driver):
with driver.session() as session:
index_name = f"NodeNumericIDIndex"
# Check if the index already exists
existing_indexes = session.run("SHOW INDEXES").data()
index_exists = any(index['name'] == index_name for index in existing_indexes)
# Drop the index if it exists
if not index_exists:
start_time = time.time()
session.run(f"CREATE INDEX {index_name} FOR (n:Node) ON (n.numeric_id)")
print(f"Index {index_name} created in {time.time() - start_time:.2f} seconds")
index_name = f"TextNumericIDIndex"
index_exists = any(index['name'] == index_name for index in existing_indexes)
if not index_exists:
start_time = time.time()
session.run(f"CREATE INDEX {index_name} FOR (t:Text) ON (t.numeric_id)")
print(f"Index {index_name} created in {time.time() - start_time:.2f} seconds")
index_name = f"EntityEventEdgeNumericIDIndex"
index_exists = any(index['name'] == index_name for index in existing_indexes)
if not index_exists:
start_time = time.time()
session.run(f"CREATE INDEX {index_name} FOR ()-[r:Relation]-() on (r.numeric_id)")
print(f"Index {index_name} created in {time.time() - start_time:.2f} seconds")
def load_indexes(path_dict):
for key, value in path_dict.items():
if key == 'node':
node_index = faiss.read_index(value, faiss.IO_FLAG_MMAP)
print(f"Node index loaded from {value}")
elif key == 'edge':
edge_index = faiss.read_index(value, faiss.IO_FLAG_MMAP)
print(f"Edge index loaded from {value}")
elif key == 'text':
passage_index = faiss.read_index(value, faiss.IO_FLAG_MMAP)
print(f"Passage index loaded from {value}")
return node_index, edge_index, passage_index
def start_up_large_kg_index_graph(neo4j_driver: Driver) -> None:
gds_driver = GraphDataScience(neo4j_driver)
# build label index and projection graph
build_neo4j_label_index(neo4j_driver)
build_projection_graph(gds_driver)

View File

@ -0,0 +1,22 @@
from dataclasses import dataclass
@dataclass
class ProcessingConfig:
"""Configuration for text processing pipeline."""
model_path: str
data_directory: str
filename_pattern: str
batch_size_triple: int = 16
batch_size_concept: int = 64
output_directory: str = "./generation_result_debug"
total_shards_triple: int = 1
current_shard_triple: int = 0
total_shards_concept: int = 1
current_shard_concept: int = 0
use_8bit: bool = False
debug_mode: bool = False
resume_from: int = 0
record : bool = False
max_new_tokens: int = 8192
max_workers: int = 8
remove_doc_spaces: bool = False

View File

@ -0,0 +1,497 @@
#!/usr/bin/env python3
"""
Knowledge Graph Extraction Pipeline
Extracts entities, relations, and events from text data using transformer models.
"""
import re
import json
import os
import argparse
from datetime import datetime
from typing import List, Dict, Any, Tuple
from pathlib import Path
import torch
from datasets import load_dataset
from tqdm import tqdm
import json_repair
from atlas_rag.llm_generator import LLMGenerator
from atlas_rag.kg_construction.utils.json_processing.json_to_csv import json2csv
from atlas_rag.kg_construction.concept_generation import generate_concept
from atlas_rag.kg_construction.utils.csv_processing.merge_csv import merge_csv_files
from atlas_rag.kg_construction.utils.csv_processing.csv_to_graphml import csvs_to_graphml, csvs_to_temp_graphml
from atlas_rag.kg_construction.concept_to_csv import all_concept_triples_csv_to_csv
from atlas_rag.kg_construction.utils.csv_processing.csv_add_numeric_id import add_csv_columns
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.vectorstore.create_neo4j_index import create_faiss_index
from atlas_rag.llm_generator.prompt.triple_extraction_prompt import TRIPLE_INSTRUCTIONS
from atlas_rag.kg_construction.triple_config import ProcessingConfig
# Constants
TOKEN_LIMIT = 1024
INSTRUCTION_TOKEN_ESTIMATE = 200
CHAR_TO_TOKEN_RATIO = 3.5
class TextChunker:
"""Handles text chunking based on token limits."""
def __init__(self, max_tokens: int = TOKEN_LIMIT, instruction_tokens: int = INSTRUCTION_TOKEN_ESTIMATE):
self.max_tokens = max_tokens
self.instruction_tokens = instruction_tokens
self.char_ratio = CHAR_TO_TOKEN_RATIO
def calculate_max_chars(self) -> int:
"""Calculate maximum characters per chunk."""
available_tokens = self.max_tokens - self.instruction_tokens
return int(available_tokens * self.char_ratio)
def split_text(self, text: str) -> List[str]:
"""Split text into chunks that fit within token limits."""
max_chars = self.calculate_max_chars()
chunks = []
while len(text) > max_chars:
chunks.append(text[:max_chars])
text = text[max_chars:]
if text: # Add remaining text
chunks.append(text)
return chunks
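# Worked example: with the defaults above, calculate_max_chars() returns
# (1024 - 200) * 3.5 = 2884 characters, so a 6000-character passage is split
# into chunks of 2884, 2884 and 232 characters.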
class DatasetProcessor:
"""Processes and prepares dataset for knowledge graph extraction."""
def __init__(self, config: ProcessingConfig):
self.config = config
self.chunker = TextChunker()
def filter_language_content(self, sample: Dict[str, Any]) -> bool:
"""Check if content is in English."""
metadata = sample.get("metadata", {})
language = metadata.get("lang", "en") # Default to English if not specified
supported_languages = list(TRIPLE_INSTRUCTIONS.keys())
return language in supported_languages
def create_sample_chunks(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Create chunks from a single sample."""
original_text = sample.get("text", "")
if self.config.remove_doc_spaces:
original_text = re.sub(r'\s+', ' ',original_text).strip()
text_chunks = self.chunker.split_text(original_text)
chunks = []
for chunk_idx, chunk_text in enumerate(text_chunks):
chunk_data = {
"id": sample["id"],
"text": chunk_text,
"chunk_id": chunk_idx,
"metadata": sample["metadata"]
}
chunks.append(chunk_data)
return chunks
def prepare_dataset(self, raw_dataset) -> List[Dict[str, Any]]:
"""Process raw dataset into chunks suitable for processing with generalized slicing."""
processed_samples = []
total_texts = len(raw_dataset)
# Handle edge cases
if total_texts == 0:
print(f"No texts found for shard {self.config.current_shard_triple+1}/{self.config.total_shards_triple}")
return processed_samples
# Calculate base and remainder for fair distribution
base_texts_per_shard = total_texts // self.config.total_shards_triple
remainder = total_texts % self.config.total_shards_triple
# Calculate start index
if self.config.current_shard_triple < remainder:
start_idx = self.config.current_shard_triple * (base_texts_per_shard + 1)
else:
start_idx = remainder * (base_texts_per_shard + 1) + (self.config.current_shard_triple - remainder) * base_texts_per_shard
# Calculate end index
if self.config.current_shard_triple < remainder:
end_idx = start_idx + (base_texts_per_shard + 1)
else:
end_idx = start_idx + base_texts_per_shard
# Ensure indices are within bounds
start_idx = min(start_idx, total_texts)
end_idx = min(end_idx, total_texts)
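# Worked example: 10 documents with total_shards_triple=3 give base=3 and remainder=1,
# so shard 0 covers indices 0-3 (4 documents) while shards 1 and 2 cover 4-6 and 7-9 (3 each).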
print(f"Processing shard {self.config.current_shard_triple+1}/{self.config.total_shards_triple} "
f"(texts {start_idx}-{end_idx-1} of {total_texts}, {end_idx - start_idx} documents)")
# Process documents in assigned shard
for idx in range(start_idx, end_idx):
sample = raw_dataset[idx]
# Filter by language
if not self.filter_language_content(sample):
print(f"Unsupported language in sample {idx}, skipping.")
continue
# Create chunks
chunks = self.create_sample_chunks(sample)
processed_samples.extend(chunks)
# Debug mode early termination
if self.config.debug_mode and len(processed_samples) >= 20:
print("Debug mode: Stopping at 20 chunks")
break
print(f"Generated {len(processed_samples)} chunks for shard {self.config.current_shard_triple+1}/{self.config.total_shards_triple}")
return processed_samples
class CustomDataLoader:
"""Custom data loader for knowledge graph extraction."""
def __init__(self, dataset, processor: DatasetProcessor):
self.raw_dataset = dataset
self.processor = processor
self.processed_data = processor.prepare_dataset(dataset)
self.stage_to_prompt_dict = {
"stage_1": "entity_relation",
"stage_2": "event_entity",
"stage_3": "event_relation"
}
def __len__(self) -> int:
return len(self.processed_data)
def create_batch_instructions(self, batch_data: List[Dict[str, Any]]) -> List[str]:
messages_dict = {
'stage_1': [],
'stage_2': [],
'stage_3': []
}
for item in batch_data:
# get language
language = item.get("metadata",{}).get("lang", "en")
system_msg = TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['system']
stage_1_msg = TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['entity_relation'] + TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['passage_start'] + '\n' + item["text"]
stage_2_msg = TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['event_entity'] + TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['passage_start'] + '\n'+ item["text"]
stage_3_msg = TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['event_relation'] + TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['passage_start'] + '\n'+ item["text"]
stage_one_message = [
{"role": "system", "content": system_msg},
{"role": "user", "content": stage_1_msg}
]
stage_two_message = [
{"role": "system", "content": system_msg},
{"role": "user", "content": stage_2_msg}
]
stage_three_message = [
{"role": "system", "content": system_msg},
{"role": "user", "content": stage_3_msg}
]
messages_dict['stage_1'].append(stage_one_message)
messages_dict['stage_2'].append(stage_two_message)
messages_dict['stage_3'].append(stage_three_message)
return messages_dict
def __iter__(self):
"""Iterate through batches."""
batch_size = self.processor.config.batch_size_triple
start_idx = self.processor.config.resume_from * batch_size
for i in tqdm(range(start_idx, len(self.processed_data), batch_size)):
batch_data = self.processed_data[i:i + batch_size]
# Prepare instructions
instructions = self.create_batch_instructions(batch_data)
# Extract batch information
batch_ids = [item["id"] for item in batch_data]
batch_metadata = [item["metadata"] for item in batch_data]
batch_texts = [item["text"] for item in batch_data]
yield instructions, batch_ids, batch_texts, batch_metadata
class OutputParser:
"""Parses model outputs and extracts structured data."""
def __init__(self):
pass
def extract_structured_data(self, outputs: List[str]) -> List[List[Dict[str, Any]]]:
"""Extract structured data from model outputs."""
results = []
for output in outputs:
parsed_data = json_repair.loads(output)
results.append(parsed_data)
return results
class KnowledgeGraphExtractor:
"""Main class for knowledge graph extraction pipeline."""
def __init__(self, model:LLMGenerator, config: ProcessingConfig):
self.config = config
self.model = model
self.model_name = model.model_name
self.parser = OutputParser()
def load_dataset(self) -> Any:
"""Load and prepare dataset."""
data_path = Path(self.config.data_directory)
all_files = os.listdir(data_path)
valid_files = [
filename for filename in all_files
if filename.startswith(self.config.filename_pattern) and
(filename.endswith(".json.gz") or filename.endswith(".json") or filename.endswith(".jsonl") or filename.endswith(".jsonl.gz"))
]
print(f"Found data files: {valid_files}")
data_files = valid_files
dataset_config = {"train": data_files}
return load_dataset(self.config.data_directory, data_files=dataset_config["train"])
def process_stage(self, instructions: Dict[str, str], stage = 1) -> Tuple[List[str], List[List[Dict[str, Any]]]]:
"""Process first stage: entity-relation extraction."""
outputs = self.model.triple_extraction(messages=instructions, max_tokens=self.config.max_new_tokens, stage=stage, record=self.config.record)
if self.config.record:
text_outputs = [output[0] for output in outputs]
else:
text_outputs = outputs
structured_data = self.parser.extract_structured_data(text_outputs)
return outputs, structured_data
def create_output_filename(self) -> str:
"""Create output filename with timestamp and shard info."""
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
model_name_safe = self.config.model_path.replace("/", "_")
filename = (f"{model_name_safe}_{self.config.filename_pattern}_output_"
f"{timestamp}_{self.config.current_shard_triple + 1}_in_{self.config.total_shards_triple}.json")
extraction_dir = os.path.join(self.config.output_directory, "kg_extraction")
os.makedirs(extraction_dir, exist_ok=True)
return os.path.join(extraction_dir, filename)
def prepare_result_dict(self, batch_data: Tuple, stage_outputs: Tuple, index: int) -> Dict[str, Any]:
"""Prepare result dictionary for a single sample."""
ids, original_texts, metadata = batch_data
(stage1_results, entity_relations), (stage2_results, event_entities), (stage3_results, event_relations) = stage_outputs
if self.config.record:
stage1_outputs = [output[0] for output in stage1_results]
stage1_usage = [output[1] for output in stage1_results]
stage2_outputs = [output[0] for output in stage2_results]
stage2_usage = [output[1] for output in stage2_results]
stage3_outputs = [output[0] for output in stage3_results]
stage3_usage = [output[1] for output in stage3_results]
else:
stage1_outputs = stage1_results
stage2_outputs = stage2_results
stage3_outputs = stage3_results
result = {
"id": ids[index],
"metadata": metadata[index],
"original_text": original_texts[index],
"entity_relation_dict": entity_relations[index],
"event_entity_relation_dict": event_entities[index],
"event_relation_dict": event_relations[index],
"output_stage_one": stage1_outputs[index],
"output_stage_two": stage2_outputs[index],
"output_stage_three": stage3_outputs[index],
}
if self.config.record:
result['usage_stage_one'] = stage1_usage[index]
result['usage_stage_two'] = stage2_usage[index]
result['usage_stage_three'] = stage3_usage[index]
# Handle date serialization
if 'date_download' in result['metadata']:
result['metadata']['date_download'] = str(result['metadata']['date_download'])
return result
def debug_print_result(self, result: Dict[str, Any]):
"""Print result for debugging."""
for key, value in result.items():
print(f"{key}: {value}")
print("-" * 100)
def run_extraction(self):
"""Run the complete knowledge graph extraction pipeline."""
# Setup
os.makedirs(self.config.output_directory+'/kg_extraction', exist_ok=True)
dataset = self.load_dataset()
if self.config.debug_mode:
print("Debug mode: Processing only 20 samples")
# Create data processor and loader
processor = DatasetProcessor(self.config)
data_loader = CustomDataLoader(dataset["train"], processor)
output_file = self.create_output_filename()
print(f"Model: {self.config.model_path}")
batch_counter = 0
with torch.no_grad():
with open(output_file, "w") as output_stream:
for batch in data_loader:
batch_counter += 1
messages_dict, batch_ids, batch_texts, batch_metadata = batch
# Process all three stages
stage1_results = self.process_stage(messages_dict['stage_1'],1)
stage2_results = self.process_stage(messages_dict['stage_2'],2)
stage3_results = self.process_stage(messages_dict['stage_3'],3)
# Combine results
batch_data = (batch_ids, batch_texts, batch_metadata)
stage_outputs = (stage1_results, stage2_results, stage3_results)
# Write results
print(f"Processed {batch_counter} batches ({batch_counter * self.config.batch_size_triple} chunks)")
for i in range(len(batch_ids)):
result = self.prepare_result_dict(batch_data, stage_outputs, i)
if self.config.debug_mode:
self.debug_print_result(result)
output_stream.write(json.dumps(result, ensure_ascii=False) + "\n")
output_stream.flush()
def convert_json_to_csv(self):
json2csv(
dataset = self.config.filename_pattern,
output_dir=f"{self.config.output_directory}/triples_csv",
data_dir=f"{self.config.output_directory}/kg_extraction"
)
csvs_to_temp_graphml(
triple_node_file=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
triple_edge_file=f"{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_without_emb.csv",
config = self.config
)
def generate_concept_csv_temp(self, batch_size: int = None, **kwargs):
generate_concept(
model=self.model,
input_file=f"{self.config.output_directory}/triples_csv/missing_concepts_{self.config.filename_pattern}_from_json.csv",
output_folder=f"{self.config.output_directory}/concepts",
output_file="concept.json",
logging_file=f"{self.config.output_directory}/concepts/logging.txt",
config=self.config,
batch_size=batch_size if batch_size else self.config.batch_size_concept,
shard=self.config.current_shard_concept,
num_shards=self.config.total_shards_concept,
record = self.config.record,
**kwargs
)
def create_concept_csv(self):
merge_csv_files(
output_file=f"{self.config.output_directory}/triples_csv/{self.config.filename_pattern}_from_json_with_concept.csv",
input_dir=f"{self.config.output_directory}/concepts",
)
all_concept_triples_csv_to_csv(
node_file=f'{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv',
edge_file=f'{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_without_emb.csv',
concepts_file=f'{self.config.output_directory}/triples_csv/{self.config.filename_pattern}_from_json_with_concept.csv',
output_node_file=f'{self.config.output_directory}/concept_csv/concept_nodes_{self.config.filename_pattern}_from_json_with_concept.csv',
output_edge_file=f'{self.config.output_directory}/concept_csv/concept_edges_{self.config.filename_pattern}_from_json_with_concept.csv',
output_full_concept_triple_edges=f'{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv',
)
def convert_to_graphml(self):
csvs_to_graphml(
triple_node_file=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
text_node_file=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json.csv",
concept_node_file=f"{self.config.output_directory}/concept_csv/concept_nodes_{self.config.filename_pattern}_from_json_with_concept.csv",
triple_edge_file=f"{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
text_edge_file=f"{self.config.output_directory}/triples_csv/text_edges_{self.config.filename_pattern}_from_json.csv",
concept_edge_file=f"{self.config.output_directory}/concept_csv/concept_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
output_file=f"{self.config.output_directory}/kg_graphml/{self.config.filename_pattern}_graph.graphml",
)
def add_numeric_id(self):
add_csv_columns(
node_csv=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
edge_csv=f"{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
text_csv=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json.csv",
node_with_numeric_id=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb_with_numeric_id.csv",
edge_with_numeric_id=f"{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_without_emb_with_numeric_id.csv",
text_with_numeric_id=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json_with_numeric_id.csv",
)
def compute_kg_embedding(self, encoder_model:BaseEmbeddingModel, batch_size: int = 2048):
encoder_model.compute_kg_embedding(
node_csv_without_emb=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
node_csv_file=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_with_emb.csv",
edge_csv_without_emb=f"{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
edge_csv_file=f"{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept_with_emb.csv",
text_node_csv_without_emb=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json.csv",
text_node_csv=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json_with_emb.csv",
batch_size=batch_size
)
def create_faiss_index(self, index_type="HNSW,Flat"):
create_faiss_index(self.config.output_directory, self.config.filename_pattern, index_type)
def parse_command_line_arguments() -> ProcessingConfig:
"""Parse command line arguments and return configuration."""
parser = argparse.ArgumentParser(description="Knowledge Graph Extraction Pipeline")
parser.add_argument("-m", "--model", type=str, required=True,
default="meta-llama/Meta-Llama-3-8B-Instruct",
help="Model path for knowledge extraction")
parser.add_argument("--data_dir", type=str, default="your_data_dir",
help="Directory containing input data")
parser.add_argument("--file_name", type=str, default="en_simple_wiki_v0",
help="Filename pattern to match")
parser.add_argument("-b", "--batch_size", type=int, default=16,
help="Batch size for processing")
parser.add_argument("--output_dir", type=str, default="./generation_result_debug",
help="Output directory for results")
parser.add_argument("--total_shards_triple", type=int, default=1,
help="Total number of data shards")
parser.add_argument("--shard", type=int, default=0,
help="Current shard index")
parser.add_argument("--bit8", action="store_true",
help="Use 8-bit quantization")
parser.add_argument("--debug", action="store_true",
help="Enable debug mode")
parser.add_argument("--resume", type=int, default=0,
help="Resume from specific batch")
args = parser.parse_args()
return ProcessingConfig(
model_path=args.model,
data_directory=args.data_dir,
filename_pattern=args.file_name,
batch_size_triple=args.batch_size,
output_directory=args.output_dir,
total_shards_triple=args.total_shards_triple,
current_shard_triple=args.shard,
use_8bit=args.bit8,
debug_mode=args.debug,
resume_from=args.resume
)
def main():
"""Main entry point for the knowledge graph extraction pipeline."""
config = parse_command_line_arguments()
extractor = KnowledgeGraphExtractor(config)
extractor.run_extraction()
if __name__ == "__main__":
main()
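# Example invocation (a sketch; the script name and paths are placeholders) showing the
# flags accepted by parse_command_line_arguments():
#
#     python extract_kg.py -m meta-llama/Meta-Llama-3-8B-Instruct \
#         --data_dir ./data --file_name en_simple_wiki_v0 \
#         -b 16 --output_dir ./generation_result_debug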

View File

@ -0,0 +1,160 @@
import csv
from tqdm import tqdm
def check_created_csv_header(keyword, csv_dir):
keyword_to_paths ={
'cc_en':{
'node_with_numeric_id': f"{csv_dir}/triple_nodes_cc_en_from_json_without_emb_with_numeric_id.csv",
'edge_with_numeric_id': f"{csv_dir}/triple_edges_cc_en_from_json_without_emb_with_numeric_id.csv",
'text_with_numeric_id': f"{csv_dir}/text_nodes_cc_en_from_json_with_numeric_id.csv",
'concept_with_numeric_id': f"{csv_dir}/concept_nodes_pes2o_abstract_from_json_without_emb_with_numeric_id.csv",
},
'pes2o_abstract':{
'node_with_numeric_id': f"{csv_dir}/triple_nodes_pes2o_abstract_from_json_without_emb_with_numeric_id.csv",
'edge_with_numeric_id': f"{csv_dir}/triple_edges_pes2o_abstract_from_json_without_emb_full_concept_with_numeric_id.csv",
'text_with_numeric_id': f"{csv_dir}/text_nodes_pes2o_abstract_from_json_with_numeric_id.csv",
},
'en_simple_wiki_v0':{
'node_with_numeric_id': f"{csv_dir}/triple_nodes_en_simple_wiki_v0_from_json_without_emb_with_numeric_id.csv",
'edge_with_numeric_id': f"{csv_dir}/triple_edges_en_simple_wiki_v0_from_json_without_emb_full_concept_with_numeric_id.csv",
'text_with_numeric_id': f"{csv_dir}/text_nodes_en_simple_wiki_v0_from_json_with_numeric_id.csv",
},
}
for key, path in keyword_to_paths[keyword].items():
with open(path) as infile:
reader = csv.reader(infile)
header = next(reader)
print(f"Header of {key}: {header}")
# print the first data row
for i, row in enumerate(reader):
if i < 1:
print(row)
else:
break
def add_csv_columns(node_csv, edge_csv, text_csv, node_with_numeric_id, edge_with_numeric_id, text_with_numeric_id):
with open(node_csv) as infile, open(node_with_numeric_id, 'w', newline='') as outfile:
reader = csv.reader(infile)
writer = csv.writer(outfile)
header = next(reader)
print(header)
label_index = header.index(':LABEL')
header.insert(label_index, 'numeric_id') # Add new column name
writer.writerow(header)
for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
row.insert(label_index, row_number) # Add numeric ID before ':LABEL'
writer.writerow(row)
with open(edge_csv) as infile, open(edge_with_numeric_id, 'w', newline='') as outfile:
reader = csv.reader(infile)
writer = csv.writer(outfile)
header = next(reader)
print(header)
label_index = header.index(':TYPE')
header.insert(label_index, 'numeric_id') # Add new column name
writer.writerow(header)
for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
row.insert(label_index, row_number) # Add numeric ID before ':LABEL'
writer.writerow(row)
with open(text_csv) as infile, open(text_with_numeric_id, 'w', newline='') as outfile:
reader = csv.reader(infile)
writer = csv.writer(outfile)
header = next(reader)
print(header)
label_index = header.index(':LABEL')
header.insert(label_index, 'numeric_id') # Add new column name
writer.writerow(header)
for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
row.insert(label_index, row_number) # Add numeric ID before ':LABEL'
writer.writerow(row)
# def add_csv_columns(keyword, csv_dir):
# keyword_to_paths ={
# 'cc_en':{
# 'node_csv': f"{csv_dir}/triple_nodes_cc_en_from_json_without_emb.csv",
# 'edge_csv': f"{csv_dir}/triple_edges_cc_en_from_json_without_emb.csv",
# 'text_csv': f"{csv_dir}/text_nodes_cc_en_from_json.csv",
# 'node_with_numeric_id': f"{csv_dir}/triple_nodes_cc_en_from_json_without_emb_with_numeric_id.csv",
# 'edge_with_numeric_id': f"{csv_dir}/triple_edges_cc_en_from_json_without_emb_with_numeric_id.csv",
# 'text_with_numeric_id': f"{csv_dir}/text_nodes_cc_en_from_json_with_numeric_id.csv"
# },
# 'pes2o_abstract':{
# 'node_csv': f"{csv_dir}/triple_nodes_pes2o_abstract_from_json_without_emb.csv",
# 'edge_csv': f"{csv_dir}/triple_edges_pes2o_abstract_from_json_without_emb_full_concept.csv",
# 'text_csv': f"{csv_dir}/text_nodes_pes2o_abstract_from_json.csv",
# 'node_with_numeric_id': f"{csv_dir}/triple_nodes_pes2o_abstract_from_json_without_emb_with_numeric_id.csv",
# 'edge_with_numeric_id': f"{csv_dir}/triple_edges_pes2o_abstract_from_json_without_emb_full_concept_with_numeric_id.csv",
# 'text_with_numeric_id': f"{csv_dir}/text_nodes_pes2o_abstract_from_json_with_numeric_id.csv"
# },
# 'en_simple_wiki_v0':{
# 'node_csv': f"{csv_dir}/triple_nodes_en_simple_wiki_v0_from_json_without_emb.csv",
# 'edge_csv': f"{csv_dir}/triple_edges_en_simple_wiki_v0_from_json_without_emb_full_concept.csv",
# 'text_csv': f"{csv_dir}/text_nodes_en_simple_wiki_v0_from_json.csv",
# 'node_with_numeric_id': f"{csv_dir}/triple_nodes_en_simple_wiki_v0_from_json_without_emb_with_numeric_id.csv",
# 'edge_with_numeric_id': f"{csv_dir}/triple_edges_en_simple_wiki_v0_from_json_without_emb_full_concept_with_numeric_id.csv",
# 'text_with_numeric_id': f"{csv_dir}/text_nodes_en_simple_wiki_v0_from_json_with_numeric_id.csv"
# },
# }
# # ouput node
# with open(keyword_to_paths[keyword]['node_csv']) as infile, open(keyword_to_paths[keyword]['node_with_numeric_id'], 'w') as outfile:
# reader = csv.reader(infile)
# writer = csv.writer(outfile)
# # Read the header
# header = next(reader)
# print(header)
# # Insert 'numeric_id' before ':LABEL'
# label_index = header.index(':LABEL')
# header.insert(label_index, 'numeric_id') # Add new column name
# writer.writerow(header)
# # Process each row and add a numeric ID
# for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
# row.insert(label_index, row_number) # Add numeric ID before ':LABEL'
# writer.writerow(row)
# # output edge (TYPE instead of LABEL for edge)
# with open(keyword_to_paths[keyword]['edge_csv']) as infile, open(keyword_to_paths[keyword]['edge_with_numeric_id'], 'w') as outfile:
# reader = csv.reader(infile)
# writer = csv.writer(outfile)
# # Read the header
# header = next(reader)
# print(header)
# # Insert 'numeric_id' before ':TYPE'
# label_index = header.index(':TYPE')
# header.insert(label_index, 'numeric_id') # Add new column name
# writer.writerow(header)
# # Process each row and add a numeric ID
# for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
# row.insert(label_index, row_number) # Add numeric ID before ':LABEL'
# writer.writerow(row)
# # output text
# with open(keyword_to_paths[keyword]['text_csv']) as infile, open(keyword_to_paths[keyword]['text_with_numeric_id'], 'w') as outfile:
# reader = csv.reader(infile)
# writer = csv.writer(outfile)
# # Read the header
# header = next(reader)
# print(header)
# # Insert 'numeric_id' before ':LABEL'
# label_index = header.index(':LABEL')
# header.insert(label_index, 'numeric_id') # Add new column name
# writer.writerow(header)
# # Process each row and add a numeric ID
# for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
# row.insert(label_index, row_number) # Add numeric ID before ':LABEL'
# writer.writerow(row)
if __name__ == "__main__":
keyword = "en_simple_wiki_v0"
csv_dir = "./import" # Change this to your CSV directory
add_csv_columns(
node_csv=f"{csv_dir}/triple_nodes_{keyword}_from_json_without_emb.csv",
edge_csv=f"{csv_dir}/triple_edges_{keyword}_from_json_without_emb_full_concept.csv",
text_csv=f"{csv_dir}/text_nodes_{keyword}_from_json.csv",
node_with_numeric_id=f"{csv_dir}/triple_nodes_{keyword}_from_json_without_emb_with_numeric_id.csv",
edge_with_numeric_id=f"{csv_dir}/triple_edges_{keyword}_from_json_without_emb_full_concept_with_numeric_id.csv",
text_with_numeric_id=f"{csv_dir}/text_nodes_{keyword}_from_json_with_numeric_id.csv",
)
# check_created_csv_header(keyword, csv_dir)

View File

@ -0,0 +1,189 @@
import networkx as nx
import csv
import ast
import hashlib
import os
from atlas_rag.kg_construction.triple_config import ProcessingConfig
import pickle
def get_node_id(entity_name, entity_to_id={}):
"""Returns existing or creates new nX ID for an entity using a hash-based approach."""
if entity_name not in entity_to_id:
# Use a hash function to generate a unique ID
hash_object = hashlib.sha256(entity_name.encode('utf-8'))
hash_hex = hash_object.hexdigest() # Get the hexadecimal representation of the hash
# Use the full hexadecimal digest as the ID (shorten it here if a more compact ID is needed)
entity_to_id[entity_name] = hash_hex
return entity_to_id[entity_name]
def csvs_to_temp_graphml(triple_node_file, triple_edge_file, config:ProcessingConfig=None):
g = nx.DiGraph()
entity_to_id = {}
# Add triple nodes
with open(triple_node_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
node_id = row["name:ID"]
mapped_id = get_node_id(node_id, entity_to_id)
if mapped_id not in g.nodes:
g.add_node(mapped_id, id=node_id, type=row["type"])
# Add triple edges
with open(triple_edge_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
start_id = get_node_id(row[":START_ID"], entity_to_id)
end_id = get_node_id(row[":END_ID"], entity_to_id)
# Check if edge already exists to prevent duplicates
if not g.has_edge(start_id, end_id):
g.add_edge(start_id, end_id, relation=row["relation"], type=row[":TYPE"])
# Save the intermediate graph (without concept nodes) to a pickle file
output_name = f"{config.output_directory}/kg_graphml/{config.filename_pattern}_without_concept.pkl"
# check if output file directory exists
output_dir = os.path.dirname(output_name)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
# store the graph to a pickle file
with open(output_name, 'wb') as output_file:
pickle.dump(g, output_file)
def csvs_to_graphml(triple_node_file, text_node_file, concept_node_file,
triple_edge_file, text_edge_file, concept_edge_file,
output_file):
'''
Convert multiple CSV files into a single GraphML file.
Types of nodes to be added to the graph:
- Triple nodes: Nodes representing triples, with properties like subject, predicate, object.
- Text nodes: Nodes representing text, with properties like text content.
- Concept nodes: Nodes representing concepts, with properties like concept name and type.
Types of edges to be added to the graph:
- Triple edges: Edges representing relationships between triples, with properties like relation type.
- Text edges: Edges representing relationships between text and nodes, with properties like text type.
- Concept edges: Edges representing relationships between concepts and nodes, with properties like concept type.
DiGraph networkx attributes:
Node:
- type: Type of the node (e.g., entity, event, text, concept).
- file_id: List of text IDs the node is associated with.
- id: Node Name
Edge:
- relation: relation name
- file_id: List of text IDs the edge is associated with.
- type: Type of the edge (e.g., Source, Relation, Concept).
- synsets: List of synsets associated with the edge.
'''
g = nx.DiGraph()
entity_to_id = {}
# Add triple nodes
with open(triple_node_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
node_id = row["name:ID"]
mapped_id = get_node_id(node_id, entity_to_id)
# Check if node already exists to prevent duplicates
if mapped_id not in g.nodes:
g.add_node(mapped_id, id=node_id, type=row["type"])
# Add text nodes
with open(text_node_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
node_id = row["text_id:ID"]
# Check if node already exists to prevent duplicates
if node_id not in g.nodes:
g.add_node(node_id, file_id=node_id, id=row["original_text"], type="passage")
# Add concept nodes
with open(concept_node_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
node_id = row["concept_id:ID"]
# Check if node already exists to prevent duplicates
if node_id not in g.nodes:
g.add_node(node_id, file_id="concept_file", id=row["name"], type="concept")
# Add file id for triple nodes and concept nodes when add the edges
# Add triple edges
with open(triple_edge_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
start_id = get_node_id(row[":START_ID"], entity_to_id)
end_id = get_node_id(row[":END_ID"], entity_to_id)
# Check if edge already exists to prevent duplicates
if not g.has_edge(start_id, end_id):
g.add_edge(start_id, end_id, relation=row["relation"], type=row[":TYPE"])
# Add file_id to start and end nodes if they are triple or concept nodes
for node_id in [start_id, end_id]:
if g.nodes[node_id]['type'] in ['triple', 'concept'] and 'file_id' not in g.nodes[node_id]:
g.nodes[node_id]['file_id'] = row.get("file_id", "triple_file")
# Add concepts to the edge
concepts = ast.literal_eval(row["concepts"])
for concept in concepts:
if "concepts" not in g.edges[start_id, end_id]:
g.edges[start_id, end_id]['concepts'] = str(concept)
else:
# Avoid duplicate concepts by checking if concept is already in the list
current_concepts = g.edges[start_id, end_id]['concepts'].split(",")
if str(concept) not in current_concepts:
g.edges[start_id, end_id]['concepts'] += "," + str(concept)
# Add text edges
with open(text_edge_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
start_id = get_node_id(row[":START_ID"], entity_to_id)
end_id = row[":END_ID"]
# Check if edge already exists to prevent duplicates
if not g.has_edge(start_id, end_id):
g.add_edge(start_id, end_id, relation="mention in", type=row[":TYPE"])
# Add file_id to start node if it is a triple or concept node
if 'file_id' in g.nodes[start_id]:
g.nodes[start_id]['file_id'] += "," + str(end_id)
else:
g.nodes[start_id]['file_id'] = str(end_id)
# Add concept edges between triple nodes and concept nodes
with open(concept_edge_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
start_id = get_node_id(row[":START_ID"], entity_to_id)
end_id = row[":END_ID"] # end id is concept node id
if not g.has_edge(start_id, end_id):
g.add_edge(start_id, end_id, relation=row["relation"], type=row[":TYPE"])
# Write to GraphML
# check if output file directory exists
output_dir = os.path.dirname(output_file)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
nx.write_graphml(g, output_file, infer_numeric_types=True)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Convert CSV files to GraphML format.')
parser.add_argument('--triple_node_file', type=str, required=True, help='Path to the triple node CSV file.')
parser.add_argument('--text_node_file', type=str, required=True, help='Path to the text node CSV file.')
parser.add_argument('--concept_node_file', type=str, required=True, help='Path to the concept node CSV file.')
parser.add_argument('--triple_edge_file', type=str, required=True, help='Path to the triple edge CSV file.')
parser.add_argument('--text_edge_file', type=str, required=True, help='Path to the text edge CSV file.')
parser.add_argument('--concept_edge_file', type=str, required=True, help='Path to the concept edge CSV file.')
parser.add_argument('--output_file', type=str, required=True, help='Path to the output GraphML file.')
args = parser.parse_args()
csvs_to_graphml(args.triple_node_file, args.text_node_file, args.concept_node_file,
args.triple_edge_file, args.text_edge_file, args.concept_edge_file,
args.output_file)

View File

@ -0,0 +1,70 @@
import pandas as pd
import numpy as np
from ast import literal_eval # Safer string-to-list conversion
import os
CHUNKSIZE = 100_000 # Adjust based on your RAM (100K rows per chunk)
EMBEDDING_COL = "embedding:STRING" # Column name with embeddings
# DIMENSION = 32 # Update with your embedding dimension
ENTITY_ONLY = True
def parse_embedding(embed_str):
"""Convert embedding string to numpy array"""
# Remove brackets and convert to list
return np.array(literal_eval(embed_str), dtype=np.float32)
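# Example: parse_embedding("[0.1, 0.2, 0.3]") -> array([0.1, 0.2, 0.3], dtype=float32)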
# Create memory-mapped numpy file
def convert_csv_to_npy(csv_path, npy_path):
total_embeddings = 0
# check dir exist, if not then create it
os.makedirs(os.path.dirname(npy_path), exist_ok=True)
with open(npy_path, "wb") as f:
pass # Initialize empty file
# Process CSV in chunks
for chunk_idx, df_chunk in enumerate(
pd.read_csv(csv_path, chunksize=CHUNKSIZE, usecols=[EMBEDDING_COL])
):
# Parse embeddings
embeddings = np.stack(
df_chunk[EMBEDDING_COL].apply(parse_embedding).values
)
# Verify dimensions
# assert embeddings.shape[1] == DIMENSION, \
# f"Dimension mismatch at chunk {chunk_idx}"
total_embeddings += embeddings.shape[0]
# Append to .npy file
with open(npy_path, "ab") as f:
np.save(f, embeddings.astype(np.float32))
print(f"Processed chunk {chunk_idx} ({CHUNKSIZE*(chunk_idx+1)} rows)")
print(f"Total number of embeddings: {total_embeddings}")
print("Conversion complete!")
if __name__ == "__main__":
keyword = 'cc_en' # Change this to your desired keyword
csv_dir="./import" # Change this to your CSV directory
keyword_to_paths ={
'cc_en':{
'node_csv': f"{csv_dir}/triple_nodes_cc_en_from_json_2.csv",
# 'edge_csv': f"{csv_dir}/triple_edges_cc_en_from_json_2.csv",
'text_csv': f"{csv_dir}/text_nodes_cc_en_from_json_with_emb.csv",
},
'pes2o_abstract':{
'node_csv': f"{csv_dir}/triple_nodes_pes2o_abstract_from_json.csv",
# 'edge_csv': f"{csv_dir}/triple_edges_pes2o_abstract_from_json.csv",
'text_csv': f"{csv_dir}/text_nodes_pes2o_abstract_from_json_with_emb.csv",
},
'en_simple_wiki_v0':{
'node_csv': f"{csv_dir}/triple_nodes_en_simple_wiki_v0_from_json.csv",
# 'edge_csv': f"{csv_dir}/triple_edges_en_simple_wiki_v0_from_json.csv",
'text_csv': f"{csv_dir}/text_nodes_en_simple_wiki_v0_from_json_with_emb.csv",
},
}
for key, path in keyword_to_paths[keyword].items():
npy_path = path.replace(".csv", ".npy")
convert_csv_to_npy(path, npy_path)
print(f"Converted {path} to {npy_path}")

View File

@ -0,0 +1,27 @@
import os
import glob
def merge_csv_files(output_file, input_dir):
"""
Merge all CSV files in the input directory into a single output file.
Args:
output_file (str): Path to the output CSV file.
input_dir (str): Directory containing the input CSV files.
"""
# Delete the output file if it exists
if os.path.exists(output_file):
os.remove(output_file)
# Write the header to the output file
with open(output_file, 'w') as outfile:
outfile.write("node,conceptualized_node,node_type\n")
# Append the contents of all CSV files in the input directory
for csv_file in glob.glob(os.path.join(input_dir, '*.csv')):
with open(csv_file, 'r') as infile:
# Skip the header line
next(infile)
# Append the remaining lines to the output file
with open(output_file, 'a') as outfile:
outfile.writelines(infile)

View File

@ -0,0 +1,277 @@
from tqdm import tqdm
import argparse
import os
import csv
import json
import re
import hashlib
# Increase the field size limit
csv.field_size_limit(10 * 1024 * 1024) # 10 MB limit
# Function to compute a hash ID from text
def compute_hash_id(text):
# Use SHA-256 to generate a hash
hash_object = hashlib.sha256(text.encode('utf-8'))
return hash_object.hexdigest() # Return hash as a hex string
def clean_text(text):
# remove NUL as well
new_text = text.replace("\n", " ").replace("\r", " ").replace("\t", " ").replace("\v", " ").replace("\f", " ").replace("\b", " ").replace("\a", " ").replace("\x1b", " ").replace(";", ",")
new_text = new_text.replace("\x00", "")
new_text = re.sub(r'\s+', ' ', new_text).strip()
return new_text
def remove_NUL(text):
return text.replace("\x00", "")
def json2csv(dataset, data_dir, output_dir, test=False):
"""
Convert JSON files to CSV files for nodes, edges, and missing concepts.
Args:
dataset (str): Name of the dataset.
data_dir (str): Directory containing the JSON files.
output_dir (str): Directory to save the output CSV files.
test (bool): If True, run in test mode (process only 3 files).
"""
visited_nodes = set()
visited_hashes = set()
all_entities = set()
all_events = set()
all_relations = set()
file_dir_list = [f for f in os.listdir(data_dir) if dataset in f]
file_dir_list = sorted(file_dir_list)
if test:
file_dir_list = file_dir_list[:3]
print("Loading data from the json files")
print("Number of files: ", len(file_dir_list))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Define output file paths
node_csv_without_emb = os.path.join(output_dir, f"triple_nodes_{dataset}_from_json_without_emb.csv")
edge_csv_without_emb = os.path.join(output_dir, f"triple_edges_{dataset}_from_json_without_emb.csv")
node_text_file = os.path.join(output_dir, f"text_nodes_{dataset}_from_json.csv")
edge_text_file = os.path.join(output_dir, f"text_edges_{dataset}_from_json.csv")
missing_concepts_file = os.path.join(output_dir, f"missing_concepts_{dataset}_from_json.csv")
if test:
node_text_file = os.path.join(output_dir, f"text_nodes_{dataset}_from_json_test.csv")
edge_text_file = os.path.join(output_dir, f"text_edges_{dataset}_from_json_test.csv")
node_csv_without_emb = os.path.join(output_dir, f"triple_nodes_{dataset}_from_json_without_emb_test.csv")
edge_csv_without_emb = os.path.join(output_dir, f"triple_edges_{dataset}_from_json_without_emb_test.csv")
missing_concepts_file = os.path.join(output_dir, f"missing_concepts_{dataset}_from_json_test.csv")
# Open CSV files for writing
with open(node_text_file, "w", newline='', encoding='utf-8', errors='ignore') as csvfile_node_text, \
open(edge_text_file, "w", newline='', encoding='utf-8', errors='ignore') as csvfile_edge_text, \
open(node_csv_without_emb, "w", newline='', encoding='utf-8', errors='ignore') as csvfile_node, \
open(edge_csv_without_emb, "w", newline='', encoding='utf-8', errors='ignore') as csvfile_edge:
csv_writer_node_text = csv.writer(csvfile_node_text)
csv_writer_edge_text = csv.writer(csvfile_edge_text)
writer_node = csv.writer(csvfile_node)
writer_edge = csv.writer(csvfile_edge)
# Write headers
csv_writer_node_text.writerow(["text_id:ID", "original_text", ":LABEL"])
csv_writer_edge_text.writerow([":START_ID", ":END_ID", ":TYPE"])
writer_node.writerow(["name:ID", "type", "concepts", "synsets", ":LABEL"])
writer_edge.writerow([":START_ID", ":END_ID", "relation", "concepts", "synsets", ":TYPE"])
# Process each file
for file_dir in tqdm(file_dir_list):
print("Processing file for file ids: ", file_dir)
with open(os.path.join(data_dir, file_dir), "r") as jsonfile:
for line in jsonfile:
data = json.loads(line.strip())
original_text = data["original_text"]
original_text = remove_NUL(original_text)
if "Here is the passage." in original_text:
original_text = original_text.split("Here is the passage.")[-1]
eot_token = "<|eot_id|>"
original_text = original_text.split(eot_token)[0]
text_hash_id = compute_hash_id(original_text)
# Write the original text as nodes
if text_hash_id not in visited_hashes:
visited_hashes.add(text_hash_id)
csv_writer_node_text.writerow([text_hash_id, original_text, "Text"])
file_id = str(data["id"])
entity_relation_dict = data["entity_relation_dict"]
event_entity_relation_dict = data["event_entity_relation_dict"]
event_relation_dict = data["event_relation_dict"]
# Process entity triples
entity_triples = []
for entity_triple in entity_relation_dict:
try:
assert isinstance(entity_triple["Head"], str)
assert isinstance(entity_triple["Relation"], str)
assert isinstance(entity_triple["Tail"], str)
head_entity = entity_triple["Head"]
relation = entity_triple["Relation"]
tail_entity = entity_triple["Tail"]
# Clean the text
head_entity = clean_text(head_entity)
relation = clean_text(relation)
tail_entity = clean_text(tail_entity)
if head_entity.isspace() or len(head_entity) == 0 or tail_entity.isspace() or len(tail_entity) == 0:
continue
entity_triples.append((head_entity, relation, tail_entity))
except Exception:
print(f"Error processing entity triple: {entity_triple}")
continue
# Process event triples
event_triples = []
for event_triple in event_relation_dict:
try:
assert isinstance(event_triple["Head"], str)
assert isinstance(event_triple["Relation"], str)
assert isinstance(event_triple["Tail"], str)
head_event = event_triple["Head"]
relation = event_triple["Relation"]
tail_event = event_triple["Tail"]
# Clean the text
head_event = clean_text(head_event)
relation = clean_text(relation)
tail_event = clean_text(tail_event)
if head_event.isspace() or len(head_event) == 0 or tail_event.isspace() or len(tail_event) == 0:
continue
event_triples.append((head_event, relation, tail_event))
except Exception:
print(f"Error processing event triple: {event_triple}")
# Process event-entity triples
event_entity_triples = []
for event_entity_participations in event_entity_relation_dict:
if "Event" not in event_entity_participations or "Entity" not in event_entity_participations:
continue
if not isinstance(event_entity_participations["Event"], str) or not isinstance(event_entity_participations["Entity"], list):
continue
for entity in event_entity_participations["Entity"]:
if not isinstance(entity, str):
continue
entity = clean_text(entity)
event = clean_text(event_entity_participations["Event"])
if event.isspace() or len(event) == 0 or entity.isspace() or len(entity) == 0:
continue
event_entity_triples.append((event, "is participated by", entity))
# Write nodes and edges to CSV files
for entity_triple in entity_triples:
head_entity, relation, tail_entity = entity_triple
if head_entity is None or tail_entity is None or relation is None:
continue
if head_entity.isspace() or tail_entity.isspace() or relation.isspace():
continue
if len(head_entity) == 0 or len(tail_entity) == 0 or len(relation) == 0:
continue
# Add nodes to files
if head_entity not in visited_nodes:
visited_nodes.add(head_entity)
all_entities.add(head_entity)
writer_node.writerow([head_entity, "entity", [], [], "Node"])
csv_writer_edge_text.writerow([head_entity, text_hash_id, "Source"])
if tail_entity not in visited_nodes:
visited_nodes.add(tail_entity)
all_entities.add(tail_entity)
writer_node.writerow([tail_entity, "entity", [], [], "Node"])
csv_writer_edge_text.writerow([tail_entity, text_hash_id, "Source"])
all_relations.add(relation)
writer_edge.writerow([head_entity, tail_entity, relation, [], [], "Relation"])
for event_triple in event_triples:
head_event, relation, tail_event = event_triple
if head_event is None or tail_event is None or relation is None:
continue
if head_event.isspace() or tail_event.isspace() or relation.isspace():
continue
if len(head_event) == 0 or len(tail_event) == 0 or len(relation) == 0:
continue
# Add nodes to files
if head_event not in visited_nodes:
visited_nodes.add(head_event)
all_events.add(head_event)
writer_node.writerow([head_event, "event", [], [], "Node"])
csv_writer_edge_text.writerow([head_event, text_hash_id, "Source"])
if tail_event not in visited_nodes:
visited_nodes.add(tail_event)
all_events.add(tail_event)
writer_node.writerow([tail_event, "event", [], [], "Node"])
csv_writer_edge_text.writerow([tail_event, text_hash_id, "Source"])
all_relations.add(relation)
writer_edge.writerow([head_event, tail_event, relation, [], [], "Relation"])
for event_entity_triple in event_entity_triples:
head_event, relation, tail_entity = event_entity_triple
if head_event is None or tail_entity is None or relation is None:
continue
if head_event.isspace() or tail_entity.isspace() or relation.isspace():
continue
if len(head_event) == 0 or len(tail_entity) == 0 or len(relation) == 0:
continue
# Add nodes to files
if head_event not in visited_nodes:
visited_nodes.add(head_event)
all_events.add(head_event)
writer_node.writerow([head_event, "event", [], [], "Node"])
csv_writer_edge_text.writerow([head_event, text_hash_id, "Source"])
if tail_entity not in visited_nodes:
visited_nodes.add(tail_entity)
all_entities.add(tail_entity)
writer_node.writerow([tail_entity, "entity", [], [], "Node"])
csv_writer_edge_text.writerow([tail_entity, text_hash_id, "Source"])
all_relations.add(relation)
writer_edge.writerow([head_event, tail_entity, relation, [], [], "Relation"])
# Write missing concepts to CSV
with open(missing_concepts_file, "w", newline='', encoding='utf-8', errors='ignore') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(["Name", "Type"])
for entity in all_entities:
writer.writerow([entity, "Entity"])
for event in all_events:
writer.writerow([event, "Event"])
for relation in all_relations:
writer.writerow([relation, "Relation"])
print("Data to CSV completed successfully.")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, required=True, help="[pes2o_abstract, en_simple_wiki_v0, cc_en]")
parser.add_argument("--data_dir", type=str, required=True, help="Directory containing the graph raw JSON files")
parser.add_argument("--output_dir", type=str, required=True, help="Directory to save the output CSV files")
parser.add_argument("--test", action="store_true", help="Test the script")
args = parser.parse_args()
json2csv(dataset=args.dataset, data_dir=args.data_dir, output_dir=args.output_dir, test=args.test)
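# Example invocation (a sketch; the directories mirror the defaults used by
# KnowledgeGraphExtractor.convert_json_to_csv):
#
#     python json_to_csv.py --dataset en_simple_wiki_v0 \
#         --data_dir ./generation_result_debug/kg_extraction \
#         --output_dir ./generation_result_debug/triples_csv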

View File

@ -0,0 +1,169 @@
import networkx as nx
import json
from tqdm import tqdm
import os
import hashlib
def get_node_id(entity_name, entity_to_id):
"""Returns existing or creates new nX ID for an entity using a hash-based approach."""
if entity_name not in entity_to_id:
# Use a hash function to generate a unique ID
hash_object = hashlib.md5(entity_name.encode()) # Use MD5 or another hashing algorithm
hash_hex = hash_object.hexdigest() # Get the hexadecimal representation of the hash
# Use the first 16 characters of the hash as the ID (adjust the length as needed)
entity_to_id[entity_name] = f'n{hash_hex[:16]}'
return entity_to_id[entity_name]
def clean_text(text):
# remove NUL as well
new_text = text.replace("\n", " ").replace("\r", " ").replace("\t", " ").replace("\v", " ").replace("\f", " ").replace("\b", " ").replace("\a", " ").replace("\x1b", " ").replace(";", ",")
new_text = new_text.replace("\x00", "")
return new_text
def process_kg_data(input_passage_dir, input_triple_dir, output_dir, keyword):
# Get file names containing the keyword
file_names = [file for file in list(os.listdir(input_triple_dir)) if keyword in file]
print(f"Keyword: {keyword}")
print(f"Number of files: {len(file_names)}")
print(file_names)
passage_file_names = [file for file in list(os.listdir(input_passage_dir)) if keyword in file]
print(f'Passage file names: {passage_file_names}')
g = nx.DiGraph()
print("Graph created.")
entity_to_id = {}
# check if output directory exists, if not create it
if not os.path.exists(output_dir):
os.makedirs(output_dir)
print(f"Output directory {output_dir} created.")
output_path = f"{output_dir}/{keyword}_kg_from_corpus.graphml"
# Create the original_text to node_id dictionary and add passage node to the graph
with open(f"{input_passage_dir}/{passage_file_names[0]}") as f:
data = json.load(f)
for item in tqdm(data, desc="Processing passages"):
passage_id = item["id"]
passage_text = item["text"]
node_id = get_node_id(passage_text, entity_to_id)
if passage_text.isspace() or len(passage_text) == 0:
continue
# Add the passage node to the graph
g.add_node(node_id, type="passage", id=passage_text, file_id=passage_id)
for file_name in tqdm(file_names):
print(f"Processing {file_name}")
input_file_path = f"{input_triple_dir}/{file_name}"
with open(input_file_path) as f:
for line in tqdm(f):
data = json.loads(line)
metadata = data["metadata"]
file_id = data["id"]
original_text = data["original_text"]
entity_relation_dict = data["entity_relation_dict"]
event_entity_relation_dict = data["event_entity_relation_dict"]
event_relation_dict = data["event_relation_dict"]
# Process entity triples
entity_triples = []
for entity_triple in entity_relation_dict:
if not all(key in entity_triple for key in ["Head", "Relation", "Tail"]):
continue
head_entity = clean_text(entity_triple["Head"])
relation = clean_text(entity_triple["Relation"])
tail_entity = clean_text(entity_triple["Tail"])
if head_entity.isspace() or len(head_entity) == 0 or tail_entity.isspace() or len(tail_entity) == 0:
continue
entity_triples.append((head_entity, relation, tail_entity))
# Add entity triples to the graph
for triple in entity_triples:
head_id = get_node_id(triple[0], entity_to_id)
tail_id = get_node_id(triple[2], entity_to_id)
g.add_node(head_id, type="entity", id=triple[0])
g.add_node(tail_id, type="entity", id=triple[2])
g.add_edge(head_id, get_node_id(original_text, entity_to_id), relation='mention in')
g.add_edge(tail_id, get_node_id(original_text, entity_to_id), relation='mention in')
g.add_edge(head_id, tail_id, relation=triple[1])
for node_id in [head_id, tail_id]:
if "file_id" not in g.nodes[node_id]:
g.nodes[node_id]["file_id"] = str(file_id)
else:
g.nodes[node_id]["file_id"] += "," + str(file_id)
edge = g.edges[head_id, tail_id]
if "file_id" not in edge:
edge["file_id"] = str(file_id)
else:
edge["file_id"] += "," + str(file_id)
# Process event triples
event_triples = []
for event_triple in event_relation_dict:
if not all(key in event_triple for key in ["Head", "Relation", "Tail"]):
continue
head_event = clean_text(event_triple["Head"])
relation = clean_text(event_triple["Relation"])
tail_event = clean_text(event_triple["Tail"])
if head_event.isspace() or len(head_event) == 0 or tail_event.isspace() or len(tail_event) == 0:
continue
event_triples.append((head_event, relation, tail_event))
# Add event triples to the graph
for triple in event_triples:
head_id = get_node_id(triple[0], entity_to_id)
tail_id = get_node_id(triple[2], entity_to_id)
g.add_node(head_id, type="event", id=triple[0])
g.add_node(tail_id, type="event", id=triple[2])
g.add_edge(head_id, get_node_id(original_text, entity_to_id), relation='mention in')
g.add_edge(tail_id, get_node_id(original_text, entity_to_id), relation='mention in')
g.add_edge(head_id, tail_id, relation=triple[1])
for node_id in [head_id, tail_id]:
if "file_id" not in g.nodes[node_id]:
g.nodes[node_id]["file_id"] = str(file_id)
else:
g.nodes[node_id]["file_id"] += "," + str(file_id)
edge = g.edges[head_id, tail_id]
if "file_id" not in edge:
edge["file_id"] = str(file_id)
else:
edge["file_id"] += "," + str(file_id)
# Process event-entity triples
event_entity_triples = []
for event_entity_participations in event_entity_relation_dict:
if not all(key in event_entity_participations for key in ["Event", "Entity"]):
continue
event = clean_text(event_entity_participations["Event"])
if event.isspace() or len(event) == 0:
continue
for entity in event_entity_participations["Entity"]:
if not isinstance(entity, str) or entity.isspace() or len(entity) == 0:
continue
entity = clean_text(entity)
event_entity_triples.append((event, "is participated by", entity))
# Add event-entity triples to the graph
for triple in event_entity_triples:
head_id = get_node_id(triple[0], entity_to_id)
tail_id = get_node_id(triple[2], entity_to_id)
g.add_node(head_id, type="event", id=triple[0])
g.add_node(tail_id, type="entity", id=triple[2])
g.add_edge(head_id, tail_id, relation=triple[1])
for node_id in [head_id, tail_id]:
if "file_id" not in g.nodes[node_id]:
g.nodes[node_id]["file_id"] = str(file_id)
edge = g.edges[head_id, tail_id]
if "file_id" not in edge:
edge["file_id"] = str(file_id)
else:
edge["file_id"] += "," + str(file_id)
print(f"Number of nodes: {g.number_of_nodes()}")
print(f"Number of edges: {g.number_of_edges()}")
print(f"Graph density: {nx.density(g)}")
with open(output_path, 'wb') as f:
nx.write_graphml(g, f, infer_numeric_types=True)

View File

@ -0,0 +1,63 @@
import argparse
import json
import os
import sys
from pathlib import Path
# Set up argument parser
parser = argparse.ArgumentParser(description="Convert all Markdown files in a folder to separate JSON files.")
parser.add_argument(
"--input", required=True, help="Path to the folder containing Markdown files"
)
parser.add_argument(
"--output", default=None, help="Output folder for JSON files (defaults to input folder if not specified)"
)
# Parse arguments
args = parser.parse_args()
# Resolve input folder path
input_folder = Path(args.input)
if not input_folder.is_dir():
print(f"Error: '{args.input}' is not a directory.", file=sys.stderr)
sys.exit(1)
# Set output folder (use input folder if not specified)
output_folder = Path(args.output) if args.output else input_folder
output_folder.mkdir(parents=True, exist_ok=True)
# Find all .md files in the input folder
markdown_files = [f for f in input_folder.iterdir() if f.suffix.lower() == ".md"]
if not markdown_files:
print(f"Error: No Markdown files found in '{args.input}'.", file=sys.stderr)
sys.exit(1)
# Process each Markdown file
for file in markdown_files:
try:
# Read the content of the file
with open(file, "r", encoding="utf-8") as f:
content = f.read()
# Create the JSON object
obj = {
"id": "1",
"text": content,
"metadata": {
"lang": "en"
}
}
# Create output JSON filename (e.g., file1.md -> file1.json)
output_file = output_folder / f"{file.stem}.json"
# Write JSON to file
with open(output_file, "w", encoding="utf-8") as f:
json.dump([obj], f, indent=4)
print(f"Successfully converted '{file}' to '{output_file}'")
except FileNotFoundError:
print(f"Error: File '{file}' not found.", file=sys.stderr)
except Exception as e:
print(f"Error processing file '{file}': {e}", file=sys.stderr)

View File

@ -0,0 +1 @@
from .llm_generator import LLMGenerator

View File

@ -0,0 +1,144 @@
import json
from typing import List, Any
import json_repair
import jsonschema
def normalize_key(key):
return key.strip().lower()
# The fix_function passed in kwargs can be fix_triple_extraction_response or fix_filter_triplets
def validate_output(output_str, **kwargs):
schema = kwargs.get("schema")
fix_function = kwargs.get("fix_function", None)
allow_empty = kwargs.get("allow_empty", True)
if fix_function:
parsed_data = fix_function(output_str, **kwargs)
jsonschema.validate(instance=parsed_data, schema=schema)
if not allow_empty and (not parsed_data or len(parsed_data) == 0):
raise ValueError("Parsed data is empty after validation.")
return json.dumps(parsed_data, ensure_ascii=False)
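# A usage sketch (assumes a JSON schema such as the filter_fact_json_schema defined alongside
# these helpers; raw_llm_output is a placeholder): the fix function repairs the raw model
# output before schema validation.
#
#     validated = validate_output(raw_llm_output,
#                                 schema=filter_fact_json_schema,
#                                 fix_function=fix_filter_triplets,
#                                 allow_empty=False)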
def fix_filter_triplets(data: str, **kwargs) -> dict:
data = json_repair.loads(data)
processed_facts = []
def find_triplet(element: Any) -> List[str] | None:
# Base case: a valid triplet
if isinstance(element, list) and len(element) == 3 and all(isinstance(item, str) for item in element):
return element
# Recursive case: dig deeper into nested lists
elif isinstance(element, list):
for sub_element in element:
result = find_triplet(sub_element)
if result:
return result
return None
for item in data.get("fact", []):
triplet = find_triplet(item)
if triplet:
processed_facts.append(triplet)
return {"fact": processed_facts}
def fix_triple_extraction_response(response: str, **kwargs) -> list:
"""Attempt to fix and validate the JSON triple list in a model response based on the prompt type."""
# Extract the JSON list from the response
# raise error if prompt_type is not provided
if "prompt_type" not in kwargs:
raise ValueError("The 'prompt_type' argument is required.")
prompt_type = kwargs.get("prompt_type")
json_start_token = response.find("[")
if json_start_token == -1:
# add [ at the start
response = "[" + response.strip() + "]"
parsed_objects = json_repair.loads(response)
if len(parsed_objects) == 0:
return []
# Define required keys for each prompt type
required_keys = {
"entity_relation": {"Head", "Relation", "Tail"},
"event_entity": {"Event", "Entity"},
"event_relation": {"Head", "Relation", "Tail"}
}
corrected_data = []
seen_triples = set()
for idx, item in enumerate(parsed_objects):
if not isinstance(item, dict):
print(f"Item {idx} must be a JSON object. Problematic item: {item}")
continue
# Correct the keys
corrected_item = {}
for key, value in item.items():
norm_key = normalize_key(key)
matching_expected_keys = [exp_key for exp_key in required_keys[prompt_type] if normalize_key(exp_key) in norm_key]
if len(matching_expected_keys) == 1:
corrected_key = matching_expected_keys[0]
corrected_item[corrected_key] = value
else:
corrected_item[key] = value
# Check for missing keys in corrected_item
missing = required_keys[prompt_type] - corrected_item.keys()
if missing:
print(f"Item {idx} missing required keys: {missing}. Problematic item: {item}")
continue
# Validate and correct the values in corrected_item
if prompt_type == "entity_relation":
for key in ["Head", "Relation", "Tail"]:
if not isinstance(corrected_item[key], str) or not corrected_item[key].strip():
print(f"Item {idx} {key} must be a non-empty string. Problematic item: {corrected_item}")
continue
elif prompt_type == "event_entity":
if not isinstance(corrected_item["Event"], str) or not corrected_item["Event"].strip():
print(f"Item {idx} Event must be a non-empty string. Problematic item: {corrected_item}")
continue
if not isinstance(corrected_item["Entity"], list) or not corrected_item["Entity"]:
print(f"Item {idx} Entity must be a non-empty array. Problematic item: {corrected_item}")
continue
else:
corrected_item["Entity"] = [ent.strip() for ent in corrected_item["Entity"] if isinstance(ent, str)]
elif prompt_type == "event_relation":
for key in ["Head", "Tail", "Relation"]:
if not isinstance(corrected_item[key], str) or not corrected_item[key].strip():
print(f"Item {idx} {key} must be a non-empty sentence. Problematic item: {corrected_item}")
continue
triple_tuple = tuple((k, str(v)) for k, v in corrected_item.items())
if triple_tuple in seen_triples:
print(f"Item {idx} is a duplicate triple: {corrected_item}")
continue
else:
seen_triples.add(triple_tuple)
corrected_data.append(corrected_item)
if not corrected_data:
return []
return corrected_data
def fix_lkg_keywords(data: str, **kwargs) -> dict:
"""
Extract and flatten keywords into a list of strings, filtering invalid types.
"""
data = json_repair.loads(data)
processed_keywords = []
def collect_strings(element: Any) -> None:
if isinstance(element, str):
if len(element) <= 200: # Filter out keywords longer than 200 characters
processed_keywords.append(element)
elif isinstance(element, list):
for item in element:
collect_strings(item)
# Start processing from the root "keywords" field
collect_strings(data.get("keywords", []))
return {"keywords": processed_keywords}

View File

@ -0,0 +1,93 @@
filter_fact_json_schema = {
"type": "object",
"properties": {
"fact": {
"type": "array",
"items": {
"type": "array",
"items": {
"type": "string" # All items in the inner array must be strings
},
"minItems": 3,
"maxItems": 3,
"additionalItems": False # Block extra items
},
}
},
"required": ["fact"]
}
lkg_keyword_json_schema = {
"type": "object",
"properties": {
"keywords": {
"type": "array",
"items": {
"type": "string"
},
"minItems": 1,
}
},
"required": ["keywords"]
}
triple_json_schema = {
"type": "array",
"items": {
"type": "object",
"properties": {
"Head": {
"type": "string"
},
"Relation": {
"type": "string"
},
"Tail": {
"type": "string"
}
},
"required": ["Head", "Relation", "Tail"]
},
}
event_relation_json_schema = {
"type": "array",
"items": {
"type": "object",
"properties": {
"Head": {
"type": "string"
},
"Relation": {
"type": "string",
},
"Tail": {
"type": "string"
}
},
"required": ["Head", "Relation", "Tail"]
},
}
event_entity_json_schema = {
"type": "array",
"items": {
"type": "object",
"properties": {
"Event": {
"type": "string"
},
"Entity": {
"type": "array",
"items": {
"type": "string"
},
"minItems": 1
}
},
"required": ["Event", "Entity"]
},
}
stage_to_schema = {
1: triple_json_schema,
2: event_entity_json_schema,
3: event_relation_json_schema
}
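A brief sketch of how these schemas are consumed (the sample triples are invented): `stage_to_schema` selects the schema for an extraction stage, and `jsonschema.validate` rejects malformed output.

```python
# Sketch: validate a stage-1 (entity_relation) extraction result against its schema.
import jsonschema
from atlas_rag.llm_generator.format.validate_json_schema import stage_to_schema

triples = [{"Head": "University of Southampton", "Relation": "founded in", "Tail": "1862"}]
jsonschema.validate(instance=triples, schema=stage_to_schema[1])  # passes silently

try:
    jsonschema.validate(instance=[{"Head": "x"}], schema=stage_to_schema[1])
except jsonschema.ValidationError as exc:
    print("rejected:", exc.message)  # missing "Relation" and "Tail"
```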

View File

@ -0,0 +1,364 @@
import json
from openai import OpenAI, AzureOpenAI, NOT_GIVEN
from tenacity import retry, stop_after_attempt, stop_after_delay, wait_fixed, wait_exponential, wait_random
from copy import deepcopy
from concurrent.futures import ThreadPoolExecutor
from atlas_rag.llm_generator.prompt.rag_prompt import cot_system_instruction, cot_system_instruction_kg, cot_system_instruction_no_doc, prompt_template
from atlas_rag.llm_generator.prompt.lkg_prompt import ner_prompt, keyword_filtering_prompt
from atlas_rag.llm_generator.prompt.rag_prompt import filter_triple_messages
from atlas_rag.llm_generator.format.validate_json_output import *
from atlas_rag.llm_generator.format.validate_json_schema import filter_fact_json_schema, lkg_keyword_json_schema, stage_to_schema
from transformers.pipelines import Pipeline
import jsonschema
import time
stage_to_prompt_type = {
1: "entity_relation",
2: "event_entity",
3: "event_relation",
}
retry_decorator = retry(
stop=(stop_after_delay(120) | stop_after_attempt(5)), # Max 2 minutes or 5 attempts
wait=wait_exponential(multiplier=1, min=2, max=30) + wait_random(min=0, max=2),
)
class LLMGenerator():
def __init__(self, client, model_name):
self.model_name = model_name
self.client : OpenAI|Pipeline = client
if isinstance(client, OpenAI|AzureOpenAI):
self.inference_type = "openai"
elif isinstance(client, Pipeline):
self.inference_type = "pipeline"
else:
raise ValueError("Unsupported client type. Please provide either an OpenAI client or a Huggingface Pipeline Object.")
@retry_decorator
def _api_inference(self, message, max_new_tokens=8192,
temperature = 0.7,
frequency_penalty = None,
response_format = {"type": "text"},
return_text_only=True,
return_thinking=False,
reasoning_effort=None,
**kwargs):
start_time = time.time()
response = self.client.chat.completions.create(
model=self.model_name,
messages=message,
max_tokens=max_new_tokens,
temperature=temperature,
frequency_penalty= NOT_GIVEN if frequency_penalty is None else frequency_penalty,
response_format = response_format if response_format is not None else {"type": "text"},
timeout = 120,
reasoning_effort= NOT_GIVEN if reasoning_effort is None else reasoning_effort,
)
time_cost = time.time() - start_time
content = response.choices[0].message.content
if content is None and hasattr(response.choices[0].message, 'reasoning_content'):
content = response.choices[0].message.reasoning_content
validate_function = kwargs.get('validate_function', None)
content = validate_function(content, **kwargs) if validate_function else content
if '</think>' in content and not return_thinking:
content = content.split('</think>')[-1].strip()
else:
if hasattr(response.choices[0].message, 'reasoning_content') and response.choices[0].message.reasoning_content is not None and return_thinking:
content = '<think>' + response.choices[0].message.reasoning_content + '</think>' + content
if return_text_only:
return content
else:
completion_usage_dict = response.usage.model_dump()
completion_usage_dict['time'] = time_cost
return content, completion_usage_dict
def generate_response(self, batch_messages, do_sample=True, max_new_tokens=8192,
temperature=0.7, frequency_penalty=None, response_format={"type": "text"},
return_text_only=True, return_thinking=False, reasoning_effort=None, **kwargs):
if temperature == 0.0:
do_sample = False
# single = list of dict, batch = list of list of dict
is_batch = isinstance(batch_messages[0], list)
if not is_batch:
batch_messages = [batch_messages]
results = [None] * len(batch_messages)
to_process = list(range(len(batch_messages)))
if self.inference_type == "openai":
max_workers = kwargs.get('max_workers', 3) # Default to 3 workers if not specified
with ThreadPoolExecutor(max_workers=max_workers) as executor:
def process_message(i):
try:
return self._api_inference(
batch_messages[i], max_new_tokens, temperature,
frequency_penalty, response_format, return_text_only, return_thinking, reasoning_effort, **kwargs
)
except Exception as e:
print(f"Error processing message {i}: {e.last_attempt.result()}")
return ""
futures = [executor.submit(process_message, i) for i in to_process]
for i, future in enumerate(futures):
results[i] = future.result()
elif self.inference_type == "pipeline":
max_retries = kwargs.get('max_retries', 3) # Default to 3 retries if not specified
start_time = time.time()
# Initial processing of all messages
responses = self.client(
batch_messages,
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=do_sample,
return_full_text=False
)
time_cost = time.time() - start_time
# Extract contents
contents = [resp[0]['generated_text'].strip() for resp in responses]
# Validate and collect failed indices
validate_function = kwargs.get('validate_function', None)
failed_indices = []
for i, content in enumerate(contents):
if validate_function:
try:
contents[i] = validate_function(content, **kwargs)
except Exception as e:
print(f"Validation failed for index {i}: {e}")
failed_indices.append(i)
# Retry failed messages in batches
for attempt in range(max_retries):
if not failed_indices:
break # No more failures to retry
print(f"Retry attempt {attempt + 1}/{max_retries} for {len(failed_indices)} failed messages")
# Prepare batch of failed messages
failed_messages = [batch_messages[i] for i in failed_indices]
try:
# Process failed messages as a batch
retry_responses = self.client(
failed_messages,
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=do_sample,
return_full_text=False
)
retry_contents = [resp[0]['generated_text'].strip() for resp in retry_responses]
# Validate retry results and update contents
new_failed_indices = []
for j, i in enumerate(failed_indices):
try:
if validate_function:
retry_contents[j] = validate_function(retry_contents[j], **kwargs)
contents[i] = retry_contents[j]
except Exception as e:
print(f"Validation failed for index {i} on retry {attempt + 1}: {e}")
new_failed_indices.append(i)
failed_indices = new_failed_indices # Update failed indices for next retry
except Exception as e:
print(f"Batch retry {attempt + 1} failed: {e}")
# If batch processing fails, keep all indices in failed_indices
if attempt == max_retries - 1:
for i in failed_indices:
contents[i] = "" # Set to "" if all retries fail
# Set remaining failed messages to "" after all retries
for i in failed_indices:
contents[i] = ""
# Process thinking tags
if not return_thinking:
contents = [content.split('</think>')[-1].strip() if '</think>' in content else content for content in contents]
if return_text_only:
results = contents
else:
usage_dicts = [{
'completion_tokens': len(content.split()),
'time': time_cost / len(batch_messages)
} for content in contents]
results = list(zip(contents, usage_dicts))
return results[0] if not is_batch else results
def generate_cot(self, question, max_new_tokens=1024):
messages = [
{"role": "system", "content": "".join(cot_system_instruction_no_doc)},
{"role": "user", "content": question},
]
return self.generate_response(messages, max_new_tokens=max_new_tokens)
def generate_with_context(self, question, context, max_new_tokens=1024, temperature = 0.7):
messages = [
{"role": "system", "content": "".join(cot_system_instruction)},
{"role": "user", "content": f"{context}\n\n{question}\nThought:"},
]
return self.generate_response(messages, max_new_tokens=max_new_tokens, temperature = temperature)
def generate_with_context_one_shot(self, question, context, max_new_tokens=4096, temperature = 0.7):
messages = deepcopy(prompt_template)
messages.append(
{"role": "user", "content": f"{context}\n\nQuestions:{question}\nThought:"},
)
return self.generate_response(messages, max_new_tokens=max_new_tokens, temperature = temperature)
def generate_with_context_kg(self, question, context, max_new_tokens=1024, temperature = 0.7):
messages = [
{"role": "system", "content": "".join(cot_system_instruction_kg)},
{"role": "user", "content": f"{context}\n\n{question}"},
]
return self.generate_response(messages, max_new_tokens=max_new_tokens, temperature = temperature)
@retry_decorator
def filter_triples_with_entity_event(self,question, triples):
messages = deepcopy(filter_triple_messages)
messages.append(
{"role": "user", "content": f"""[ ## question ## ]]
{question}
[[ ## fact_before_filter ## ]]
{triples}"""})
try:
validate_args = {
"schema": filter_fact_json_schema,
"fix_function": fix_filter_triplets,
}
response = self.generate_response(messages, max_new_tokens=4096, temperature=0.0, response_format={"type": "json_object"},
validate_function=validate_output, **validate_args)
return response
except Exception as e:
# If all retries fail, return the original triples
return triples
@retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
def large_kg_filter_keywords_with_entity(self, question, keywords):
messages = deepcopy(keyword_filtering_prompt)
messages.append({
"role": "user",
"content": f"""[[ ## question ## ]]
{question}
[[ ## keywords_before_filter ## ]]
{keywords}"""
})
try:
response = self.generate_response(messages, response_format={"type": "json_object"}, temperature=0.0, max_new_tokens=2048)
# Validate and clean the response
cleaned_data = json.loads(validate_output(response, schema=lkg_keyword_json_schema, fix_function=fix_lkg_keywords))
return cleaned_data['keywords']
except Exception as e:
return keywords
def ner(self, text):
messages = [
{"role": "system", "content": "Please extract the entities from the following question and output them separated by comma, in the following format: entity1, entity2, ..."},
{"role": "user", "content": f"Extract the named entities from: {text}"},
]
return self.generate_response(messages)
@retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
def large_kg_ner(self, text):
messages = deepcopy(ner_prompt)
messages.append(
{
"role": "user",
"content": f"[[ ## question ## ]]\n{text}"
}
)
validation_args = {
"schema": lkg_keyword_json_schema,
"fix_function": fix_lkg_keywords
}
# Generate raw response from LLM
raw_response = self.generate_response(messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"}, validate_function=validate_output, **validation_args)
try:
# Validate and clean the response
cleaned_data = json_repair.loads(raw_response)
return cleaned_data['keywords']
except (json.JSONDecodeError, jsonschema.ValidationError) as e:
return [] # Fallback to empty list or raise custom exception
@retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
def large_kg_tog_ner(self, text):
messages = [
{"role": "system", "content": "You are an advanced AI assistant that extracts named entities from given text. "},
{"role": "user", "content": f"Extract the named entities from: {text}"}
]
# Generate raw response from LLM
validation_args = {
"schema": lkg_keyword_json_schema,
"fix_function": fix_lkg_keywords
}
raw_response = self.generate_response(messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"}, validate_function=validate_output, **validation_args)
try:
# Validate and clean the response
cleaned_data = json_repair.loads(raw_response)
return cleaned_data['keywords']
except (json.JSONDecodeError, jsonschema.ValidationError) as e:
return [] # Fallback to empty list or raise custom exception
def generate_with_react(self, question, context=None, max_new_tokens=1024, search_history=None, logger=None):
react_system_instruction = (
'You are an advanced AI assistant that uses the ReAct framework to solve problems through iterative search. '
'Follow these steps in your response:\n'
'1. Thought: Think step by step and analyze if the current context is sufficient to answer the question. If not, review the current context and think critically about what can be searched to help answer the question.\n'
' - Break down the question into *1-hop* sub-questions if necessary (e.g., identify key entities like people or places before addressing specific events).\n'
' - Use the available context to make inferences about key entities and their relationships.\n'
' - If a previous search query (prefix with "Previous search attempt") was not useful, reflect on why and adjust your strategy—avoid repeating similar queries and consider searching for general information about key entities or related concepts.\n'
'2. Action: Choose one of:\n'
' - Search for [Query]: If you need more information, specify a new query. The [Query] must differ from previous searches in wording and direction to explore new angles.\n'
' - No Action: If the current context is sufficient.\n'
'3. Answer: Provide one of:\n'
' - A concise, definitive response as a noun phrase if you can answer.\n'
' - "Need more information" if you need to search.\n\n'
'Format your response exactly as:\n'
'Thought: [your reasoning]\n'
'Action: [Search for [Query] or No Action]\n'
'Answer: [concise noun phrase if you can answer, or "Need more information" if you need to search]\n\n'
)
# Build context with search history if available
full_context = []
if search_history:
for i, (thought, action, observation) in enumerate(search_history):
search_history_text = f"\nPrevious search attempt {i}:\n"
search_history_text += f"{action}\n Result: {observation}\n"
full_context.append(search_history_text)
if context:
full_context_text = f"Current Retrieved Context:\n{context}\n"
full_context.append(full_context_text)
if logger:
logger.info(f"Full context for ReAct generation: {full_context}")
# Combine few-shot examples with system instruction and user query
messages = [
{"role": "system", "content": react_system_instruction},
{"role": "user", "content": f"Search History:\n\n{''.join(full_context)}\n\nQuestion: {question}"
if full_context else f"Question: {question}"}
]
if logger:
logger.info(f"Messages for ReAct generation: {search_history}Question: {question}")
return self.generate_response(messages, max_new_tokens=max_new_tokens)
def triple_extraction(self, messages, max_tokens=4096, stage=None, record=False, allow_empty=False):
if isinstance(messages[0], dict):
messages = [messages]
validate_kwargs = {
'schema': stage_to_schema.get(stage, None),
'fix_function': fix_triple_extraction_response,
'prompt_type': stage_to_prompt_type.get(stage, None),
'allow_empty': allow_empty
}
result = self.generate_response(messages, max_new_tokens=max_tokens, validate_function=validate_output, return_text_only = not record, **validate_kwargs)
return result
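A minimal wiring sketch for this class against an OpenAI-compatible endpoint; the API key, base URL, and sample inputs are placeholders rather than values from this repository.

```python
# Sketch: drive LLMGenerator with an OpenAI-compatible client.
from openai import OpenAI
from atlas_rag.llm_generator import LLMGenerator

client = OpenAI(api_key="sk-placeholder", base_url="https://example.com/v1")  # placeholders
llm = LLMGenerator(client, model_name="qwen2-7b-instruct")

# Answer a question over retrieved context using the CoT system instruction.
answer = llm.generate_with_context(
    question="When was Neville A. Stanton's employer founded?",
    context="Neville A. Stanton is a professor at the University of Southampton, founded in 1862.",
    temperature=0.0,
)
print(answer)

# Stage-1 triple extraction; returns a list with one schema-validated JSON string per message batch.
messages = [{"role": "user", "content": "Extract triples: The University of Southampton was founded in 1862."}]
triples_json = llm.triple_extraction(messages, stage=1)
```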

View File

@ -0,0 +1,381 @@
import json
from openai import OpenAI, AzureOpenAI, NOT_GIVEN
from tenacity import retry, stop_after_attempt, stop_after_delay, wait_fixed, wait_exponential, wait_random
from copy import deepcopy
from concurrent.futures import ThreadPoolExecutor
import time
from atlas_rag.llm_generator.prompt.rag_prompt import cot_system_instruction, cot_system_instruction_kg, cot_system_instruction_no_doc, prompt_template
from atlas_rag.llm_generator.format.validate_json_output import validate_filter_output, messages as filter_messages
from atlas_rag.llm_generator.prompt.lkg_prompt import ner_prompt, validate_keyword_output, keyword_filtering_prompt
from atlas_rag.retriever.base import BaseEdgeRetriever, BasePassageRetriever
from atlas_rag.llm_generator.format.validate_json_output import fix_and_validate_response
from transformers.pipelines import Pipeline
import jsonschema
from typing import Union
from logging import Logger
stage_to_prompt_type = {
1: "entity_relation",
2: "event_entity",
3: "event_relation",
}
retry_decorator = retry(
stop=(stop_after_delay(120) | stop_after_attempt(5)),
wait=wait_exponential(multiplier=1, min=2, max=30) + wait_random(min=0, max=2),
)
class LLMGenerator:
def __init__(self, client, model_name):
self.model_name = model_name
self.client: OpenAI | Pipeline = client
if isinstance(client, (OpenAI, AzureOpenAI)):
self.inference_type = "openai"
elif isinstance(client, Pipeline):
self.inference_type = "pipeline"
else:
raise ValueError("Unsupported client type6Please provide either an OpenAI client or a Huggingface Pipeline Object.")
@retry_decorator
def _generate_response(self, messages, do_sample=True, max_new_tokens=8192, temperature=0.7,
frequency_penalty=None, response_format={"type": "text"}, return_text_only=True,
return_thinking=False, reasoning_effort=None):
if temperature == 0.0:
do_sample = False
if self.inference_type == "openai":
start_time = time.time()
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=max_new_tokens,
temperature=temperature,
frequency_penalty=NOT_GIVEN if frequency_penalty is None else frequency_penalty,
response_format=response_format if response_format is not None else {"type": "text"},
timeout=120,
reasoning_effort=NOT_GIVEN if reasoning_effort is None else reasoning_effort,
)
time_cost = time.time() - start_time
content = response.choices[0].message.content
if content is None and hasattr(response.choices[0].message, 'reasoning_content'):
content = response.choices[0].message.reasoning_content
if '</think>' in content and not return_thinking:
content = content.split('</think>')[-1].strip()
else:
if hasattr(response.choices[0].message, 'reasoning_content') and response.choices[0].message.reasoning_content is not None:
content = '<think>' + response.choices[0].message.reasoning_content + '</think>' + content
if return_text_only:
return content
else:
completion_usage_dict = response.usage.model_dump()
completion_usage_dict['time'] = time_cost
return content, completion_usage_dict
elif self.inference_type == "pipeline":
start_time = time.time()
if hasattr(self.client, 'tokenizer'):
input_text = self.client.tokenizer.apply_chat_template(messages, tokenize=False)
else:
input_text = "\n".join([msg["content"] for msg in messages])
response = self.client(
input_text,
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=do_sample,
)
time_cost = time.time() - start_time
content = response[0]['generated_text'].strip()
if '</think>' in content and not return_thinking:
content = content.split('</think>')[-1].strip()
if return_text_only:
return content
else:
token_count = len(content.split())
completion_usage_dict = {
'completion_tokens': token_count,
'time': time_cost
}
return content, completion_usage_dict
def _generate_batch_responses(self, batch_messages, do_sample=True, max_new_tokens=8192,
temperature=0.7, frequency_penalty=None, response_format={"type": "text"},
return_text_only=True, return_thinking=False, reasoning_effort=None):
if self.inference_type == "openai":
with ThreadPoolExecutor(max_workers=3) as executor:
futures = [
executor.submit(
self._generate_response, messages, do_sample, max_new_tokens, temperature,
frequency_penalty, response_format, return_text_only, return_thinking, reasoning_effort
) for messages in batch_messages
]
results = [future.result() for future in futures]
return results
elif self.inference_type == "pipeline":
if not hasattr(self.client, 'tokenizer'):
raise ValueError("Pipeline must have a tokenizer for batch processing.")
batch_inputs = [self.client.tokenizer.apply_chat_template(messages, tokenize=False) for messages in batch_messages]
start_time = time.time()
responses = self.client(
batch_inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=do_sample,
)
time_cost = time.time() - start_time
contents = [resp['generated_text'].strip() for resp in responses]
if not return_thinking:
contents = [content.split('</think>')[-1].strip() if '</think>' in content else content for content in contents]
if return_text_only:
return contents
else:
usage_dicts = [{
'completion_tokens': len(content.split()),
'time': time_cost / len(batch_messages)
} for content in contents]
return list(zip(contents, usage_dicts))
def generate_cot(self, questions, max_new_tokens=1024):
if isinstance(questions, str):
messages = [{"role": "system", "content": "".join(cot_system_instruction_no_doc)},
{"role": "user", "content": questions}]
return self._generate_response(messages, max_new_tokens=max_new_tokens)
elif isinstance(questions, list):
batch_messages = [[{"role": "system", "content": "".join(cot_system_instruction_no_doc)},
{"role": "user", "content": q}] for q in questions]
return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens)
def generate_with_context(self, question, context, max_new_tokens=1024, temperature=0.7):
if isinstance(question, str):
messages = [{"role": "system", "content": "".join(cot_system_instruction)},
{"role": "user", "content": f"{context}\n\n{question}\nThought:"}]
return self._generate_response(messages, max_new_tokens=max_new_tokens, temperature=temperature)
elif isinstance(question, list):
batch_messages = [[{"role": "system", "content": "".join(cot_system_instruction)},
{"role": "user", "content": f"{context}\n\n{q}\nThought:"}] for q in question]
return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens, temperature=temperature)
def generate_with_context_one_shot(self, question, context, max_new_tokens=4096, temperature=0.7):
if isinstance(question, str):
messages = deepcopy(prompt_template)
messages.append({"role": "user", "content": f"{context}\n\nQuestions:{question}\nThought:"})
return self._generate_response(messages, max_new_tokens=max_new_tokens, temperature=temperature)
elif isinstance(question, list):
batch_messages = [deepcopy(prompt_template) + [{"role": "user", "content": f"{context}\n\nQuestions:{q}\nThought:"}]
for q in question]
return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens, temperature=temperature)
def generate_with_context_kg(self, question, context, max_new_tokens=1024, temperature=0.7):
if isinstance(question, str):
messages = [{"role": "system", "content": "".join(cot_system_instruction_kg)},
{"role": "user", "content": f"{context}\n\n{question}"}]
return self._generate_response(messages, max_new_tokens=max_new_tokens, temperature=temperature)
elif isinstance(question, list):
batch_messages = [[{"role": "system", "content": "".join(cot_system_instruction_kg)},
{"role": "user", "content": f"{context}\n\n{q}"}] for q in question]
return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens, temperature=temperature)
@retry_decorator
def filter_triples_with_entity(self, question, nodes, max_new_tokens=1024):
if isinstance(question, str):
messages = [{"role": "system", "content": """
Your task is to filter text candidates based on their relevance to a given query...
"""}, {"role": "user", "content": f"{question} \n Output Before Filter: {nodes} \n Output After Filter:"}]
try:
response = json.loads(self._generate_response(messages, max_new_tokens=max_new_tokens))
return response
except Exception:
return json.loads(nodes)
elif isinstance(question, list):
batch_messages = [[{"role": "system", "content": """
Your task is to filter text candidates based on their relevance to a given query...
"""}, {"role": "user", "content": f"{q} \n Output Before Filter: {nodes} \n Output After Filter:"}]
for q in question]
responses = self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens)
return [json.loads(resp) if json.loads(resp) else json.loads(nodes) for resp in responses]
@retry_decorator
def filter_triples_with_entity_event(self, question, triples):
if isinstance(question, str):
messages = deepcopy(filter_messages)
messages.append({"role": "user", "content": f"[ ## question ## ]]\n{question}\n[[ ## fact_before_filter ## ]]\n{triples}"})
try:
response = self._generate_response(messages, max_new_tokens=4096, temperature=0.0, response_format={"type": "json_object"})
cleaned_data = validate_filter_output(response)
return cleaned_data['fact']
except Exception:
return []
elif isinstance(question, list):
batch_messages = [deepcopy(filter_messages) + [{"role": "user", "content": f"[ ## question ## ]]\n{q}\n[[ ## fact_before_filter ## ]]\n{triples}"}]
for q in question]
responses = self._generate_batch_responses(batch_messages, max_new_tokens=4096, temperature=0.0, response_format={"type": "json_object"})
return [validate_filter_output(resp)['fact'] if validate_filter_output(resp) else [] for resp in responses]
def generate_with_custom_messages(self, custom_messages, do_sample=True, max_new_tokens=1024, temperature=0.8, frequency_penalty=None):
if isinstance(custom_messages[0], dict):
return self._generate_response(custom_messages, do_sample, max_new_tokens, temperature, frequency_penalty)
elif isinstance(custom_messages[0], list):
return self._generate_batch_responses(custom_messages, do_sample, max_new_tokens, temperature, frequency_penalty)
@retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
def large_kg_filter_keywords_with_entity(self, question, keywords):
if isinstance(question, str):
messages = deepcopy(keyword_filtering_prompt)
messages.append({"role": "user", "content": f"[[ ## question ## ]]\n{question}\n[[ ## keywords_before_filter ## ]]\n{keywords}"})
try:
response = self._generate_response(messages, response_format={"type": "json_object"}, temperature=0.0, max_new_tokens=2048)
cleaned_data = validate_keyword_output(response)
return cleaned_data['keywords']
except Exception:
return keywords
elif isinstance(question, list):
batch_messages = [deepcopy(keyword_filtering_prompt) + [{"role": "user", "content": f"[[ ## question ## ]]\n{q}\n[[ ## keywords_before_filter ## ]]\n{k}"}]
for q, k in zip(question, keywords)]
responses = self._generate_batch_responses(batch_messages, response_format={"type": "json_object"}, temperature=0.0, max_new_tokens=2048)
return [validate_keyword_output(resp)['keywords'] if validate_keyword_output(resp) else keywords for resp in responses]
def ner(self, text):
if isinstance(text, str):
messages = [{"role": "system", "content": "Please extract the entities..."},
{"role": "user", "content": f"Extract the named entities from: {text}"}]
return self._generate_response(messages)
elif isinstance(text, list):
batch_messages = [[{"role": "system", "content": "Please extract the entities..."},
{"role": "user", "content": f"Extract the named entities from: {t}"}] for t in text]
return self._generate_batch_responses(batch_messages)
@retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
def large_kg_ner(self, text):
if isinstance(text, str):
messages = deepcopy(ner_prompt)
messages.append({"role": "user", "content": f"[[ ## question ## ]]\n{text}"})
try:
response = self._generate_response(messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
cleaned_data = validate_keyword_output(response)
return cleaned_data['keywords']
except Exception:
return []
elif isinstance(text, list):
batch_messages = [deepcopy(ner_prompt) + [{"role": "user", "content": f"[[ ## question ## ]]\n{t}"}] for t in text]
responses = self._generate_batch_responses(batch_messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
return [validate_keyword_output(resp)['keywords'] if validate_keyword_output(resp) else [] for resp in responses]
@retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
def large_kg_tog_ner(self, text):
if isinstance(text, str):
messages = [{"role": "system", "content": "You are an advanced AI assistant..."},
{"role": "user", "content": f"Extract the named entities from: {text}"}]
try:
response = self._generate_response(messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
cleaned_data = validate_keyword_output(response)
return cleaned_data['keywords']
except Exception:
return []
elif isinstance(text, list):
batch_messages = [[{"role": "system", "content": "You are an advanced AI assistant..."},
{"role": "user", "content": f"Extract the named entities from: {t}"}] for t in text]
responses = self._generate_batch_responses(batch_messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
return [validate_keyword_output(resp)['keywords'] if validate_keyword_output(resp) else [] for resp in responses]
def generate_with_react(self, question, context=None, max_new_tokens=1024, search_history=None, logger=None):
# Implementation remains single-input focused as it's iterative; batching not applicable here
react_system_instruction = (
'You are an advanced AI assistant that uses the ReAct framework...'
)
full_context = []
if search_history:
for i, (thought, action, observation) in enumerate(search_history):
full_context.append(f"\nPrevious search attempt {i}:\n{action}\n Result: {observation}\n")
if context:
full_context.append(f"Current Retrieved Context:\n{context}\n")
messages = [{"role": "system", "content": react_system_instruction},
{"role": "user", "content": f"Search History:\n\n{''.join(full_context)}\n\nQuestion: {question}"
if full_context else f"Question: {question}"}]
return self._generate_response(messages, max_new_tokens=max_new_tokens)
def generate_with_rag_react(self, question: str, retriever: Union['BaseEdgeRetriever', 'BasePassageRetriever'],
max_iterations: int = 5, max_new_tokens: int = 1024, logger: Logger = None):
# Single-input iterative process; batching not applicable
search_history = []
if isinstance(retriever, BaseEdgeRetriever):
initial_context, _ = retriever.retrieve(question, topN=5)
current_context = ". ".join(initial_context)
elif isinstance(retriever, BasePassageRetriever):
initial_context, _ = retriever.retrieve(question, topN=5)
current_context = "\n".join(initial_context)
for iteration in range(max_iterations):
analysis_response = self.generate_with_react(
question=question, context=current_context, max_new_tokens=max_new_tokens, search_history=search_history, logger=logger
)
try:
thought = analysis_response.split("Thought:")[1].split("\n")[0]
action = analysis_response.split("Action:")[1].split("\n")[0]
answer = analysis_response.split("Answer:")[1].strip()
if answer.lower() != "need more information":
search_history.append((thought, action, "Using current context"))
return answer, search_history
if "search" in action.lower():
search_query = action.split("search for")[-1].strip()
if isinstance(retriever, BaseEdgeRetriever):
new_context, _ = retriever.retrieve(search_query, topN=3)
current_contexts = current_context.split(". ")
new_context = [ctx for ctx in new_context if ctx not in current_contexts]
new_context = ". ".join(new_context)
elif isinstance(retriever, BasePassageRetriever):
new_context, _ = retriever.retrieve(search_query, topN=3)
current_contexts = current_context.split("\n")
new_context = [ctx for ctx in new_context if ctx not in current_contexts]
new_context = "\n".join(new_context)
observation = f"Found information: {new_context}" if new_context else "No new information found..."
search_history.append((thought, action, observation))
if new_context:
current_context = f"{current_context}\n{new_context}"
else:
search_history.append((thought, action, "No action taken but answer not found"))
return "Unable to find answer", search_history
except Exception as e:
return analysis_response, search_history
return answer, search_history
def triple_extraction(self, messages, max_tokens=4096, stage=None, record=False):
if isinstance(messages[0], dict):
messages = [messages]
responses = self._generate_batch_responses(
batch_messages=messages,
max_new_tokens=max_tokens,
temperature=0.0,
do_sample=False,
frequency_penalty=0.5,
reasoning_effort="none",
return_text_only=not record
)
processed_responses = []
for response in responses:
if record:
content, usage_dict = response
else:
content = response
usage_dict = None
try:
prompt_type = stage_to_prompt_type.get(stage, None)
if prompt_type:
corrected, error = fix_and_validate_response(content, prompt_type)
if error:
raise ValueError(f"Validation failed for prompt_type '{prompt_type}'")
else:
corrected = content
if corrected and corrected.strip():
if record:
processed_responses.append((corrected, usage_dict))
else:
processed_responses.append(corrected)
else:
raise ValueError("Invalid response")
except Exception as e:
print(f"Failed to process response: {str(e)}")
if record:
usage_dict = {'completion_tokens': 0, 'total_tokens': 0, 'time': 0}
processed_responses.append(("[]", usage_dict))
else:
processed_responses.append("[]")
return processed_responses
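This second generator variant accepts either a single prompt or a list and routes lists through `_generate_batch_responses`; a sketch of the batched path with invented questions (the package export is assumed to resolve to this class):

```python
# Sketch: batched chain-of-thought generation with the LLMGenerator variant above.
from openai import OpenAI
from atlas_rag.llm_generator import LLMGenerator  # assumed export; two same-named classes exist in this commit

client = OpenAI(api_key="sk-placeholder", base_url="https://example.com/v1")  # placeholders
llm = LLMGenerator(client, model_name="qwen2-7b-instruct")

questions = [
    "Which country is the Imperial River (Florida) located in?",
    "Who directed the film The Ancestor?",
]
answers = llm.generate_cot(questions, max_new_tokens=256)  # list in -> list of answers out
for q, a in zip(questions, answers):
    print(q, "->", a)
```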

View File

@ -0,0 +1,263 @@
import json
import jsonschema
from typing import Any
ner_prompt =[
{"role": "system",
"content": """
You are a domain analysis engine. You must provide keywords for searching relevant documents.
When given any academic question, follow these steps:
1. **Identify Tested Skills:** Determine the *abstract knowledge/skills* required to solve the problem (e.g., "translating universal statements into predicate logic"), not the concrete entities in the question (e.g., "children," "school").
2. **Extract Domain Specific Term:** Extract domain-specific technical terms (e.g., "school" in *educational policy*), exclude common nouns/verbs describing the question's *scenario* (e.g., "child," "school," "goes to").
3. **Prioritize Formal Structures:** For logic/math problems, focus on notation rules (e.g., quantifier order, implication vs. conjunction), not scenario labels.
4. **Capture Rare Technical Terms:** Include uncommon domain-specific terms critical to the question (e.g., "epigenetics" in biology, "monad" in computer science), even if they appear infrequently in general language.
Your input fields are:
1. question (str): Query for keyword extraction
Your output fields are:
1. keywords (array): Extracted keywords in JSON format
All interactions will be structured as:
[[ ## question ## ]]
{question}
[[ ## keywords ## ]]
{keywords}
The output must be parseable according to JSON schema:
{"type": "object", "properties": {"keywords": {"type": "array", "items": {"type": "string"}}}, "required": ["keywords"]}
"""},
{
"role": "user",
"content": "[[ ## question ## ]]\nSolve \(x^2 - 5x + 6 = 0\)."
},
{
"role": "assistant",
"content": "{\"keywords\": [\"quadratic equation\", \"factoring\", \"roots\", \"algebraic manipulation\"]}"
},
{
"role": "user",
"content": "[[ ## question ## ]]\nExplain the socio-economic causes of the French Revolution."
},
{
"role": "assistant",
"content": "{\"keywords\": [\"historical causation\", \"class struggle\", \"economic inequality\", \"Enlightenment philosophy\"]}"
},
{
"role": "user",
"content": "[[ ## question ## ]]\nProve that the square root of 2 is irrational."
},
{
"role": "assistant",
"content": "{\"keywords\": [\"proof by contradiction\", \"irrational numbers\", \"number theory\", \"rational/irrational distinction\"]}"
},
{
"role": "user",
"content": "[[ ## question ## ]]\nExplain the implications of Heisenberg's uncertainty principle on quantum measurements."
},
{
"role": "assistant",
"content": "{\"keywords\": [\"Heisenberg uncertainty principle\", \"quantum mechanics\", \"measurement theory\", \"quantum state collapse\", \"observable quantities\"]}"
}
]
keyword_filtering_prompt = [
{
"role": "system",
"content": """You are a precision-focused component of a knowledge retrieval system used by researchers and educators.
Your task is to filter keywords based on their relevance to the **core domain knowledge** required to answer a given question.
You must critically evaluate whether each item represents:
1. Academic concepts/methodologies (entities)
2. Critical processes or relationships (events)
The output should be in JSON format, e.g., {"keywords": ["concept1", "concept2"]}. If no keywords are relevant, return {"keywords": []}.
Accuracy is crucial, as these keywords drive document retrieval for critical research. Do not generate any new keywords or explanations.
**Do not change the content of each object in the list. You must only use text from the candidate list and cannot generate new text.**
**Include all characters of the selected keywords.**
**Do not include any duplicate keywords.**
**The keywords can be a sentence describing a event as long as it is helpful for searching.**
---
**Input Fields:**
1. question (str): Query requiring knowledge analysis
2. keywords_before_filter (list): Candidate keywords to evaluate
**Output Field:**
1. keywords_after_filter (list): Filtered keywords in JSON format
---
**Interaction Structure:**
[[ ## question ## ]]
{question}
[[ ## keywords_before_filter ## ]]
{keywords_before_filter}
[[ ## keywords_after_filter ## ]]
{keywords_after_filter}
**JSON Schema Validation:**
{"type": "object", "properties": {"keywords": {"type": "array", "items": {"type": "string"}}}, "required": ["keywords"]}
"""},
{
"role": "user",
"content": """[[ ## question ## ]]
Explain the causes of the French Revolution.
[[ ## keywords_before_filter ## ]]
["French Revolution", "social inequality", "Enlightenment ideas spreading", "economic crisis", "Bastille storming event"]"""
},
{
"role": "assistant",
"content": """{"keywords": ["social inequality", "Enlightenment ideas spreading", "economic crisis"]}"""
},
{
"role": "user",
"content": """[[ ## question ## ]]
Translate 'All children go to some school' into predicate logic.
[[ ## keywords_before_filter ## ]]
["predicate logic", "children", "school", "universal quantifier (∀)", "attendance"]"""
},
{
"role": "assistant",
"content": """{"keywords": ["predicate logic", "universal quantifier (∀)"]}"""
},
{
"role": "user",
"content": """[[ ## question ## ]]
Solve \(x^2 - 5x + 6 = 0\).
[[ ## keywords_before_filter ## ]]
["quadratic equation", "polynomial", "algebra", "roots", "classroom teaching"]"""
},
{
"role": "assistant",
"content": """{"keywords": ["quadratic equation", "polynomial", "roots"]}"""
}
]
test_messages = [
[
{
"role":"system",
"content":"You are a helpful asssistant. You will answer the question based on the context provided."
},
{
"role": "user",
"content":"""Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
Question: What is the phenotype of a congenital disorder impairing the secretion of leptin?
A. Normal energy intake, normal body weight and hyperthyroidism
B. Obesity, excess energy intake, normal growth and hypoinsulinaemia
C. Obesity, abnormal growth, hypothyroidism, hyperinsulinaemia
D. Underweight, abnormal growth, hypothyroidism, hyperinsulinaemia"""
}
],
[
{
"role":"system",
"content":"You are a helpful asssistant. You will answer the question based on the context provided."
},
{
"role": "user",
"content":"""Query: Given the following question and five candidate answers (A, B, C and D), choose the answer.
Question: Which of the following is notan advantage of stratified random sampling over simple random sampling?
A. When done correctly, a stratified random sample is less biased than a simple random sample.
B. When done correctly, a stratified random sampling process has less variability from sample to sample than a simple random sample.
C. When done correctly, a stratified random sample can provide, with a smaller sample size, an estimate that is just as reliable as that of a simple random sample with a larger sample size.
D. A stratified random sample provides information about each stratum in the population as well as an estimate for the population as a whole, and a simple random sample does not."""
}
],
[
{
"role":"system",
"content":"You are a helpful asssistant. You will answer the question based on the context provided."
},
{
"role": "user",
"content":"""Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
Question: Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q.
A. 0
B. 4
C. 2
D. 6"""
}
], [
{
"role":"system",
"content":"You are a helpful asssistant. You will answer the question based on the context provided."
},
{
"role": "user",
"content":"""Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
Question: A lesion causing compression of the facial nerve at the stylomastoid foramen will cause ipsilateral
A. paralysis of the facial muscles.
B. paralysis of the facial muscles and loss of taste.
C. paralysis of the facial muscles, loss of taste and lacrimation.
D. paralysis of the facial muscles, loss of taste, lacrimation and decreased salivation."""
}
], [
{
"role":"system",
"content":"You are a helpful asssistant. You will answer the question based on the context provided."
},
{
"role": "user",
"content":"""Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
Question: What is true for a type-Ia ("type one-a") supernova?
A. This type occurs in binary systems.
B. This type occurs in young galaxies.
C. This type produces gamma-ray bursts.
D. This type produces high amounts of X-rays."""
}
], [
{
"role":"system",
"content":"You are a helpful asssistant. You will answer the question based on the context provided."
},
{
"role": "user",
"content":"""Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
Question: _______ such as bitcoin are becoming increasingly mainstream and have a whole host of associated ethical implications, for example, they are______ and more ______. However, they have also been used to engage in _______.
A. Cryptocurrencies, Expensive, Secure, Financial Crime
B. Traditional currency, Cheap, Unsecure, Charitable giving
C. Cryptocurrencies, Cheap, Secure, Financial crime
D. Traditional currency, Expensive, Unsecure, Charitable giving"""
}
], [
{
"role":"system",
"content":"You are a helpful asssistant. You will answer the question based on the context provided."
},
{
"role": "user",
"content":"""Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
Question: The access matrix approach to protection has the difficulty that
A. the matrix, if stored directly, is large and can be clumsy to manage
B. it is not capable of expressing complex protection requirements
C. deciding whether a process has access to a resource is undecidable
D. there is no way to express who has rights to change the access matrix itself"""
}
], [
{
"role":"system",
"content":"You are a helpful asssistant. You will answer the question based on the context provided."
},
{
"role": "user",
"content":"""Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
Question: In the NoNicks operating system, the time required by a single file-read operation has four nonoverlapping components: disk seek time-25 msec disk latency time-8 msec disk transfer time- 1 msec per 1,000 bytes operating system overhead-1 msec per 1,000 bytes + 10 msec In version 1 of the system, the file read retrieved blocks of 1,000 bytes. In version 2, the file read (along with the underlying layout on disk) was modified to retrieve blocks of 4,000 bytes. The ratio of-the time required to read a large file under version 2 to the time required to read the same large file under version 1 is approximately
A. 1:4
B. 1:3.5
C. 1:1
D. 1.1:1"""
}
], [
{
"role":"system",
"content":"You are a helpful asssistant. You will answer the question based on the context provided."
},
{
"role": "user",
"content":"""Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
Question: Which of the following propositions is an immediate (one-step) consequence in PL of the given premises? ~E ⊃ ~F G ⊃ F H ~E H ⊃ I ~I
A. E ⊃ F
B. F ⊃ G
C. H ⊃ ~E
D. ~H"""
}
],
]
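A sketch of how these templates are consumed, mirroring `large_kg_filter_keywords_with_entity` in the generator files above; the question and candidate keywords are invented and the client settings are placeholders.

```python
# Sketch: build a keyword-filtering request from the template and parse the JSON reply.
import json
from copy import deepcopy

from openai import OpenAI
from atlas_rag.llm_generator import LLMGenerator
from atlas_rag.llm_generator.prompt.lkg_prompt import keyword_filtering_prompt

llm = LLMGenerator(OpenAI(api_key="sk-placeholder", base_url="https://example.com/v1"), model_name="qwen2-7b-instruct")

question = "Explain the causes of the French Revolution."
candidates = ["French Revolution", "social inequality", "economic crisis"]

messages = deepcopy(keyword_filtering_prompt)
messages.append({
    "role": "user",
    "content": f"[[ ## question ## ]]\n{question}\n[[ ## keywords_before_filter ## ]]\n{candidates}",
})
reply = llm.generate_response(messages, temperature=0.0, max_new_tokens=512, response_format={"type": "json_object"})
keywords = json.loads(reply).get("keywords", [])
```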

View File

@ -0,0 +1,115 @@
one_shot_rag_qa_docs = (
"""Wikipedia Title: The Last Horse\nThe Last Horse (Spanish:El último caballo) is a 1950 Spanish comedy film directed by Edgar Neville starring Fernando Fernán Gómez.\n"""
"""Wikipedia Title: Southampton\nThe University of Southampton, which was founded in 1862 and received its Royal Charter as a university in 1952, has over 22,000 students. The university is ranked in the top 100 research universities in the world in the Academic Ranking of World Universities 2010. In 2010, the THES - QS World University Rankings positioned the University of Southampton in the top 80 universities in the world. The university considers itself one of the top 5 research universities in the UK. The university has a global reputation for research into engineering sciences, oceanography, chemistry, cancer sciences, sound and vibration research, computer science and electronics, optoelectronics and textile conservation at the Textile Conservation Centre (which is due to close in October 2009.) It is also home to the National Oceanography Centre, Southampton (NOCS), the focus of Natural Environment Research Council-funded marine research.\n"""
"""Wikipedia Title: Stanton Township, Champaign County, Illinois\nStanton Township is a township in Champaign County, Illinois, USA. As of the 2010 census, its population was 505 and it contained 202 housing units.\n"""
"""Wikipedia Title: Neville A. Stanton\nNeville A. Stanton is a British Professor of Human Factors and Ergonomics at the University of Southampton. Prof Stanton is a Chartered Engineer (C.Eng), Chartered Psychologist (C.Psychol) and Chartered Ergonomist (C.ErgHF). He has written and edited over a forty books and over three hundered peer-reviewed journal papers on applications of the subject. Stanton is a Fellow of the British Psychological Society, a Fellow of The Institute of Ergonomics and Human Factors and a member of the Institution of Engineering and Technology. He has been published in academic journals including "Nature". He has also helped organisations design new human-machine interfaces, such as the Adaptive Cruise Control system for Jaguar Cars.\n"""
"""Wikipedia Title: Finding Nemo\nFinding Nemo Theatrical release poster Directed by Andrew Stanton Produced by Graham Walters Screenplay by Andrew Stanton Bob Peterson David Reynolds Story by Andrew Stanton Starring Albert Brooks Ellen DeGeneres Alexander Gould Willem Dafoe Music by Thomas Newman Cinematography Sharon Calahan Jeremy Lasky Edited by David Ian Salter Production company Walt Disney Pictures Pixar Animation Studios Distributed by Buena Vista Pictures Distribution Release date May 30, 2003 (2003 - 05 - 30) Running time 100 minutes Country United States Language English Budget $$94 million Box office $$940.3 million"""
)
one_shot_ircot_demo = (
f'{one_shot_rag_qa_docs}'
'\n\nQuestion: '
f"When was Neville A. Stanton's employer founded?"
'\nThought: '
f"The employer of Neville A. Stanton is University of Southampton. The University of Southampton was founded in 1862. So the answer is: 1862."
'\n\n'
)
rag_qa_system = (
'As an advanced reading comprehension assistant, your task is to analyze text passages and corresponding questions meticulously. '
'Your response starts after "Thought: ", where you will methodically break down the reasoning process, illustrating how you arrive at conclusions. '
'Conclude with "Answer: " to present a concise, definitive response, devoid of additional elaborations.'
)
one_shot_rag_qa_input = (
f"{one_shot_rag_qa_docs}"
"\n\nQuestion: "
"When was Neville A. Stanton's employer founded?"
'\nThought: '
)
one_shot_rag_qa_output = (
"The employer of Neville A. Stanton is University of Southampton. The University of Southampton was founded in 1862. "
"\nAnswer: 1862."
)
prompt_template = [
{"role": "system", "content": rag_qa_system},
{"role": "user", "content": one_shot_rag_qa_input},
{"role": "assistant", "content": one_shot_rag_qa_output},
]
# from https://github.com/OSU-NLP-Group/HippoRAG/blob/main/src/qa/qa_reader.py
cot_system_instruction = ('As an advanced reading comprehension assistant, your task is to analyze text passages and corresponding questions meticulously. If the information is not enough, you can use your own knowledge to answer the question.'
'Your response starts after "Thought: ", where you will methodically break down the reasoning process, illustrating how you arrive at conclusions. '
'Conclude with "Answer: " to present a concise, definitive response as a noun phrase, no elaborations.')
cot_system_instruction_no_doc = ('As an advanced reading comprehension assistant, your task is to analyze the questions and then answer them. '
'Your response starts after "Thought: ", where you will methodically break down the reasoning process, illustrating how you arrive at conclusions. '
'Conclude with "Answer: " to present a concise, definitive response as a noun phrase, no elaborations.')
cot_system_instruction_kg = ('As an advanced reading comprehension assistant, your task is to analyze extracted information and corresponding questions meticulously. If the knowledge graph information is not enough, you can use your own knowledge to answer the question. '
'Your response starts after "Thought: ", where you will methodically break down the reasoning process, illustrating how you arrive at conclusions. '
'Conclude with "Answer: " to present a concise, definitive response as a noun phrase, no elaborations.')
filter_triple_messages = [
{
"role": "system",
"content": """You are a critical component of a high-stakes question-answering system used by top researchers and decision-makers worldwide.
Your task is to filter facts based on their relevance to a given query.
The query requires careful analysis and possibly multi-hop reasoning to connect different pieces of information.
You must select all relevant facts from the provided candidate list, aiding in reasoning and providing an accurate answer.
The output should be in JSON format, e.g., {"fact": [["s1", "p1", "o1"], ["s2", "p2", "o2"]]}, and if no facts are relevant, return an empty list, {"fact": []}.
The accuracy of your response is paramount, as it will directly impact the decisions made by these high-level stakeholders. You must only use facts from the candidate list and not generate new facts.
The future of critical decision-making relies on your ability to accurately filter and present relevant information.
Your input fields are:
1. question (str): Query for retrieval
2. fact_before_filter (str): Candidate facts to be filtered
Your output fields are:
1. fact_after_filter (Fact): Filtered facts in JSON format
All interactions will be structured as:
[[ ## question ## ]]
{question}
[[ ## fact_before_filter ## ]]
{fact_before_filter}
[[ ## fact_after_filter ## ]]
{fact_after_filter}
The output must be parseable according to JSON schema: {"type": "object", "properties": {"fact": {"type": "array", "items": {"type": "array", "items": {"type": "string"}}}}, "required": ["fact"]}"""
},
# Example 1
{
"role": "user",
"content": """[[ ## question ## ]]
Are Imperial River (Florida) and Amaradia (Dolj) both located in the same country?
[[ ## fact_before_filter ## ]]
{"fact": [["imperial river", "is located in", "florida"], ["imperial river", "is a river in", "united states"], ["imperial river", "may refer to", "south america"], ["amaradia", "flows through", "ro ia de amaradia"], ["imperial river", "may refer to", "united states"]]}"""
},
{
"role": "assistant",
"content": """{"fact":[["imperial river","is located in","florida"],["imperial river","is a river in","united states"],["amaradia","flows through","ro ia de amaradia"]]}"""
},
# Example 2
{
"role": "user",
"content": """[[ ## question ## ]]
When is the director of film The Ancestor 's birthday?
[[ ## fact_before_filter ## ]]
{"fact": [["jean jacques annaud", "born on", "1 october 1943"], ["tsui hark", "born on", "15 february 1950"], ["pablo trapero", "born on", "4 october 1971"], ["the ancestor", "directed by", "guido brignone"], ["benh zeitlin", "born on", "october 14 1982"]]}"""
},
{
"role": "assistant",
"content": """{"fact":[["the ancestor","directed by","guido brignone"]]}"""
},
]
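A short sketch of how `prompt_template` is extended at query time, mirroring `generate_with_context_one_shot` in the generator files above (the context passage is invented):

```python
# Sketch: assemble the one-shot RAG QA prompt defined above.
from copy import deepcopy
from atlas_rag.llm_generator.prompt.rag_prompt import prompt_template

context = "Wikipedia Title: Example\nThe University of Southampton was founded in 1862.\n"
question = "When was Neville A. Stanton's employer founded?"

messages = deepcopy(prompt_template)
messages.append({"role": "user", "content": f"{context}\n\nQuestions:{question}\nThought:"})
# messages now holds the system instruction, one worked example (user + assistant turns),
# and the real query, ready to pass to an LLMGenerator instance.
```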

View File

@ -0,0 +1,108 @@
from atlas_rag.llm_generator import LLMGenerator
from atlas_rag.retriever.base import BaseEdgeRetriever, BasePassageRetriever
from typing import Union
from logging import Logger
class ReAct():
def __init__(self, llm:LLMGenerator):
self.llm = llm
def generate_with_rag_react(self, question: str, retriever: Union['BaseEdgeRetriever', 'BasePassageRetriever'], max_iterations: int = 5, max_new_tokens: int = 1024, logger: Logger = None):
"""
Generate a response using RAG with ReAct framework, starting with an initial search using the original query.
Args:
question (str): The question to answer
retriever: The retriever instance to use for searching
max_iterations (int): Maximum number of ReAct iterations
max_new_tokens (int): Maximum number of tokens to generate per iteration
Returns:
tuple: (final_answer, search_history)
- final_answer: The final answer generated
- search_history: List of (thought, action, observation) tuples
"""
search_history = []
# Perform initial search with the original query
if isinstance(retriever, BaseEdgeRetriever):
initial_context, _ = retriever.retrieve(question, topN=5)
current_context = ". ".join(initial_context)
elif isinstance(retriever, BasePassageRetriever):
initial_context, _ = retriever.retrieve(question, topN=5)
current_context = "\n".join(initial_context)
# Start ReAct process with the initial context
for iteration in range(max_iterations):
# First, analyze if we can answer with current context
analysis_response = self.llm.generate_with_react(
question=question,
context=current_context,
max_new_tokens=max_new_tokens,
search_history=search_history,
logger = logger
)
if logger:
logger.info(f"Analysis response: {analysis_response}")
try:
# Parse the analysis response
thought = analysis_response.split("Thought:")[1].split("\n")[0]
if logger:
logger.info(f"Thought: {thought}")
action = analysis_response.split("Action:")[1].split("\n")[0]
answer = analysis_response.split("Answer:")[1].strip()
# If the answer indicates we can answer with current context
if answer.lower() != "need more information":
search_history.append((thought, action, "Using current context"))
return answer, search_history
# If we need more information, perform the search
if "search" in action.lower():
# Extract search query from the action
search_query = action.split("search for")[-1].strip()
# Perform the search
if isinstance(retriever, BaseEdgeRetriever):
new_context, _ = retriever.retrieve(search_query, topN=3)
# Filter out contexts that are already in current_context
current_contexts = current_context.split(". ")
new_context = [ctx for ctx in new_context if ctx not in current_contexts]
new_context = ". ".join(new_context)
elif isinstance(retriever, BasePassageRetriever):
new_context, _ = retriever.retrieve(search_query, topN=3)
# Filter out contexts that are already in current_context
current_contexts = current_context.split("\n")
new_context = [ctx for ctx in new_context if ctx not in current_contexts]
new_context = "\n".join(new_context)
# Store the search results as observation
if new_context:
observation = f"Found information: {new_context}"
else:
observation = "No new information found. Consider searching for related entities or events."
search_history.append((thought, action, observation))
# Update context with new search results
if new_context:
current_context = f"{current_context}\n{new_context}"
if logger:
logger.info(f"New search results: {new_context}")
else:
if logger:
logger.info("No new information found, suggesting to try related entities")
else:
# If no search is needed but we can't answer, something went wrong
search_history.append((thought, action, "No action taken but answer not found"))
return "Unable to find answer", search_history
except Exception as e:
if logger:
logger.error(f"Error parsing ReAct response: {e}")
return analysis_response, search_history
# If we've reached max iterations, return the last answer
return answer, search_history
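The loop above depends on the LLM replying with `Thought:` / `Action:` / `Answer:` sections. A small self-contained sketch of that parsing step, using an invented response string, shows what `generate_with_rag_react` expects:

```python
# Self-contained sketch of the response format generate_with_rag_react parses.
# The response text below is invented for illustration.
analysis_response = (
    "Thought: The current context names the director but not the birthday.\n"
    "Action: search for Guido Brignone birthday\n"
    "Answer: need more information"
)

thought = analysis_response.split("Thought:")[1].split("\n")[0].strip()
action = analysis_response.split("Action:")[1].split("\n")[0].strip()
answer = analysis_response.split("Answer:")[1].strip()

if answer.lower() != "need more information":
    print("Answer found:", answer)
elif "search" in action.lower():
    search_query = action.split("search for")[-1].strip()
    print("Next retrieval query:", search_query)  # -> "Guido Brignone birthday"
```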

View File

@ -0,0 +1,313 @@
TRIPLE_INSTRUCTIONS = {
"en":{
"system": "You are a helpful assistant who always response in a valid array of JSON objects without any explanation",
"entity_relation": """Given a passage, summarize all the important entities and the relations between them in a concise manner. Relations should briefly capture the connections between entities, without repeating information from the head and tail entities. The entities should be as specific as possible. Exclude pronouns from being considered as entities.
You must **strictly output in the following JSON format**:\n
[
{
"Head": "{a noun}",
"Relation": "{a verb}",
"Tail": "{a noun}",
}...
]""",
"event_entity": """Please analyze and summarize the participation relations between the events and entities in the given paragraph. Each event is a single independent sentence. Additionally, identify all the entities that participated in the events. Do not use ellipses.
You must **strictly output in the following JSON format**:\n
[
{
"Event": "{a simple sentence describing an event}",
"Entity": ["entity 1", "entity 2", "..."]
}...
] """,
"event_relation": """Please analyze and summarize the relationships between the events in the paragraph. Each event is a single independent sentence. Identify temporal and causal relationships between the events using the following types: before, after, at the same time, because, and as a result. Each extracted triple should be specific, meaningful, and able to stand alone. Do not use ellipses.
You must **strictly output in the following JSON format**:\n
[
{
"Head": "{a simple sentence describing the event 1}",
"Relation": "{temporal or causality relation between the events}",
"Tail": "{a simple sentence describing the event 2}"
}...
]""",
"passage_start" : """Here is the passage."""
},
"zh-CN": {
"system": """"你是一个始终以有效JSON数组格式回应的助手""",
"entity_relation": """给定一段文字,提取所有重要实体及其关系,并以简洁的方式总结。关系描述应清晰表达实体间的联系,且不重复头尾实体的信息。实体需具体明确,排除代词。
**重要格式要求:**
1. Head字段必须是一个字符串不能为空
2. Relation字段必须是一个字符串且不能为空如果不确定请用"相关"
3. Tail字段必须是一个字符串不能为空。如果有多个项目请用顿号连接
返回格式必须为以下JSON结构,内容需用简体中文表述:
[
{
"Head": "{名词}",
"Relation": "{动词或关系描述,不能为空}",
"Tail": "{名词,多个用顿号连接}"
}...
]
示例:
输入:"企业内部审计数字化产品技术要求包括审计作业、审计管理、数据建设及应用"
输出:
[
{"Head": "企业内部审计数字化产品", "Relation": "技术要求包括", "Tail": "审计作业、审计管理、数据建设及应用"}
]""",
"event_entity": """分析段落中的事件及其参与实体。每个事件应为独立单句,列出所有相关实体(需具体,不含代词)。
返回格式必须为以下JSON结构,内容需用简体中文表述:
[
{
"Event": "{描述事件的简单句子}",
"Entity": ["实体1", "实体2", "..."]
}...
]""",
"event_relation": """分析事件间的时序或因果关系,关系类型包括:之前,之后,同时,因为,结果.每个事件应为独立单句。
返回格式必须为以下JSON结构.内容需用简体中文表述.
[
{
"Head": "{事件1描述}",
"Relation": "{时序/因果关系}",
"Tail": "{事件2描述}"
}...
]""",
"passage_start": "给定以下段落:"
},
"zh-HK": {
"system": "你是一個始終以有效JSON數組格式回覆的助手",
"entity_relation": """給定一段文字,提取所有重要實體及其關係,並以簡潔的方式總結。關係描述應清晰表達實體間的聯繫,且不重複頭尾實體的信息。實體需具體明確,排除代詞。
返回格式必須為以下JSON結構,內容需用繁體中文表述:
[
{
"Head": "{名詞}",
"Relation": "{動詞或關係描述}",
"Tail": "{名詞}"
}...
]""",
"event_entity": """分析段落中的事件及其參與實體。每個事件應為獨立單句,列出所有相關實體(需具體,不含代詞)。
返回格式必須為以下JSON結構,內容需用繁體中文表述:
[
{
"Event": "{描述事件的簡單句子}",
"Entity": ["實體1", "實體2", "..."]
}...
]""",
"event_relation": """分析事件間的時序或因果關係,關係類型包括:之前,之後,同時,因為,結果.每個事件應為獨立單句。
返回格式必須為以下JSON結構.內容需用繁體中文表述.
[
{
"Head": "{事件1描述}",
"Relation": "{時序/因果關係}",
"Tail": "{事件2描述}"
}...
]""",
"passage_start": "給定以下段落:"
}
}
CONCEPT_INSTRUCTIONS = {
"en": {
"event": """I will give you an EVENT. You need to give several phrases containing 1-2 words for the ABSTRACT EVENT of this EVENT.
You must return your answer in the following format: phrases1, phrases2, phrases3,...
You can't return anything other than answers.
These abstract event words should fulfill the following requirements.
1. The ABSTRACT EVENT phrases can well represent the EVENT, and it could be the type of the EVENT or the related concepts of the EVENT.
2. Strictly follow the provided format, do not add extra characters or words.
3. Write at least 3 or more phrases at different abstract level if possible.
4. Do not repeat the same word and the input in the answer.
5. Stop immediately if you can't think of any more phrases, and no explanation is needed.
EVENT: A man retreats to mountains and forests.
Your answer: retreat, relaxation, escape, nature, solitude
EVENT: A cat chased a prey into its shelter
Your answer: hunting, escape, predation, hiding, stalking
EVENT: Sam playing with his dog
Your answer: relaxing event, petting, playing, bonding, friendship
EVENT: [EVENT]
Your answer:""",
"entity":"""I will give you an ENTITY. You need to give several phrases containing 1-2 words for the ABSTRACT ENTITY of this ENTITY.
You must return your answer in the following format: phrases1, phrases2, phrases3,...
You can't return anything other than answers.
These abstract entity words should fulfill the following requirements.
1. The ABSTRACT ENTITY phrases can well represent the ENTITY, and it could be the type of the ENTITY or the related concepts of the ENTITY.
2. Strictly follow the provided format, do not add extra characters or words.
3. Write at least 3 or more phrases at different abstract level if possible.
4. Do not repeat the same word and the input in the answer.
5. Stop immediately if you can't think of any more phrases, and no explanation is needed.
ENTITY: Soul
CONTEXT: premiered BFI London Film Festival, became highest-grossing Pixar release
Your answer: movie, film
ENTITY: Thinkpad X60
CONTEXT: Richard Stallman announced he is using Trisquel on a Thinkpad X60
Your answer: Thinkpad, laptop, machine, device, hardware, computer, brand
ENTITY: Harry Callahan
CONTEXT: bluffs another robber, tortures Scorpio
Your answer: person, American, character, police officer, detective
ENTITY: Black Mountain College
CONTEXT: was started by John Andrew Rice, attracted faculty
Your answer: college, university, school, liberal arts college
ENTITY: 1st April
CONTEXT: Utkal Dibas celebrates
Your answer: date, day, time, festival
ENTITY: [ENTITY]
CONTEXT: [CONTEXT]
Your answer:""",
"relation":"""I will give you an RELATION. You need to give several phrases containing 1-2 words for the ABSTRACT RELATION of this RELATION.
You must return your answer in the following format: phrases1, phrases2, phrases3,...
You can't return anything other than answers.
These abstract relation words should fulfill the following requirements.
1. The ABSTRACT RELATION phrases can well represent the RELATION, and it could be the type of the RELATION or the simplest concepts of the RELATION.
2. Strictly follow the provided format, do not add extra characters or words.
3. Write at least 3 or more phrases at different abstract level if possible.
4. Do not repeat the same word and the input in the answer.
5. Stop immediately if you can't think of any more phrases, and no explanation is needed.
RELATION: participated in
Your answer: become part of, attend, take part in, engage in, involve in
RELATION: be included in
Your answer: join, be a part of, be a member of, be a component of
RELATION: [RELATION]
Your answer:"""
},
"zh-CN": {
"event": """我将给你一个事件。你需要为这个事件的抽象概念提供几个1-2个词的短语。
你必须按照以下格式返回答案短语1, 短语2, 短语3,...
除了答案外不要返回任何其他内容,请以简体中文输出。
这些抽象事件短语应满足以下要求:
1. 能很好地代表该事件的类型或相关概念
2. 严格遵循给定格式,不要添加额外字符或词语
3. 尽可能提供3个或以上不同抽象层次的短语
4. 不要重复相同词语或输入内容
5. 如果无法想出更多短语立即停止,不需要解释
事件:一个人退隐到山林中
你的回答:退隐, 放松, 逃避, 自然, 独处
事件:一只猫将猎物追进巢穴
你的回答:捕猎, 逃跑, 捕食, 躲藏, 潜行
事件:山姆和他的狗玩耍
你的回答:休闲活动, 抚摸, 玩耍, bonding, 友谊
事件:[EVENT]
请以简体中文输出你的回答:""",
"entity":"""我将给你一个实体。你需要为这个实体的抽象概念提供几个1-2个词的短语。
你必须按照以下格式返回答案短语1, 短语2, 短语3,...
除了答案外不要返回任何其他内容,请以简体中文输出。
这些抽象实体短语应满足以下要求:
1. 能很好地代表该实体的类型或相关概念
2. 严格遵循给定格式,不要添加额外字符或词语
3. 尽可能提供3个或以上不同抽象层次的短语
4. 不要重复相同词语或输入内容
5. 如果无法想出更多短语立即停止,不需要解释
实体:心灵奇旅
上下文在BFI伦敦电影节首映成为皮克斯最卖座影片
你的回答:电影, 影片
实体Thinkpad X60
上下文Richard Stallman宣布他在Thinkpad X60上使用Trisquel系统
你的回答Thinkpad, 笔记本电脑, 机器, 设备, 硬件, 电脑, 品牌
实体:哈利·卡拉汉
上下文:吓退另一个劫匪,折磨天蝎座
你的回答:人物, 美国人, 角色, 警察, 侦探
实体:黑山学院
上下文由John Andrew Rice创办吸引了众多教员
你的回答:学院, 大学, 学校, 文理学院
事件4月1日
上下文庆祝Utkal Dibas
你的回答:日期, 日子, 时间, 节日
实体:[ENTITY]
上下文:[CONTEXT]
请以简体中文输出你的回答:""",
"relation":"""我将给你一个关系。你需要为这个关系的抽象概念提供几个1-2个词的短语。
你必须按照以下格式返回答案短语1, 短语2, 短语3,...
除了答案外不要返回任何其他内容,请以简体中文输出。
这些抽象关系短语应满足以下要求:
1. 能很好地代表该关系的类型或最基本概念
2. 严格遵循给定格式,不要添加额外字符或词语
3. 尽可能提供3个或以上不同抽象层次的短语
4. 不要重复相同词语或输入内容
5. 如果无法想出更多短语立即停止,不需要解释
关系:参与
你的回答:成为一部分, 参加, 参与其中, 涉及, 卷入
关系:被包含在
你的回答:加入, 成为一部分, 成为成员, 成为组成部分
关系:[RELATION]
请以简体中文输出你的回答:"""
},
"zh-HK": {
"event": """我將給你一個事件。你需要為這個事件的抽象概念提供幾個1-2個詞的短語。
你必須按照以下格式返回答案短語1, 短語2, 短語3,...
除了答案外不要返回任何其他內容,請以繁體中文輸出。
這些抽象事件短語應滿足以下要求:
1. 能很好地代表該事件的類型或相關概念
2. 嚴格遵循給定格式,不要添加額外字符或詞語
3. 盡可能提供3個或以上不同抽象層次的短語
4. 不要重複相同詞語或輸入內容
5. 如果無法想出更多短語立即停止,不需要解釋
事件:一個人退隱到山林中
你的回答:退隱, 放鬆, 逃避, 自然, 獨處
事件:一隻貓將獵物追進巢穴
你的回答:捕獵, 逃跑, 捕食, 躲藏, 潛行
事件:山姆和他的狗玩耍
你的回答:休閒活動, 撫摸, 玩耍, bonding, 友誼
事件:[EVENT]
請以繁體中文輸出你的回答:""",
"entity":"""我將給你一個實體。你需要為這個實體的抽象概念提供幾個1-2個詞的短語。
你必須按照以下格式返回答案短語1, 短語2, 短語3,...
除了答案外不要返回任何其他內容,請以繁體中文輸出。
這些抽象實體短語應滿足以下要求:
1. 能很好地代表該實體的類型或相關概念
2. 嚴格遵循給定格式,不要添加額外字符或詞語
3. 盡可能提供3個或以上不同抽象層次的短語
4. 不要重複相同詞語或輸入內容
5. 如果無法想出更多短語立即停止,不需要解釋
實體:心靈奇旅
上下文在BFI倫敦電影節首映成為皮克斯最賣座影片
你的回答:電影, 影片
實體Thinkpad X60
上下文Richard Stallman宣布他在Thinkpad X60上使用Trisquel系統
你的回答Thinkpad, 筆記本電腦, 機器, 設備, 硬件, 電腦, 品牌
實體:哈利·卡拉漢
上下文:嚇退另一個劫匪,折磨天蠍座
你的回答:人物, 美國人, 角色, 警察, 偵探
實體:黑山學院
上下文由John Andrew Rice創辦吸引了眾多教員
你的回答:學院, 大學, 學校, 文理學院
事件4月1日
上下文慶祝Utkal Dibas
你的回答:日期, 日子, 時間, 節日
實體:[ENTITY]
上下文:[CONTEXT]
請以繁體中文輸出你的回答:""",
"relation":"""我將給你一個關係。你需要為這個關係的抽象概念提供幾個1-2個詞的短語。
你必須按照以下格式返回答案短語1, 短語2, 短語3,...
除了答案外不要返回任何其他內容,請以繁體中文輸出。
這些抽象關係短語應滿足以下要求:
1. 能很好地代表該關係的類型或最基本概念
2. 嚴格遵循給定格式,不要添加額外字符或詞語
3. 盡可能提供3個或以上不同抽象層次的短語
4. 不要重複相同詞語或輸入內容
5. 如果無法想出更多短語立即停止,不需要解釋
關係:參與
你的回答:成為一部分, 參加, 參與其中, 涉及, 捲入
關係:被包含在
你的回答:加入, 成為一部分, 成為成員, 成為組成部分
關係:[RELATION]
請以繁體中文輸出你的回答:"""
}
}
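A minimal sketch of how an extraction prompt could be assembled from `TRIPLE_INSTRUCTIONS`; the helper name, the sample passage, and the chat-message layout are illustrative assumptions (how the messages are actually sent to a model is defined elsewhere in the package):

```python
# Sketch only: build an entity-relation extraction prompt from TRIPLE_INSTRUCTIONS.
def build_triple_messages(passage: str, language: str = "en", task: str = "entity_relation"):
    instructions = TRIPLE_INSTRUCTIONS[language]
    return [
        {"role": "system", "content": instructions["system"]},
        {"role": "user", "content": f"{instructions[task]}\n{instructions['passage_start']}\n{passage}"},
    ]

messages = build_triple_messages("Marie Curie discovered polonium in 1898.")
for m in messages:
    print(m["role"], ":", m["content"][:80], "...")
```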

19
atlas_rag/logging.py Normal file
View File

@ -0,0 +1,19 @@
import logging
from logging import Logger
from logging.handlers import RotatingFileHandler
import os
import datetime
from atlas_rag.evaluation.benchmark import BenchMarkConfig
def setup_logger(config:BenchMarkConfig, logger_name = "MyLogger") -> Logger:
date_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_file_path = f'./log/{config.dataset_name}_event{config.include_events}_concept{config.include_concept}_{date_time}.log'
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)
max_bytes = 50 * 1024 * 1024
os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
handler = RotatingFileHandler(log_file_path, maxBytes=max_bytes, backupCount=5)
logger.addHandler(handler)
return logger
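A usage sketch, assuming `BenchMarkConfig` exposes the `dataset_name`, `include_events`, and `include_concept` fields referenced in the log-file name above (its full constructor lives in `atlas_rag.evaluation.benchmark`):

```python
# Sketch only: BenchMarkConfig field names are inferred from the log path above.
from atlas_rag.evaluation.benchmark import BenchMarkConfig

config = BenchMarkConfig(dataset_name="musique", include_events=True, include_concept=False)
logger = setup_logger(config, logger_name="HippoRAG2Bench")
logger.info("benchmark run started")
# -> appended to ./log/musique_eventTrue_conceptFalse_<timestamp>.log
```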

View File

@ -0,0 +1,4 @@
from .hipporag import HippoRAGRetriever
from .hipporag2 import HippoRAG2Retriever
from .simple_retriever import SimpleGraphRetriever, SimpleTextRetriever
from .tog import TogRetriever

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,27 @@
from abc import ABC, abstractmethod
from typing import List, Tuple
class BaseRetriever(ABC):
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
@abstractmethod
def retrieve(self, query, topk=5, **kwargs) -> Tuple[List[str], List[str]]:
raise NotImplementedError("This method should be overridden by subclasses.")
class BaseEdgeRetriever(BaseRetriever):
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
@abstractmethod
def retrieve(self, query, topk=5, **kwargs) -> Tuple[List[str], List[str]]:
raise NotImplementedError("This method should be overridden by subclasses.")
class BasePassageRetriever(BaseRetriever):
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
@abstractmethod
def retrieve(self, query, topk=5, **kwargs) -> Tuple[List[str], List[str]]:
raise NotImplementedError("This method should be overridden by subclasses.")
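To make the contract concrete, here is a minimal, self-contained subclass sketch (purely illustrative; the real retrievers in this package follow below). `retrieve()` returns a tuple of passage contents and their ids:

```python
# Illustrative toy subclass; not part of the package.
class InMemoryPassageRetriever(BasePassageRetriever):
    def __init__(self, passages: dict, **kwargs):
        super().__init__(**kwargs)
        self.passages = passages  # {passage_id: passage_text}

    def retrieve(self, query, topk=5, **kwargs):
        # naive scoring: number of query words contained in each passage
        scored = sorted(
            self.passages.items(),
            key=lambda kv: sum(w.lower() in kv[1].lower() for w in query.split()),
            reverse=True,
        )[:topk]
        return [text for _, text in scored], [pid for pid, _ in scored]
```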

View File

@ -0,0 +1,140 @@
from tqdm import tqdm
import networkx as nx
import numpy as np
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from logging import Logger
from typing import Optional
from atlas_rag.retriever.base import BasePassageRetriever
from atlas_rag.retriever.inference_config import InferenceConfig
class HippoRAGRetriever(BasePassageRetriever):
def __init__(self, llm_generator:LLMGenerator, sentence_encoder:BaseEmbeddingModel,
data:dict, inference_config: Optional[InferenceConfig] = None, logger = None, **kwargs):
self.passage_dict = data["text_dict"]
self.llm_generator = llm_generator
self.sentence_encoder = sentence_encoder
self.node_embeddings = data["node_embeddings"]
self.node_list = data["node_list"]
file_id_to_node_id = {}
self.KG = data["KG"]
for node_id in tqdm(list(self.KG.nodes)):
if self.KG.nodes[node_id]['type'] == "passage":
if self.KG.nodes[node_id]['file_id'] not in file_id_to_node_id:
file_id_to_node_id[self.KG.nodes[node_id]['file_id']] = []
file_id_to_node_id[self.KG.nodes[node_id]['file_id']].append(node_id)
self.file_id_to_node_id = file_id_to_node_id
self.KG:nx.DiGraph = self.KG.subgraph(self.node_list)
self.node_name_list = [self.KG.nodes[node]["id"] for node in self.node_list]
self.logger :Logger = logger
if self.logger is None:
self.logging = False
else:
self.logging = True
self.inference_config = inference_config if inference_config is not None else InferenceConfig()
def retrieve_personalization_dict(self, query, topN=10):
# extract entities from the query
entities = self.llm_generator.ner(query)
entities = [e.strip() for e in entities.split(",") if e.strip()]
if self.logging:
self.logger.info(f"HippoRAG NER Entities: {entities}")
# print("Entities:", entities)
if len(entities) == 0:
# If the NER cannot extract any entities, we
# use the query as the entity to do approximate search
entities = [query]
# evenly distribute the topk for each entity
topk_for_each_entity = topN//len(entities)
# retrieve the top k nodes
topk_nodes = []
for entity_index, entity in enumerate(entities):
if entity in self.node_name_list:
# get the index of the entity in the node list
index = self.node_name_list.index(entity)
topk_nodes.append(self.node_list[index])
else:
topk_for_this_entity = 1
# print("Topk for this entity:", topk_for_this_entity)
entity_embedding = self.sentence_encoder.encode([entity], query_type="search")
scores = self.node_embeddings@entity_embedding[0].T
index_matrix = np.argsort(scores)[-topk_for_this_entity:][::-1]
topk_nodes += [self.node_list[i] for i in index_matrix]
if self.logging:
self.logger.info(f"HippoRAG Topk Nodes: {[self.KG.nodes[node]['id'] for node in topk_nodes]}")
topk_nodes = list(set(topk_nodes))
# assert len(topk_nodes) <= topN
if len(topk_nodes) > 2*topN:
topk_nodes = topk_nodes[:2*topN]
# print("Topk nodes:", topk_nodes)
# find the number of docs that one work appears in
freq_dict_for_nodes = {}
for node in topk_nodes:
node_data = self.KG.nodes[node]
# print(node_data)
file_ids = node_data["file_id"]
file_ids_list = file_ids.split(",")
#uniq this list
file_ids_list = list(set(file_ids_list))
freq_dict_for_nodes[node] = len(file_ids_list)
personalization_dict = {node: 1 / freq_dict_for_nodes[node] for node in topk_nodes}
# print("personalization dict: ")
return personalization_dict
def retrieve(self, query, topN=5, **kwargs):
topN_nodes = self.inference_config.topk_nodes
personalization_dict = self.retrieve_personalization_dict(query, topN=topN_nodes)
# retrieve the top N passages
pr = nx.pagerank(self.KG, personalization=personalization_dict)
for node in pr:
pr[node] = round(pr[node], 4)
if pr[node] < 0.001:
pr[node] = 0
passage_probabilities_sum = {}
for node in pr:
node_data = self.KG.nodes[node]
file_ids = node_data["file_id"]
# for each file id check through each text_id
file_ids_list = file_ids.split(",")
#uniq this list
file_ids_list = list(set(file_ids_list))
# file id to node id
for file_id in file_ids_list:
if file_id == 'concept_file':
continue
for node_id in self.file_id_to_node_id[file_id]:
if node_id not in passage_probabilities_sum:
passage_probabilities_sum[node_id] = 0
passage_probabilities_sum[node_id] += pr[node]
sorted_passages = sorted(passage_probabilities_sum.items(), key=lambda x: x[1], reverse=True)
top_passages = sorted_passages[:topN]
top_passages, scores = zip(*top_passages)
passage_contents = [self.passage_dict[passage_id] for passage_id in top_passages]
return passage_contents, top_passages
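The scoring above boils down to a personalised PageRank seeded on the NER-matched entity nodes, with seed mass inversely proportional to how many files an entity appears in; passage scores are then the PageRank mass accumulated by the entities linked to each passage. A toy, self-contained illustration with a hand-made graph (all node names invented):

```python
# Toy illustration of the personalised-PageRank step: "e1"/"e2" play the role of
# NER-matched entities, "p1"/"p2" the passage nodes they are connected to.
import networkx as nx

G = nx.DiGraph()
G.add_edges_from([("e1", "e2"), ("e2", "e1"), ("e1", "p1"), ("e2", "p2"), ("p1", "e2")])

# seed mass only on the matched entity nodes, weighted by 1 / document frequency
personalization = {"e1": 1.0, "e2": 0.5}
pr = nx.pagerank(G, personalization=personalization)

passage_scores = {n: s for n, s in pr.items() if n.startswith("p")}
print(sorted(passage_scores.items(), key=lambda kv: kv[1], reverse=True))
```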

View File

@ -0,0 +1,237 @@
import networkx as nx
import json
from tqdm import tqdm
from typing import Dict, List, Tuple
import numpy as np
import json_repair
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from logging import Logger
from dataclasses import dataclass
from typing import Optional
from atlas_rag.retriever.base import BasePassageRetriever
from atlas_rag.retriever.inference_config import InferenceConfig
def min_max_normalize(x):
min_val = np.min(x)
max_val = np.max(x)
range_val = max_val - min_val
# Handle the case where all values are the same (range is zero)
if range_val == 0:
return np.ones_like(x) # Return an array of ones with the same shape as x
return (x - min_val) / range_val
class HippoRAG2Retriever(BasePassageRetriever):
def __init__(self, llm_generator:LLMGenerator,
sentence_encoder:BaseEmbeddingModel,
data : dict,
inference_config: Optional[InferenceConfig] = None,
logger = None,
**kwargs):
self.llm_generator = llm_generator
self.sentence_encoder = sentence_encoder
self.node_embeddings = data["node_embeddings"]
self.node_list = data["node_list"]
self.edge_list = data["edge_list"]
self.edge_embeddings = data["edge_embeddings"]
self.text_embeddings = data["text_embeddings"]
self.edge_faiss_index = data["edge_faiss_index"]
self.passage_dict = data["text_dict"]
self.text_id_list = list(self.passage_dict.keys())
self.KG = data["KG"]
self.KG = self.KG.subgraph(self.node_list + self.text_id_list)
self.logger = logger
if self.logger is None:
self.logging = False
else:
self.logging = True
hipporag2mode = "query2edge"
if hipporag2mode == "query2edge":
self.retrieve_node_fn = self.query2edge
elif hipporag2mode == "query2node":
self.retrieve_node_fn = self.query2node
elif hipporag2mode == "ner2node":
self.retrieve_node_fn = self.ner2node
else:
raise ValueError(f"Invalid mode: {hipporag2mode}. Choose from 'query2edge', 'query2node', or 'query2passage'.")
self.inference_config = inference_config if inference_config is not None else InferenceConfig()
node_id_to_file_id = {}
for node_id in tqdm(list(self.KG.nodes)):
if self.inference_config.keyword == "musique" and self.KG.nodes[node_id]['type']=="passage":
node_id_to_file_id[node_id] = self.KG.nodes[node_id]["id"]
else:
node_id_to_file_id[node_id] = self.KG.nodes[node_id]["file_id"]
self.node_id_to_file_id = node_id_to_file_id
def ner(self, text):
return self.llm_generator.ner(text)
def ner2node(self, query, topN = 10):
entities = self.ner(query)
entities = [e.strip() for e in entities.split(",") if e.strip()]
if len(entities) == 0:
entities = [query]
# retrieve the top k nodes
topk_nodes = []
node_score_dict = {}
for entity_index, entity in enumerate(entities):
topk_for_this_entity = 1
entity_embedding = self.sentence_encoder.encode([entity], query_type="search")
scores = min_max_normalize(self.node_embeddings@entity_embedding[0].T)
index_matrix = np.argsort(scores)[-topk_for_this_entity:][::-1]
similarity_matrix = [scores[i] for i in index_matrix]
for index, sim_score in zip(index_matrix, similarity_matrix):
node = self.node_list[index]
if node not in topk_nodes:
topk_nodes.append(node)
node_score_dict[node] = sim_score
topk_nodes = list(set(topk_nodes))
result_node_score_dict = {}
if len(topk_nodes) > 2*topN:
topk_nodes = topk_nodes[:2*topN]
for node in topk_nodes:
if node in node_score_dict:
result_node_score_dict[node] = node_score_dict[node]
return result_node_score_dict
def query2node(self, query, topN = 10):
query_emb = self.sentence_encoder.encode([query], query_type="entity")
scores = min_max_normalize(self.node_embeddings@query_emb[0].T)
index_matrix = np.argsort(scores)[-topN:][::-1]
similarity_matrix = [scores[i] for i in index_matrix]
result_node_score_dict = {}
for index, sim_score in zip(index_matrix, similarity_matrix):
node = self.node_list[index]
result_node_score_dict[node] = sim_score
return result_node_score_dict
def query2edge(self, query, topN = 10):
query_emb = self.sentence_encoder.encode([query], query_type="edge")
scores = min_max_normalize(self.edge_embeddings@query_emb[0].T)
index_matrix = np.argsort(scores)[-topN:][::-1]
log_edge_list = []
for index in index_matrix:
edge = self.edge_list[index]
edge_str = [self.KG.nodes[edge[0]]['id'], self.KG.edges[edge]['relation'], self.KG.nodes[edge[1]]['id']]
log_edge_list.append(edge_str)
similarity_matrix = [scores[i] for i in index_matrix]
# construct the edge list
before_filter_edge_json = {}
before_filter_edge_json['fact'] = []
for index, sim_score in zip(index_matrix, similarity_matrix):
edge = self.edge_list[index]
edge_str = [self.KG.nodes[edge[0]]['id'], self.KG.edges[edge]['relation'], self.KG.nodes[edge[1]]['id']]
before_filter_edge_json['fact'].append(edge_str)
if self.logging:
self.logger.info(f"HippoRAG2 Before Filter Edge: {before_filter_edge_json['fact']}")
filtered_facts = self.llm_generator.filter_triples_with_entity_event(query, json.dumps(before_filter_edge_json, ensure_ascii=False))
filtered_facts = json_repair.loads(filtered_facts)['fact']
if len(filtered_facts) == 0:
return {}
# use filtered facts to get the edge id and check if it exists in the original candidate list.
node_score_dict = {}
log_edge_list = []
for edge in filtered_facts:
edge_str = f'{edge[0]} {edge[1]} {edge[2]}'
search_emb = self.sentence_encoder.encode([edge_str], query_type="search")
D, I = self.edge_faiss_index.search(search_emb, 1)
filtered_index = I[0][0]
# get the edge and the original score
edge = self.edge_list[filtered_index]
log_edge_list.append([self.KG.nodes[edge[0]]['id'], self.KG.edges[edge]['relation'], self.KG.nodes[edge[1]]['id']])
head, tail = edge[0], edge[1]
sim_score = scores[filtered_index]
if head not in node_score_dict:
node_score_dict[head] = [sim_score]
else:
node_score_dict[head].append(sim_score)
if tail not in node_score_dict:
node_score_dict[tail] = [sim_score]
else:
node_score_dict[tail].append(sim_score)
# average the scores
if self.logging:
self.logger.info(f"HippoRAG2: Filtered edges: {log_edge_list}")
# take average of the scores
for node in node_score_dict:
node_score_dict[node] = sum(node_score_dict[node]) / len(node_score_dict[node])
return node_score_dict
def query2passage(self, query, weight_adjust = 0.05):
query_emb = self.sentence_encoder.encode([query], query_type="passage")
sim_scores = self.text_embeddings @ query_emb[0].T
sim_scores = min_max_normalize(sim_scores)*weight_adjust # converted to probability
# create dict of passage id and score
return dict(zip(self.text_id_list, sim_scores))
def retrieve_personalization_dict(self, query, topN=30, weight_adjust=0.05):
node_dict = self.retrieve_node_fn(query, topN=topN)
text_dict = self.query2passage(query, weight_adjust=weight_adjust)
return node_dict, text_dict
def retrieve(self, query, topN=5, **kwargs):
topN_edges = self.inference_config.topk_edges
weight_adjust = self.inference_config.weight_adjust
node_dict, text_dict = self.retrieve_personalization_dict(query, topN=topN_edges, weight_adjust=weight_adjust)
personalization_dict = {}
if len(node_dict) == 0:
# return topN text passages
sorted_passages = sorted(text_dict.items(), key=lambda x: x[1], reverse=True)
sorted_passages = sorted_passages[:topN]
sorted_passages_contents = []
sorted_scores = []
sorted_passage_ids = []
for passage_id, score in sorted_passages:
sorted_passages_contents.append(self.passage_dict[passage_id])
sorted_scores.append(float(score))
sorted_passage_ids.append(self.node_id_to_file_id[passage_id])
return sorted_passages_contents, sorted_passage_ids
personalization_dict.update(node_dict)
personalization_dict.update(text_dict)
# retrieve the top N passages
pr = nx.pagerank(self.KG, personalization=personalization_dict,
alpha = self.inference_config.ppr_alpha,
max_iter=self.inference_config.ppr_max_iter,
tol=self.inference_config.ppr_tol)
# get the top N passages based on the text_id list and pagerank score
text_dict_score = {}
for node in self.text_id_list:
# filter out nodes that have 0 score
if pr[node] > 0.0:
text_dict_score[node] = pr[node]
# return topN passages
sorted_passages_ids = sorted(text_dict_score.items(), key=lambda x: x[1], reverse=True)
sorted_passages_ids = sorted_passages_ids[:topN]
sorted_passages_contents = []
sorted_scores = []
sorted_passage_ids = []
for passage_id, score in sorted_passages_ids:
sorted_passages_contents.append(self.passage_dict[passage_id])
sorted_scores.append(score)
sorted_passage_ids.append(self.node_id_to_file_id[passage_id])
return sorted_passages_contents, sorted_passage_ids
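The personalization dict built in `retrieve` mixes two signals: node scores derived from the filtered edges at full weight, and dense passage similarities scaled down by `weight_adjust` so they only nudge the random walk. A self-contained toy example of that fusion (all names and numbers invented; the normalizer mirrors the module-level helper above):

```python
import numpy as np

def min_max_normalize(x):
    # same behaviour as the module-level helper above
    rng = np.max(x) - np.min(x)
    return np.ones_like(x) if rng == 0 else (x - np.min(x)) / rng

node_scores = {"n_director": 0.92, "n_film": 0.81}        # from query2edge
raw_passage_sims = np.array([0.34, 0.71, 0.55])           # query vs. 3 passages
passage_scores = dict(zip(["t1", "t2", "t3"],
                          min_max_normalize(raw_passage_sims) * 0.05))  # weight_adjust = 0.05

personalization = {**node_scores, **passage_scores}
print(personalization)  # passages contribute, but only weakly, to the random walk
```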

View File

@ -0,0 +1,23 @@
from dataclasses import dataclass
@dataclass
class InferenceConfig:
"""
Configuration class for inference settings.
Attributes:
topk (int): Number of top results to retrieve. Default is 5.
Dmax (int): Maximum depth for search. Default is 4.
weight_adjust (float): Weight adjustment factor for passage retrieval. Default is 0.05.
topk_edges (int): Number of top edges to retrieve. Default is 50.
topk_nodes (int): Number of top nodes to retrieve. Default is 10.
"""
keyword: str = "musique"
topk: int = 5
Dmax: int = 4
weight_adjust: float = 1.0
topk_edges: int = 50
topk_nodes: int = 10
ppr_alpha: float = 0.99
ppr_max_iter: int = 2000
ppr_tol: float = 1e-7
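For example, a configuration tuned for a smaller corpus might look like this (values are illustrative; field names come from the dataclass above):

```python
config = InferenceConfig(
    keyword="musique",
    topk=5,
    topk_edges=30,
    topk_nodes=5,
    weight_adjust=0.05,   # damp the dense-passage signal relative to graph scores
    ppr_alpha=0.95,
)
# the config is then passed to a retriever, e.g. HippoRAG2Retriever(..., inference_config=config)
```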

View File

@ -0,0 +1,40 @@
from abc import ABC, abstractmethod
class BaseLargeKGRetriever(ABC):
def __init__(self):
raise NotImplementedError("This is a base class and cannot be instantiated directly.")
@abstractmethod
def retrieve_passages(self, query, retriever_config:dict):
"""
Retrieve passages based on the query.
Args:
query (str): The input query.
retriever_config (dict): Retrieval settings, e.g. topN, number_of_source_nodes_per_ner, and sampling_area.
Returns:
List of retrieved passages and their scores.
"""
raise NotImplementedError("This method should be implemented by subclasses.")
class BaseLargeKGEdgeRetriever(ABC):
def __init__(self):
raise NotImplementedError("This is a base class and cannot be instantiated directly.")
@abstractmethod
def retrieve_passages(self, query, retriever_config:dict):
"""
Retrieve Edges / Paths based on the query.
Args:
query (str): The input query.
retriever_config (dict): Retrieval settings, e.g. topN, number_of_source_nodes_per_ner, and sampling_area.
Returns:
List of retrieved passages and their scores.
"""
raise NotImplementedError("This method should be implemented by subclasses.")

View File

@ -0,0 +1,313 @@
from difflib import get_close_matches
from logging import Logger
import faiss
from neo4j import GraphDatabase
import time
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')
from graphdatascience import GraphDataScience
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
import string
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGRetriever
class LargeKGRetriever(BaseLargeKGRetriever):
def __init__(self, keyword:str, neo4j_driver: GraphDatabase,
llm_generator:LLMGenerator, sentence_encoder:BaseEmbeddingModel,
node_index:faiss.Index, passage_index:faiss.Index,
topN: int = 5,
number_of_source_nodes_per_ner: int = 10,
sampling_area : int = 250,logger:Logger = None):
# instantiate resources for one KG
self.keyword = keyword
self.neo4j_driver = neo4j_driver
self.gds_driver = GraphDataScience(self.neo4j_driver)
self.topN = topN
self.number_of_source_nodes_per_ner = number_of_source_nodes_per_ner
self.sampling_area = sampling_area
self.llm_generator = llm_generator
self.sentence_encoder = sentence_encoder
self.node_faiss_index = node_index
# self.edge_faiss_index = self.edge_indexes[keyword]
self.passage_faiss_index = passage_index
self.verbose = False if logger is None else True
self.logger = logger
self.ppr_weight_threshold = 0.00005
def set_model(self, model):
if self.llm_generator.inference_type == 'openai':
self.llm_generator.model_name = model
else:
raise ValueError("Model can only be set for OpenAI inference type.")
def ner(self, text):
return self.llm_generator.large_kg_ner(text)
def convert_numeric_id_to_name(self, numeric_id):
if numeric_id.isdigit():
return self.gds_driver.util.asNode(self.gds_driver.find_node_id(["Node"], {"numeric_id": numeric_id})).get('name')
else:
return numeric_id
def has_intersection(self, word_set, input_string):
cleaned_string = input_string.translate(str.maketrans('', '', string.punctuation)).lower()
if self.keyword == 'cc_en':
# Check if any phrase in word_set is a substring of cleaned_string
for phrase in word_set:
if phrase in cleaned_string:
return True
return False
else:
# Check if any word in word_set is present in the cleaned_string's words
words_in_string = set(cleaned_string.split())
return not word_set.isdisjoint(words_in_string)
def pagerank(self, personalization_dict, topN=5, sampling_area=200):
graph = self.gds_driver.graph.get('largekgrag_graph')
node_count = graph.node_count()
sampling_ratio = sampling_area / node_count
aggregation_node_dict = []
ppr_weight_threshold = self.ppr_weight_threshold
start_time = time.time()
# Pre-filter nodes based on ppr_weight threshold
# Precompute word sets based on keyword
if self.keyword == 'cc_en':
filtered_personalization = {
node_id: ppr_weight
for node_id, ppr_weight in personalization_dict.items()
if ppr_weight >= ppr_weight_threshold
}
stop_words = set(stopwords.words('english'))
word_set_phrases = set()
word_set_words = set()
for node_id, ppr_weight in filtered_personalization.items():
name = self.gds_driver.util.asNode(node_id)['name']
if name:
cleaned_phrase = name.translate(str.maketrans('', '', string.punctuation)).lower().strip()
if cleaned_phrase:
# Process for 'cc_en': remove stop words and add cleaned phrase
filtered_words = [word for word in cleaned_phrase.split() if word not in stop_words]
if filtered_words:
cleaned_phrase_filtered = ' '.join(filtered_words)
word_set_phrases.add(cleaned_phrase_filtered)
word_set = word_set_phrases if self.keyword == 'cc_en' else word_set_words
if self.verbose:
self.logger.info(f"Optimized word set: {word_set}")
else:
filtered_personalization = personalization_dict
if self.verbose:
self.logger.info(f"largekgRAG : Personalization dict: {filtered_personalization}")
self.logger.info(f"largekgRAG : Sampling ratio: {sampling_ratio}")
self.logger.info(f"largekgRAG : PPR weight threshold: {ppr_weight_threshold}")
# Process each node in the filtered personalization dict
for node_id, ppr_weight in filtered_personalization.items():
try:
self.gds_driver.graph.drop('rwr_sample')
start_time = time.time()
G_sample, _ = self.gds_driver.graph.sample.rwr("rwr_sample", graph, concurrency=4, samplingRatio = sampling_ratio, startNodes = [node_id],
restartProbability = 0.4, logProgress = False)
if self.verbose:
self.logger.info(f"largekgRAG : Sampled graph for node {node_id} in {time.time() - start_time:.2f} seconds")
start_time = time.time()
result = self.gds_driver.pageRank.stream(
G_sample, maxIterations=30, sourceNodes=[node_id], logProgress=False
).sort_values("score", ascending=False)
if self.verbose:
self.logger.info(f"pagerank type: {type(result)}")
self.logger.info(f"pagerank result: {result}")
self.logger.info(f"largekgRAG : PageRank calculated for node {node_id} in {time.time() - start_time:.2f} seconds")
start_time = time.time()
# if self.keyword == 'cc_en':
if self.keyword != 'cc_en':
result = result[result['score'] > 0.0].nlargest(50, 'score').to_dict('records')
else:
result = result.to_dict('records')
if self.verbose:
self.logger.info(f"largekgRAG :result: {result}")
for entry in result:
if self.keyword == 'cc_en':
node_name = self.gds_driver.util.asNode(entry['nodeId'])['name']
if not self.has_intersection(word_set, node_name):
continue
numeric_id = self.gds_driver.util.asNode(entry['nodeId'])['numeric_id']
aggregation_node_dict.append({
'nodeId': numeric_id,
'score': entry['score'] * ppr_weight
})
except Exception as e:
if self.verbose:
self.logger.error(f"Error processing node {node_id}: {e}")
self.logger.error(f"Node is filtered out: {self.gds_driver.util.asNode(node_id)['name']}")
else:
continue
aggregation_node_dict = sorted(aggregation_node_dict, key=lambda x: x['score'], reverse=True)[:25]
if self.verbose:
self.logger.info(f"Aggregation node dict: {aggregation_node_dict}")
if self.verbose:
self.logger.info(f"Time taken to sample and calculate PageRank: {time.time() - start_time:.2f} seconds")
start_time = time.time()
with self.neo4j_driver.session() as session:
intermediate_time = time.time()
# Step 1: Distribute entity scores to connected text nodes and find the top 5
query_scores = """
UNWIND $entries AS entry
MATCH (n:Node {numeric_id: entry.nodeId})-[:Source]->(t:Text)
WITH t.numeric_id AS textId, SUM(entry.score) AS total_score
ORDER BY total_score DESC
LIMIT $topN
RETURN textId, total_score
"""
# Execute query to aggregate scores
result_scores = session.run(query_scores, entries=aggregation_node_dict, topN=topN)
top_numeric_ids = []
top_scores = []
# Extract the top text node IDs and scores
for record in result_scores:
top_numeric_ids.append(record["textId"])
top_scores.append(record["total_score"])
# Step 2: Use top numeric IDs to retrieve the original text
if self.verbose:
self.logger.info(f"Time taken to prepare query 1 : {time.time() - intermediate_time:.2f} seconds")
intermediate_time = time.time()
query_text = """
UNWIND $textIds AS textId
MATCH (t:Text {numeric_id: textId})
RETURN t.original_text AS text, t.numeric_id AS textId
"""
result_texts = session.run(query_text, textIds=top_numeric_ids)
topN_passages = []
score_dict = dict(zip(top_numeric_ids, top_scores))
# Combine original text with scores
for record in result_texts:
original_text = record["text"]
text_id = record["textId"]
score = score_dict.get(text_id, 0)
topN_passages.append((original_text, score))
if self.verbose:
self.logger.info(f"Time taken to prepare query 2 : {time.time() - intermediate_time:.2f} seconds")
# Sort passages by score
topN_passages = sorted(topN_passages, key=lambda x: x[1], reverse=True)
top_texts = [item[0] for item in topN_passages][:topN]
top_scores = [item[1] for item in topN_passages][:topN]
if self.verbose:
self.logger.info(f"Total passages retrieved: {len(top_texts)}")
self.logger.info(f"Top passages: {top_texts}")
self.logger.info(f"Top scores: {top_scores}")
if self.verbose:
self.logger.info(f"Neo4j Query Time: {time.time() - start_time:.2f} seconds")
return top_texts, top_scores
def retrieve_topk_nodes(self, query, top_k_nodes = 2):
# extract entities from the query
entities = self.ner(query)
if self.verbose:
self.logger.info(f"largekgRAG : LLM Extracted entities: {entities}")
if len(entities) == 0:
entities = [query]
num_entities = len(entities)
initial_nodes = []
for entity in entities:
entity_embedding = self.sentence_encoder.encode([entity])
D, I = self.node_faiss_index.search(entity_embedding, top_k_nodes)
if self.verbose:
self.logger.info(f"largekgRAG : Search results - Distances: {D}, Indices: {I}")
initial_nodes += [str(i)for i in I[0]]
if self.verbose:
self.logger.info(f"largekgRAG : Initial nodes: {initial_nodes}")
name_id_map = {}
for node_id in initial_nodes:
name = self.convert_numeric_id_to_name(node_id)
name_id_map[name] = node_id
topk_nodes = list(set(initial_nodes))
# convert the numeric id to string and filter again then return numeric id
keywords_before_filter = [self.convert_numeric_id_to_name(n) for n in initial_nodes]
filtered_keywords = self.llm_generator.large_kg_filter_keywords_with_entity(query, keywords_before_filter)
# Second pass: Add filtered keywords
filtered_top_k_nodes = []
filter_log_dict = {}
match_threshold = 0.8
if self.verbose:
self.logger.info(f"largekgRAG : Filtered Before Match Keywords Candidate: {filtered_keywords}")
for keyword in filtered_keywords:
# Check for an exact match first
if keyword in name_id_map:
filtered_top_k_nodes.append(name_id_map[keyword])
filter_log_dict[keyword] = name_id_map[keyword]
else:
# Look for close matches using difflib's get_close_matches
close_matches = get_close_matches(keyword, name_id_map.keys(), n=1, cutoff=match_threshold)
if close_matches:
# If a close match is found, add the corresponding node
filtered_top_k_nodes.append(name_id_map[close_matches[0]])
filter_log_dict[keyword] = name_id_map[close_matches[0]] if close_matches else None
if self.verbose:
self.logger.info(f"largekgRAG : Filtered After Match Keywords Candidate: {filter_log_dict}")
topk_nodes = list(set(filtered_top_k_nodes))
if len(topk_nodes) > 2 * num_entities:
topk_nodes = topk_nodes[:2 * num_entities]
return topk_nodes
def _process_text(self, text):
"""Normalize text for containment checks (lowercase, alphanumeric+spaces)"""
text = text.lower()
text = ''.join([c for c in text if c.isalnum() or c.isspace()])
return set(text.split())
def retrieve_personalization_dict(self, query, number_of_source_nodes_per_ner=5):
topk_nodes = self.retrieve_topk_nodes(query, number_of_source_nodes_per_ner)
if topk_nodes == []:
if self.verbose:
self.logger.info(f"largekgRAG : No nodes found for query: {query}")
return {}
if self.verbose:
self.logger.info(f"largekgRAG : Topk nodes: {[self.convert_numeric_id_to_name(node_id) for node_id in topk_nodes]}")
freq_dict_for_nodes = {}
query = """
UNWIND $nodes AS node
MATCH (n1:Node {numeric_id: node})-[r:Source]-(n2:Text)
RETURN n1.numeric_id as numeric_id, COUNT(DISTINCT n2.text_id) AS fileCount
"""
with self.neo4j_driver.session() as session:
result = session.run(query, nodes=topk_nodes)
for record in result:
freq_dict_for_nodes[record["numeric_id"]] = record["fileCount"]
# Create the personalization dictionary
personalization_dict = {self.gds_driver.find_node_id(["Node"],{"numeric_id": numeric_id}): 1 / file_count for numeric_id, file_count in freq_dict_for_nodes.items()}
if self.verbose:
self.logger.info(f"largekgRAG : Personalization dict's number of node: {len(personalization_dict)}")
return personalization_dict
def retrieve_passages(self, query):
if self.verbose:
self.logger.info(f"largekgRAG : Retrieving passages for query: {query}")
topN = self.topN
number_of_source_nodes_per_ner = self.number_of_source_nodes_per_ner
sampling_area = self.sampling_area
personalization_dict = self.retrieve_personalization_dict(query, number_of_source_nodes_per_ner)
if personalization_dict == {}:
return [], [0]
topN_passages, topN_scores = self.pagerank(personalization_dict, topN, sampling_area = sampling_area)
return topN_passages, topN_scores
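The Cypher aggregation above can be pictured with a toy example: PageRank mass on entity nodes is pushed to the text chunks they are `Source`-linked to, summed per text, and the top-N texts are returned (all ids and scores below are invented):

```python
# Toy version of the score-aggregation step performed by query_scores above.
entity_scores = {"n1": 0.40, "n2": 0.25, "n3": 0.10}
entity_to_texts = {"n1": ["t1", "t2"], "n2": ["t2"], "n3": ["t3"]}

text_scores = {}
for node, score in entity_scores.items():
    for text_id in entity_to_texts[node]:
        text_scores[text_id] = text_scores.get(text_id, 0.0) + score

topN = sorted(text_scores.items(), key=lambda kv: kv[1], reverse=True)[:2]
print(topN)  # [('t2', 0.65), ('t1', 0.4)]
```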

View File

@ -0,0 +1,469 @@
from neo4j import GraphDatabase
import faiss
import numpy as np
import random
from collections import defaultdict
from typing import List
import time
import logging
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGEdgeRetriever
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')
class LargeKGToGRetriever(BaseLargeKGEdgeRetriever):
def __init__(self, keyword: str, neo4j_driver: GraphDatabase,
llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel, filter_encoder: BaseEmbeddingModel,
node_index: faiss.Index,
topN : int = 5,
Dmax : int = 3,
Wmax : int = 3,
prune_size: int = 10,
logger: logging.Logger = None):
"""
Initialize the LargeKGToGRetriever for billion-level KG retrieval using Neo4j.
Args:
keyword (str): Identifier for the KG dataset (e.g., 'cc_en').
neo4j_driver (GraphDatabase): Neo4j driver for database access.
llm_generator (LLMGenerator): LLM for NER, rating, and reasoning.
sentence_encoder (BaseEmbeddingModel): Encoder for generating embeddings.
node_index (faiss.Index): FAISS index for node embeddings.
logger (Logger, optional): Logger for verbose output.
"""
self.keyword = keyword
self.neo4j_driver = neo4j_driver
self.llm_generator = llm_generator
self.sentence_encoder = sentence_encoder
self.filter_encoder = filter_encoder
self.node_faiss_index = node_index
self.verbose = logger is not None
self.logger = logger
self.topN = topN
self.Dmax = Dmax
self.Wmax = Wmax
self.prune_size = prune_size
def ner(self, text: str) -> List[str]:
"""
Extract named entities from the query text using the LLM.
Args:
text (str): The query text.
Returns:
List[str]: List of extracted entities.
"""
entities = self.llm_generator.large_kg_ner(text)
if self.verbose:
self.logger.info(f"Extracted entities: {entities}")
return entities
def retrieve_topk_nodes(self, query: str, top_k_nodes: int = 5) -> List[str]:
"""
Retrieve top-k nodes similar to entities in the query.
Args:
query (str): The user query.
top_k_nodes (int): Number of nodes to retrieve per entity.
Returns:
List[str]: List of node numeric_ids.
"""
start_time = time.time()
entities = self.ner(query)
if self.verbose:
ner_time = time.time() - start_time
self.logger.info(f"NER took {ner_time:.2f} seconds, entities: {entities}")
if not entities:
entities = [query]
initial_nodes = []
for entity in entities:
entity_embedding = self.sentence_encoder.encode([entity])
D, I = self.node_faiss_index.search(entity_embedding, 3)
if self.verbose:
self.logger.info(f"Entity: {entity}, FAISS Distances: {D}, Indices: {I}")
if len(I[0]) > 0: # Check if results exist
initial_nodes.extend([str(i) for i in I[0]])
# no need filtering as ToG pruning will handle it.
topk_nodes_ids = list(set(initial_nodes))
with self.neo4j_driver.session() as session:
start_time = time.time()
query = """
MATCH (n:Node)
WHERE n.numeric_id IN $topk_nodes_ids
RETURN n.numeric_id AS id, n.name AS name
"""
result = session.run(query, topk_nodes_ids=topk_nodes_ids)
topk_nodes_dict = {}
for record in result:
topk_nodes_dict[record["id"]] = record["name"]
if self.verbose:
neo4j_time = time.time() - start_time
self.logger.info(f"Neo4j query took {neo4j_time:.2f} seconds, count: {len(topk_nodes_ids)}")
if self.verbose:
self.logger.info(f"Top-k nodes: {topk_nodes_dict}")
return list(topk_nodes_dict.keys()), list(topk_nodes_dict.values()) # numeric_id of nodes returned
def expand_paths(self, P: List[List[str]], PIDS: List[List[str]], PTYPES: List[List[str]], width: int, query: str) -> List[List[str]]:
"""
Expand each path by adding neighbors of the last node.
Args:
P (List[List[str]]): Current list of paths, where each path is a list of alternating node_ids and relation types.
Returns:
List[List[str]]: List of expanded paths.
"""
last_nodes = []
last_node_ids = []
last_node_types = []
paths_end_with_text = []
paths_end_with_text_id = []
paths_end_with_text_type = []
paths_end_with_node = []
paths_end_with_node_id = []
paths_end_with_node_type = []
if self.verbose:
self.logger.info(f"Expanding paths, current paths: {P}")
for p, pid, ptype in zip(P, PIDS, PTYPES):
if not p or not pid or not ptype: # Skip empty paths
continue
t = ptype[-1]
if t == "Text":
paths_end_with_text.append(p)
paths_end_with_text_id.append(pid)
paths_end_with_text_type.append(ptype)
continue
last_node = p[-1] # Last node in the path
last_node_id = pid[-1] # Last node numeric_id
last_nodes.append(last_node)
last_node_ids.append(last_node_id)
last_node_types.append(t)
paths_end_with_node.append(p)
paths_end_with_node_id.append(pid)
paths_end_with_node_type.append(ptype)
assert len(last_nodes) == len(last_node_ids) == len(last_node_types), "Mismatch in last nodes, ids, and types lengths"
if not last_node_ids:
return paths_end_with_text, paths_end_with_text_id, paths_end_with_text_type
with self.neo4j_driver.session() as session:
# Query Node relationships
start_time = time.time()
outgoing_query = """
CALL apoc.cypher.runTimeboxed(
"MATCH (n:Node)-[r:Relation]-(m:Node) WHERE n.numeric_id IN $last_node_ids
WITH n, r, m ORDER BY rand() LIMIT 60000
RETURN n.numeric_id AS source, n.name AS source_name, r.relation AS rel_type, m.numeric_id AS target, m.name AS target_name, 'Node' AS target_type",
{last_node_ids: $last_node_ids},
60000
)
YIELD value
RETURN value.source AS source, value.source_name AS source_name, value.rel_type AS rel_type, value.target AS target, value.target_name AS target_name, value.target_type AS target_type
"""
outgoing_result = session.run(outgoing_query, last_node_ids=last_node_ids)
outgoing = [(record["source"], record['source_name'], record["rel_type"], record["target"], record["target_name"], record["target_type"])
for record in outgoing_result]
if self.verbose:
outgoing_time = time.time() - start_time
self.logger.info(f"Outgoing relationships query took {outgoing_time:.2f} seconds, count: {len(outgoing)}")
# # Query outgoing Node -> Text relationships
# start_time = time.time()
# outgoing_text_query = """
# MATCH (n:Node)-[r:Source]->(t:Text)
# WHERE n.numeric_id IN $last_node_ids
# RETURN n.numeric_id AS source, n.name AS source_name, 'from Source' AS rel_type, t.numeric_id as target, t.original_text AS target_name, 'Text' AS target_type
# """
# outgoing_text_result = session.run(outgoing_text_query, last_node_ids=last_node_ids)
# outgoing_text = [(record["source"], record["source_name"], record["rel_type"], record["target"], record["target_name"], record["target_type"])
# for record in outgoing_text_result]
# if self.verbose:
# outgoing_text_time = time.time() - start_time
# self.logger.info(f"Outgoing Node->Text relationships query took {outgoing_text_time:.2f} seconds, count: {len(outgoing_text)}")
last_node_to_new_paths = defaultdict(list)
last_node_to_new_paths_ids = defaultdict(list)
last_node_to_new_paths_types = defaultdict(list)
for p, pid, ptype in zip(P, PIDS, PTYPES):
last_node = p[-1]
last_node_id = pid[-1]
# Outgoing Node -> Node
for source, source_name, rel_type, target, target_name, target_type in outgoing:
if source == last_node_id and target_name not in p:
new_path = p + [rel_type, target_name]
if target_name.lower() in stopwords.words('english'):
continue
last_node_to_new_paths[last_node].append(new_path)
last_node_to_new_paths_ids[last_node].append(pid + [target])
last_node_to_new_paths_types[last_node].append(ptype + [target_type])
# # Outgoing Node -> Text
# for source, source_name, rel_type, target, target_name, target_type in outgoing_text:
# if source == last_node_id and target_name not in p:
# new_path = p + [rel_type, target_name]
# last_node_to_new_paths_text[last_node].append(new_path)
# last_node_to_new_paths_text_ids[last_node].append(pid + [target])
# last_node_to_new_paths_text_types[last_node].append(ptype + [target_type])
# # Incoming Node -> Node
# for source, rel_type, target, source_name, source_type in incoming:
# if target == last_node_id and source not in p:
# new_path = p + [rel_type, source_name]
num_paths = 0
for last_node, new_paths in last_node_to_new_paths.items():
num_paths += len(new_paths)
# for last_node, new_paths in last_node_to_new_paths_text.items():
# num_paths += len(new_paths)
new_paths = []
new_pids = []
new_ptypes = []
if self.verbose:
self.logger.info(f"Number of new paths before filtering: {num_paths}")
self.logger.info(f"last nodes: {last_node_to_new_paths.keys()}")
if num_paths > len(last_node_ids) * width:
# Apply filtering when total paths exceed threshold
for last_node, new_ps in last_node_to_new_paths.items():
if len(new_ps) > width:
path_embeddings = self.filter_encoder.encode(new_ps)
query_embeddings = self.filter_encoder.encode([query])
scores = np.dot(path_embeddings, query_embeddings.T).flatten()
top_indices = np.argsort(scores)[-width:]
new_paths.extend([new_ps[i] for i in top_indices])
new_pids.extend([last_node_to_new_paths_ids[last_node][i] for i in top_indices])
new_ptypes.extend([last_node_to_new_paths_types[last_node][i] for i in top_indices])
else:
new_paths.extend(new_ps)
new_pids.extend(last_node_to_new_paths_ids[last_node])
new_ptypes.extend(last_node_to_new_paths_types[last_node])
else:
# Collect all paths without filtering when total is at or below threshold
for last_node, new_ps in last_node_to_new_paths.items():
new_paths.extend(new_ps)
new_pids.extend(last_node_to_new_paths_ids[last_node])
new_ptypes.extend(last_node_to_new_paths_types[last_node])
if self.verbose:
self.logger.info(f"Expanded paths count: {len(new_paths)}")
self.logger.info(f"Expanded paths: {new_paths}")
return new_paths, new_pids, new_ptypes
def path_to_string(self, path: List[str]) -> str:
"""
Convert a path to a human-readable string for LLM rating.
Args:
path (List[str]): Path as a list of node_ids and relation types.
Returns:
str: String representation of the path.
"""
if len(path) < 1:
return ""
path_str = []
with self.neo4j_driver.session() as session:
for i in range(0, len(path), 2):
node_id = path[i]
result = session.run("MATCH (n:Node {numeric_id: $node_id}) RETURN n.name", node_id=node_id)
record = result.single()
node_name = record["n.name"] if record else node_id
if i + 1 < len(path):
rel_type = path[i + 1]
path_str.append(f"{node_name} ---> {rel_type} --->")
else:
path_str.append(node_name)
return " ".join(path_str).strip()
def prune(self, query: str, P: List[List[str]], PIDS: List[List[str]], PTYPES: List[List[str]], topN: int = 5) -> List[List[str]]:
"""
Prune paths to keep the top N based on LLM relevance ratings.
Args:
query (str): The user query.
P (List[List[str]]): List of paths to prune.
topN (int): Number of paths to retain.
Returns:
List[List[str]]: Top N paths.
"""
ratings = []
path_strings = P
# Process paths in chunks of 10
for i in range(0, len(path_strings), self.prune_size):
chunk = path_strings[i:i + self.prune_size]
# Construct user prompt with the current chunk of paths listed
user_prompt = f"Please rate the following paths based on how well they help answer the query (1-5, 0 if not relevant).\n\nQuery: {query}\n\nPaths:\n"
for j, path_str in enumerate(chunk, 1):
user_prompt += f"{j + i}. {path_str}\n"
user_prompt += "\nProvide a list of integers, each corresponding to the rating of the path's ability to help answer the query."
# Define system prompt to expect a list of integers
system_prompt = "You are a rating machine that only provides a list of comma-separated integers (0-5) as a response, each rating how well the corresponding path helps answer the query."
# Send the prompt to the language model
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
response = self.llm_generator.generate_response(messages, max_new_tokens=1024, temperature=0.0)
if self.verbose:
self.logger.info(f"LLM response for chunk {i // self.prune_size + 1}: {response}")
# Parse the response into a list of ratings
rating_str = response.strip()
chunk_ratings = [int(r) for r in rating_str.split(',') if r.strip().isdigit()]
if len(chunk_ratings) > len(chunk):
chunk_ratings = chunk_ratings[:len(chunk)]
if self.verbose:
self.logger.warning(f"Received more ratings ({len(chunk_ratings)}) than paths in chunk ({len(chunk)}). Trimming ratings.")
ratings.extend(chunk_ratings) # Concatenate ratings
# Ensure ratings length matches number of paths, padding with 0s if necessary
if len(ratings) < len(path_strings):
# self.logger.warning(f"Number of ratings ({len(ratings)}) does not match number of paths ({len(path_strings)}). Padding with 0s.")
# ratings += [0] * (len(path_strings) - len(ratings))
# fall back to use filter encoder to get topN
self.logger.warning(f"Number of ratings ({len(ratings)}) does not match number of paths ({len(path_strings)}). Using filter encoder to get topN paths.")
path_embeddings = self.filter_encoder.encode(path_strings)
query_embedding = self.filter_encoder.encode([query])
scores = np.dot(path_embeddings, query_embedding.T).flatten()
top_indices = np.argsort(scores)[-topN:]
top_paths = [path_strings[i] for i in top_indices]
return top_paths, [PIDS[i] for i in top_indices], [PTYPES[i] for i in top_indices]
elif len(ratings) > len(path_strings):
self.logger.warning(f"Number of ratings ({len(ratings)}) exceeds number of paths ({len(path_strings)}). Trimming ratings.")
ratings = ratings[:len(path_strings)]
# Sort indices based on ratings in descending order
sorted_indices = sorted(range(len(ratings)), key=lambda i: ratings[i], reverse=True)
# Filter out indices where the rating is 0
filtered_indices = [i for i in sorted_indices if ratings[i] > 0]
# Take the top N indices from the filtered list
top_indices = filtered_indices[:topN]
# Use the filtered indices to get the top paths, PIDS, and PTYPES
if self.verbose:
self.logger.info(f"Top indices after pruning: {top_indices}")
self.logger.info(f"length of path_strings: {len(path_strings)}")
top_paths = [path_strings[i] for i in top_indices]
top_pids = [PIDS[i] for i in top_indices]
top_ptypes = [PTYPES[i] for i in top_indices]
# Log top paths if verbose mode is enabled
if self.verbose:
self.logger.info(f"Pruned to top {topN} paths: {top_paths}")
return top_paths, top_pids, top_ptypes
def reasoning(self, query: str, P: List[List[str]]) -> bool:
"""
Check if the current paths are sufficient to answer the query.
Args:
query (str): The user query.
P (List[List[str]]): Current list of paths.
Returns:
bool: True if sufficient, False otherwise.
"""
triples = []
with self.neo4j_driver.session() as session:
for path in P:
if len(path) < 3:
continue
for i in range(0, len(path) - 2, 2):
node1_name = path[i]
rel = path[i + 1]
node2_name = path[i + 2]
triples.append(f"({node1_name}, {rel}, {node2_name})")
triples_str = ". ".join(triples)
prompt = f"Are these triples, along with your knowledge, sufficient to answer the query?\nQuery: {query}\nTriples: {triples_str}"
messages = [
{"role": "system", "content": "Answer Yes or No only."},
{"role": "user", "content": prompt}
]
response = self.llm_generator.generate_response(messages, max_new_tokens=512)
if self.verbose:
self.logger.info(f"Reasoning result: {response}")
return "yes" in response.lower()
def retrieve_passages(self, query: str) -> List[str]:
"""
Retrieve the top N paths to answer the query.
Args:
query (str): The user query.
topN (int): Number of paths to return.
Dmax (int): Maximum depth of path expansion.
Wmax (int): Maximum width of path expansion.
Returns:
List[str]: List of triples as strings.
"""
topN = self.topN
Dmax = self.Dmax
Wmax = self.Wmax
if self.verbose:
self.logger.info(f"Retrieving paths for query: {query}")
initial_nodes_ids, initial_nodes = self.retrieve_topk_nodes(query, top_k_nodes=topN)
if not initial_nodes:
if self.verbose:
self.logger.info("No initial nodes found.")
return []
P = [[node] for node in initial_nodes]
PIDS = [[node_id] for node_id in initial_nodes_ids]
PTYPES = [["Node"] for _ in initial_nodes_ids] # Assuming all initial nodes are of type 'Node'
for D in range(Dmax + 1):
if self.verbose:
self.logger.info(f"Depth {D}, Current paths: {len(P)}")
P, PIDS, PTYPES = self.expand_paths(P, PIDS, PTYPES, Wmax, query)
if not P:
if self.verbose:
self.logger.info("No paths to expand.")
break
P, PIDS, PTYPES = self.prune(query, P, PIDS, PTYPES, topN)
if D == Dmax:
if self.verbose:
self.logger.info(f"Reached maximum depth {Dmax}, stopping expansion.")
break
if self.reasoning(query, P):
if self.verbose:
self.logger.info("Paths sufficient, stopping expansion.")
break
# Extract final triples
triples = []
with self.neo4j_driver.session() as session:
for path in P:
for i in range(0, len(path) - 2, 2):
node1_name = path[i]
rel = path[i + 1]
node2_name = path[i + 2]
triples.append(f"({node1_name}, {rel}, {node2_name})")
if self.verbose:
self.logger.info(f"Final triples: {triples}")
return triples, 'N/A' # 'N/A' for passages_score as this retriever does not return passages
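The triples returned by retrieve_passages are plain strings, so downstream answer generation only has to pack them into a prompt. The sketch below shows one way that hand-off could look; the triples are canned examples and the chat-style message format simply mirrors the calls used elsewhere in this file, so treat it as an illustration rather than the project's actual generation path.

```python
# Illustrative only: canned triples standing in for retrieve_passages output.
query = "Which airport is located in Oregon?"
triples = [
    "(Portland International Airport, located in, Oregon)",
    "(Gerald R. Ford International Airport, located in, Michigan)",
]

context = ". ".join(triples)
messages = [
    {"role": "system", "content": "Answer the question using the provided knowledge triples."},
    {"role": "user", "content": f"Query: {query}\nTriples: {context}"},
]
print(messages)  # in real use these messages would go to an LLMGenerator.generate_response call
```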

View File

@ -0,0 +1,51 @@
from typing import Dict
import numpy as np
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.retriever.base import BaseEdgeRetriever, BasePassageRetriever
class SimpleGraphRetriever(BaseEdgeRetriever):
def __init__(self, llm_generator:LLMGenerator, sentence_encoder:BaseEmbeddingModel,
data:dict):
self.KG = data["KG"]
self.node_list = data["node_list"]
self.edge_list = data["edge_list"]
self.llm_generator = llm_generator
self.sentence_encoder = sentence_encoder
self.node_faiss_index = data["node_faiss_index"]
self.edge_faiss_index = data["edge_faiss_index"]
def retrieve(self, query, topN=5, **kwargs):
# retrieve the top k edges
topk_edges = []
query_embedding = self.sentence_encoder.encode([query], query_type='edge')
D, I = self.edge_faiss_index.search(query_embedding, topN)
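# FAISS returns two arrays: D holds the similarity scores and I the positions of the topN nearest edges in edge_list.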
topk_edges += [self.edge_list[i] for i in I[0]]
topk_edges_with_data = [(edge[0], self.KG.edges[edge]["relation"], edge[1]) for edge in topk_edges]
string_edge_edges = [f"{self.KG.nodes[edge[0]]['id']} {edge[1]} {self.KG.nodes[edge[2]]['id']}" for edge in topk_edges_with_data]
return string_edge_edges, ["N/A" for _ in range(len(string_edge_edges))]
class SimpleTextRetriever(BasePassageRetriever):
def __init__(self, passage_dict:Dict[str,str], sentence_encoder:BaseEmbeddingModel, data:dict):
self.sentence_encoder = sentence_encoder
self.passage_dict = passage_dict
self.passage_list = list(passage_dict.values())
self.passage_keys = list(passage_dict.keys())
self.text_embeddings = data["text_embeddings"]
def retrieve(self, query, topN=5, **kwargs):
query_emb = self.sentence_encoder.encode([query], query_type="passage")
sim_scores = self.text_embeddings @ query_emb[0].T
topk_indices = np.argsort(sim_scores)[-topN:][::-1] # Get indices of top-k scores
# Retrieve top-k passages
topk_passages = [self.passage_list[i] for i in topk_indices]
topk_passages_ids = [self.passage_keys[i] for i in topk_indices]
return topk_passages, topk_passages_ids
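Both retrievers in this file lean on precomputed artifacts: SimpleGraphRetriever expects a ready-made FAISS index over edge embeddings, and SimpleTextRetriever expects a dense matrix of passage embeddings. The snippet below is a minimal, self-contained sketch of how such an inner-product FAISS index might be built and queried, using random vectors; the dimensionality, corpus size, and the choice of IndexFlatIP are illustrative assumptions, not the project's actual indexing pipeline.

```python
import numpy as np
import faiss  # assumes the faiss-cpu package is installed

dim = 128                                   # embedding dimension (placeholder value)
edge_embeddings = np.random.rand(1000, dim).astype("float32")

index = faiss.IndexFlatIP(dim)              # inner-product index, matching the dot-product scoring above
index.add(edge_embeddings)                  # register all edge vectors with the index

query_embedding = np.random.rand(1, dim).astype("float32")
D, I = index.search(query_embedding, 5)     # D: similarity scores, I: indices of the 5 closest edges
print(I[0])
```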

195
atlas_rag/retriever/tog.py Normal file
View File

@ -0,0 +1,195 @@
import numpy as np
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from typing import Optional
from atlas_rag.retriever.base import BaseEdgeRetriever
from atlas_rag.retriever.inference_config import InferenceConfig
class TogRetriever(BaseEdgeRetriever):
def __init__(self, llm_generator, sentence_encoder, data, inference_config: Optional[InferenceConfig] = None):
self.KG = data["KG"]
self.node_list = list(self.KG.nodes)
self.edge_list = list(self.KG.edges)
self.edge_list_with_relation = [(edge[0], self.KG.edges[edge]["relation"], edge[1]) for edge in self.edge_list]
self.edge_list_string = [f"{edge[0]} {self.KG.edges[edge]['relation']} {edge[1]}" for edge in self.edge_list]
self.llm_generator:LLMGenerator = llm_generator
self.sentence_encoder:BaseEmbeddingModel = sentence_encoder
self.node_embeddings = data["node_embeddings"]
self.edge_embeddings = data["edge_embeddings"]
self.inference_config = inference_config if inference_config is not None else InferenceConfig()
def ner(self, text):
messages = [
{"role": "system", "content": "Please extract the entities from the following question and output them separated by comma, in the following format: entity1, entity2, ..."},
{"role": "user", "content": f"Extract the named entities from: Are Portland International Airport and Gerald R. Ford International Airport both located in Oregon?"},
{"role": "system", "content": "Portland International Airport, Gerald R. Ford International Airport, Oregon"},
{"role": "user", "content": f"Extract the named entities from: {text}"},
]
response = self.llm_generator.generate_response(messages)
generated_text = response
# print(generated_text)
return generated_text
def retrieve_topk_nodes(self, query, topN=5, **kwargs):
# extract entities from the query
entities = self.ner(query)
entities = [entity.strip() for entity in entities.split(",") if entity.strip()]
if len(entities) == 0:
# If the NER cannot extract any entities, we
# use the query as the entity to do approximate search
entities = [query]
# evenly distribute the topk for each entity
topk_for_each_entity = topN//len(entities)
# retrieve the top k nodes
topk_nodes = []
for entity_index, entity in enumerate(entities):
if entity in self.node_list:
topk_nodes.append(entity)
for entity_index, entity in enumerate(entities):
topk_for_this_entity = topk_for_each_entity + 1
entity_embedding = self.sentence_encoder.encode([entity])
# Calculate similarity scores using dot product
scores = self.node_embeddings @ entity_embedding[0].T
# Get top-k indices
top_indices = np.argsort(scores)[-topk_for_this_entity:][::-1]
topk_nodes += [self.node_list[i] for i in top_indices]
topk_nodes = list(set(topk_nodes))
if len(topk_nodes) > 2*topN:
topk_nodes = topk_nodes[:2*topN]
return topk_nodes
def retrieve(self, query, topN=5, **kwargs):
"""
Retrieve the top N paths that connect the entities in the query.
Dmax is the maximum depth of the search.
"""
Dmax = self.inference_config.Dmax
# in the first step, we retrieve the top k nodes
initial_nodes = self.retrieve_topk_nodes(query, topN=topN)
E = initial_nodes
P = [ [e] for e in E]
D = 0
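# Depth-bounded loop: expand every path by one hop (search), keep the best-rated ones (prune),
# and stop early once the LLM judges the collected triples sufficient (reasoning).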
while D <= Dmax:
P = self.search(query, P)
P = self.prune(query, P, topN)
if self.reasoning(query, P):
generated_text = self.generate(query, P)
break
D += 1
if D > Dmax:
generated_text = self.generate(query, P)
# print(generated_text)
return generated_text
def search(self, query, P):
new_paths = []
for path in P:
tail_entity = path[-1]
successors = list(self.KG.successors(tail_entity))
predecessors = list(self.KG.predecessors(tail_entity))
# print(f"tail_entity: {tail_entity}")
# print(f"successors: {successors}")
# print(f"predecessors: {predecessors}")
# # print the attributes of the tail_entity
# print(f"attributes of the tail_entity: {self.KG.nodes[tail_entity]}")
# remove entities that are already in the path to avoid revisiting them
successors = [neighbour for neighbour in successors if neighbour not in path]
predecessors = [neighbour for neighbour in predecessors if neighbour not in path]
if len(successors) == 0 and len(predecessors) == 0:
new_paths.append(path)
continue
for neighbour in successors:
relation = self.KG.edges[(tail_entity, neighbour)]["relation"]
new_path = path + [relation, neighbour]
new_paths.append(new_path)
for neighbour in predecessors:
relation = self.KG.edges[(neighbour, tail_entity)]["relation"]
new_path = path + [relation, neighbour]
new_paths.append(new_path)
return new_paths
def prune(self, query, P, topN=3):
ratings = []
for path in P:
path_string = ""
for index, node_or_relation in enumerate(path):
if index % 2 == 0:
id_path = self.KG.nodes[node_or_relation]["id"]
else:
id_path = node_or_relation
path_string += f"{id_path} --->"
path_string = path_string[:-5]
prompt = f"Please rating the following path based on the relevance to the question. The ratings should be in the range of 1 to 5. 1 for least relevant and 5 for most relevant. Only provide the rating, do not provide any other information. The output should be a single integer number. If you think the path is not relevant, please provide 0. If you think the path is relevant, please provide a rating between 1 and 5. \n Query: {query} \n path: {path_string}"
messages = [{"role": "system", "content": "Answer the question following the prompt."},
{"role": "user", "content": f"{prompt}"}]
response = self.llm_generator.generate_response(messages)
# print(response)
rating_str = response.strip()
rating = int(rating_str) if rating_str.isdigit() else 0  # default to 0 when the LLM reply is not a clean integer
ratings.append(rating)
# sort the paths based on the ratings
sorted_paths = [path for _, path in sorted(zip(ratings, P), reverse=True)]
return sorted_paths[:topN]
def reasoning(self, query, P):
triples = []
for path in P:
for i in range(0, len(path)-2, 2):
# triples.append((path[i], path[i+1], path[i+2]))
triples.append((self.KG.nodes[path[i]]["id"], path[i+1], self.KG.nodes[path[i+2]]["id"]))
triples_string = [f"({triple[0]}, {triple[1]}, {triple[2]})" for triple in triples]
triples_string = ". ".join(triples_string)
prompt = f"Given a question and the associated retrieved knowledge graph triples (entity, relation, entity), you are asked to answer whether it's sufficient for you to answer the question with these triples and your knowledge (Yes or No). Query: {query} \n Knowledge triples: {triples_string}"
messages = [{"role": "system", "content": "Answer the question following the prompt."},
{"role": "user", "content": f"{prompt}"}]
response = self.llm_generator.generate_response(messages)
return "yes" in response.lower()
def generate(self, query, P):
triples = []
for path in P:
for i in range(0, len(path)-2, 2):
# triples.append((path[i], path[i+1], path[i+2]))
triples.append((self.KG.nodes[path[i]]["id"], path[i+1], self.KG.nodes[path[i+2]]["id"]))
triples_string = [f"({triple[0]}, {triple[1]}, {triple[2]})" for triple in triples]
# response = self.llm_generator.generate_with_context_kg(query, triples_string)
return triples_string, ["N/A" for _ in range(len(triples_string))]

View File

@ -0,0 +1 @@
from .create_graph_index import create_embeddings_and_index

Some files were not shown because too many files have changed in this diff