first commit
9
.claude/settings.local.json
Normal file
@@ -0,0 +1,9 @@
{
  "permissions": {
    "allow": [
      "Bash(python test_numpy_embedding.py:*)"
    ],
    "deny": [],
    "ask": []
  }
}
57
.dockerignore
Normal file
@@ -0,0 +1,57 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
*.egg
*.egg-info/
dist/
build/
.pytest_cache/
.coverage
htmlcov/

# Virtual environments
venv/
env/
ENV/

# IDE
.vscode/
.idea/
*.swp
*.swo

# Git
.git/
.gitignore

# Environment variables (contain sensitive information)
.env
.env.*

# Test data
*.pkl
test_*.py
*_test.py

# Logs and output
*.log
api_outputs/
logs/

# Documentation
*.md
docs/

# System files
.DS_Store
Thumbs.db

# Claude-related
.claude/

# Temporary files
*.tmp
*.bak
*.backup
24
.env
Normal file
@@ -0,0 +1,24 @@
# Alibaba Cloud DashScope API key (called directly, no OneAPI relay required)
ONEAPI_KEY=sk-2c5045478a244022865e65c5c9b6adf7

# LLM names (native Alibaba Cloud DashScope model names)
ONEAPI_MODEL=qwen2-7b-instruct
ONEAPI_MODEL_GEN=qwen2-7b-instruct
ONEAPI_MODEL_MAX=qwen2-7b-instruct

# Embedding model name
ONEAPI_MODEL_EMBED=text-embedding-v3

# LangSmith tracing configuration
LANGCHAIN_TRACING_V2=false
LANGCHAIN_ENDPOINT=https://api.smith.langchain.com
LANGCHAIN_API_KEY=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
LANGCHAIN_PROJECT=hipporag-retriever

# ============= Elasticsearch configuration =============
# Elasticsearch server address
ELASTICSEARCH_HOST=http://101.200.154.78:9200

# Elasticsearch credentials
ELASTICSEARCH_USERNAME=elastic
ELASTICSEARCH_PASSWORD=Abcd123456
87
.env.example
Normal file
@@ -0,0 +1,87 @@
# =======================
# AIEC-RAG environment configuration example
# =======================
# Copy this file to .env and fill in the actual values

# =======================
# API configuration
# =======================

# OneAPI/DashScope API key
# Obtain one at: https://dashscope.console.aliyun.com/
ONEAPI_KEY=sk-your-api-key-here

# =======================
# Model configuration
# =======================

# Primary language model (used to generate answers)
# Options: qwen2-7b-instruct, qwen-max, gpt-3.5-turbo, etc.
ONEAPI_MODEL=qwen2-7b-instruct

# Generation model (used for query decomposition and similar tasks)
ONEAPI_MODEL_GEN=qwen2-7b-instruct

# Max-context model (used for long documents)
ONEAPI_MODEL_MAX=qwen2-7b-instruct

# Embedding model (used for vectorization)
# Options: text-embedding-v3, text-embedding-ada-002, etc.
ONEAPI_MODEL_EMBED=text-embedding-v3

# =======================
# Elasticsearch configuration
# =======================

# Elasticsearch server address
# Local deployment: http://localhost:9200
# Remote deployment: http://your-es-server:9200
ELASTICSEARCH_HOST=http://localhost:9200

# Elasticsearch credentials
ELASTICSEARCH_USERNAME=elastic
ELASTICSEARCH_PASSWORD=your-password-here

# =======================
# LangSmith configuration (optional)
# =======================

# Whether to enable LangSmith tracing
# true: enable tracing, false: disable tracing
LANGCHAIN_TRACING_V2=false

# LangSmith API endpoint
LANGCHAIN_ENDPOINT=https://api.smith.langchain.com

# LangSmith API key
# Obtain one at: https://smith.langchain.com/
LANGCHAIN_API_KEY=lsv2_pt_your_api_key_here

# LangSmith project name
LANGCHAIN_PROJECT=aiec-rag

# =======================
# Service configuration (optional)
# =======================

# Service port (default 8100)
# SERVICE_PORT=8100

# Service host (default 0.0.0.0)
# SERVICE_HOST=0.0.0.0

# Debug mode (default false)
# DEBUG_MODE=false

# =======================
# Performance configuration (optional)
# =======================

# Maximum number of concurrent requests
# MAX_CONCURRENT_REQUESTS=10

# Request timeout (seconds)
# REQUEST_TIMEOUT=120

# Cache size (MB)
# CACHE_SIZE=512
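
How the service loads these variables is not shown in this commit. A minimal sketch, assuming `python-dotenv` is used; `load_settings` is an illustrative helper, not a function from this repository:

```python
import os

from dotenv import load_dotenv  # assumption: python-dotenv is installed


def load_settings() -> dict:
    """Read .env from the working directory and collect the values used by the RAG service."""
    load_dotenv()
    return {
        "api_key": os.environ["ONEAPI_KEY"],
        "model": os.environ.get("ONEAPI_MODEL", "qwen2-7b-instruct"),
        "embed_model": os.environ.get("ONEAPI_MODEL_EMBED", "text-embedding-v3"),
        "es_host": os.environ.get("ELASTICSEARCH_HOST", "http://localhost:9200"),
    }


if __name__ == "__main__":
    print(load_settings()["model"])
```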
579
DEPLOYMENT_GUIDE.md
Normal file
@@ -0,0 +1,579 @@
# AIEC-RAG Deployment Guide

## Table of Contents

1. [Deployment Architecture](#deployment-architecture)
2. [Single-Node Deployment](#single-node-deployment)
3. [Docker Deployment](#docker-deployment)
4. [Production Deployment](#production-deployment)
5. [Performance Tuning](#performance-tuning)
6. [Monitoring](#monitoring)
7. [Backup and Recovery](#backup-and-recovery)
8. [Troubleshooting](#troubleshooting)
9. [Security Recommendations](#security-recommendations)

## Deployment Architecture

### Recommended Architecture

```
              [Load Balancer]
                     |
      ┌────────────┼────────────┐
      ↓             ↓             ↓
[AIEC-RAG-1]  [AIEC-RAG-2]  [AIEC-RAG-3]
      ↓             ↓             ↓
      └────────────┼────────────┘
                     ↓
         [Elasticsearch Cluster]
                     ↓
            [Vector Database]
```

### Minimum Requirements

| Component | CPU | Memory | Storage | Notes |
|-----------|-----|--------|---------|-------|
| API service | 4 cores | 8 GB | 50 GB | Minimum for a single instance |
| Elasticsearch | 4 cores | 16 GB | 200 GB | SSD recommended |
| Overall system | 8 cores | 32 GB | 500 GB | Recommended for production |

## Single-Node Deployment

### 1. System Preparation

```bash
# Ubuntu/Debian
sudo apt update
sudo apt install -y python3.8 python3-pip git curl wget

# CentOS/RHEL
sudo yum update -y
sudo yum install -y python38 python38-pip git curl wget
```

### 2. Install Elasticsearch

```bash
# Download and install Elasticsearch 8.x
wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-8.11.0-linux-x86_64.tar.gz
tar -xzf elasticsearch-8.11.0-linux-x86_64.tar.gz
cd elasticsearch-8.11.0

# Configure Elasticsearch
cat >> config/elasticsearch.yml << EOF
network.host: 0.0.0.0
discovery.type: single-node
xpack.security.enabled: true
xpack.security.authc.api_key.enabled: true
EOF

# Start Elasticsearch
./bin/elasticsearch -d
```

### 3. Deploy AIEC-RAG

```bash
# Clone the project
git clone <repository_url>
cd AIEC-RAG

# Create a virtual environment
python3 -m venv venv
source venv/bin/activate

# Install dependencies
pip install -r requirements.txt

# Configure environment variables
cp .env.example .env
# Edit the .env file and fill in the actual values

# Start the service
python rag_api_server_production.py
```

### 4. Set Up a System Service

Create `/etc/systemd/system/aiec-rag.service`:

```ini
[Unit]
Description=AIEC-RAG Service
After=network.target elasticsearch.service

[Service]
Type=simple
User=aiec
WorkingDirectory=/opt/AIEC-RAG
Environment="PATH=/opt/AIEC-RAG/venv/bin"
ExecStart=/opt/AIEC-RAG/venv/bin/python /opt/AIEC-RAG/rag_api_server_production.py
Restart=always
RestartSec=10

[Install]
WantedBy=multi-user.target
```

Enable the service:

```bash
sudo systemctl daemon-reload
sudo systemctl enable aiec-rag
sudo systemctl start aiec-rag
sudo systemctl status aiec-rag
```

## Docker Deployment

### 1. Use a Prebuilt Image

```bash
# Pull the image (if a private registry is available)
docker pull your-registry/aiec-rag:latest

# Or build the image locally
docker build -t aiec-rag:latest .
```

### 2. Docker Compose Deployment

Create `docker-compose.yml`:

```yaml
version: '3.8'

services:
  elasticsearch:
    image: docker.elastic.co/elasticsearch/elasticsearch:8.11.0
    container_name: aiec-elasticsearch
    environment:
      - discovery.type=single-node
      - "ES_JAVA_OPTS=-Xms2g -Xmx2g"
      - xpack.security.enabled=true
      - ELASTIC_PASSWORD=your_password
    volumes:
      - es_data:/usr/share/elasticsearch/data
    ports:
      - "9200:9200"
    networks:
      - aiec_network
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:9200"]
      interval: 30s
      timeout: 10s
      retries: 5

  aiec-rag:
    build: .
    container_name: aiec-rag
    depends_on:
      elasticsearch:
        condition: service_healthy
    environment:
      - ELASTICSEARCH_HOST=http://elasticsearch:9200
      - ELASTICSEARCH_USERNAME=elastic
      - ELASTICSEARCH_PASSWORD=your_password
    env_file:
      - .env
    ports:
      - "8100:8100"
    volumes:
      - ./rag_config_production.yaml:/app/rag_config_production.yaml
      - ./api_outputs:/app/api_outputs
    networks:
      - aiec_network
    restart: unless-stopped

volumes:
  es_data:
    driver: local

networks:
  aiec_network:
    driver: bridge
```

Start the services:

```bash
docker-compose up -d
docker-compose logs -f
```

### 3. Kubernetes Deployment

Create `k8s-deployment.yaml`:

```yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: aiec-rag
  labels:
    app: aiec-rag
spec:
  replicas: 3
  selector:
    matchLabels:
      app: aiec-rag
  template:
    metadata:
      labels:
        app: aiec-rag
    spec:
      containers:
      - name: aiec-rag
        image: your-registry/aiec-rag:latest
        ports:
        - containerPort: 8100
        env:
        - name: ELASTICSEARCH_HOST
          value: "http://elasticsearch-service:9200"
        envFrom:
        - secretRef:
            name: aiec-secrets
        resources:
          requests:
            memory: "4Gi"
            cpu: "2"
          limits:
            memory: "8Gi"
            cpu: "4"
        livenessProbe:
          httpGet:
            path: /health
            port: 8100
          initialDelaySeconds: 30
          periodSeconds: 10
        readinessProbe:
          httpGet:
            path: /health
            port: 8100
          initialDelaySeconds: 5
          periodSeconds: 5
---
apiVersion: v1
kind: Service
metadata:
  name: aiec-rag-service
spec:
  selector:
    app: aiec-rag
  ports:
  - protocol: TCP
    port: 80
    targetPort: 8100
  type: LoadBalancer
```

Deploy to Kubernetes:

```bash
# Create the secret
kubectl create secret generic aiec-secrets --from-env-file=.env

# Deploy the application
kubectl apply -f k8s-deployment.yaml

# Check status
kubectl get pods
kubectl get services
```

## Production Deployment

### 1. Load Balancer Configuration

Use Nginx as the load balancer:

```nginx
upstream aiec_backend {
    least_conn;
    server 10.0.1.10:8100 weight=1 max_fails=3 fail_timeout=30s;
    server 10.0.1.11:8100 weight=1 max_fails=3 fail_timeout=30s;
    server 10.0.1.12:8100 weight=1 max_fails=3 fail_timeout=30s;
}

server {
    listen 80;
    server_name api.aiec-rag.com;

    location / {
        proxy_pass http://aiec_backend;
        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection 'upgrade';
        proxy_set_header Host $host;
        proxy_cache_bypass $http_upgrade;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;

        # Timeout settings
        proxy_connect_timeout 60s;
        proxy_send_timeout 120s;
        proxy_read_timeout 120s;
    }

    # Health check endpoint
    location /health {
        proxy_pass http://aiec_backend/health;
    }
}
```

### 2. SSL/TLS Configuration

```nginx
server {
    listen 443 ssl http2;
    server_name api.aiec-rag.com;

    ssl_certificate /etc/nginx/ssl/aiec-rag.crt;
    ssl_certificate_key /etc/nginx/ssl/aiec-rag.key;

    ssl_protocols TLSv1.2 TLSv1.3;
    ssl_ciphers HIGH:!aNULL:!MD5;
    ssl_prefer_server_ciphers on;

    # ... rest of the configuration as above
}
```

### 3. Database Optimization

Optimized Elasticsearch settings:

```yaml
# elasticsearch.yml
cluster.name: aiec-rag-cluster
node.name: node-1

# Memory settings
bootstrap.memory_lock: true

# Thread pools
thread_pool:
  write:
    size: 8
    queue_size: 1000
  search:
    size: 16
    queue_size: 1000

# Index settings
index:
  number_of_shards: 3
  number_of_replicas: 1
  refresh_interval: 30s
```

## Performance Tuning

### 1. Python Application Optimization

```bash
# Use Gunicorn with Uvicorn workers as the application server (Linux)
gunicorn -w 4 -k uvicorn.workers.UvicornWorker \
    --bind 0.0.0.0:8100 \
    --timeout 120 \
    --keep-alive 5 \
    --max-requests 1000 \
    --max-requests-jitter 50 \
    rag_api_server_production:app
```

### 2. System Parameter Tuning

```bash
# /etc/sysctl.conf
net.ipv4.tcp_fin_timeout = 30
net.ipv4.tcp_tw_reuse = 1
net.ipv4.tcp_tw_recycle = 1
net.ipv4.tcp_max_syn_backlog = 8192
net.ipv4.tcp_max_tw_buckets = 10000
net.core.somaxconn = 65535
net.core.netdev_max_backlog = 65535

# Apply the settings
sudo sysctl -p
```

### 3. Caching Strategy

Configure a Redis cache:

```python
# Add caching support in the application code
import redis
from functools import lru_cache

redis_client = redis.Redis(
    host='localhost',
    port=6379,
    decode_responses=True,
    max_connections=50
)

@lru_cache(maxsize=128)
def get_cached_embedding(text: str):
    # Cache the embedding vector: check redis_client first, otherwise compute
    # the embedding with the project's embedding model and store it here.
    pass
```

## Monitoring

### 1. Prometheus Monitoring

```yaml
# prometheus.yml
scrape_configs:
  - job_name: 'aiec-rag'
    static_configs:
      - targets: ['localhost:8100']
    metrics_path: '/metrics'
    scrape_interval: 15s
```

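The scrape job above expects the API service to expose a `/metrics` endpoint, which this guide does not show. A minimal sketch, assuming the service is a FastAPI application and the `prometheus_client` package is installed (both assumptions, not confirmed by this commit):

```python
from fastapi import FastAPI
from prometheus_client import Counter, make_asgi_app

app = FastAPI()

# Illustrative counter; increment it inside the request handlers you care about.
REQUESTS_TOTAL = Counter("aiec_rag_requests_total", "Total retrieval requests served")

# Mount the Prometheus ASGI app so the scrape job above can read /metrics.
app.mount("/metrics", make_asgi_app())
```
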
### 2. Log Management

Configure log rotation:

```bash
# /etc/logrotate.d/aiec-rag
/opt/AIEC-RAG/logs/*.log {
    daily
    rotate 30
    compress
    delaycompress
    missingok
    notifempty
    create 644 aiec aiec
    sharedscripts
    postrotate
        systemctl reload aiec-rag
    endscript
}
```

### 3. Alerting Configuration

```yaml
# alerting_rules.yml
groups:
  - name: aiec_alerts
    rules:
      - alert: HighResponseTime
        expr: http_request_duration_seconds{quantile="0.99"} > 5
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: "High response time on {{ $labels.instance }}"

      - alert: ServiceDown
        expr: up{job="aiec-rag"} == 0
        for: 1m
        labels:
          severity: critical
        annotations:
          summary: "AIEC-RAG service is down"
```

## Backup and Recovery

### 1. Data Backup

```bash
#!/bin/bash
# backup.sh
DATE=$(date +%Y%m%d_%H%M%S)
BACKUP_DIR="/backup/aiec-rag"

# Back up Elasticsearch data
curl -X PUT "localhost:9200/_snapshot/backup_repo" -H 'Content-Type: application/json' -d'
{
  "type": "fs",
  "settings": {
    "location": "'$BACKUP_DIR'/elasticsearch"
  }
}'

curl -X PUT "localhost:9200/_snapshot/backup_repo/snapshot_$DATE?wait_for_completion=true"

# Back up configuration files
tar -czf $BACKUP_DIR/config_$DATE.tar.gz \
    /opt/AIEC-RAG/.env \
    /opt/AIEC-RAG/rag_config_production.yaml

echo "Backup completed: $DATE"
```

### 2. Recovery Procedure

```bash
#!/bin/bash
# restore.sh
SNAPSHOT_NAME=$1

# Restore Elasticsearch data
curl -X POST "localhost:9200/_snapshot/backup_repo/$SNAPSHOT_NAME/_restore"

# Restore configuration files
tar -xzf /backup/aiec-rag/config_latest.tar.gz -C /

# Restart the service
systemctl restart aiec-rag

echo "Restore completed from: $SNAPSHOT_NAME"
```

## Troubleshooting

### Common Issues

1. **Service not responding**
   ```bash
   # Check service status
   systemctl status aiec-rag
   # View logs
   journalctl -u aiec-rag -n 100
   # Restart the service
   systemctl restart aiec-rag
   ```

2. **Elasticsearch connection failure**
   ```bash
   # Check Elasticsearch cluster status
   curl -X GET "localhost:9200/_cluster/health?pretty"
   # Check network connectivity
   telnet localhost 9200
   ```

3. **Out of memory**
   ```bash
   # Increase memory-related limits
   export PYTHONUNBUFFERED=1
   export OMP_NUM_THREADS=4
   ```

## Security Recommendations

1. **API key management**
   - Use a secrets management service (such as HashiCorp Vault)
   - Rotate API keys regularly
   - Do not hard-code keys in the codebase

2. **Network security**
   - Restrict access with a firewall
   - Configure SSL/TLS encryption
   - Enforce rate limiting

3. **Data security**
   - Encrypt sensitive data
   - Back up regularly
   - Enforce access control

---

*For additional deployment questions, see the project wiki or contact technical support.*
41
Dockerfile
Normal file
@@ -0,0 +1,41 @@
# Dockerfile for the Python RAG API service
FROM python:3.10-slim

# Set the working directory
WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy the dependency file
COPY retriver/requirements.txt /app/requirements.txt

# Install Python dependencies
RUN pip install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt && \
    pip install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple \
    fastapi \
    uvicorn[standard] \
    pyyaml

# Copy project files
COPY . /app

# Create the output directory
RUN mkdir -p /app/api_outputs

# Set environment variables
ENV PYTHONPATH=/app
ENV RAG_CONFIG_PATH=rag_config_production.yaml

# Expose the service port
EXPOSE 8000

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Start the service
CMD ["uvicorn", "rag_api_server_production:app", "--host", "0.0.0.0", "--port", "8000"]
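
The `HEALTHCHECK` above, like the Kubernetes probes in DEPLOYMENT_GUIDE.md, expects the service to answer `GET /health`. The actual handler in `rag_api_server_production.py` is not shown in this commit; a minimal FastAPI sketch of such an endpoint:

```python
from fastapi import FastAPI

app = FastAPI()


@app.get("/health")
def health() -> dict:
    # Keep this handler cheap: the Docker HEALTHCHECK and the Kubernetes
    # liveness/readiness probes call it every few seconds.
    return {"status": "ok"}
```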
61
Dockerfile.cpu-full
Normal file
@@ -0,0 +1,61 @@
FROM python:3.11-slim

WORKDIR /app

# Configure pip to use the Aliyun mirror
ENV PIP_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/
ENV PIP_TRUSTED_HOST=mirrors.aliyun.com
ENV PIP_NO_CACHE_DIR=1

# Copy the requirements file
COPY requirements-cpu-full.txt requirements.txt

# Skip apt-get; install everything with pip

# Upgrade pip and install all dependencies
RUN pip install --upgrade pip setuptools wheel && \
    pip install torch --index-url https://download.pytorch.org/whl/cpu && \
    pip install transformers sentence-transformers && \
    pip install -r requirements.txt

# Copy project files
COPY . .

# Make sure test_with_concept.pkl is in the right place
# The file should be in the project root; if it is not, copy it separately
RUN if [ -f /app/test_with_concept.pkl ]; then \
    echo "test_with_concept.pkl found in /app"; \
    else \
    echo "WARNING: test_with_concept.pkl not found"; \
    fi

# Create required directories
RUN mkdir -p /app/api_outputs /app/logs

# Set environment variables
ENV PYTHONPATH=/app
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
ENV RAG_CONFIG_PATH=rag_config_production.yaml

# API configuration
ENV ONEAPI_KEY=sk-2c5045478a244022865e65c5c9b6adf7
ENV ONEAPI_MODEL=qwen2-7b-instruct
ENV ONEAPI_MODEL_GEN=qwen2-7b-instruct
ENV ONEAPI_MODEL_MAX=qwen2-7b-instruct
ENV ONEAPI_MODEL_EMBED=text-embedding-v3

# LangSmith configuration
ENV LANGCHAIN_TRACING_V2=false
ENV LANGCHAIN_ENDPOINT=https://api.smith.langchain.com
ENV LANGCHAIN_API_KEY=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
ENV LANGCHAIN_PROJECT=hipporag-retriever

# Elasticsearch configuration
ENV ELASTICSEARCH_HOST=http://101.200.154.78:9200
ENV ELASTICSEARCH_USERNAME=elastic
ENV ELASTICSEARCH_PASSWORD=Abcd123456

EXPOSE 8000

CMD ["uvicorn", "rag_api_server_production:app", "--host", "0.0.0.0", "--port", "8000"]
BIN
__pycache__/config_loader.cpython-311.pyc
Normal file
Binary file not shown.
BIN
__pycache__/config_manager.cpython-311.pyc
Normal file
Binary file not shown.
BIN
__pycache__/config_wrapper.cpython-311.pyc
Normal file
Binary file not shown.
BIN
__pycache__/edge_filter_monkey_patch.cpython-311.pyc
Normal file
Binary file not shown.
BIN
__pycache__/prompt_loader.cpython-311.pyc
Normal file
Binary file not shown.
BIN
__pycache__/prompt_loader.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/rag_api_server.cpython-311.pyc
Normal file
Binary file not shown.
BIN
__pycache__/rag_api_server_production.cpython-311.pyc
Normal file
Binary file not shown.
1
atlas_rag/__init__.py
Normal file
@@ -0,0 +1 @@
from .logging import setup_logger
BIN
atlas_rag/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/__pycache__/logging.cpython-311.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/__pycache__/logging.cpython-39.pyc
Normal file
Binary file not shown.
1
atlas_rag/evaluation/__init__.py
Normal file
@@ -0,0 +1 @@
from .benchmark import BenchMarkConfig, RAGBenchmark
BIN
atlas_rag/evaluation/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/evaluation/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/evaluation/__pycache__/benchmark.cpython-311.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/evaluation/__pycache__/benchmark.cpython-39.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/evaluation/__pycache__/evaluation.cpython-311.pyc
Normal file
Binary file not shown.
236
atlas_rag/evaluation/benchmark.py
Normal file
@@ -0,0 +1,236 @@
import os
|
||||
import json
|
||||
import numpy as np
|
||||
from logging import Logger
|
||||
from atlas_rag.retriever.base import BaseRetriever, BaseEdgeRetriever, BasePassageRetriever
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
from transformers import AutoModel
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from atlas_rag.vectorstore.embedding_model import NvEmbed, SentenceEmbedding
|
||||
from atlas_rag.llm_generator.llm_generator import LLMGenerator
|
||||
from atlas_rag.evaluation.evaluation import QAJudger
|
||||
from dataclasses import dataclass
|
||||
from atlas_rag.llm_generator.prompt.react import ReAct
|
||||
|
||||
|
||||
def normalize_embeddings(embeddings):
|
||||
"""Normalize the embeddings to unit length (L2 norm)."""
|
||||
if isinstance(embeddings, torch.Tensor):
|
||||
# Handle PyTorch tensors
|
||||
norm_emb = F.normalize(embeddings, p=2, dim=1).detach().cpu().numpy()
|
||||
elif isinstance(embeddings, np.ndarray):
|
||||
# Handle numpy arrays
|
||||
norm_emb = F.normalize(torch.tensor(embeddings), p=2, dim=1).detach().cpu().numpy()
|
||||
else:
|
||||
raise TypeError(f"Unsupported input type: {type(embeddings)}. Must be torch.Tensor or np.ndarray")
|
||||
|
||||
return norm_emb
|
||||
|
||||
@dataclass
|
||||
class BenchMarkConfig:
|
||||
"""
|
||||
Configuration class for benchmarking.
|
||||
|
||||
Attributes:
|
||||
dataset_name (str): Name of the dataset. Default is "hotpotqa".
|
||||
question_file (str): Path to the question file. Default is "hotpotqa".
|
||||
graph_file (str): Path to the graph file. Default is "hotpotqa_concept.graphml".
|
||||
include_events (bool): Whether to include events. Default is False.
|
||||
include_concept (bool): Whether to include concepts. Default is False.
|
||||
reader_model_name (str): Name of the reader model. Default is "meta-llama/Llama-2-7b-chat-hf".
|
||||
encoder_model_name (str): Name of the encoder model. Default is "nvidia/NV-Embed-v2".
|
||||
number_of_samples (int): Number of samples to use from the dataset. Default is -1 (use all samples).
|
||||
"""
|
||||
dataset_name: str = "hotpotqa"
|
||||
question_file: str = "hotpotqa"
|
||||
include_events: bool = False
|
||||
include_concept: bool = False
|
||||
reader_model_name: str = "meta-llama/Llama-2-7b-chat-hf"
|
||||
encoder_model_name: str = "nvidia/NV-Embed-v2"
|
||||
number_of_samples: int = -1 # Default to -1 to use all samples
|
||||
react_max_iterations: int = 5
|
||||
|
||||
|
||||
class RAGBenchmark:
|
||||
def __init__(self, config:BenchMarkConfig, logger:Logger = None):
|
||||
self.config = config
|
||||
self.logger = logger
|
||||
self.logging : bool = self.logger is not None
|
||||
|
||||
def load_encoder_model(self, encoder_model_name, **kwargs):
|
||||
if encoder_model_name == "nvidia/NV-Embed-v2":
|
||||
sentence_encoder = AutoModel.from_pretrained("nvidia/NV-Embed-v2", **kwargs)
|
||||
return NvEmbed(sentence_encoder)
|
||||
else:
|
||||
sentence_encoder = SentenceTransformer(encoder_model_name, **kwargs)
|
||||
return SentenceEmbedding(sentence_encoder)
|
||||
|
||||
def run(self, retrievers:List[BaseRetriever],
|
||||
llm_generator:LLMGenerator,
|
||||
use_react: bool = False):
|
||||
qa_judge = QAJudger()
|
||||
if use_react:
|
||||
react_agent = ReAct(llm_generator=llm_generator)
|
||||
result_list = []
|
||||
with open(self.config.question_file, "r") as f:
|
||||
data = json.load(f)
|
||||
print(f"Data loaded from {self.config.question_file}")
|
||||
if self.config.number_of_samples > 0:
|
||||
data = data[:self.config.number_of_samples]
|
||||
print(f"Using only the first {self.config.number_of_samples} samples from the dataset")
|
||||
for sample in tqdm(data):
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
|
||||
gold_file_ids = []
|
||||
if self.config.dataset_name in ("hotpotqa", "2wikimultihopqa"):
|
||||
for fact in sample["supporting_facts"]:
|
||||
gold_file_ids.append(fact[0])
|
||||
elif self.config.dataset_name == "musique":
|
||||
for paragraph in sample["paragraphs"]:
|
||||
if paragraph["is_supporting"]:
|
||||
gold_file_ids.append(paragraph["paragraph_text"])
|
||||
else:
|
||||
print("Dataset not supported")
|
||||
continue
|
||||
|
||||
result = {
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"gold_file_ids": gold_file_ids,
|
||||
}
|
||||
|
||||
if self.logging:
|
||||
self.logger.info(f"Question: {question}")
|
||||
for retriever in retrievers:
|
||||
if use_react:
|
||||
# Use RAG with ReAct
|
||||
llm_generated_answer, search_history = react_agent.generate_with_rag_react(
|
||||
question=question,
|
||||
retriever=retriever,
|
||||
max_iterations=self.config.react_max_iterations,
|
||||
max_new_tokens=2048,
|
||||
logger=self.logger
|
||||
)
|
||||
self.logger.info(f"Search history: {search_history}")
|
||||
self.logger.info(f"Final answer: {llm_generated_answer}")
|
||||
# Store search history in results
|
||||
result[f"{retriever.__class__.__name__}_search_history"] = search_history
|
||||
|
||||
# Extract all retrieved contexts from search history
|
||||
all_contexts = []
|
||||
for _, action, observation in search_history:
|
||||
if "search" in action.lower() or "look up" in action.lower():
|
||||
all_contexts.append(observation)
|
||||
|
||||
sorted_context = "\n".join(all_contexts)
|
||||
sorted_context_ids = [] # We don't track IDs in ReAct mode
|
||||
else:
|
||||
# Original RAG implementation
|
||||
sorted_context, sorted_context_ids = retriever.retrieve(question, topN=5)
|
||||
|
||||
if isinstance(retriever, BaseEdgeRetriever):
|
||||
retrieved_context = ". ".join(sorted_context)
|
||||
llm_generated_answer = llm_generator.generate_with_context_kg(question, retrieved_context, max_new_tokens=2048, temperature=0.5)
|
||||
elif isinstance(retriever, BasePassageRetriever):
|
||||
retrieved_context = "\n".join(sorted_context)
|
||||
llm_generated_answer = llm_generator.generate_with_context(question, retrieved_context, max_new_tokens=2048, temperature=0.5)
|
||||
|
||||
if self.logging:
|
||||
self.logger.info(f"{retriever.__class__.__name__} retrieved passages: {sorted_context}")
|
||||
self.logger.info(f"{retriever.__class__.__name__} generated answer: {llm_generated_answer}")
|
||||
|
||||
short_answer = qa_judge.split_answer(llm_generated_answer)
|
||||
em, f1 = qa_judge.judge(short_answer, answer)
|
||||
|
||||
result[f"{retriever.__class__.__name__ }_em"] = em
|
||||
result[f"{retriever.__class__.__name__ }_f1"] = f1
|
||||
result[f"{retriever.__class__.__name__ }_passages"] = sorted_context
|
||||
if not use_react:
|
||||
result[f"{retriever.__class__.__name__ }_id"] = sorted_context_ids
|
||||
result[f"{retriever.__class__.__name__ }_generated_answer"] = llm_generated_answer
|
||||
result[f"{retriever.__class__.__name__ }short_answer"] = short_answer
|
||||
|
||||
# Calculate recall
|
||||
if not use_react: # Only calculate recall for non-ReAct mode
|
||||
if self.config.dataset_name in ("hotpotqa", "2wikimultihopqa"):
|
||||
recall_2, recall_5 = qa_judge.recall(sorted_context_ids, gold_file_ids)
|
||||
elif self.config.dataset_name == "musique":
|
||||
recall_2, recall_5 = qa_judge.recall(sorted_context, gold_file_ids)
|
||||
|
||||
result[f"{retriever.__class__.__name__ }_recall@2"] = recall_2
|
||||
result[f"{retriever.__class__.__name__ }_recall@5"] = recall_5
|
||||
|
||||
result_list.append(result)
|
||||
|
||||
self.save_results(result_list, [retriever.__class__.__name__ for retriever in retrievers])
|
||||
|
||||
|
||||
def save_results(self, result_list, retriever_names:List[str]):
|
||||
current_time = datetime.now()
|
||||
formatted_time = current_time.strftime("%Y%m%d%H%M%S")
|
||||
|
||||
dataset_name = self.config.dataset_name
|
||||
include_events = self.config.include_events
|
||||
include_concept = self.config.include_concept
|
||||
encoder_model_name = self.config.encoder_model_name
|
||||
reader_model_name = self.config.reader_model_name
|
||||
|
||||
# use last part of model name as identifier
|
||||
if "/" in encoder_model_name:
|
||||
encoder_model_name = encoder_model_name.split("/")[-1]
|
||||
if "/" in reader_model_name:
|
||||
reader_model_name = reader_model_name.split("/")[-1]
|
||||
|
||||
summary_file = f"./result/{dataset_name}/summary_{formatted_time}_event{include_events}_concept{include_concept}_{encoder_model_name}_{reader_model_name}.json"
|
||||
if not os.path.exists(os.path.dirname(summary_file)):
|
||||
os.makedirs(os.path.dirname(summary_file), exist_ok=True)
|
||||
|
||||
result_dir = f"./result/{dataset_name}/result_{formatted_time}_event{include_events}_concept{include_concept}_{encoder_model_name}_{reader_model_name}.json"
|
||||
if not os.path.exists(os.path.dirname(result_dir)):
|
||||
os.makedirs(os.path.dirname(result_dir), exist_ok=True)
|
||||
|
||||
summary_dict = self.calculate_summary(result_list, retriever_names)
|
||||
|
||||
with open(summary_file, "w") as f_summary:
|
||||
json.dump(summary_dict, f_summary)
|
||||
f_summary.write("\n")
|
||||
|
||||
with open(result_dir, "w") as f:
|
||||
for result in result_list:
|
||||
json.dump(result, f)
|
||||
f.write("\n")
|
||||
|
||||
def calculate_summary(self, result_list, method):
|
||||
summary_dict = {}
|
||||
for retriever_name in method:
|
||||
if not all(f"{retriever_name}_em" in result for result in result_list):
|
||||
raise ValueError(f"Missing {retriever_name}_em in results")
|
||||
if not all(f"{retriever_name}_f1" in result for result in result_list):
|
||||
raise ValueError(f"Missing {retriever_name}_f1 in results")
|
||||
|
||||
average_em = sum([result[f"{retriever_name}_em"] for result in result_list]) / len(result_list)
|
||||
average_f1 = sum([result[f"{retriever_name}_f1"] for result in result_list]) / len(result_list)
|
||||
|
||||
# Only calculate recall metrics if they exist in the results
|
||||
if all(f"{retriever_name}_recall@2" in result for result in result_list):
|
||||
average_recall_2 = sum([result[f"{retriever_name}_recall@2"] for result in result_list]) / len(result_list)
|
||||
average_recall_5 = sum([result[f"{retriever_name}_recall@5"] for result in result_list]) / len(result_list)
|
||||
summary_dict.update({
|
||||
f"{retriever_name}_average_f1": average_f1,
|
||||
f"{retriever_name}_average_em": average_em,
|
||||
f"{retriever_name}_average_recall@2": average_recall_2,
|
||||
f"{retriever_name}_average_recall@5": average_recall_5,
|
||||
})
|
||||
else:
|
||||
# For ReAct mode where recall metrics don't exist
|
||||
summary_dict.update({
|
||||
f"{retriever_name}_average_f1": average_f1,
|
||||
f"{retriever_name}_average_em": average_em,
|
||||
})
|
||||
|
||||
return summary_dict
|
||||
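
A minimal sketch of how `RAGBenchmark` might be driven, based only on the signatures above; constructing the retrievers and the `LLMGenerator` is project-specific and omitted, and the question-file path is hypothetical:

```python
from atlas_rag.evaluation import BenchMarkConfig, RAGBenchmark

config = BenchMarkConfig(
    dataset_name="hotpotqa",
    question_file="data/hotpotqa_questions.json",  # hypothetical path
    number_of_samples=100,                         # evaluate a subset only
)
benchmark = RAGBenchmark(config)
# benchmark.run([my_retriever], llm_generator=my_llm, use_react=False)
```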
158
atlas_rag/evaluation/evaluation.py
Normal file
@@ -0,0 +1,158 @@
import re
|
||||
from collections import Counter
|
||||
from typing import Tuple
|
||||
import argparse
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
class QAJudger:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def split_answer(self, generated_text):
|
||||
if "Answer:" in generated_text:
|
||||
generated_text = generated_text.split("Answer:")[-1]
|
||||
elif "answer:" in generated_text:
|
||||
generated_text = generated_text.split("answer:")[-1]
|
||||
# if answer is none
|
||||
if not generated_text:
|
||||
return "none"
|
||||
return generated_text
|
||||
|
||||
def normalize_answer(self, answer: str) -> str:
|
||||
"""Direct copy of the normalization from QAExactMatch/QAF1Score"""
|
||||
# Lowercase and normalize whitespace
|
||||
answer = answer.lower()
|
||||
# Replace hyphens with spaces
|
||||
answer = answer.replace('-', ' ')
|
||||
# Remove all other punctuation
|
||||
answer = re.sub(r'[^\w\s]', '', answer)
|
||||
# Standardize whitespace
|
||||
return ' '.join(answer.split())
|
||||
|
||||
def judge(self, generated_text: str, reference_text: str) -> Tuple[int, float]:
|
||||
"""Direct port of the original scoring logic"""
|
||||
# Extract answer from generated text
|
||||
pred_answer = self.split_answer(generated_text)
|
||||
|
||||
# Normalize both answers
|
||||
pred_norm = self.normalize_answer(pred_answer)
|
||||
ref_norm = self.normalize_answer(reference_text)
|
||||
|
||||
# Exact match calculation
|
||||
em = 1 if pred_norm == ref_norm else 0
|
||||
|
||||
# F1 calculation (direct port from QAF1Score)
|
||||
pred_tokens = pred_norm.split()
|
||||
ref_tokens = ref_norm.split()
|
||||
|
||||
common = Counter(pred_tokens) & Counter(ref_tokens)
|
||||
num_same = sum(common.values())
|
||||
|
||||
if num_same == 0:
|
||||
return em, 0.0
|
||||
|
||||
precision = num_same / len(pred_tokens) if pred_tokens else 0.0
|
||||
recall = num_same / len(ref_tokens) if ref_tokens else 0.0
|
||||
|
||||
if (precision + recall) == 0:
|
||||
f1 = 0.0
|
||||
else:
|
||||
f1 = 2 * (precision * recall) / (precision + recall)
|
||||
|
||||
return em, f1
|
||||
|
||||
def recall_at_k(self, retrieved_text: list, reference_text: list, k: int) -> float:
|
||||
"""Calculates recall at k based on the top k retrieved texts."""
|
||||
successful_retrievals = 0
|
||||
|
||||
# Limit the retrieved texts to the top k entries
|
||||
limited_retrieved_text = retrieved_text[:k]
|
||||
|
||||
for ref_text in reference_text:
|
||||
for ret_text in limited_retrieved_text:
|
||||
if ref_text in ret_text:
|
||||
successful_retrievals += 1
|
||||
break
|
||||
|
||||
recall = successful_retrievals / len(reference_text) if reference_text else 0
|
||||
return recall
|
||||
|
||||
# recall for 1 answer
|
||||
def recall(self, retrieved_text: list, reference_text: list) -> dict:
|
||||
"""Calculates recall values at different k levels."""
|
||||
recall_values = {
|
||||
'recall@2': self.recall_at_k(retrieved_text, reference_text, 2),
|
||||
'recall@5': self.recall_at_k(retrieved_text, reference_text, 5),
|
||||
}
|
||||
return recall_values['recall@2'], recall_values['recall@5']
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
argument_parser = argparse.ArgumentParser()
|
||||
argument_parser.add_argument("--file_path", type=str, required=True, help="Path to the JSON file containing results.")
|
||||
args = argument_parser.parse_args()
|
||||
|
||||
# Initialize the QAJudger
|
||||
llm_judge = QAJudger()
|
||||
|
||||
# Load results from the JSON file
|
||||
result_list = []
|
||||
with open(args.file_path, 'r') as file:
|
||||
for line in file:
|
||||
if line.strip(): # Make sure the line is not empty
|
||||
try:
|
||||
result = json.loads(line.strip())
|
||||
result_list.append(result)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error decoding JSON: {e}")
|
||||
|
||||
# Debugging output to inspect the loaded data structure
|
||||
# print("Loaded data structure:", result_list)
|
||||
|
||||
# Evaluate each entry in result_list
|
||||
for result in tqdm(result_list):
|
||||
if isinstance(result, dict): # Ensure each result is a dictionary
|
||||
question = result["question"]
|
||||
answer = result["answer"]
|
||||
|
||||
# Evaluate generated answers with Hippo and Hippo2
|
||||
hippo_generated_answer = result["hippo_generated_answer"]
|
||||
hippo2_generated_answer = result["hippo2_generated_answer"]
|
||||
|
||||
# Split and judge the answers
|
||||
hippo_short_answer = llm_judge.split_answer(hippo_generated_answer)
|
||||
hippo_em, hippo_f1 = llm_judge.judge(hippo_short_answer, answer)
|
||||
|
||||
hippo2_short_answer = llm_judge.split_answer(hippo2_generated_answer)
|
||||
hippo2_em, hippo2_f1 = llm_judge.judge(hippo2_short_answer, answer)
|
||||
|
||||
# Store the scores back in the result dictionary
|
||||
result["hippo_em"] = hippo_em
|
||||
result["hippo_f1"] = hippo_f1
|
||||
result["hippo2_em"] = hippo2_em
|
||||
result["hippo2_f1"] = hippo2_f1
|
||||
|
||||
result['recall@2'], result['recall@5'] = llm_judge.recall(result['hippo2_id'], result['gold_file_ids'])
|
||||
result['recall@2_hippo'], result['recall@5_hippo'] = llm_judge.recall(result['hippo_id'], result['gold_file_ids'])
|
||||
|
||||
# Calculate averages
|
||||
average_em_with_hippo = sum(result["hippo_em"] for result in result_list) / len(result_list)
|
||||
average_em_with_hippo2 = sum(result["hippo2_em"] for result in result_list) / len(result_list)
|
||||
|
||||
average_f1_with_hippo = sum(result["hippo_f1"] for result in result_list) / len(result_list)
|
||||
average_f1_with_hippo2 = sum(result["hippo2_f1"] for result in result_list) / len(result_list)
|
||||
|
||||
average_recall2_with_hippo = sum(result['recall@2'] for result in result_list) / len(result_list)
|
||||
average_recall5_with_hippo = sum(result['recall@5'] for result in result_list) / len(result_list)
|
||||
average_recall2 = sum(result['recall@2_hippo'] for result in result_list) / len(result_list)
|
||||
average_recall5 = sum(result['recall@5_hippo'] for result in result_list) / len(result_list)
|
||||
# Output the averages
|
||||
print(f"Average EM with Hippo: {average_em_with_hippo:.4f}")
|
||||
print(f"Average EM with Hippo2: {average_em_with_hippo2:.4f}")
|
||||
print(f"Average F1 with Hippo: {average_f1_with_hippo:.4f}")
|
||||
print(f"Average F1 with Hippo2: {average_f1_with_hippo2:.4f}")
|
||||
|
||||
print(f"Average Recall@2: {average_recall2:.4f}")
|
||||
print(f"Average Recall@5: {average_recall5:.4f}")
|
||||
print(f"Average Recall@2 with Hippo: {average_recall2_with_hippo:.4f}")
|
||||
print(f"Average Recall@5 with Hippo: {average_recall5_with_hippo:.4f}")
|
||||
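
A small worked example of the scoring above:

```python
from atlas_rag.evaluation.evaluation import QAJudger

judge = QAJudger()
em, f1 = judge.judge("Answer: the Eiffel Tower", "Eiffel Tower")
# normalize_answer lowercases and strips punctuation, so the prediction becomes
# "the eiffel tower" (3 tokens) vs. "eiffel tower" (2 tokens): EM = 0,
# precision = 2/3, recall = 1, F1 = 2 * (2/3) / (5/3) = 0.8.
print(em, f1)  # 0 0.8
```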
0
atlas_rag/kg_construction/__init__.py
Normal file
BIN
atlas_rag/kg_construction/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
282
atlas_rag/kg_construction/concept_generation.py
Normal file
@@ -0,0 +1,282 @@
from tqdm import tqdm
|
||||
import random
|
||||
import logging
|
||||
import csv
|
||||
import os
|
||||
import hashlib
|
||||
import re
|
||||
from atlas_rag.llm_generator import LLMGenerator
|
||||
from atlas_rag.kg_construction.triple_config import ProcessingConfig
|
||||
from atlas_rag.kg_construction.utils.csv_processing.csv_to_graphml import get_node_id
|
||||
from atlas_rag.llm_generator.prompt.triple_extraction_prompt import CONCEPT_INSTRUCTIONS
|
||||
import pickle
|
||||
# Increase the field size limit
|
||||
csv.field_size_limit(10 * 1024 * 1024) # 10 MB limit
|
||||
|
||||
|
||||
|
||||
def build_batch_data(sessions, batch_size):
|
||||
batched_sessions = []
|
||||
for i in range(0, len(sessions), batch_size):
|
||||
batched_sessions.append(sessions[i:i+batch_size])
|
||||
return batched_sessions
|
||||
|
||||
# Function to compute a hash ID from text
|
||||
def compute_hash_id(text):
|
||||
# Use SHA-256 to generate a hash
|
||||
hash_object = hashlib.sha256(text.encode('utf-8'))
|
||||
return hash_object.hexdigest() # Return hash as a hex string
|
||||
|
||||
def convert_attribute(value):
|
||||
""" Convert attributes to GDS-compatible types. """
|
||||
if isinstance(value, list):
|
||||
return [str(v) for v in value]
|
||||
elif isinstance(value, (int, float)):
|
||||
return value
|
||||
else:
|
||||
return str(value)
|
||||
|
||||
def clean_text(text):
|
||||
# remove NUL as well
|
||||
|
||||
new_text = text.replace("\n", " ").replace("\r", " ").replace("\t", " ").replace("\v", " ").replace("\f", " ").replace("\b", " ").replace("\a", " ").replace("\e", " ").replace(";", ",")
|
||||
new_text = new_text.replace("\x00", "")
|
||||
new_text = re.sub(r'\s+', ' ', new_text).strip()
|
||||
|
||||
return new_text
|
||||
|
||||
def remove_NUL(text):
|
||||
return text.replace("\x00", "")
|
||||
|
||||
def build_batched_events(all_node_list, batch_size):
|
||||
"The types are in Entity Event Relation"
|
||||
event_nodes = [node[0] for node in all_node_list if node[1].lower() == "event"]
|
||||
batched_events = []
|
||||
|
||||
for i in range(0, len(event_nodes), batch_size):
|
||||
batched_events.append(event_nodes[i:i+batch_size])
|
||||
|
||||
return batched_events
|
||||
|
||||
def build_batched_entities(all_node_list, batch_size):
|
||||
|
||||
entity_nodes = [node[0] for node in all_node_list if node[1].lower() == "entity"]
|
||||
batched_entities = []
|
||||
|
||||
for i in range(0, len(entity_nodes), batch_size):
|
||||
batched_entities.append(entity_nodes[i:i+batch_size])
|
||||
|
||||
return batched_entities
|
||||
|
||||
def build_batched_relations(all_node_list, batch_size):
|
||||
|
||||
relations = [node[0] for node in all_node_list if node[1].lower() == "relation"]
|
||||
# relations = list(set(relations))
|
||||
batched_relations = []
|
||||
|
||||
for i in range(0, len(relations), batch_size):
|
||||
batched_relations.append(relations[i:i+batch_size])
|
||||
|
||||
return batched_relations
|
||||
|
||||
def batched_inference(model:LLMGenerator, inputs, record=False, **kwargs):
|
||||
responses = model.generate_response(inputs, return_text_only = not record, **kwargs)
|
||||
answers = []
|
||||
if record:
|
||||
text_responses = [response[0] for response in responses]
|
||||
usages = [response[1] for response in responses]
|
||||
else:
|
||||
text_responses = responses
|
||||
for i in range(len(text_responses)):
|
||||
answer = text_responses[i]
|
||||
answers.append([x.strip().lower() for x in answer.split(",")])
|
||||
|
||||
if record:
|
||||
return answers, usages
|
||||
else:
|
||||
return answers
|
||||
|
||||
def load_data_with_shard(input_file, shard_idx, num_shards):
|
||||
|
||||
with open(input_file, "r") as f:
|
||||
csv_reader = list(csv.reader(f))
|
||||
|
||||
# data = csv_reader
|
||||
data = csv_reader[1:]
|
||||
# Random shuffle the data before splitting into shards
|
||||
random.shuffle(data)
|
||||
|
||||
total_lines = len(data)
|
||||
lines_per_shard = (total_lines + num_shards - 1) // num_shards
|
||||
start_idx = shard_idx * lines_per_shard
|
||||
end_idx = min((shard_idx + 1) * lines_per_shard, total_lines)
|
||||
|
||||
return data[start_idx:end_idx]
|
||||
|
||||
def generate_concept(model: LLMGenerator,
|
||||
input_file = 'processed_data/triples_csv',
|
||||
output_folder = 'processed_data/triples_conceptualized',
|
||||
output_file = 'output.json',
|
||||
logging_file = 'processed_data/logging.txt',
|
||||
config:ProcessingConfig=None,
|
||||
batch_size=32,
|
||||
shard=0,
|
||||
num_shards=1,
|
||||
**kwargs):
|
||||
log_dir = os.path.dirname(logging_file)
|
||||
if log_dir and not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
|
||||
# Create the log file if it doesn't exist
|
||||
if not os.path.exists(logging_file):
|
||||
open(logging_file, 'w').close()
|
||||
|
||||
language = kwargs.get('language', 'en')
|
||||
record = kwargs.get('record', False)
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
file_handler = logging.FileHandler(logging_file)
|
||||
file_handler.setLevel(logging.INFO)
|
||||
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
|
||||
logging.getLogger().addHandler(file_handler)
|
||||
|
||||
with open(f"{config.output_directory}/kg_graphml/{config.filename_pattern}_without_concept.pkl", "rb") as f:
|
||||
temp_kg = pickle.load(f)
|
||||
|
||||
# read data
|
||||
if not os.path.exists(output_folder):
|
||||
os.makedirs(output_folder)
|
||||
all_missing_nodes = load_data_with_shard(
|
||||
input_file,
|
||||
shard_idx=shard,
|
||||
num_shards=num_shards
|
||||
)
|
||||
|
||||
batched_events = build_batched_events(all_missing_nodes, batch_size)
|
||||
batched_entities = build_batched_entities(all_missing_nodes, batch_size)
|
||||
batched_relations = build_batched_relations(all_missing_nodes, batch_size)
|
||||
|
||||
all_batches = []
|
||||
all_batches.extend(('event', batch) for batch in batched_events)
|
||||
all_batches.extend(('entity', batch) for batch in batched_entities)
|
||||
all_batches.extend(('relation', batch) for batch in batched_relations)
|
||||
|
||||
print("all_batches", len(all_batches))
|
||||
|
||||
|
||||
|
||||
|
||||
output_file = output_folder + f"/{output_file.rsplit('.', 1)[0]}_shard_{shard}.csv"
|
||||
with open(output_file, "w", newline='') as file:
|
||||
csv_writer = csv.writer(file)
|
||||
csv_writer.writerow(["node", "conceptualized_node", "node_type"])
|
||||
|
||||
# for batch_type, batch in tqdm(all_batches, total=total_batches, desc="Generating concepts"):
|
||||
# don't use tqdm for now
|
||||
for batch_type, batch in tqdm(all_batches, desc="Shard_{}".format(shard)):
|
||||
# print("batch_type", batch_type)
|
||||
# print("batch", batch)
|
||||
replace_context_token = None
|
||||
if batch_type == 'event':
|
||||
template = CONCEPT_INSTRUCTIONS[language]['event']
|
||||
node_type = 'event'
|
||||
replace_token = '[EVENT]'
|
||||
elif batch_type == 'entity':
|
||||
template = CONCEPT_INSTRUCTIONS[language]['entity']
|
||||
node_type = 'entity'
|
||||
replace_token = '[ENTITY]'
|
||||
replace_context_token = '[CONTEXT]'
|
||||
elif batch_type == 'relation':
|
||||
template = CONCEPT_INSTRUCTIONS[language]['relation']
|
||||
node_type = 'relation'
|
||||
replace_token = '[RELATION]'
|
||||
|
||||
inputs = []
|
||||
for node in batch:
|
||||
# sample node from given node and replace context token.
|
||||
if replace_context_token:
|
||||
node_id = get_node_id(node)
|
||||
entity_predecessors = list(temp_kg.predecessors(node_id))
|
||||
entity_successors = list(temp_kg.successors(node_id))
|
||||
|
||||
context = ""
|
||||
|
||||
if len(entity_predecessors) > 0:
|
||||
random_two_neighbors = random.sample(entity_predecessors, min(1, len(entity_predecessors)))
|
||||
context += ", ".join([f"{temp_kg.nodes[neighbor]['id']} {temp_kg[neighbor][node_id]['relation']}" for neighbor in random_two_neighbors])
|
||||
|
||||
if len(entity_successors) > 0:
|
||||
random_two_neighbors = random.sample(entity_successors, min(1, len(entity_successors)))
|
||||
context += ", ".join([f"{temp_kg[node_id][neighbor]['relation']} {temp_kg.nodes[neighbor]['id']}" for neighbor in random_two_neighbors])
|
||||
|
||||
prompt = template.replace(replace_token, node).replace(replace_context_token, context)
|
||||
else:
|
||||
prompt = template.replace(replace_token, node)
|
||||
constructed_input = [
|
||||
{"role": "system", "content": "You are a helpful AI assistant."},
|
||||
{"role": "user", "content": f"{prompt}"},
|
||||
]
|
||||
inputs.append(constructed_input)
|
||||
|
||||
try:
|
||||
# print("inputs", inputs)
|
||||
if record:
|
||||
# If recording, we will get both answers and responses
|
||||
answers, usages = batched_inference(model, inputs, record=record, max_workers = config.max_workers)
|
||||
else:
|
||||
answers = batched_inference(model, inputs, record=record, max_workers = config.max_workers)
|
||||
usages = None
|
||||
# print("answers", answers)
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing {batch_type} batch: {e}")
|
||||
raise e
|
||||
# try:
|
||||
# answers = batched_inference(llm, sampling_params, inputs)
|
||||
# except Exception as e:
|
||||
# logging.error(f"Error processing {batch_type} batch: {e}")
|
||||
# continue
|
||||
|
||||
for i,(node, answer) in enumerate(zip(batch, answers)):
|
||||
# print(node, answer, node_type)
|
||||
if usages is not None:
|
||||
logging.info(f"Usage log: Node {node}, completion_usage: {usages[i]}")
|
||||
csv_writer.writerow([node, ", ".join(answer), node_type])
|
||||
file.flush()
|
||||
# count unique conceptualized nodes
|
||||
conceptualized_nodes = []
|
||||
|
||||
conceptualized_events = []
|
||||
conceptualized_entities = []
|
||||
conceptualized_relations = []
|
||||
|
||||
with open(output_file, "r") as file:
|
||||
reader = csv.reader(file)
|
||||
next(reader)
|
||||
for row in reader:
|
||||
conceptualized_nodes.extend(row[1].split(","))
|
||||
if row[2] == "event":
|
||||
conceptualized_events.extend(row[1].split(","))
|
||||
elif row[2] == "entity":
|
||||
conceptualized_entities.extend(row[1].split(","))
|
||||
elif row[2] == "relation":
|
||||
conceptualized_relations.extend(row[1].split(","))
|
||||
|
||||
conceptualized_nodes = [x.strip() for x in conceptualized_nodes]
|
||||
conceptualized_events = [x.strip() for x in conceptualized_events]
|
||||
conceptualized_entities = [x.strip() for x in conceptualized_entities]
|
||||
conceptualized_relations = [x.strip() for x in conceptualized_relations]
|
||||
|
||||
unique_conceptualized_nodes = list(set(conceptualized_nodes))
|
||||
unique_conceptualized_events = list(set(conceptualized_events))
|
||||
unique_conceptualized_entities = list(set(conceptualized_entities))
|
||||
unique_conceptualized_relations = list(set(conceptualized_relations))
|
||||
|
||||
print(f"Number of unique conceptualized nodes: {len(unique_conceptualized_nodes)}")
|
||||
print(f"Number of unique conceptualized events: {len(unique_conceptualized_events)}")
|
||||
print(f"Number of unique conceptualized entities: {len(unique_conceptualized_entities)}")
|
||||
print(f"Number of unique conceptualized relations: {len(unique_conceptualized_relations)}")
|
||||
|
||||
return
|
||||
|
||||
|
||||
|
||||
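
A small sketch of the batching and hashing helpers above, using a made-up node list (the node texts are illustrative only):

```python
from atlas_rag.kg_construction.concept_generation import (
    build_batched_entities,
    compute_hash_id,
)

nodes = [
    ("Alan Turing", "Entity"),
    ("Turing joined the NPL", "Event"),
    ("worked at", "Relation"),
]
print(build_batched_entities(nodes, batch_size=2))  # [['Alan Turing']]
print(compute_hash_id("Alan Turing")[:12])          # stable SHA-256 prefix
```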
153
atlas_rag/kg_construction/concept_to_csv.py
Normal file
@@ -0,0 +1,153 @@
import ast
|
||||
import uuid
|
||||
import csv
|
||||
from tqdm import tqdm
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
def generate_uuid():
|
||||
"""Generate a random UUID"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def parse_concepts(s):
|
||||
"""Parse concepts field and filter empty values"""
|
||||
try:
|
||||
parsed = ast.literal_eval(s) if s and s != '[]' else []
|
||||
return [c.strip() for c in parsed if c.strip()]
|
||||
except:
|
||||
return []
|
||||
|
||||
|
||||
# Function to compute a hash ID from text
|
||||
def compute_hash_id(text):
|
||||
# Use SHA-256 to generate a hash
|
||||
text = text + '_concept'
|
||||
hash_object = hashlib.sha256(text.encode('utf-8'))
|
||||
return hash_object.hexdigest() # Return hash as a hex string
|
||||
|
||||
|
||||
def all_concept_triples_csv_to_csv(node_file, edge_file, concepts_file, output_node_file, output_edge_file, output_full_concept_triple_edges):
|
||||
|
||||
# to deal add output the concepts nodes, edges, and new full_triple_edges,
|
||||
# we need to read the concepts maps to the memory, as it is usually not too large.
|
||||
# Then we need to iterate over the triple nodes to create concept edges
|
||||
# Finally we iterate over the triple edges to create the full_triple_edges
|
||||
|
||||
# Read missing concept
|
||||
# relation_concepts_mapping = {}
|
||||
# all_missing_concepts = []
|
||||
|
||||
# check if all output directories exist
|
||||
output_dir = os.path.dirname(output_node_file)
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
output_dir = os.path.dirname(output_edge_file)
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
output_dir = os.path.dirname(output_full_concept_triple_edges)
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
node_to_concepts = {}
|
||||
relation_to_concepts = {}
|
||||
|
||||
all_concepts = set()
|
||||
with open(concepts_file, 'r', encoding='utf-8') as f:
|
||||
reader = csv.DictReader(f)
|
||||
|
||||
# Load missing concepts list
|
||||
print("Loading concepts...")
|
||||
for row in tqdm(reader):
|
||||
|
||||
|
||||
if row['node_type'] == 'relation':
|
||||
relation = row['node']
|
||||
concepts = [c.strip() for c in row['conceptualized_node'].split(',') if c.strip()]
|
||||
|
||||
if relation not in relation_to_concepts:
|
||||
relation_to_concepts[relation] = concepts
|
||||
else:
|
||||
relation_to_concepts[relation].extend(concepts)
|
||||
relation_to_concepts[relation] = list(set(relation_to_concepts[relation]))
|
||||
|
||||
else:
|
||||
node = row['node']
|
||||
concepts = [c.strip() for c in row['conceptualized_node'].split(',') if c.strip()]
|
||||
|
||||
if node not in node_to_concepts:
|
||||
node_to_concepts[node] = concepts
|
||||
else:
|
||||
node_to_concepts[node].extend(concepts)
|
||||
node_to_concepts[node] = list(set(node_to_concepts[node]))
|
||||
|
||||
print("Loading concepts done.")
|
||||
print(f"Relation to concepts: {len(relation_to_concepts)}")
|
||||
print(f"Node to concepts: {len(node_to_concepts)}")
|
||||
|
||||
# Read triple nodes and write to output concept edges files
|
||||
print("Processing triple nodes...")
|
||||
with open(node_file, 'r', encoding='utf-8') as f:
|
||||
reader = csv.reader(f)
|
||||
# name:ID,type,concepts,synsets,:LABEL
|
||||
header = next(reader)
|
||||
|
||||
with open (output_edge_file, 'w', newline='', encoding='utf-8') as f_out:
|
||||
|
||||
writer = csv.writer(f_out, quoting=csv.QUOTE_ALL)
|
||||
writer.writerow([':START_ID', ':END_ID', 'relation', ':TYPE'])
|
||||
|
||||
for row in tqdm(reader):
|
||||
node_name = row[0]
|
||||
if node_name in node_to_concepts:
|
||||
|
||||
for concept in node_to_concepts[node_name]:
|
||||
concept_id = compute_hash_id(concept)
|
||||
writer.writerow([row[0], concept_id, 'has_concept', 'Concept'])
|
||||
all_concepts.add(concept)
|
||||
|
||||
|
||||
for concept in parse_concepts(row[2]):
|
||||
concept_id = compute_hash_id(concept)
|
||||
writer.writerow([row[0], concept_id, 'has_concept', 'Concept'])
|
||||
all_concepts.add(concept)
|
||||
|
||||
# Read the concept nodes and write to output concept nodes file
|
||||
print("Processing concept nodes...")
|
||||
    with open(output_node_file, 'w', newline='', encoding='utf-8') as f_out:
|
||||
writer = csv.writer(f_out, quoting=csv.QUOTE_ALL)
|
||||
writer.writerow(['concept_id:ID', 'name', ':LABEL'])
|
||||
|
||||
for concept in tqdm(all_concepts):
|
||||
concept_id = compute_hash_id(concept)
|
||||
writer.writerow([concept_id, concept, 'Concept'])
|
||||
|
||||
|
||||
# Read triple edges and write to output full concept triple edges file
|
||||
print("Processing triple edges...")
|
||||
with open(edge_file, 'r', encoding='utf-8') as f:
|
||||
with open(output_full_concept_triple_edges, 'w', newline='', encoding='utf-8') as f_out:
|
||||
reader = csv.reader(f)
|
||||
writer = csv.writer(f_out, quoting=csv.QUOTE_ALL)
|
||||
|
||||
|
||||
header = next(reader)
|
||||
writer.writerow([':START_ID', ':END_ID', 'relation', 'concepts', 'synsets', ':TYPE'])
|
||||
|
||||
for row in tqdm(reader):
|
||||
src_id = row[0]
|
||||
end_id = row[1]
|
||||
relation = row[2]
|
||||
concepts = row[3]
|
||||
synsets = row[4]
|
||||
|
||||
original_concepts = parse_concepts(concepts)
|
||||
|
||||
|
||||
if relation in relation_to_concepts:
|
||||
for concept in relation_to_concepts[relation]:
|
||||
if concept not in original_concepts:
|
||||
original_concepts.append(concept)
|
||||
original_concepts = list(set(original_concepts))
|
||||
|
||||
writer.writerow([src_id, end_id, relation, original_concepts, synsets, 'Relation'])
|
||||
return
|
||||
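# Minimal usage sketch (paths are placeholders; the node CSV must carry the
# "name:ID,type,concepts,synsets,:LABEL" header noted above, and the concepts CSV
# needs "node", "node_type" and "conceptualized_node" columns):
if __name__ == "__main__":
    all_concept_triples_csv_to_csv(
        node_file="triples_csv/triple_nodes_example_from_json_without_emb.csv",
        edge_file="triples_csv/triple_edges_example_from_json_without_emb.csv",
        concepts_file="triples_csv/example_from_json_with_concept.csv",
        output_node_file="concept_csv/concept_nodes_example.csv",
        output_edge_file="concept_csv/concept_edges_example.csv",
        output_full_concept_triple_edges="concept_csv/triple_edges_example_with_concept.csv",
    )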
0
atlas_rag/kg_construction/neo4j/__init__.py
Normal file
0
atlas_rag/kg_construction/neo4j/__init__.py
Normal file
267
atlas_rag/kg_construction/neo4j/neo4j_api.py
Normal file
267
atlas_rag/kg_construction/neo4j/neo4j_api.py
Normal file
@ -0,0 +1,267 @@
|
||||
import time
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, Response
|
||||
from typing import List, Literal, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from logging import Logger
|
||||
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGRetriever, BaseLargeKGEdgeRetriever
|
||||
from atlas_rag.kg_construction.neo4j.utils import start_up_large_kg_index_graph
|
||||
from atlas_rag.llm_generator import LLMGenerator
|
||||
from neo4j import Driver
|
||||
from dataclasses import dataclass
|
||||
import traceback
|
||||
|
||||
@dataclass
|
||||
class LargeKGConfig:
|
||||
largekg_retriever: BaseLargeKGRetriever | BaseLargeKGEdgeRetriever = None
|
||||
reader_llm_generator : LLMGenerator = None
|
||||
driver: Driver = None
|
||||
logger: Logger = None
|
||||
is_felm: bool = False
|
||||
is_mmlu: bool = False
|
||||
|
||||
rag_exemption_list = [
|
||||
"""I will show you a question and a list of text segments. All the segments can be concatenated to form a complete answer to the question. Your task is to assess whether each text segment contains errors or not. \nPlease generate using the following format:\nAnswer: List the ids of the segments with errors (separated by commas). Please only output the ids, no more details. If all the segments are correct, output \"ALL_CORRECT\".\n\nHere is one example:\nQuestion: 8923164*7236571?\nSegments: \n1. The product of 8923164 and 7236571 is: 6,461,216,222,844\n2. So, 8923164 multiplied by 7236571 is equal to 6,461,216,222,844.\n\nBelow are your outputs:\nAnswer: 1,2\nIt means segment 1,2 contain errors.""",
|
||||
"""I will show you a question and a list of text segments. All the segments can be concatenated to form a complete answer to the question. Your task is to determine whether each text segment contains factual errors or not. \nPlease generate using the following format:\nAnswer: List the ids of the segments with errors (separated by commas). Please only output the ids, no more details. If all the segments are correct, output \"ALL_CORRECT\".\n\nHere is one example:\nQuestion: A company offers a 10% discount on all purchases over $100. A customer purchases three items, each costing $80. Does the customer qualify for the discount?\nSegments: \n1. To solve this problem, we need to use deductive reasoning. We know that the company offers a 10% discount on purchases over $100, so we need to calculate the total cost of the customer's purchase.\n2. The customer purchased three items, each costing $80, so the total cost of the purchase is: 3 x $80 = $200.\n3. Since the total cost of the purchase is greater than $100, the customer qualifies for the discount. \n4. To calculate the discounted price, we can multiply the total cost by 0.1 (which represents the 10% discount): $200 x 0.1 = $20.\n5. So the customer is eligible for a discount of $20, and the final cost of the purchase would be: $200 - $20 = $180.\n6. Therefore, the customer would pay a total of $216 for the three items with the discount applied.\n\nBelow are your outputs:\nAnswer: 2,3,4,5,6\nIt means segment 2,3,4,5,6 contains errors.""",
|
||||
]
|
||||
|
||||
mmlu_check_list = [
|
||||
"""Given the following question and four candidate answers (A, B, C and D), choose the answer."""
|
||||
]
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
global large_kg_config
|
||||
start_up_large_kg_index_graph(large_kg_config.driver)
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown():
|
||||
global large_kg_config
|
||||
print("Shutting down the model...")
|
||||
del large_kg_config
|
||||
|
||||
class ModelCard(BaseModel):
|
||||
id: str
|
||||
object: str = "model"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: str = "test"
|
||||
root: Optional[str] = None
|
||||
parent: Optional[str] = None
|
||||
permission: Optional[list] = None
|
||||
|
||||
class ModelList(BaseModel):
|
||||
object: str = "list"
|
||||
data: List[ModelCard] = []
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Literal["user", "system", "assistant"]
|
||||
content: str = None
|
||||
name: Optional[str] = None
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
role: Optional[Literal["user", "assistant", "system"]] = None
|
||||
content: Optional[str] = None
|
||||
|
||||
class UsageInfo(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
completion_tokens: Optional[int] = 0
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[ChatMessage]
|
||||
temperature: Optional[float] = 0.8
|
||||
top_p: Optional[float] = 0.8
|
||||
max_tokens: Optional[int] = None
|
||||
stream: Optional[bool] = False
|
||||
tools: Optional[Union[dict, List[dict]]] = None
|
||||
repetition_penalty: Optional[float] = 1.1
|
||||
retriever_config: Optional[dict] = {
|
||||
"topN": 5,
|
||||
"number_of_source_nodes_per_ner": 10,
|
||||
"sampling_area": 250,
|
||||
"Dmax": 2,
|
||||
"Wmax": 3
|
||||
}
|
||||
class Config:
|
||||
extra = "allow"
|
||||
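# Example request body for /v1/chat/completions (illustrative values only; any
# retriever_config keys omitted by the client fall back to the defaults declared above):
# {
#   "model": "largekg-rag",
#   "messages": [{"role": "user", "content": "Question: who founded the company?"}],
#   "max_tokens": 1024,
#   "retriever_config": {"topN": 5, "Dmax": 2, "Wmax": 3}
# }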
|
||||
|
||||
class ChatCompletionResponseChoice(BaseModel):
|
||||
index: int
|
||||
message: ChatMessage
|
||||
finish_reason: Literal["stop", "length", "function_call"]
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
delta: DeltaMessage
|
||||
finish_reason: Optional[Literal["stop", "length"]]
|
||||
index: int
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
model: str
|
||||
id: str
|
||||
object: Literal["chat.completion", "chat.completion.chunk"]
|
||||
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
usage: Optional[UsageInfo] = None
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
global large_kg_config
|
||||
try:
|
||||
if len(request.messages) < 1 or request.messages[-1].role == "assistant":
|
||||
print(request)
|
||||
# raise HTTPException(status_code=400, detail="Invalid request")
|
||||
if large_kg_config.logger is not None:
|
||||
large_kg_config.logger.info(f"Request: {request}")
|
||||
gen_params = dict(
|
||||
messages=request.messages,
|
||||
temperature=0.8,
|
||||
top_p=request.top_p,
|
||||
max_tokens=request.max_tokens or 1024,
|
||||
echo=False,
|
||||
stream=False,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
tools=request.tools,
|
||||
)
|
||||
|
||||
last_message = request.messages[-1]
|
||||
system_prompt = 'You are a helpful assistant.'
|
||||
question = last_message.content if last_message.role == 'user' else request.messages[-2].content
|
||||
|
||||
is_exemption = any(exemption in question for exemption in LargeKGConfig.rag_exemption_list)
|
||||
is_mmlu = any(exemption in question for exemption in LargeKGConfig.mmlu_check_list)
|
||||
print(f"Is exemption: {is_exemption}, Is MMLU: {is_mmlu}")
|
||||
if is_mmlu:
|
||||
rag_text = question
|
||||
else:
|
||||
parts = question.rsplit("Question:", 1)
|
||||
rag_text = parts[-1] if len(parts) > 1 else None
|
||||
print(f"RAG text: {rag_text}")
|
||||
if not is_exemption:
|
||||
passages, passages_score = large_kg_config.largekg_retriever.retrieve_passages(rag_text)
|
||||
context = "No retrieved context, Please answer the question with your own knowledge." if not passages else "\n".join([f"Passage {i+1}: {text}" for i, text in enumerate(reversed(passages))])
|
||||
|
||||
if is_mmlu:
|
||||
rag_chat_content = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"{system_prompt}"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""Here is the context: {context} \n\n
|
||||
If the context is not useful, you can answer the question with your own knowledge. \n {question}\nThink step by step. Your response should end with 'The answer is ([the_answer_letter])' where the [the_answer_letter] is one of A, B, C and D."""
|
||||
}
|
||||
]
|
||||
elif not is_exemption:
|
||||
rag_chat_content = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"{system_prompt}"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""{question} Reference doc: {context}"""
|
||||
}
|
||||
]
|
||||
else:
|
||||
rag_chat_content = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"{system_prompt}"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""{question} """
|
||||
}
|
||||
]
|
||||
if large_kg_config.logger is not None:
|
||||
large_kg_config.logger.info(rag_chat_content)
|
||||
|
||||
response = large_kg_config.reader_llm_generator.generate_response(
|
||||
batch_messages=rag_chat_content,
|
||||
max_new_tokens=gen_params["max_tokens"],
|
||||
temperature=gen_params["temperature"],
|
||||
frequency_penalty = 1.1
|
||||
)
|
||||
message = ChatMessage(
|
||||
role="assistant",
|
||||
content=response.strip()
|
||||
)
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=message,
|
||||
finish_reason="stop"
|
||||
)
|
||||
return ChatCompletionResponse(
|
||||
model=request.model,
|
||||
id="",
|
||||
object="chat.completion",
|
||||
choices=[choice_data]
|
||||
)
|
||||
except Exception as e:
|
||||
print("ERROR: ", e)
|
||||
print("Catched error")
|
||||
traceback.print_exc()
|
||||
system_prompt = 'You are a helpful assistant.'
|
||||
gen_params = dict(
|
||||
messages=request.messages,
|
||||
temperature=0.8,
|
||||
top_p=request.top_p,
|
||||
max_tokens=request.max_tokens or 1024,
|
||||
echo=False,
|
||||
stream=False,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
tools=request.tools,
|
||||
)
|
||||
last_message = request.messages[-1]
|
||||
system_prompt = 'You are a helpful assistant.'
|
||||
question = last_message.content if last_message.role == 'user' else request.messages[-2].content
|
||||
rag_chat_content = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"{system_prompt}"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""{question} """
|
||||
}
|
||||
]
|
||||
response = large_kg_config.reader_llm_generator.generate_response(
|
||||
batch_messages=rag_chat_content,
|
||||
max_new_tokens=gen_params["max_tokens"],
|
||||
temperature=gen_params["temperature"],
|
||||
frequency_penalty = 1.1
|
||||
)
|
||||
message = ChatMessage(
|
||||
role="assistant",
|
||||
content=response.strip()
|
||||
)
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=message,
|
||||
finish_reason="stop"
|
||||
)
|
||||
return ChatCompletionResponse(
|
||||
model=request.model,
|
||||
id="",
|
||||
object="chat.completion",
|
||||
choices=[choice_data]
|
||||
)
|
||||
|
||||
|
||||
def start_app(user_config:LargeKGConfig, host="0.0.0.0", port=10090, reload=False):
|
||||
"""Function to start the FastAPI application."""
|
||||
global large_kg_config
|
||||
large_kg_config = user_config # Use the passed context if provided
|
||||
|
||||
uvicorn.run(f"atlas_rag.kg_construction.neo4j.neo4j_api:app", host=host, port=port, reload=reload)
|
||||
|
||||
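# Minimal launch sketch, assuming a Neo4j driver, a retriever and an LLMGenerator
# have already been constructed elsewhere (the names below are placeholders):
#
#   config = LargeKGConfig(
#       largekg_retriever=my_retriever,          # any BaseLargeKGRetriever implementation
#       reader_llm_generator=my_llm_generator,   # LLMGenerator used to answer with retrieved passages
#       driver=my_neo4j_driver,
#   )
#   start_app(config, host="0.0.0.0", port=10090)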
73
atlas_rag/kg_construction/neo4j/utils.py
Normal file
73
atlas_rag/kg_construction/neo4j/utils.py
Normal file
@ -0,0 +1,73 @@
|
||||
import faiss
|
||||
from neo4j import Driver
|
||||
import time
|
||||
from graphdatascience import GraphDataScience
|
||||
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGRetriever
|
||||
|
||||
def build_projection_graph(driver: GraphDataScience):
|
||||
project_graph_1 = "largekgrag_graph"
|
||||
is_project_graph_1_exist = False
|
||||
# is_project_graph_2_exist = False
|
||||
result = driver.graph.list()
|
||||
for index, row in result.iterrows():
|
||||
if row['graphName'] == project_graph_1:
|
||||
is_project_graph_1_exist = True
|
||||
# if row['graphName'] == project_graph_2:
|
||||
# is_project_graph_2_exist = True
|
||||
|
||||
if not is_project_graph_1_exist:
|
||||
start_time = time.time()
|
||||
node_properties = ["Node"]
|
||||
relation_projection = [ "Relation"]
|
||||
result = driver.graph.project(
|
||||
project_graph_1,
|
||||
node_properties,
|
||||
relation_projection
|
||||
)
|
||||
graph = driver.graph.get(project_graph_1)
|
||||
print(f"Projection graph {project_graph_1} created in {time.time() - start_time:.2f} seconds")
|
||||
|
||||
def build_neo4j_label_index(driver: Driver):
|
||||
with driver.session() as session:
|
||||
index_name = f"NodeNumericIDIndex"
|
||||
# Check if the index already exists
|
||||
existing_indexes = session.run("SHOW INDEXES").data()
|
||||
index_exists = any(index['name'] == index_name for index in existing_indexes)
|
||||
# Drop the index if it exists
|
||||
if not index_exists:
|
||||
start_time = time.time()
|
||||
session.run(f"CREATE INDEX {index_name} FOR (n:Node) ON (n.numeric_id)")
|
||||
print(f"Index {index_name} created in {time.time() - start_time:.2f} seconds")
|
||||
|
||||
index_name = f"TextNumericIDIndex"
|
||||
index_exists = any(index['name'] == index_name for index in existing_indexes)
|
||||
if not index_exists:
|
||||
start_time = time.time()
|
||||
session.run(f"CREATE INDEX {index_name} FOR (t:Text) ON (t.numeric_id)")
|
||||
print(f"Index {index_name} created in {time.time() - start_time:.2f} seconds")
|
||||
|
||||
index_name = f"EntityEventEdgeNumericIDIndex"
|
||||
index_exists = any(index['name'] == index_name for index in existing_indexes)
|
||||
if not index_exists:
|
||||
start_time = time.time()
|
||||
session.run(f"CREATE INDEX {index_name} FOR ()-[r:Relation]-() on (r.numeric_id)")
|
||||
print(f"Index {index_name} created in {time.time() - start_time:.2f} seconds")
|
||||
|
||||
def load_indexes(path_dict):
|
||||
for key, value in path_dict.items():
|
||||
if key == 'node':
|
||||
node_index = faiss.read_index(value, faiss.IO_FLAG_MMAP)
|
||||
print(f"Node index loaded from {value}")
|
||||
elif key == 'edge':
|
||||
edge_index = faiss.read_index(value, faiss.IO_FLAG_MMAP)
|
||||
print(f"Edge index loaded from {value}")
|
||||
elif key == 'text':
|
||||
passage_index = faiss.read_index(value, faiss.IO_FLAG_MMAP)
|
||||
print(f"Passage index loaded from {value}")
|
||||
return node_index, edge_index, passage_index
|
||||
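# Illustrative path_dict for load_indexes (keys are fixed, paths are placeholders).
# All three keys must be present, otherwise the corresponding local variable is never
# assigned and the return statement raises a NameError:
#
#   node_index, edge_index, passage_index = load_indexes({
#       "node": "./import/triple_nodes.index",
#       "edge": "./import/triple_edges.index",
#       "text": "./import/text_nodes.index",
#   })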
|
||||
def start_up_large_kg_index_graph(neo4j_driver: Driver) -> None:
|
||||
gds_driver = GraphDataScience(neo4j_driver)
|
||||
# build label index and projection graph
|
||||
build_neo4j_label_index(neo4j_driver)
|
||||
build_projection_graph(gds_driver)
|
||||
22
atlas_rag/kg_construction/triple_config.py
Normal file
22
atlas_rag/kg_construction/triple_config.py
Normal file
@ -0,0 +1,22 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class ProcessingConfig:
|
||||
"""Configuration for text processing pipeline."""
|
||||
model_path: str
|
||||
data_directory: str
|
||||
filename_pattern: str
|
||||
batch_size_triple: int = 16
|
||||
batch_size_concept: int = 64
|
||||
output_directory: str = "./generation_result_debug"
|
||||
total_shards_triple: int = 1
|
||||
current_shard_triple: int = 0
|
||||
total_shards_concept: int = 1
|
||||
current_shard_concept: int = 0
|
||||
use_8bit: bool = False
|
||||
debug_mode: bool = False
|
||||
resume_from: int = 0
|
||||
record : bool = False
|
||||
max_new_tokens: int = 8192
|
||||
max_workers: int = 8
|
||||
remove_doc_spaces: bool = False
|
||||
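# Minimal instantiation sketch (only the first three fields have no defaults; the
# values shown here are placeholders):
#
#   config = ProcessingConfig(
#       model_path="meta-llama/Meta-Llama-3-8B-Instruct",
#       data_directory="./data",
#       filename_pattern="en_simple_wiki_v0",
#       batch_size_triple=16,
#       output_directory="./generation_result_debug",
#   )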
497
atlas_rag/kg_construction/triple_extraction.py
Normal file
497
atlas_rag/kg_construction/triple_extraction.py
Normal file
@ -0,0 +1,497 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Knowledge Graph Extraction Pipeline
|
||||
Extracts entities, relations, and events from text data using transformer models.
|
||||
"""
|
||||
import re
|
||||
import json
|
||||
import os
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from pathlib import Path
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
import json_repair
|
||||
from atlas_rag.llm_generator import LLMGenerator
|
||||
from atlas_rag.kg_construction.utils.json_processing.json_to_csv import json2csv
|
||||
from atlas_rag.kg_construction.concept_generation import generate_concept
|
||||
from atlas_rag.kg_construction.utils.csv_processing.merge_csv import merge_csv_files
|
||||
from atlas_rag.kg_construction.utils.csv_processing.csv_to_graphml import csvs_to_graphml, csvs_to_temp_graphml
|
||||
from atlas_rag.kg_construction.concept_to_csv import all_concept_triples_csv_to_csv
|
||||
from atlas_rag.kg_construction.utils.csv_processing.csv_add_numeric_id import add_csv_columns
|
||||
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
|
||||
from atlas_rag.vectorstore.create_neo4j_index import create_faiss_index
|
||||
from atlas_rag.llm_generator.prompt.triple_extraction_prompt import TRIPLE_INSTRUCTIONS
|
||||
from atlas_rag.kg_construction.triple_config import ProcessingConfig
|
||||
# Constants
|
||||
TOKEN_LIMIT = 1024
|
||||
INSTRUCTION_TOKEN_ESTIMATE = 200
|
||||
CHAR_TO_TOKEN_RATIO = 3.5
|
||||
|
||||
|
||||
|
||||
class TextChunker:
|
||||
"""Handles text chunking based on token limits."""
|
||||
|
||||
def __init__(self, max_tokens: int = TOKEN_LIMIT, instruction_tokens: int = INSTRUCTION_TOKEN_ESTIMATE):
|
||||
self.max_tokens = max_tokens
|
||||
self.instruction_tokens = instruction_tokens
|
||||
self.char_ratio = CHAR_TO_TOKEN_RATIO
|
||||
|
||||
def calculate_max_chars(self) -> int:
|
||||
"""Calculate maximum characters per chunk."""
|
||||
available_tokens = self.max_tokens - self.instruction_tokens
|
||||
return int(available_tokens * self.char_ratio)
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
"""Split text into chunks that fit within token limits."""
|
||||
max_chars = self.calculate_max_chars()
|
||||
chunks = []
|
||||
|
||||
while len(text) > max_chars:
|
||||
chunks.append(text[:max_chars])
|
||||
text = text[max_chars:]
|
||||
|
||||
if text: # Add remaining text
|
||||
chunks.append(text)
|
||||
|
||||
return chunks
|
||||
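    # Worked example with the defaults above: calculate_max_chars() returns
    # (1024 - 200) * 3.5 = 2884 characters, so a 7,000-character passage is split
    # into chunks of 2884, 2884 and 1232 characters.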
|
||||
|
||||
class DatasetProcessor:
|
||||
"""Processes and prepares dataset for knowledge graph extraction."""
|
||||
|
||||
def __init__(self, config: ProcessingConfig):
|
||||
self.config = config
|
||||
self.chunker = TextChunker()
|
||||
|
||||
def filter_language_content(self, sample: Dict[str, Any]) -> bool:
|
||||
"""Check if content is in English."""
|
||||
metadata = sample.get("metadata", {})
|
||||
language = metadata.get("lang", "en") # Default to English if not specified
|
||||
supported_languages = list(TRIPLE_INSTRUCTIONS.keys())
|
||||
return language in supported_languages
|
||||
|
||||
|
||||
def create_sample_chunks(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""Create chunks from a single sample."""
|
||||
original_text = sample.get("text", "")
|
||||
if self.config.remove_doc_spaces:
|
||||
original_text = re.sub(r'\s+', ' ',original_text).strip()
|
||||
text_chunks = self.chunker.split_text(original_text)
|
||||
chunks = []
|
||||
|
||||
for chunk_idx, chunk_text in enumerate(text_chunks):
|
||||
chunk_data = {
|
||||
"id": sample["id"],
|
||||
"text": chunk_text,
|
||||
"chunk_id": chunk_idx,
|
||||
"metadata": sample["metadata"]
|
||||
}
|
||||
chunks.append(chunk_data)
|
||||
|
||||
return chunks
|
||||
|
||||
def prepare_dataset(self, raw_dataset) -> List[Dict[str, Any]]:
|
||||
"""Process raw dataset into chunks suitable for processing with generalized slicing."""
|
||||
|
||||
processed_samples = []
|
||||
total_texts = len(raw_dataset)
|
||||
|
||||
# Handle edge cases
|
||||
if total_texts == 0:
|
||||
print(f"No texts found for shard {self.config.current_shard_triple+1}/{self.config.total_shards_triple}")
|
||||
return processed_samples
|
||||
|
||||
# Calculate base and remainder for fair distribution
|
||||
base_texts_per_shard = total_texts // self.config.total_shards_triple
|
||||
remainder = total_texts % self.config.total_shards_triple
|
||||
|
||||
# Calculate start index
|
||||
if self.config.current_shard_triple < remainder:
|
||||
start_idx = self.config.current_shard_triple * (base_texts_per_shard + 1)
|
||||
else:
|
||||
start_idx = remainder * (base_texts_per_shard + 1) + (self.config.current_shard_triple - remainder) * base_texts_per_shard
|
||||
|
||||
# Calculate end index
|
||||
if self.config.current_shard_triple < remainder:
|
||||
end_idx = start_idx + (base_texts_per_shard + 1)
|
||||
else:
|
||||
end_idx = start_idx + base_texts_per_shard
|
||||
|
||||
# Ensure indices are within bounds
|
||||
start_idx = min(start_idx, total_texts)
|
||||
end_idx = min(end_idx, total_texts)
|
||||
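        # Worked example: 10 documents over 3 shards gives base 3 with remainder 1,
        # so shard 0 covers indices 0-3 (4 docs) and shards 1 and 2 cover 3 docs each.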
|
||||
print(f"Processing shard {self.config.current_shard_triple+1}/{self.config.total_shards_triple} "
|
||||
f"(texts {start_idx}-{end_idx-1} of {total_texts}, {end_idx - start_idx} documents)")
|
||||
|
||||
# Process documents in assigned shard
|
||||
for idx in range(start_idx, end_idx):
|
||||
sample = raw_dataset[idx]
|
||||
|
||||
# Filter by language
|
||||
if not self.filter_language_content(sample):
|
||||
print(f"Unsupported language in sample {idx}, skipping.")
|
||||
continue
|
||||
|
||||
# Create chunks
|
||||
chunks = self.create_sample_chunks(sample)
|
||||
processed_samples.extend(chunks)
|
||||
|
||||
# Debug mode early termination
|
||||
if self.config.debug_mode and len(processed_samples) >= 20:
|
||||
print("Debug mode: Stopping at 20 chunks")
|
||||
break
|
||||
|
||||
print(f"Generated {len(processed_samples)} chunks for shard {self.config.current_shard_triple+1}/{self.config.total_shards_triple}")
|
||||
return processed_samples
|
||||
|
||||
|
||||
class CustomDataLoader:
|
||||
"""Custom data loader for knowledge graph extraction."""
|
||||
|
||||
def __init__(self, dataset, processor: DatasetProcessor):
|
||||
self.raw_dataset = dataset
|
||||
self.processor = processor
|
||||
self.processed_data = processor.prepare_dataset(dataset)
|
||||
self.stage_to_prompt_dict = {
|
||||
"stage_1": "entity_relation",
|
||||
"stage_2": "event_entity",
|
||||
"stage_3": "event_relation"
|
||||
}
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.processed_data)
|
||||
|
||||
def create_batch_instructions(self, batch_data: List[Dict[str, Any]]) -> List[str]:
|
||||
messages_dict = {
|
||||
'stage_1': [],
|
||||
'stage_2': [],
|
||||
'stage_3': []
|
||||
}
|
||||
for item in batch_data:
|
||||
# get language
|
||||
language = item.get("metadata",{}).get("lang", "en")
|
||||
system_msg = TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['system']
|
||||
stage_1_msg = TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['entity_relation'] + TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['passage_start'] + '\n' + item["text"]
|
||||
stage_2_msg = TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['event_entity'] + TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['passage_start'] + '\n'+ item["text"]
|
||||
stage_3_msg = TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['event_relation'] + TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['passage_start'] + '\n'+ item["text"]
|
||||
stage_one_message = [
|
||||
{"role": "system", "content": system_msg},
|
||||
{"role": "user", "content": stage_1_msg}
|
||||
]
|
||||
stage_two_message = [
|
||||
{"role": "system", "content": system_msg},
|
||||
{"role": "user", "content": stage_2_msg}
|
||||
]
|
||||
stage_three_message = [
|
||||
{"role": "system", "content": system_msg},
|
||||
{"role": "user", "content": stage_3_msg}
|
||||
]
|
||||
messages_dict['stage_1'].append(stage_one_message)
|
||||
messages_dict['stage_2'].append(stage_two_message)
|
||||
messages_dict['stage_3'].append(stage_three_message)
|
||||
|
||||
return messages_dict
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate through batches."""
|
||||
batch_size = self.processor.config.batch_size_triple
|
||||
start_idx = self.processor.config.resume_from * batch_size
|
||||
|
||||
for i in tqdm(range(start_idx, len(self.processed_data), batch_size)):
|
||||
batch_data = self.processed_data[i:i + batch_size]
|
||||
|
||||
# Prepare instructions
|
||||
instructions = self.create_batch_instructions(batch_data)
|
||||
|
||||
# Extract batch information
|
||||
batch_ids = [item["id"] for item in batch_data]
|
||||
batch_metadata = [item["metadata"] for item in batch_data]
|
||||
batch_texts = [item["text"] for item in batch_data]
|
||||
|
||||
yield instructions, batch_ids, batch_texts, batch_metadata
|
||||
|
||||
|
||||
class OutputParser:
|
||||
"""Parses model outputs and extracts structured data."""
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def extract_structured_data(self, outputs: List[str]) -> List[List[Dict[str, Any]]]:
|
||||
"""Extract structured data from model outputs."""
|
||||
results = []
|
||||
for output in outputs:
|
||||
parsed_data = json_repair.loads(output)
|
||||
results.append(parsed_data)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class KnowledgeGraphExtractor:
|
||||
"""Main class for knowledge graph extraction pipeline."""
|
||||
|
||||
def __init__(self, model:LLMGenerator, config: ProcessingConfig):
|
||||
self.config = config
|
||||
self.model = model
|
||||
self.model_name = model.model_name
|
||||
self.parser = OutputParser()
|
||||
|
||||
def load_dataset(self) -> Any:
|
||||
"""Load and prepare dataset."""
|
||||
data_path = Path(self.config.data_directory)
|
||||
all_files = os.listdir(data_path)
|
||||
|
||||
valid_files = [
|
||||
filename for filename in all_files
|
||||
if filename.startswith(self.config.filename_pattern) and
|
||||
(filename.endswith(".json.gz") or filename.endswith(".json") or filename.endswith(".jsonl") or filename.endswith(".jsonl.gz"))
|
||||
]
|
||||
|
||||
print(f"Found data files: {valid_files}")
|
||||
data_files = valid_files
|
||||
dataset_config = {"train": data_files}
|
||||
return load_dataset(self.config.data_directory, data_files=dataset_config["train"])
|
||||
|
||||
def process_stage(self, instructions: Dict[str, str], stage = 1) -> Tuple[List[str], List[List[Dict[str, Any]]]]:
|
||||
"""Process first stage: entity-relation extraction."""
|
||||
outputs = self.model.triple_extraction(messages=instructions, max_tokens=self.config.max_new_tokens, stage=stage, record=self.config.record)
|
||||
if self.config.record:
|
||||
text_outputs = [output[0] for output in outputs]
|
||||
else:
|
||||
text_outputs = outputs
|
||||
structured_data = self.parser.extract_structured_data(text_outputs)
|
||||
return outputs, structured_data
|
||||
|
||||
def create_output_filename(self) -> str:
|
||||
"""Create output filename with timestamp and shard info."""
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
model_name_safe = self.config.model_path.replace("/", "_")
|
||||
|
||||
filename = (f"{model_name_safe}_{self.config.filename_pattern}_output_"
|
||||
f"{timestamp}_{self.config.current_shard_triple + 1}_in_{self.config.total_shards_triple}.json")
|
||||
|
||||
extraction_dir = os.path.join(self.config.output_directory, "kg_extraction")
|
||||
os.makedirs(extraction_dir, exist_ok=True)
|
||||
|
||||
return os.path.join(extraction_dir, filename)
|
||||
|
||||
def prepare_result_dict(self, batch_data: Tuple, stage_outputs: Tuple, index: int) -> Dict[str, Any]:
|
||||
"""Prepare result dictionary for a single sample."""
|
||||
ids, original_texts, metadata = batch_data
|
||||
(stage1_results, entity_relations), (stage2_results, event_entities), (stage3_results, event_relations) = stage_outputs
|
||||
if self.config.record:
|
||||
stage1_outputs = [output[0] for output in stage1_results]
|
||||
stage1_usage = [output[1] for output in stage1_results]
|
||||
stage2_outputs = [output[0] for output in stage2_results]
|
||||
stage2_usage = [output[1] for output in stage2_results]
|
||||
stage3_outputs = [output[0] for output in stage3_results]
|
||||
stage3_usage = [output[1] for output in stage3_results]
|
||||
else:
|
||||
stage1_outputs = stage1_results
|
||||
stage2_outputs = stage2_results
|
||||
stage3_outputs = stage3_results
|
||||
result = {
|
||||
"id": ids[index],
|
||||
"metadata": metadata[index],
|
||||
"original_text": original_texts[index],
|
||||
"entity_relation_dict": entity_relations[index],
|
||||
"event_entity_relation_dict": event_entities[index],
|
||||
"event_relation_dict": event_relations[index],
|
||||
"output_stage_one": stage1_outputs[index],
|
||||
"output_stage_two": stage2_outputs[index],
|
||||
"output_stage_three": stage3_outputs[index],
|
||||
}
|
||||
if self.config.record:
|
||||
result['usage_stage_one'] = stage1_usage[index]
|
||||
result['usage_stage_two'] = stage2_usage[index]
|
||||
result['usage_stage_three'] = stage3_usage[index]
|
||||
|
||||
# Handle date serialization
|
||||
if 'date_download' in result['metadata']:
|
||||
result['metadata']['date_download'] = str(result['metadata']['date_download'])
|
||||
|
||||
return result
|
||||
|
||||
def debug_print_result(self, result: Dict[str, Any]):
|
||||
"""Print result for debugging."""
|
||||
for key, value in result.items():
|
||||
print(f"{key}: {value}")
|
||||
print("-" * 100)
|
||||
|
||||
def run_extraction(self):
|
||||
"""Run the complete knowledge graph extraction pipeline."""
|
||||
# Setup
|
||||
os.makedirs(self.config.output_directory+'/kg_extraction', exist_ok=True)
|
||||
dataset = self.load_dataset()
|
||||
|
||||
if self.config.debug_mode:
|
||||
print("Debug mode: Processing only 20 samples")
|
||||
|
||||
# Create data processor and loader
|
||||
processor = DatasetProcessor(self.config)
|
||||
data_loader = CustomDataLoader(dataset["train"], processor)
|
||||
|
||||
output_file = self.create_output_filename()
|
||||
print(f"Model: {self.config.model_path}")
|
||||
|
||||
batch_counter = 0
|
||||
|
||||
with torch.no_grad():
|
||||
with open(output_file, "w") as output_stream:
|
||||
for batch in data_loader:
|
||||
batch_counter += 1
|
||||
messages_dict, batch_ids, batch_texts, batch_metadata = batch
|
||||
|
||||
# Process all three stages
|
||||
stage1_results = self.process_stage(messages_dict['stage_1'],1)
|
||||
stage2_results = self.process_stage(messages_dict['stage_2'],2)
|
||||
stage3_results = self.process_stage(messages_dict['stage_3'],3)
|
||||
|
||||
# Combine results
|
||||
batch_data = (batch_ids, batch_texts, batch_metadata)
|
||||
stage_outputs = (stage1_results, stage2_results, stage3_results)
|
||||
|
||||
# Write results
|
||||
print(f"Processed {batch_counter} batches ({batch_counter * self.config.batch_size_triple} chunks)")
|
||||
for i in range(len(batch_ids)):
|
||||
result = self.prepare_result_dict(batch_data, stage_outputs, i)
|
||||
|
||||
if self.config.debug_mode:
|
||||
self.debug_print_result(result)
|
||||
|
||||
output_stream.write(json.dumps(result, ensure_ascii=False) + "\n")
|
||||
output_stream.flush()
|
||||
|
||||
def convert_json_to_csv(self):
|
||||
json2csv(
|
||||
dataset = self.config.filename_pattern,
|
||||
output_dir=f"{self.config.output_directory}/triples_csv",
|
||||
data_dir=f"{self.config.output_directory}/kg_extraction"
|
||||
)
|
||||
csvs_to_temp_graphml(
|
||||
triple_node_file=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
|
||||
triple_edge_file=f"{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_without_emb.csv",
|
||||
config = self.config
|
||||
)
|
||||
|
||||
def generate_concept_csv_temp(self, batch_size: int = None, **kwargs):
|
||||
generate_concept(
|
||||
model=self.model,
|
||||
input_file=f"{self.config.output_directory}/triples_csv/missing_concepts_{self.config.filename_pattern}_from_json.csv",
|
||||
output_folder=f"{self.config.output_directory}/concepts",
|
||||
output_file="concept.json",
|
||||
logging_file=f"{self.config.output_directory}/concepts/logging.txt",
|
||||
config=self.config,
|
||||
batch_size=batch_size if batch_size else self.config.batch_size_concept,
|
||||
shard=self.config.current_shard_concept,
|
||||
num_shards=self.config.total_shards_concept,
|
||||
record = self.config.record,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def create_concept_csv(self):
|
||||
merge_csv_files(
|
||||
output_file=f"{self.config.output_directory}/triples_csv/{self.config.filename_pattern}_from_json_with_concept.csv",
|
||||
input_dir=f"{self.config.output_directory}/concepts",
|
||||
)
|
||||
all_concept_triples_csv_to_csv(
|
||||
node_file=f'{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv',
|
||||
edge_file=f'{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_without_emb.csv',
|
||||
concepts_file=f'{self.config.output_directory}/triples_csv/{self.config.filename_pattern}_from_json_with_concept.csv',
|
||||
output_node_file=f'{self.config.output_directory}/concept_csv/concept_nodes_{self.config.filename_pattern}_from_json_with_concept.csv',
|
||||
output_edge_file=f'{self.config.output_directory}/concept_csv/concept_edges_{self.config.filename_pattern}_from_json_with_concept.csv',
|
||||
output_full_concept_triple_edges=f'{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv',
|
||||
)
|
||||
|
||||
def convert_to_graphml(self):
|
||||
csvs_to_graphml(
|
||||
triple_node_file=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
|
||||
text_node_file=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json.csv",
|
||||
concept_node_file=f"{self.config.output_directory}/concept_csv/concept_nodes_{self.config.filename_pattern}_from_json_with_concept.csv",
|
||||
triple_edge_file=f"{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
|
||||
text_edge_file=f"{self.config.output_directory}/triples_csv/text_edges_{self.config.filename_pattern}_from_json.csv",
|
||||
concept_edge_file=f"{self.config.output_directory}/concept_csv/concept_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
|
||||
output_file=f"{self.config.output_directory}/kg_graphml/{self.config.filename_pattern}_graph.graphml",
|
||||
)
|
||||
|
||||
def add_numeric_id(self):
|
||||
add_csv_columns(
|
||||
node_csv=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
|
||||
edge_csv=f"{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
|
||||
text_csv=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json.csv",
|
||||
node_with_numeric_id=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb_with_numeric_id.csv",
|
||||
edge_with_numeric_id=f"{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_without_emb_with_numeric_id.csv",
|
||||
text_with_numeric_id=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json_with_numeric_id.csv",
|
||||
)
|
||||
|
||||
def compute_kg_embedding(self, encoder_model:BaseEmbeddingModel, batch_size: int = 2048):
|
||||
encoder_model.compute_kg_embedding(
|
||||
node_csv_without_emb=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
|
||||
node_csv_file=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_with_emb.csv",
|
||||
edge_csv_without_emb=f"{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
|
||||
edge_csv_file=f"{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept_with_emb.csv",
|
||||
text_node_csv_without_emb=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json.csv",
|
||||
text_node_csv=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json_with_emb.csv",
|
||||
            batch_size=batch_size
|
||||
)
|
||||
|
||||
def create_faiss_index(self, index_type="HNSW,Flat"):
|
||||
create_faiss_index(self.config.output_directory, self.config.filename_pattern, index_type)
|
||||
|
||||
def parse_command_line_arguments() -> ProcessingConfig:
|
||||
"""Parse command line arguments and return configuration."""
|
||||
parser = argparse.ArgumentParser(description="Knowledge Graph Extraction Pipeline")
|
||||
|
||||
parser.add_argument("-m", "--model", type=str, required=True,
|
||||
default="meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
help="Model path for knowledge extraction")
|
||||
parser.add_argument("--data_dir", type=str, default="your_data_dir",
|
||||
help="Directory containing input data")
|
||||
parser.add_argument("--file_name", type=str, default="en_simple_wiki_v0",
|
||||
help="Filename pattern to match")
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=16,
|
||||
help="Batch size for processing")
|
||||
parser.add_argument("--output_dir", type=str, default="./generation_result_debug",
|
||||
help="Output directory for results")
|
||||
parser.add_argument("--total_shards_triple", type=int, default=1,
|
||||
help="Total number of data shards")
|
||||
parser.add_argument("--shard", type=int, default=0,
|
||||
help="Current shard index")
|
||||
parser.add_argument("--bit8", action="store_true",
|
||||
help="Use 8-bit quantization")
|
||||
parser.add_argument("--debug", action="store_true",
|
||||
help="Enable debug mode")
|
||||
parser.add_argument("--resume", type=int, default=0,
|
||||
help="Resume from specific batch")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ProcessingConfig(
|
||||
model_path=args.model,
|
||||
data_directory=args.data_dir,
|
||||
filename_pattern=args.file_name,
|
||||
        batch_size_triple=args.batch_size,
|
||||
output_directory=args.output_dir,
|
||||
total_shards_triple=args.total_shards_triple,
|
||||
current_shard_triple=args.shard,
|
||||
use_8bit=args.bit8,
|
||||
debug_mode=args.debug,
|
||||
resume_from=args.resume
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the knowledge graph extraction pipeline."""
|
||||
config = parse_command_line_arguments()
|
||||
extractor = KnowledgeGraphExtractor(config)
|
||||
extractor.run_extraction()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
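# Hedged end-to-end sketch of the pipeline this class exposes (my_llm_generator and
# my_embedding_model are placeholders for an LLMGenerator and a BaseEmbeddingModel):
#
#   extractor = KnowledgeGraphExtractor(model=my_llm_generator, config=config)
#   extractor.run_extraction()               # stages 1-3 -> kg_extraction/*.json
#   extractor.convert_json_to_csv()          # JSON -> triples_csv/*.csv (+ temp graphml pickle)
#   extractor.generate_concept_csv_temp()    # LLM concept generation -> concepts/
#   extractor.create_concept_csv()           # merge concepts -> concept_csv/*.csv
#   extractor.convert_to_graphml()           # -> kg_graphml/<pattern>_graph.graphml
#   extractor.add_numeric_id()
#   extractor.compute_kg_embedding(my_embedding_model)
#   extractor.create_faiss_index()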
0
atlas_rag/kg_construction/utils/__init__.py
Normal file
0
atlas_rag/kg_construction/utils/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
160
atlas_rag/kg_construction/utils/csv_processing/csv_add_numeric_id.py
Normal file
@ -0,0 +1,160 @@
|
||||
import csv
|
||||
from tqdm import tqdm
|
||||
|
||||
def check_created_csv_header(keyword, csv_dir):
|
||||
keyword_to_paths ={
|
||||
'cc_en':{
|
||||
'node_with_numeric_id': f"{csv_dir}/triple_nodes_cc_en_from_json_without_emb_with_numeric_id.csv",
|
||||
'edge_with_numeric_id': f"{csv_dir}/triple_edges_cc_en_from_json_without_emb_with_numeric_id.csv",
|
||||
'text_with_numeric_id': f"{csv_dir}/text_nodes_cc_en_from_json_with_numeric_id.csv",
|
||||
'concept_with_numeric_id': f"{csv_dir}/concept_nodes_pes2o_abstract_from_json_without_emb_with_numeric_id.csv",
|
||||
},
|
||||
'pes2o_abstract':{
|
||||
'node_with_numeric_id': f"{csv_dir}/triple_nodes_pes2o_abstract_from_json_without_emb_with_numeric_id.csv",
|
||||
'edge_with_numeric_id': f"{csv_dir}/triple_edges_pes2o_abstract_from_json_without_emb_full_concept_with_numeric_id.csv",
|
||||
'text_with_numeric_id': f"{csv_dir}/text_nodes_pes2o_abstract_from_json_with_numeric_id.csv",
|
||||
},
|
||||
'en_simple_wiki_v0':{
|
||||
'node_with_numeric_id': f"{csv_dir}/triple_nodes_en_simple_wiki_v0_from_json_without_emb_with_numeric_id.csv",
|
||||
'edge_with_numeric_id': f"{csv_dir}/triple_edges_en_simple_wiki_v0_from_json_without_emb_full_concept_with_numeric_id.csv",
|
||||
'text_with_numeric_id': f"{csv_dir}/text_nodes_en_simple_wiki_v0_from_json_with_numeric_id.csv",
|
||||
},
|
||||
}
|
||||
for key, path in keyword_to_paths[keyword].items():
|
||||
with open(path) as infile:
|
||||
reader = csv.reader(infile)
|
||||
header = next(reader)
|
||||
print(f"Header of {key}: {header}")
|
||||
|
||||
# print first 5 rows
|
||||
for i, row in enumerate(reader):
|
||||
if i < 1:
|
||||
print(row)
|
||||
else:
|
||||
break
|
||||
|
||||
def add_csv_columns(node_csv, edge_csv, text_csv, node_with_numeric_id, edge_with_numeric_id, text_with_numeric_id):
|
||||
with open(node_csv) as infile, open(node_with_numeric_id, 'w', newline='') as outfile:
|
||||
reader = csv.reader(infile)
|
||||
writer = csv.writer(outfile)
|
||||
header = next(reader)
|
||||
print(header)
|
||||
label_index = header.index(':LABEL')
|
||||
header.insert(label_index, 'numeric_id') # Add new column name
|
||||
writer.writerow(header)
|
||||
for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
|
||||
row.insert(label_index, row_number) # Add numeric ID before ':LABEL'
|
||||
writer.writerow(row)
|
||||
with open(edge_csv) as infile, open(edge_with_numeric_id, 'w', newline='') as outfile:
|
||||
reader = csv.reader(infile)
|
||||
writer = csv.writer(outfile)
|
||||
header = next(reader)
|
||||
print(header)
|
||||
label_index = header.index(':TYPE')
|
||||
header.insert(label_index, 'numeric_id') # Add new column name
|
||||
writer.writerow(header)
|
||||
for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
|
||||
row.insert(label_index, row_number) # Add numeric ID before ':LABEL'
|
||||
writer.writerow(row)
|
||||
with open(text_csv) as infile, open(text_with_numeric_id, 'w', newline='') as outfile:
|
||||
reader = csv.reader(infile)
|
||||
writer = csv.writer(outfile)
|
||||
header = next(reader)
|
||||
print(header)
|
||||
label_index = header.index(':LABEL')
|
||||
header.insert(label_index, 'numeric_id') # Add new column name
|
||||
writer.writerow(header)
|
||||
for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
|
||||
row.insert(label_index, row_number) # Add numeric ID before ':LABEL'
|
||||
writer.writerow(row)
|
||||
|
||||
|
||||
# def add_csv_columns(keyword, csv_dir):
|
||||
# keyword_to_paths ={
|
||||
# 'cc_en':{
|
||||
# 'node_csv': f"{csv_dir}/triple_nodes_cc_en_from_json_without_emb.csv",
|
||||
# 'edge_csv': f"{csv_dir}/triple_edges_cc_en_from_json_without_emb.csv",
|
||||
# 'text_csv': f"{csv_dir}/text_nodes_cc_en_from_json.csv",
|
||||
|
||||
# 'node_with_numeric_id': f"{csv_dir}/triple_nodes_cc_en_from_json_without_emb_with_numeric_id.csv",
|
||||
# 'edge_with_numeric_id': f"{csv_dir}/triple_edges_cc_en_from_json_without_emb_with_numeric_id.csv",
|
||||
# 'text_with_numeric_id': f"{csv_dir}/text_nodes_cc_en_from_json_with_numeric_id.csv"
|
||||
# },
|
||||
# 'pes2o_abstract':{
|
||||
# 'node_csv': f"{csv_dir}/triple_nodes_pes2o_abstract_from_json_without_emb.csv",
|
||||
# 'edge_csv': f"{csv_dir}/triple_edges_pes2o_abstract_from_json_without_emb_full_concept.csv",
|
||||
# 'text_csv': f"{csv_dir}/text_nodes_pes2o_abstract_from_json.csv",
|
||||
|
||||
# 'node_with_numeric_id': f"{csv_dir}/triple_nodes_pes2o_abstract_from_json_without_emb_with_numeric_id.csv",
|
||||
# 'edge_with_numeric_id': f"{csv_dir}/triple_edges_pes2o_abstract_from_json_without_emb_full_concept_with_numeric_id.csv",
|
||||
# 'text_with_numeric_id': f"{csv_dir}/text_nodes_pes2o_abstract_from_json_with_numeric_id.csv"
|
||||
# },
|
||||
# 'en_simple_wiki_v0':{
|
||||
# 'node_csv': f"{csv_dir}/triple_nodes_en_simple_wiki_v0_from_json_without_emb.csv",
|
||||
# 'edge_csv': f"{csv_dir}/triple_edges_en_simple_wiki_v0_from_json_without_emb_full_concept.csv",
|
||||
# 'text_csv': f"{csv_dir}/text_nodes_en_simple_wiki_v0_from_json.csv",
|
||||
|
||||
# 'node_with_numeric_id': f"{csv_dir}/triple_nodes_en_simple_wiki_v0_from_json_without_emb_with_numeric_id.csv",
|
||||
# 'edge_with_numeric_id': f"{csv_dir}/triple_edges_en_simple_wiki_v0_from_json_without_emb_full_concept_with_numeric_id.csv",
|
||||
# 'text_with_numeric_id': f"{csv_dir}/text_nodes_en_simple_wiki_v0_from_json_with_numeric_id.csv"
|
||||
# },
|
||||
# }
|
||||
# # ouput node
|
||||
# with open(keyword_to_paths[keyword]['node_csv']) as infile, open(keyword_to_paths[keyword]['node_with_numeric_id'], 'w') as outfile:
|
||||
# reader = csv.reader(infile)
|
||||
# writer = csv.writer(outfile)
|
||||
|
||||
# # Read the header
|
||||
# header = next(reader)
|
||||
# print(header)
|
||||
# # Insert 'numeric_id' before ':LABEL'
|
||||
# label_index = header.index(':LABEL')
|
||||
# header.insert(label_index, 'numeric_id') # Add new column name
|
||||
# writer.writerow(header)
|
||||
|
||||
# # Process each row and add a numeric ID
|
||||
# for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
|
||||
# row.insert(label_index, row_number) # Add numeric ID before ':LABEL'
|
||||
# writer.writerow(row)
|
||||
|
||||
# # output edge (TYPE instead of LABEL for edge)
|
||||
# with open(keyword_to_paths[keyword]['edge_csv']) as infile, open(keyword_to_paths[keyword]['edge_with_numeric_id'], 'w') as outfile:
|
||||
# reader = csv.reader(infile)
|
||||
# writer = csv.writer(outfile)
|
||||
|
||||
# # Read the header
|
||||
# header = next(reader)
|
||||
# print(header)
|
||||
# # Insert 'numeric_id' before ':TYPE'
|
||||
# label_index = header.index(':TYPE')
|
||||
# header.insert(label_index, 'numeric_id') # Add new column name
|
||||
# writer.writerow(header)
|
||||
|
||||
# # Process each row and add a numeric ID
|
||||
# for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
|
||||
# row.insert(label_index, row_number) # Add numeric ID before ':LABEL'
|
||||
# writer.writerow(row)
|
||||
|
||||
# # output text
|
||||
# with open(keyword_to_paths[keyword]['text_csv']) as infile, open(keyword_to_paths[keyword]['text_with_numeric_id'], 'w') as outfile:
|
||||
# reader = csv.reader(infile)
|
||||
# writer = csv.writer(outfile)
|
||||
|
||||
# # Read the header
|
||||
# header = next(reader)
|
||||
# print(header)
|
||||
# # Insert 'numeric_id' before ':LABEL'
|
||||
# label_index = header.index(':LABEL')
|
||||
# header.insert(label_index, 'numeric_id') # Add new column name
|
||||
# writer.writerow(header)
|
||||
|
||||
# # Process each row and add a numeric ID
|
||||
# for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
|
||||
# row.insert(label_index, row_number) # Add numeric ID before ':LABEL'
|
||||
# writer.writerow(row)
|
||||
|
||||
if __name__ == "__main__":
|
||||
keyword = "en_simple_wiki_v0"
|
||||
csv_dir = "./import" # Change this to your CSV directory
|
||||
    # add_csv_columns takes explicit input/output paths (mirroring the en_simple_wiki_v0 mapping above)
    add_csv_columns(
        node_csv=f"{csv_dir}/triple_nodes_{keyword}_from_json_without_emb.csv",
        edge_csv=f"{csv_dir}/triple_edges_{keyword}_from_json_without_emb_full_concept.csv",
        text_csv=f"{csv_dir}/text_nodes_{keyword}_from_json.csv",
        node_with_numeric_id=f"{csv_dir}/triple_nodes_{keyword}_from_json_without_emb_with_numeric_id.csv",
        edge_with_numeric_id=f"{csv_dir}/triple_edges_{keyword}_from_json_without_emb_full_concept_with_numeric_id.csv",
        text_with_numeric_id=f"{csv_dir}/text_nodes_{keyword}_from_json_with_numeric_id.csv",
    )
|
||||
    # check_created_csv_header(keyword, csv_dir)
|
||||
189
atlas_rag/kg_construction/utils/csv_processing/csv_to_graphml.py
Normal file
189
atlas_rag/kg_construction/utils/csv_processing/csv_to_graphml.py
Normal file
@ -0,0 +1,189 @@
|
||||
import networkx as nx
|
||||
import csv
|
||||
import ast
|
||||
import hashlib
|
||||
import os
|
||||
from atlas_rag.kg_construction.triple_config import ProcessingConfig
|
||||
import pickle
|
||||
|
||||
def get_node_id(entity_name, entity_to_id={}):
|
||||
"""Returns existing or creates new nX ID for an entity using a hash-based approach."""
|
||||
if entity_name not in entity_to_id:
|
||||
# Use a hash function to generate a unique ID
|
||||
hash_object = hashlib.sha256(entity_name.encode('utf-8'))
|
||||
hash_hex = hash_object.hexdigest() # Get the hexadecimal representation of the hash
|
||||
        # Use the full hex digest as the ID (truncate here if shorter IDs are needed)
|
||||
entity_to_id[entity_name] = hash_hex
|
||||
return entity_to_id[entity_name]
|
||||
|
||||
def csvs_to_temp_graphml(triple_node_file, triple_edge_file, config:ProcessingConfig=None):
|
||||
g = nx.DiGraph()
|
||||
entity_to_id = {}
|
||||
|
||||
# Add triple nodes
|
||||
with open(triple_node_file, 'r') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
node_id = row["name:ID"]
|
||||
mapped_id = get_node_id(node_id, entity_to_id)
|
||||
if mapped_id not in g.nodes:
|
||||
g.add_node(mapped_id, id=node_id, type=row["type"])
|
||||
|
||||
|
||||
# Add triple edges
|
||||
with open(triple_edge_file, 'r') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
start_id = get_node_id(row[":START_ID"], entity_to_id)
|
||||
end_id = get_node_id(row[":END_ID"], entity_to_id)
|
||||
# Check if edge already exists to prevent duplicates
|
||||
if not g.has_edge(start_id, end_id):
|
||||
g.add_edge(start_id, end_id, relation=row["relation"], type=row[":TYPE"])
|
||||
|
||||
# save graph to
|
||||
output_name = f"{config.output_directory}/kg_graphml/{config.filename_pattern}_without_concept.pkl"
|
||||
# check if output file directory exists
|
||||
output_dir = os.path.dirname(output_name)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
# store the graph to a pickle file
|
||||
with open(output_name, 'wb') as output_file:
|
||||
pickle.dump(g, output_file)
|
||||
|
||||
|
||||
|
||||
def csvs_to_graphml(triple_node_file, text_node_file, concept_node_file,
|
||||
triple_edge_file, text_edge_file, concept_edge_file,
|
||||
output_file):
|
||||
'''
|
||||
Convert multiple CSV files into a single GraphML file.
|
||||
|
||||
Types of nodes to be added to the graph:
|
||||
- Triple nodes: Nodes representing triples, with properties like subject, predicate, object.
|
||||
- Text nodes: Nodes representing text, with properties like text content.
|
||||
- Concept nodes: Nodes representing concepts, with properties like concept name and type.
|
||||
|
||||
Types of edges to be added to the graph:
|
||||
- Triple edges: Edges representing relationships between triples, with properties like relation type.
|
||||
- Text edges: Edges representing relationships between text and nodes, with properties like text type.
|
||||
- Concept edges: Edges representing relationships between concepts and nodes, with properties like concept type.
|
||||
|
||||
DiGraph networkx attributes:
|
||||
Node:
|
||||
- type: Type of the node (e.g., entity, event, text, concept).
|
||||
- file_id: List of text IDs the node is associated with.
|
||||
- id: Node Name
|
||||
Edge:
|
||||
- relation: relation name
|
||||
- file_id: List of text IDs the edge is associated with.
|
||||
- type: Type of the edge (e.g., Source, Relation, Concept).
|
||||
- synsets: List of synsets associated with the edge.
|
||||
|
||||
'''
|
||||
g = nx.DiGraph()
|
||||
entity_to_id = {}
|
||||
|
||||
# Add triple nodes
|
||||
with open(triple_node_file, 'r') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
node_id = row["name:ID"]
|
||||
mapped_id = get_node_id(node_id, entity_to_id)
|
||||
# Check if node already exists to prevent duplicates
|
||||
if mapped_id not in g.nodes:
|
||||
g.add_node(mapped_id, id=node_id, type=row["type"])
|
||||
|
||||
# Add text nodes
|
||||
with open(text_node_file, 'r') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
node_id = row["text_id:ID"]
|
||||
# Check if node already exists to prevent duplicates
|
||||
if node_id not in g.nodes:
|
||||
g.add_node(node_id, file_id=node_id, id=row["original_text"], type="passage")
|
||||
|
||||
# Add concept nodes
|
||||
with open(concept_node_file, 'r') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
node_id = row["concept_id:ID"]
|
||||
# Check if node already exists to prevent duplicates
|
||||
if node_id not in g.nodes:
|
||||
g.add_node(node_id, file_id="concept_file", id=row["name"], type="concept")
|
||||
|
||||
# Add file id for triple nodes and concept nodes when add the edges
|
||||
|
||||
# Add triple edges
|
||||
with open(triple_edge_file, 'r') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
start_id = get_node_id(row[":START_ID"], entity_to_id)
|
||||
end_id = get_node_id(row[":END_ID"], entity_to_id)
|
||||
# Check if edge already exists to prevent duplicates
|
||||
if not g.has_edge(start_id, end_id):
|
||||
g.add_edge(start_id, end_id, relation=row["relation"], type=row[":TYPE"])
|
||||
# Add file_id to start and end nodes if they are triple or concept nodes
|
||||
for node_id in [start_id, end_id]:
|
||||
if g.nodes[node_id]['type'] in ['triple', 'concept'] and 'file_id' not in g.nodes[node_id]:
|
||||
g.nodes[node_id]['file_id'] = row.get("file_id", "triple_file")
|
||||
|
||||
# Add concepts to the edge
|
||||
concepts = ast.literal_eval(row["concepts"])
|
||||
for concept in concepts:
|
||||
if "concepts" not in g.edges[start_id, end_id]:
|
||||
g.edges[start_id, end_id]['concepts'] = str(concept)
|
||||
else:
|
||||
# Avoid duplicate concepts by checking if concept is already in the list
|
||||
current_concepts = g.edges[start_id, end_id]['concepts'].split(",")
|
||||
if str(concept) not in current_concepts:
|
||||
g.edges[start_id, end_id]['concepts'] += "," + str(concept)
|
||||
|
||||
|
||||
# Add text edges
|
||||
with open(text_edge_file, 'r') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
start_id = get_node_id(row[":START_ID"], entity_to_id)
|
||||
end_id = row[":END_ID"]
|
||||
# Check if edge already exists to prevent duplicates
|
||||
if not g.has_edge(start_id, end_id):
|
||||
g.add_edge(start_id, end_id, relation="mention in", type=row[":TYPE"])
|
||||
# Add file_id to start node if it is a triple or concept node
|
||||
if 'file_id' in g.nodes[start_id]:
|
||||
g.nodes[start_id]['file_id'] += "," + str(end_id)
|
||||
else:
|
||||
g.nodes[start_id]['file_id'] = str(end_id)
|
||||
|
||||
# Add concept edges between triple nodes and concept nodes
|
||||
with open(concept_edge_file, 'r') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
start_id = get_node_id(row[":START_ID"], entity_to_id)
|
||||
end_id = row[":END_ID"] # end id is concept node id
|
||||
if not g.has_edge(start_id, end_id):
|
||||
g.add_edge(start_id, end_id, relation=row["relation"], type=row[":TYPE"])
|
||||
|
||||
# Write to GraphML
|
||||
# check if output file directory exists
|
||||
output_dir = os.path.dirname(output_file)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
nx.write_graphml(g, output_file, infer_numeric_types=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='Convert CSV files to GraphML format.')
|
||||
parser.add_argument('--triple_node_file', type=str, required=True, help='Path to the triple node CSV file.')
|
||||
parser.add_argument('--text_node_file', type=str, required=True, help='Path to the text node CSV file.')
|
||||
parser.add_argument('--concept_node_file', type=str, required=True, help='Path to the concept node CSV file.')
|
||||
parser.add_argument('--triple_edge_file', type=str, required=True, help='Path to the triple edge CSV file.')
|
||||
parser.add_argument('--text_edge_file', type=str, required=True, help='Path to the text edge CSV file.')
|
||||
parser.add_argument('--concept_edge_file', type=str, required=True, help='Path to the concept edge CSV file.')
|
||||
parser.add_argument('--output_file', type=str, required=True, help='Path to the output GraphML file.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
csvs_to_graphml(args.triple_node_file, args.text_node_file, args.concept_node_file,
|
||||
args.triple_edge_file, args.text_edge_file, args.concept_edge_file,
|
||||
args.output_file)
|
||||
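A minimal sketch of driving the converter from Python instead of the CLI; the module path is assumed from the neighbouring csv_processing files and all CSV paths are hypothetical:

from atlas_rag.kg_construction.utils.csv_processing.csv_to_graphml import csvs_to_graphml

csvs_to_graphml(
    "./import/triple_nodes_cc_en_from_json.csv",   # triple node CSV
    "./import/text_nodes_cc_en_from_json.csv",     # text node CSV
    "./import/concept_nodes_cc_en_from_json.csv",  # concept node CSV
    "./import/triple_edges_cc_en_from_json.csv",   # triple edge CSV
    "./import/text_edges_cc_en_from_json.csv",     # text edge CSV
    "./import/concept_edges_cc_en_from_json.csv",  # concept edge CSV
    "./output/cc_en_kg.graphml",                   # output GraphML file
)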
70
atlas_rag/kg_construction/utils/csv_processing/csv_to_npy.py
Normal file
@ -0,0 +1,70 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from ast import literal_eval # Safer string-to-list conversion
|
||||
import os
|
||||
|
||||
CHUNKSIZE = 100_000 # Adjust based on your RAM (100K rows per chunk)
|
||||
EMBEDDING_COL = "embedding:STRING" # Column name with embeddings
|
||||
# DIMENSION = 32 # Update with your embedding dimension
|
||||
ENTITY_ONLY = True
|
||||
def parse_embedding(embed_str):
|
||||
"""Convert embedding string to numpy array"""
|
||||
# Remove brackets and convert to list
|
||||
return np.array(literal_eval(embed_str), dtype=np.float32)
|
||||
|
||||
# Create memory-mapped numpy file
|
||||
def convert_csv_to_npy(csv_path, npy_path):
|
||||
total_embeddings = 0
|
||||
# check dir exist, if not then create it
|
||||
os.makedirs(os.path.dirname(npy_path), exist_ok=True)
|
||||
|
||||
with open(npy_path, "wb") as f:
|
||||
pass # Initialize empty file
|
||||
|
||||
# Process CSV in chunks
|
||||
for chunk_idx, df_chunk in enumerate(
|
||||
pd.read_csv(csv_path, chunksize=CHUNKSIZE, usecols=[EMBEDDING_COL])
|
||||
):
|
||||
|
||||
|
||||
# Parse embeddings
|
||||
embeddings = np.stack(
|
||||
df_chunk[EMBEDDING_COL].apply(parse_embedding).values
|
||||
)
|
||||
|
||||
# Verify dimensions
|
||||
# assert embeddings.shape[1] == DIMENSION, \
|
||||
# f"Dimension mismatch at chunk {chunk_idx}"
|
||||
total_embeddings += embeddings.shape[0]
|
||||
# Append to .npy file
|
||||
with open(npy_path, "ab") as f:
|
||||
np.save(f, embeddings.astype(np.float32))
|
||||
|
||||
print(f"Processed chunk {chunk_idx} ({CHUNKSIZE*(chunk_idx+1)} rows)")
|
||||
print(f"Total number of embeddings: {total_embeddings}")
|
||||
print("Conversion complete!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
keyword = 'cc_en' # Change this to your desired keyword
|
||||
csv_dir="./import" # Change this to your CSV directory
|
||||
keyword_to_paths ={
|
||||
'cc_en':{
|
||||
'node_csv': f"{csv_dir}/triple_nodes_cc_en_from_json_2.csv",
|
||||
# 'edge_csv': f"{csv_dir}/triple_edges_cc_en_from_json_2.csv",
|
||||
'text_csv': f"{csv_dir}/text_nodes_cc_en_from_json_with_emb.csv",
|
||||
},
|
||||
'pes2o_abstract':{
|
||||
'node_csv': f"{csv_dir}/triple_nodes_pes2o_abstract_from_json.csv",
|
||||
# 'edge_csv': f"{csv_dir}/triple_edges_pes2o_abstract_from_json.csv",
|
||||
'text_csv': f"{csv_dir}/text_nodes_pes2o_abstract_from_json_with_emb.csv",
|
||||
},
|
||||
'en_simple_wiki_v0':{
|
||||
'node_csv': f"{csv_dir}/triple_nodes_en_simple_wiki_v0_from_json.csv",
|
||||
# 'edge_csv': f"{csv_dir}/triple_edges_en_simple_wiki_v0_from_json.csv",
|
||||
'text_csv': f"{csv_dir}/text_nodes_en_simple_wiki_v0_from_json_with_emb.csv",
|
||||
},
|
||||
}
|
||||
for key, path in keyword_to_paths[keyword].items():
|
||||
npy_path = path.replace(".csv", ".npy")
|
||||
convert_csv_to_npy(path, npy_path)
|
||||
print(f"Converted {path} to {npy_path}")
|
||||
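Because convert_csv_to_npy appends one np.save block per chunk, the resulting file holds several concatenated .npy records rather than a single array. A minimal sketch for reading it back, with a hypothetical path; np.load is called repeatedly on the same handle until the file is exhausted:

import numpy as np

def load_appended_npy(npy_path):
    chunks = []
    with open(npy_path, "rb") as f:
        while True:
            try:
                chunks.append(np.load(f))  # one array per np.save call
            except (EOFError, ValueError):  # no data left in file
                break
    return np.concatenate(chunks, axis=0)

embeddings = load_appended_npy("./import/text_nodes_cc_en_from_json_with_emb.npy")
print(embeddings.shape)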
27
atlas_rag/kg_construction/utils/csv_processing/merge_csv.py
Normal file
@ -0,0 +1,27 @@
|
||||
import os
|
||||
import glob
|
||||
|
||||
def merge_csv_files(output_file, input_dir):
|
||||
"""
|
||||
Merge all CSV files in the input directory into a single output file.
|
||||
|
||||
Args:
|
||||
output_file (str): Path to the output CSV file.
|
||||
input_dir (str): Directory containing the input CSV files.
|
||||
"""
|
||||
# Delete the output file if it exists
|
||||
if os.path.exists(output_file):
|
||||
os.remove(output_file)
|
||||
|
||||
# Write the header to the output file
|
||||
with open(output_file, 'w') as outfile:
|
||||
outfile.write("node,conceptualized_node,node_type\n")
|
||||
|
||||
# Append the contents of all CSV files in the input directory
|
||||
for csv_file in glob.glob(os.path.join(input_dir, '*.csv')):
|
||||
with open(csv_file, 'r') as infile:
|
||||
# Skip the header line
|
||||
next(infile)
|
||||
# Append the remaining lines to the output file
|
||||
with open(output_file, 'a') as outfile:
|
||||
outfile.writelines(infile)
|
||||
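A minimal usage sketch, assuming every shard CSV in the input directory shares the "node,conceptualized_node,node_type" header written above; paths are hypothetical:

from atlas_rag.kg_construction.utils.csv_processing.merge_csv import merge_csv_files

merge_csv_files(output_file="./import/concepts_merged.csv", input_dir="./concept_csv_shards")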
277
atlas_rag/kg_construction/utils/json_processing/json_to_csv.py
Normal file
@ -0,0 +1,277 @@
|
||||
from tqdm import tqdm
|
||||
import argparse
|
||||
import os
|
||||
import csv
|
||||
import json
|
||||
import re
|
||||
import hashlib
|
||||
|
||||
# Increase the field size limit
|
||||
csv.field_size_limit(10 * 1024 * 1024) # 10 MB limit
|
||||
|
||||
|
||||
# Function to compute a hash ID from text
|
||||
def compute_hash_id(text):
|
||||
# Use SHA-256 to generate a hash
|
||||
hash_object = hashlib.sha256(text.encode('utf-8'))
|
||||
return hash_object.hexdigest() # Return hash as a hex string
|
||||
|
||||
def clean_text(text):
|
||||
# remove NUL as well
|
||||
|
||||
new_text = text.replace("\n", " ").replace("\r", " ").replace("\t", " ").replace("\v", " ").replace("\f", " ").replace("\b", " ").replace("\a", " ").replace("\e", " ").replace(";", ",")
|
||||
new_text = new_text.replace("\x00", "")
|
||||
new_text = re.sub(r'\s+', ' ', new_text).strip()
|
||||
|
||||
return new_text
|
||||
|
||||
def remove_NUL(text):
|
||||
return text.replace("\x00", "")
|
||||
|
||||
def json2csv(dataset, data_dir, output_dir, test=False):
|
||||
"""
|
||||
Convert JSON files to CSV files for nodes, edges, and missing concepts.
|
||||
|
||||
Args:
|
||||
dataset (str): Name of the dataset.
|
||||
data_dir (str): Directory containing the JSON files.
|
||||
output_dir (str): Directory to save the output CSV files.
|
||||
test (bool): If True, run in test mode (process only 3 files).
|
||||
"""
|
||||
visited_nodes = set()
|
||||
visited_hashes = set()
|
||||
|
||||
all_entities = set()
|
||||
all_events = set()
|
||||
all_relations = set()
|
||||
|
||||
file_dir_list = [f for f in os.listdir(data_dir) if dataset in f]
|
||||
file_dir_list = sorted(file_dir_list)
|
||||
if test:
|
||||
file_dir_list = file_dir_list[:3]
|
||||
print("Loading data from the json files")
|
||||
print("Number of files: ", len(file_dir_list))
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
# Define output file paths
|
||||
node_csv_without_emb = os.path.join(output_dir, f"triple_nodes_{dataset}_from_json_without_emb.csv")
|
||||
edge_csv_without_emb = os.path.join(output_dir, f"triple_edges_{dataset}_from_json_without_emb.csv")
|
||||
node_text_file = os.path.join(output_dir, f"text_nodes_{dataset}_from_json.csv")
|
||||
edge_text_file = os.path.join(output_dir, f"text_edges_{dataset}_from_json.csv")
|
||||
missing_concepts_file = os.path.join(output_dir, f"missing_concepts_{dataset}_from_json.csv")
|
||||
|
||||
if test:
|
||||
node_text_file = os.path.join(output_dir, f"text_nodes_{dataset}_from_json_test.csv")
|
||||
edge_text_file = os.path.join(output_dir, f"text_edges_{dataset}_from_json_test.csv")
|
||||
node_csv_without_emb = os.path.join(output_dir, f"triple_nodes_{dataset}_from_json_without_emb_test.csv")
|
||||
edge_csv_without_emb = os.path.join(output_dir, f"triple_edges_{dataset}_from_json_without_emb_test.csv")
|
||||
missing_concepts_file = os.path.join(output_dir, f"missing_concepts_{dataset}_from_json_test.csv")
|
||||
|
||||
# Open CSV files for writing
|
||||
with open(node_text_file, "w", newline='', encoding='utf-8', errors='ignore') as csvfile_node_text, \
|
||||
open(edge_text_file, "w", newline='', encoding='utf-8', errors='ignore') as csvfile_edge_text, \
|
||||
open(node_csv_without_emb, "w", newline='', encoding='utf-8', errors='ignore') as csvfile_node, \
|
||||
open(edge_csv_without_emb, "w", newline='', encoding='utf-8', errors='ignore') as csvfile_edge:
|
||||
|
||||
csv_writer_node_text = csv.writer(csvfile_node_text)
|
||||
csv_writer_edge_text = csv.writer(csvfile_edge_text)
|
||||
writer_node = csv.writer(csvfile_node)
|
||||
writer_edge = csv.writer(csvfile_edge)
|
||||
|
||||
# Write headers
|
||||
csv_writer_node_text.writerow(["text_id:ID", "original_text", ":LABEL"])
|
||||
csv_writer_edge_text.writerow([":START_ID", ":END_ID", ":TYPE"])
|
||||
writer_node.writerow(["name:ID", "type", "concepts", "synsets", ":LABEL"])
|
||||
writer_edge.writerow([":START_ID", ":END_ID", "relation", "concepts", "synsets", ":TYPE"])
|
||||
|
||||
# Process each file
|
||||
for file_dir in tqdm(file_dir_list):
|
||||
print("Processing file for file ids: ", file_dir)
|
||||
with open(os.path.join(data_dir, file_dir), "r") as jsonfile:
|
||||
for line in jsonfile:
|
||||
data = json.loads(line.strip())
|
||||
original_text = data["original_text"]
|
||||
original_text = remove_NUL(original_text)
|
||||
if "Here is the passage." in original_text:
|
||||
original_text = original_text.split("Here is the passage.")[-1]
|
||||
eot_token = "<|eot_id|>"
|
||||
original_text = original_text.split(eot_token)[0]
|
||||
|
||||
text_hash_id = compute_hash_id(original_text)
|
||||
|
||||
# Write the original text as nodes
|
||||
if text_hash_id not in visited_hashes:
|
||||
visited_hashes.add(text_hash_id)
|
||||
csv_writer_node_text.writerow([text_hash_id, original_text, "Text"])
|
||||
|
||||
file_id = str(data["id"])
|
||||
entity_relation_dict = data["entity_relation_dict"]
|
||||
event_entity_relation_dict = data["event_entity_relation_dict"]
|
||||
event_relation_dict = data["event_relation_dict"]
|
||||
|
||||
# Process entity triples
|
||||
entity_triples = []
|
||||
for entity_triple in entity_relation_dict:
|
||||
try:
|
||||
assert isinstance(entity_triple["Head"], str)
|
||||
assert isinstance(entity_triple["Relation"], str)
|
||||
assert isinstance(entity_triple["Tail"], str)
|
||||
|
||||
head_entity = entity_triple["Head"]
|
||||
relation = entity_triple["Relation"]
|
||||
tail_entity = entity_triple["Tail"]
|
||||
|
||||
# Clean the text
|
||||
head_entity = clean_text(head_entity)
|
||||
relation = clean_text(relation)
|
||||
tail_entity = clean_text(tail_entity)
|
||||
|
||||
if head_entity.isspace() or len(head_entity) == 0 or tail_entity.isspace() or len(tail_entity) == 0:
|
||||
continue
|
||||
|
||||
entity_triples.append((head_entity, relation, tail_entity))
|
||||
except:
|
||||
print(f"Error processing entity triple: {entity_triple}")
|
||||
continue
|
||||
|
||||
# Process event triples
|
||||
event_triples = []
|
||||
for event_triple in event_relation_dict:
|
||||
try:
|
||||
assert isinstance(event_triple["Head"], str)
|
||||
assert isinstance(event_triple["Relation"], str)
|
||||
assert isinstance(event_triple["Tail"], str)
|
||||
head_event = event_triple["Head"]
|
||||
relation = event_triple["Relation"]
|
||||
tail_event = event_triple["Tail"]
|
||||
|
||||
# Clean the text
|
||||
head_event = clean_text(head_event)
|
||||
relation = clean_text(relation)
|
||||
tail_event = clean_text(tail_event)
|
||||
|
||||
if head_event.isspace() or len(head_event) == 0 or tail_event.isspace() or len(tail_event) == 0:
|
||||
continue
|
||||
|
||||
event_triples.append((head_event, relation, tail_event))
|
||||
except:
|
||||
print(f"Error processing event triple: {event_triple}")
|
||||
|
||||
# Process event-entity triples
|
||||
event_entity_triples = []
|
||||
for event_entity_participations in event_entity_relation_dict:
|
||||
if "Event" not in event_entity_participations or "Entity" not in event_entity_participations:
|
||||
continue
|
||||
if not isinstance(event_entity_participations["Event"], str) or not isinstance(event_entity_participations["Entity"], list):
|
||||
continue
|
||||
|
||||
for entity in event_entity_participations["Entity"]:
|
||||
if not isinstance(entity, str):
|
||||
continue
|
||||
|
||||
entity = clean_text(entity)
|
||||
event = clean_text(event_entity_participations["Event"])
|
||||
|
||||
if event.isspace() or len(event) == 0 or entity.isspace() or len(entity) == 0:
|
||||
continue
|
||||
|
||||
event_entity_triples.append((event, "is participated by", entity))
|
||||
|
||||
# Write nodes and edges to CSV files
|
||||
for entity_triple in entity_triples:
|
||||
head_entity, relation, tail_entity = entity_triple
|
||||
if head_entity is None or tail_entity is None or relation is None:
|
||||
continue
|
||||
if head_entity.isspace() or tail_entity.isspace() or relation.isspace():
|
||||
continue
|
||||
if len(head_entity) == 0 or len(tail_entity) == 0 or len(relation) == 0:
|
||||
continue
|
||||
|
||||
# Add nodes to files
|
||||
if head_entity not in visited_nodes:
|
||||
visited_nodes.add(head_entity)
|
||||
all_entities.add(head_entity)
|
||||
writer_node.writerow([head_entity, "entity", [], [], "Node"])
|
||||
csv_writer_edge_text.writerow([head_entity, text_hash_id, "Source"])
|
||||
|
||||
if tail_entity not in visited_nodes:
|
||||
visited_nodes.add(tail_entity)
|
||||
all_entities.add(tail_entity)
|
||||
writer_node.writerow([tail_entity, "entity", [], [], "Node"])
|
||||
csv_writer_edge_text.writerow([tail_entity, text_hash_id, "Source"])
|
||||
|
||||
all_relations.add(relation)
|
||||
writer_edge.writerow([head_entity, tail_entity, relation, [], [], "Relation"])
|
||||
|
||||
for event_triple in event_triples:
|
||||
head_event, relation, tail_event = event_triple
|
||||
if head_event is None or tail_event is None or relation is None:
|
||||
continue
|
||||
if head_event.isspace() or tail_event.isspace() or relation.isspace():
|
||||
continue
|
||||
if len(head_event) == 0 or len(tail_event) == 0 or len(relation) == 0:
|
||||
continue
|
||||
|
||||
# Add nodes to files
|
||||
if head_event not in visited_nodes:
|
||||
visited_nodes.add(head_event)
|
||||
all_events.add(head_event)
|
||||
writer_node.writerow([head_event, "event", [], [], "Node"])
|
||||
csv_writer_edge_text.writerow([head_event, text_hash_id, "Source"])
|
||||
|
||||
if tail_event not in visited_nodes:
|
||||
visited_nodes.add(tail_event)
|
||||
all_events.add(tail_event)
|
||||
writer_node.writerow([tail_event, "event", [], [], "Node"])
|
||||
csv_writer_edge_text.writerow([tail_event, text_hash_id, "Source"])
|
||||
|
||||
all_relations.add(relation)
|
||||
writer_edge.writerow([head_event, tail_event, relation, [], [], "Relation"])
|
||||
|
||||
for event_entity_triple in event_entity_triples:
|
||||
head_event, relation, tail_entity = event_entity_triple
|
||||
if head_event is None or tail_entity is None or relation is None:
|
||||
continue
|
||||
if head_event.isspace() or tail_entity.isspace() or relation.isspace():
|
||||
continue
|
||||
if len(head_event) == 0 or len(tail_entity) == 0 or len(relation) == 0:
|
||||
continue
|
||||
|
||||
# Add nodes to files
|
||||
if head_event not in visited_nodes:
|
||||
visited_nodes.add(head_event)
|
||||
all_events.add(head_event)
|
||||
writer_node.writerow([head_event, "event", [], [], "Node"])
|
||||
csv_writer_edge_text.writerow([head_event, text_hash_id, "Source"])
|
||||
|
||||
if tail_entity not in visited_nodes:
|
||||
visited_nodes.add(tail_entity)
|
||||
all_entities.add(tail_entity)
|
||||
writer_node.writerow([tail_entity, "entity", [], [], "Node"])
|
||||
csv_writer_edge_text.writerow([tail_entity, text_hash_id, "Source"])
|
||||
|
||||
all_relations.add(relation)
|
||||
writer_edge.writerow([head_event, tail_entity, relation, [], [], "Relation"])
|
||||
|
||||
# Write missing concepts to CSV
|
||||
with open(missing_concepts_file, "w", newline='', encoding='utf-8', errors='ignore') as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
writer.writerow(["Name", "Type"])
|
||||
for entity in all_entities:
|
||||
writer.writerow([entity, "Entity"])
|
||||
for event in all_events:
|
||||
writer.writerow([event, "Event"])
|
||||
for relation in all_relations:
|
||||
writer.writerow([relation, "Relation"])
|
||||
|
||||
print("Data to CSV completed successfully.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--dataset", type=str, required=True, help="[pes2o_abstract, en_simple_wiki_v0, cc_en]")
|
||||
parser.add_argument("--data_dir", type=str, required=True, help="Directory containing the graph raw JSON files")
|
||||
parser.add_argument("--output_dir", type=str, required=True, help="Directory to save the output CSV files")
|
||||
parser.add_argument("--test", action="store_true", help="Test the script")
|
||||
args = parser.parse_args()
|
||||
json2csv(dataset=args.dataset, data_dir=args.data_dir, output_dir=args.output_dir, test=args.test)
|
||||
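For reference, a sketch of the shape of one input JSON line that json2csv expects, inferred from the keys it reads; all values are illustrative:

example_line = {
    "id": "0",
    "original_text": "Alice founded Acme Corp in 2001.",
    "entity_relation_dict": [
        {"Head": "Alice", "Relation": "founded", "Tail": "Acme Corp"}
    ],
    "event_entity_relation_dict": [
        {"Event": "Alice founded Acme Corp", "Entity": ["Alice", "Acme Corp"]}
    ],
    "event_relation_dict": [
        {"Head": "Alice founded Acme Corp", "Relation": "happened in", "Tail": "2001"}
    ],
}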
@ -0,0 +1,169 @@
|
||||
import networkx as nx
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
def get_node_id(entity_name, entity_to_id):
|
||||
"""Returns existing or creates new nX ID for an entity using a hash-based approach."""
|
||||
if entity_name not in entity_to_id:
|
||||
# Use a hash function to generate a unique ID
|
||||
hash_object = hashlib.md5(entity_name.encode()) # Use MD5 or another hashing algorithm
|
||||
hash_hex = hash_object.hexdigest() # Get the hexadecimal representation of the hash
|
||||
# Use the first 16 characters of the hash as the ID (you can adjust the length as needed)
|
||||
entity_to_id[entity_name] = f'n{hash_hex[:16]}'
|
||||
return entity_to_id[entity_name]
|
||||
|
||||
def clean_text(text):
|
||||
# remove NUL as well
|
||||
new_text = text.replace("\n", " ").replace("\r", " ").replace("\t", " ").replace("\v", " ").replace("\f", " ").replace("\b", " ").replace("\a", " ").replace("\e", " ").replace(";", ",")
|
||||
new_text = new_text.replace("\x00", "")
|
||||
return new_text
|
||||
|
||||
|
||||
def process_kg_data(input_passage_dir, input_triple_dir, output_dir, keyword):
|
||||
# Get file names containing the keyword
|
||||
file_names = [file for file in list(os.listdir(input_triple_dir)) if keyword in file]
|
||||
print(f"Keyword: {keyword}")
|
||||
print(f"Number of files: {len(file_names)}")
|
||||
print(file_names)
|
||||
|
||||
passage_file_names = [file for file in list(os.listdir(input_passage_dir)) if keyword in file]
|
||||
print(f'Passage file names: {passage_file_names}')
|
||||
|
||||
g = nx.DiGraph()
|
||||
print("Graph created.")
|
||||
entity_to_id = {}
|
||||
# check if output directory exists, if not create it
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
print(f"Output directory {output_dir} created.")
|
||||
|
||||
output_path = f"{output_dir}/{keyword}_kg_from_corpus.graphml"
|
||||
|
||||
# Create the original_text to node_id dictionary and add passage node to the graph
|
||||
with open(f"{input_passage_dir}/{passage_file_names[0]}") as f:
|
||||
data = json.load(f)
|
||||
for item in tqdm(data, desc="Processing passages"):
|
||||
passage_id = item["id"]
|
||||
passage_text = item["text"]
|
||||
node_id = get_node_id(passage_text, entity_to_id)
|
||||
if passage_text.isspace() or len(passage_text) == 0:
|
||||
continue
|
||||
# Add the passage node to the graph
|
||||
g.add_node(node_id, type="passage", id=passage_text, file_id=passage_id)
|
||||
|
||||
for file_name in tqdm(file_names):
|
||||
print(f"Processing {file_name}")
|
||||
input_file_path = f"{input_triple_dir}/{file_name}"
|
||||
with open(input_file_path) as f:
|
||||
for line in tqdm(f):
|
||||
data = json.loads(line)
|
||||
metadata = data["metadata"]
|
||||
file_id = data["id"]
|
||||
original_text = data["original_text"]
|
||||
entity_relation_dict = data["entity_relation_dict"]
|
||||
event_entity_relation_dict = data["event_entity_relation_dict"]
|
||||
event_relation_dict = data["event_relation_dict"]
|
||||
|
||||
# Process entity triples
|
||||
entity_triples = []
|
||||
for entity_triple in entity_relation_dict:
|
||||
if not all(key in entity_triple for key in ["Head", "Relation", "Tail"]):
|
||||
continue
|
||||
head_entity = clean_text(entity_triple["Head"])
|
||||
relation = clean_text(entity_triple["Relation"])
|
||||
tail_entity = clean_text(entity_triple["Tail"])
|
||||
if head_entity.isspace() or len(head_entity) == 0 or tail_entity.isspace() or len(tail_entity) == 0:
|
||||
continue
|
||||
entity_triples.append((head_entity, relation, tail_entity))
|
||||
|
||||
# Add entity triples to the graph
|
||||
for triple in entity_triples:
|
||||
head_id = get_node_id(triple[0], entity_to_id)
|
||||
tail_id = get_node_id(triple[2], entity_to_id)
|
||||
g.add_node(head_id, type="entity", id=triple[0])
|
||||
g.add_node(tail_id, type="entity", id=triple[2])
|
||||
g.add_edge(head_id, get_node_id(original_text, entity_to_id), relation='mention in')
|
||||
g.add_edge(tail_id, get_node_id(original_text, entity_to_id), relation='mention in')
|
||||
g.add_edge(head_id, tail_id, relation=triple[1])
|
||||
for node_id in [head_id, tail_id]:
|
||||
if "file_id" not in g.nodes[node_id]:
|
||||
g.nodes[node_id]["file_id"] = str(file_id)
|
||||
else:
|
||||
g.nodes[node_id]["file_id"] += "," + str(file_id)
|
||||
edge = g.edges[head_id, tail_id]
|
||||
if "file_id" not in edge:
|
||||
edge["file_id"] = str(file_id)
|
||||
else:
|
||||
edge["file_id"] += "," + str(file_id)
|
||||
|
||||
# Process event triples
|
||||
event_triples = []
|
||||
for event_triple in event_relation_dict:
|
||||
if not all(key in event_triple for key in ["Head", "Relation", "Tail"]):
|
||||
continue
|
||||
head_event = clean_text(event_triple["Head"])
|
||||
relation = clean_text(event_triple["Relation"])
|
||||
tail_event = clean_text(event_triple["Tail"])
|
||||
if head_event.isspace() or len(head_event) == 0 or tail_event.isspace() or len(tail_event) == 0:
|
||||
continue
|
||||
event_triples.append((head_event, relation, tail_event))
|
||||
|
||||
# Add event triples to the graph
|
||||
for triple in event_triples:
|
||||
head_id = get_node_id(triple[0], entity_to_id)
|
||||
tail_id = get_node_id(triple[2], entity_to_id)
|
||||
g.add_node(head_id, type="event", id=triple[0])
|
||||
g.add_node(tail_id, type="event", id=triple[2])
|
||||
g.add_edge(head_id, get_node_id(original_text, entity_to_id), relation='mention in')
|
||||
g.add_edge(tail_id, get_node_id(original_text, entity_to_id), relation='mention in')
|
||||
g.add_edge(head_id, tail_id, relation=triple[1])
|
||||
for node_id in [head_id, tail_id]:
|
||||
if "file_id" not in g.nodes[node_id]:
|
||||
g.nodes[node_id]["file_id"] = str(file_id)
|
||||
else:
|
||||
g.nodes[node_id]["file_id"] += "," + str(file_id)
|
||||
edge = g.edges[head_id, tail_id]
|
||||
if "file_id" not in edge:
|
||||
edge["file_id"] = str(file_id)
|
||||
else:
|
||||
edge["file_id"] += "," + str(file_id)
|
||||
|
||||
# Process event-entity triples
|
||||
event_entity_triples = []
|
||||
for event_entity_participations in event_entity_relation_dict:
|
||||
if not all(key in event_entity_participations for key in ["Event", "Entity"]):
|
||||
continue
|
||||
event = clean_text(event_entity_participations["Event"])
|
||||
if event.isspace() or len(event) == 0:
|
||||
continue
|
||||
for entity in event_entity_participations["Entity"]:
|
||||
if not isinstance(entity, str) or entity.isspace() or len(entity) == 0:
|
||||
continue
|
||||
entity = clean_text(entity)
|
||||
event_entity_triples.append((event, "is participated by", entity))
|
||||
|
||||
# Add event-entity triples to the graph
|
||||
for triple in event_entity_triples:
|
||||
head_id = get_node_id(triple[0], entity_to_id)
|
||||
tail_id = get_node_id(triple[2], entity_to_id)
|
||||
g.add_node(head_id, type="event", id=triple[0])
|
||||
g.add_node(tail_id, type="entity", id=triple[2])
|
||||
g.add_edge(head_id, tail_id, relation=triple[1])
|
||||
for node_id in [head_id, tail_id]:
|
||||
if "file_id" not in g.nodes[node_id]:
|
||||
g.nodes[node_id]["file_id"] = str(file_id)
|
||||
edge = g.edges[head_id, tail_id]
|
||||
if "file_id" not in edge:
|
||||
edge["file_id"] = str(file_id)
|
||||
else:
|
||||
edge["file_id"] += "," + str(file_id)
|
||||
|
||||
print(f"Number of nodes: {g.number_of_nodes()}")
|
||||
print(f"Number of edges: {g.number_of_edges()}")
|
||||
print(f"Graph density: {nx.density(g)}")
|
||||
with open(output_path, 'wb') as f:
|
||||
nx.write_graphml(g, f, infer_numeric_types=True)
|
||||
|
||||
|
||||
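A minimal sketch for loading the GraphML written by process_kg_data and counting node types; the path is hypothetical:

import networkx as nx
from collections import Counter

g = nx.read_graphml("./output/cc_en_kg_from_corpus.graphml")
print(Counter(data.get("type") for _, data in g.nodes(data=True)))  # e.g. passage / entity / event counts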
@ -0,0 +1,63 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Set up argument parser
|
||||
parser = argparse.ArgumentParser(description="Convert all Markdown files in a folder to separate JSON files.")
|
||||
parser.add_argument(
|
||||
"--input", required=True, help="Path to the folder containing Markdown files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", default=None, help="Output folder for JSON files (defaults to input folder if not specified)"
|
||||
)
|
||||
|
||||
# Parse arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
# Resolve input folder path
|
||||
input_folder = Path(args.input)
|
||||
if not input_folder.is_dir():
|
||||
print(f"Error: '{args.input}' is not a directory.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Set output folder (use input folder if not specified)
|
||||
output_folder = Path(args.output) if args.output else input_folder
|
||||
output_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Find all .md files in the input folder
|
||||
markdown_files = [f for f in input_folder.iterdir() if f.suffix.lower() == ".md"]
|
||||
|
||||
if not markdown_files:
|
||||
print(f"Error: No Markdown files found in '{args.input}'.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Process each Markdown file
|
||||
for file in markdown_files:
|
||||
try:
|
||||
# Read the content of the file
|
||||
with open(file, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# Create the JSON object
|
||||
obj = {
|
||||
"id": "1",
|
||||
"text": content,
|
||||
"metadata": {
|
||||
"lang": "en"
|
||||
}
|
||||
}
|
||||
|
||||
# Create output JSON filename (e.g., file1.md -> file1.json)
|
||||
output_file = output_folder / f"{file.stem}.json"
|
||||
|
||||
# Write JSON to file
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump([obj], f, indent=4)
|
||||
|
||||
print(f"Successfully converted '{file}' to '{output_file}'")
|
||||
except FileNotFoundError:
|
||||
print(f"Error: File '{file}' not found.", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f"Error processing file '{file}': {e}", file=sys.stderr)
|
||||
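Each Markdown file is written as a single-element JSON list. A minimal sketch for reading one of the generated files back; the file name is hypothetical:

import json

with open("README.json", "r", encoding="utf-8") as f:
    docs = json.load(f)
print(docs[0]["id"], docs[0]["metadata"]["lang"], len(docs[0]["text"]))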
1
atlas_rag/llm_generator/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .llm_generator import LLMGenerator
|
||||
BIN
atlas_rag/llm_generator/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
0
atlas_rag/llm_generator/format/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
144
atlas_rag/llm_generator/format/validate_json_output.py
Normal file
@ -0,0 +1,144 @@
|
||||
import json
|
||||
from typing import List, Any
|
||||
import json_repair
|
||||
import jsonschema
|
||||
|
||||
def normalize_key(key):
|
||||
return key.strip().lower()
|
||||
|
||||
# recover function can be fix_triple_extraction_response, fix_filter_triplets
|
||||
def validate_output(output_str, **kwargs):
|
||||
schema = kwargs.get("schema")
|
||||
fix_function = kwargs.get("fix_function", None)
|
||||
allow_empty = kwargs.get("allow_empty", True)
|
||||
if fix_function:
|
||||
parsed_data = fix_function(output_str, **kwargs)
|
||||
jsonschema.validate(instance=parsed_data, schema=schema)
|
||||
if not allow_empty and (not parsed_data or len(parsed_data) == 0):
|
||||
raise ValueError("Parsed data is empty after validation.")
|
||||
return json.dumps(parsed_data, ensure_ascii=False)
|
||||
|
||||
def fix_filter_triplets(data: str, **kwargs) -> dict:
|
||||
data = json_repair.loads(data)
|
||||
processed_facts = []
|
||||
def find_triplet(element: Any) -> List[str] | None:
|
||||
# Base case: a valid triplet
|
||||
if isinstance(element, list) and len(element) == 3 and all(isinstance(item, str) for item in element):
|
||||
return element
|
||||
# Recursive case: dig deeper into nested lists
|
||||
elif isinstance(element, list):
|
||||
for sub_element in element:
|
||||
result = find_triplet(sub_element)
|
||||
if result:
|
||||
return result
|
||||
return None
|
||||
|
||||
for item in data.get("fact", []):
|
||||
triplet = find_triplet(item)
|
||||
if triplet:
|
||||
processed_facts.append(triplet)
|
||||
|
||||
return {"fact": processed_facts}
|
||||
|
||||
def fix_triple_extraction_response(response: str, **kwargs) -> list:
|
||||
"""Attempt to fix and validate JSON response based on the prompt type."""
|
||||
# Extract the JSON list from the response
|
||||
# raise error if prompt_type is not provided
|
||||
if "prompt_type" not in kwargs:
|
||||
raise ValueError("The 'prompt_type' argument is required.")
|
||||
prompt_type = kwargs.get("prompt_type")
|
||||
|
||||
json_start_token = response.find("[")
|
||||
if json_start_token == -1:
|
||||
# add [ at the start
|
||||
response = "[" + response.strip() + "]"
|
||||
parsed_objects = json_repair.loads(response)
|
||||
if len(parsed_objects) == 0:
|
||||
return []
|
||||
# Define required keys for each prompt type
|
||||
required_keys = {
|
||||
"entity_relation": {"Head", "Relation", "Tail"},
|
||||
"event_entity": {"Event", "Entity"},
|
||||
"event_relation": {"Head", "Relation", "Tail"}
|
||||
}
|
||||
|
||||
corrected_data = []
|
||||
seen_triples = set()
|
||||
for idx, item in enumerate(parsed_objects):
|
||||
if not isinstance(item, dict):
|
||||
print(f"Item {idx} must be a JSON object. Problematic item: {item}")
|
||||
continue
|
||||
|
||||
# Correct the keys
|
||||
corrected_item = {}
|
||||
for key, value in item.items():
|
||||
norm_key = normalize_key(key)
|
||||
matching_expected_keys = [exp_key for exp_key in required_keys[prompt_type] if normalize_key(exp_key) in norm_key]
|
||||
if len(matching_expected_keys) == 1:
|
||||
corrected_key = matching_expected_keys[0]
|
||||
corrected_item[corrected_key] = value
|
||||
else:
|
||||
corrected_item[key] = value
|
||||
|
||||
# Check for missing keys in corrected_item
|
||||
missing = required_keys[prompt_type] - corrected_item.keys()
|
||||
if missing:
|
||||
print(f"Item {idx} missing required keys: {missing}. Problematic item: {item}")
|
||||
continue
|
||||
|
||||
# Validate and correct the values in corrected_item
|
||||
if prompt_type == "entity_relation":
|
||||
for key in ["Head", "Relation", "Tail"]:
|
||||
if not isinstance(corrected_item[key], str) or not corrected_item[key].strip():
|
||||
print(f"Item {idx} {key} must be a non-empty string. Problematic item: {corrected_item}")
|
||||
continue
|
||||
|
||||
elif prompt_type == "event_entity":
|
||||
if not isinstance(corrected_item["Event"], str) or not corrected_item["Event"].strip():
|
||||
print(f"Item {idx} Event must be a non-empty string. Problematic item: {corrected_item}")
|
||||
continue
|
||||
if not isinstance(corrected_item["Entity"], list) or not corrected_item["Entity"]:
|
||||
print(f"Item {idx} Entity must be a non-empty array. Problematic item: {corrected_item}")
|
||||
continue
|
||||
else:
|
||||
corrected_item["Entity"] = [ent.strip() for ent in corrected_item["Entity"] if isinstance(ent, str)]
|
||||
|
||||
elif prompt_type == "event_relation":
|
||||
for key in ["Head", "Tail", "Relation"]:
|
||||
if not isinstance(corrected_item[key], str) or not corrected_item[key].strip():
|
||||
print(f"Item {idx} {key} must be a non-empty sentence. Problematic item: {corrected_item}")
|
||||
continue
|
||||
|
||||
triple_tuple = tuple((k, str(v)) for k, v in corrected_item.items())
|
||||
if triple_tuple in seen_triples:
|
||||
print(f"Item {idx} is a duplicate triple: {corrected_item}")
|
||||
continue
|
||||
else:
|
||||
seen_triples.add(triple_tuple)
|
||||
corrected_data.append(corrected_item)
|
||||
|
||||
if not corrected_data:
|
||||
return []
|
||||
|
||||
return corrected_data
|
||||
|
||||
def fix_lkg_keywords(data: str, **kwargs) -> dict:
|
||||
"""
|
||||
Extract and flatten keywords into a list of strings, filtering invalid types.
|
||||
"""
|
||||
data = json_repair.loads(data)
|
||||
processed_keywords = []
|
||||
|
||||
def collect_strings(element: Any) -> None:
|
||||
if isinstance(element, str):
|
||||
if len(element) <= 200: # Filter out keywords longer than 200 characters
|
||||
processed_keywords.append(element)
|
||||
elif isinstance(element, list):
|
||||
for item in element:
|
||||
collect_strings(item)
|
||||
|
||||
# Start processing from the root "keywords" field
|
||||
collect_strings(data.get("keywords", []))
|
||||
|
||||
return {"keywords": processed_keywords}
|
||||
|
||||
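A minimal sketch of how validate_output and fix_filter_triplets work together, using the filter_fact_json_schema defined in validate_json_schema.py; the raw model output is illustrative:

from atlas_rag.llm_generator.format.validate_json_output import validate_output, fix_filter_triplets
from atlas_rag.llm_generator.format.validate_json_schema import filter_fact_json_schema

raw = '{"fact": [[["Alice", "founded", "Acme Corp"]], ["not", "a", "triplet", "extra"]]}'
cleaned = validate_output(raw, schema=filter_fact_json_schema, fix_function=fix_filter_triplets)
print(cleaned)  # {"fact": [["Alice", "founded", "Acme Corp"]]}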
93
atlas_rag/llm_generator/format/validate_json_schema.py
Normal file
@ -0,0 +1,93 @@
|
||||
filter_fact_json_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"fact": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string" # All items in the inner array must be strings
|
||||
},
|
||||
"minItems": 3,
|
||||
"maxItems": 3,
|
||||
"additionalItems": False # Block extra items
|
||||
},
|
||||
}
|
||||
},
|
||||
"required": ["fact"]
|
||||
}
|
||||
|
||||
lkg_keyword_json_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"keywords": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"minItems": 1,
|
||||
}
|
||||
},
|
||||
"required": ["keywords"]
|
||||
}
|
||||
|
||||
triple_json_schema = {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"Head": {
|
||||
"type": "string"
|
||||
},
|
||||
"Relation": {
|
||||
"type": "string"
|
||||
},
|
||||
"Tail": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["Head", "Relation", "Tail"]
|
||||
},
|
||||
}
|
||||
event_relation_json_schema = {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"Head": {
|
||||
"type": "string"
|
||||
},
|
||||
"Relation": {
|
||||
"type": "string",
|
||||
},
|
||||
"Tail": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["Head", "Relation", "Tail"]
|
||||
},
|
||||
}
|
||||
event_entity_json_schema = {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"Event": {
|
||||
"type": "string"
|
||||
},
|
||||
"Entity": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"minItems": 1
|
||||
}
|
||||
},
|
||||
"required": ["Event", "Entity"]
|
||||
},
|
||||
}
|
||||
stage_to_schema = {
|
||||
1: triple_json_schema,
|
||||
2: event_entity_json_schema,
|
||||
3: event_relation_json_schema
|
||||
}
|
||||
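A minimal sketch of validating a stage-1 extraction result against its schema with jsonschema; the sample triple is illustrative:

import jsonschema
from atlas_rag.llm_generator.format.validate_json_schema import stage_to_schema

sample = [{"Head": "Alice", "Relation": "founded", "Tail": "Acme Corp"}]
jsonschema.validate(instance=sample, schema=stage_to_schema[1])  # raises ValidationError if malformed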
364
atlas_rag/llm_generator/llm_generator.py
Normal file
@ -0,0 +1,364 @@
|
||||
import json
|
||||
from openai import OpenAI, AzureOpenAI, NOT_GIVEN
|
||||
from tenacity import retry, stop_after_attempt, stop_after_delay, wait_fixed, wait_exponential, wait_random
|
||||
from copy import deepcopy
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from atlas_rag.llm_generator.prompt.rag_prompt import cot_system_instruction, cot_system_instruction_kg, cot_system_instruction_no_doc, prompt_template
|
||||
from atlas_rag.llm_generator.prompt.lkg_prompt import ner_prompt, keyword_filtering_prompt
|
||||
from atlas_rag.llm_generator.prompt.rag_prompt import filter_triple_messages
|
||||
|
||||
from atlas_rag.llm_generator.format.validate_json_output import *
|
||||
from atlas_rag.llm_generator.format.validate_json_schema import filter_fact_json_schema, lkg_keyword_json_schema, stage_to_schema
|
||||
|
||||
from transformers.pipelines import Pipeline
|
||||
import jsonschema
|
||||
|
||||
|
||||
import time
|
||||
stage_to_prompt_type = {
|
||||
1: "entity_relation",
|
||||
2: "event_entity",
|
||||
3: "event_relation",
|
||||
}
|
||||
retry_decorator = retry(
|
||||
stop=(stop_after_delay(120) | stop_after_attempt(5)), # Max 2 minutes or 5 attempts
|
||||
wait=wait_exponential(multiplier=1, min=2, max=30) + wait_random(min=0, max=2),
|
||||
)
|
||||
class LLMGenerator():
|
||||
def __init__(self, client, model_name):
|
||||
self.model_name = model_name
|
||||
self.client : OpenAI|Pipeline = client
|
||||
if isinstance(client, OpenAI|AzureOpenAI):
|
||||
self.inference_type = "openai"
|
||||
elif isinstance(client, Pipeline):
|
||||
self.inference_type = "pipeline"
|
||||
else:
|
||||
raise ValueError("Unsupported client type. Please provide either an OpenAI client or a Huggingface Pipeline Object.")
|
||||
|
||||
@retry_decorator
|
||||
def _api_inference(self, message, max_new_tokens=8192,
|
||||
temperature = 0.7,
|
||||
frequency_penalty = None,
|
||||
response_format = {"type": "text"},
|
||||
return_text_only=True,
|
||||
return_thinking=False,
|
||||
reasoning_effort=None,
|
||||
**kwargs):
|
||||
start_time = time.time()
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=message,
|
||||
max_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
frequency_penalty= NOT_GIVEN if frequency_penalty is None else frequency_penalty,
|
||||
response_format = response_format if response_format is not None else {"type": "text"},
|
||||
timeout = 120,
|
||||
reasoning_effort= NOT_GIVEN if reasoning_effort is None else reasoning_effort,
|
||||
)
|
||||
time_cost = time.time() - start_time
|
||||
content = response.choices[0].message.content
|
||||
if content is None and hasattr(response.choices[0].message, 'reasoning_content'):
|
||||
content = response.choices[0].message.reasoning_content
|
||||
validate_function = kwargs.get('validate_function', None)
|
||||
content = validate_function(content, **kwargs) if validate_function else content
|
||||
|
||||
if '</think>' in content and not return_thinking:
|
||||
content = content.split('</think>')[-1].strip()
|
||||
else:
|
||||
if hasattr(response.choices[0].message, 'reasoning_content') and response.choices[0].message.reasoning_content is not None and return_thinking:
|
||||
content = '<think>' + response.choices[0].message.reasoning_content + '</think>' + content
|
||||
|
||||
if return_text_only:
|
||||
return content
|
||||
else:
|
||||
completion_usage_dict = response.usage.model_dump()
|
||||
completion_usage_dict['time'] = time_cost
|
||||
return content, completion_usage_dict
|
||||
|
||||
def generate_response(self, batch_messages, do_sample=True, max_new_tokens=8192,
|
||||
temperature=0.7, frequency_penalty=None, response_format={"type": "text"},
|
||||
return_text_only=True, return_thinking=False, reasoning_effort=None, **kwargs):
|
||||
if temperature == 0.0:
|
||||
do_sample = False
|
||||
# single = list of dict, batch = list of list of dict
|
||||
is_batch = isinstance(batch_messages[0], list)
|
||||
if not is_batch:
|
||||
batch_messages = [batch_messages]
|
||||
results = [None] * len(batch_messages)
|
||||
to_process = list(range(len(batch_messages)))
|
||||
if self.inference_type == "openai":
|
||||
max_workers = kwargs.get('max_workers', 3) # Default to 3 workers if not specified
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
def process_message(i):
|
||||
try:
|
||||
return self._api_inference(
|
||||
batch_messages[i], max_new_tokens, temperature,
|
||||
frequency_penalty, response_format, return_text_only, return_thinking, reasoning_effort, **kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error processing message {i}: {e.last_attempt.result()}")
|
||||
return ""
|
||||
futures = [executor.submit(process_message, i) for i in to_process]
|
||||
for i, future in enumerate(futures):
|
||||
results[i] = future.result()
|
||||
|
||||
elif self.inference_type == "pipeline":
|
||||
max_retries = kwargs.get('max_retries', 3) # Default to 3 retries if not specified
|
||||
start_time = time.time()
|
||||
# Initial processing of all messages
|
||||
responses = self.client(
|
||||
batch_messages,
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
do_sample=do_sample,
|
||||
return_full_text=False
|
||||
)
|
||||
time_cost = time.time() - start_time
|
||||
|
||||
# Extract contents
|
||||
contents = [resp[0]['generated_text'].strip() for resp in responses]
|
||||
|
||||
# Validate and collect failed indices
|
||||
validate_function = kwargs.get('validate_function', None)
|
||||
failed_indices = []
|
||||
for i, content in enumerate(contents):
|
||||
if validate_function:
|
||||
try:
|
||||
contents[i] = validate_function(content, **kwargs)
|
||||
except Exception as e:
|
||||
print(f"Validation failed for index {i}: {e}")
|
||||
failed_indices.append(i)
|
||||
|
||||
# Retry failed messages in batches
|
||||
for attempt in range(max_retries):
|
||||
if not failed_indices:
|
||||
break # No more failures to retry
|
||||
print(f"Retry attempt {attempt + 1}/{max_retries} for {len(failed_indices)} failed messages")
|
||||
# Prepare batch of failed messages
|
||||
failed_messages = [batch_messages[i] for i in failed_indices]
|
||||
try:
|
||||
# Process failed messages as a batch
|
||||
retry_responses = self.client(
|
||||
failed_messages,
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
do_sample=do_sample,
|
||||
return_full_text=False
|
||||
)
|
||||
retry_contents = [resp[0]['generated_text'].strip() for resp in retry_responses]
|
||||
|
||||
# Validate retry results and update contents
|
||||
new_failed_indices = []
|
||||
for j, i in enumerate(failed_indices):
|
||||
try:
|
||||
if validate_function:
|
||||
retry_contents[j] = validate_function(retry_contents[j], **kwargs)
|
||||
contents[i] = retry_contents[j]
|
||||
except Exception as e:
|
||||
print(f"Validation failed for index {i} on retry {attempt + 1}: {e}")
|
||||
new_failed_indices.append(i)
|
||||
failed_indices = new_failed_indices # Update failed indices for next retry
|
||||
except Exception as e:
|
||||
print(f"Batch retry {attempt + 1} failed: {e}")
|
||||
# If batch processing fails, keep all indices in failed_indices
|
||||
if attempt == max_retries - 1:
|
||||
for i in failed_indices:
|
||||
contents[i] = "" # Set to "" if all retries fail
|
||||
|
||||
# Set remaining failed messages to "" after all retries
|
||||
for i in failed_indices:
|
||||
contents[i] = ""
|
||||
|
||||
# Process thinking tags
|
||||
if not return_thinking:
|
||||
contents = [content.split('</think>')[-1].strip() if '</think>' in content else content for content in contents]
|
||||
|
||||
if return_text_only:
|
||||
results = contents
|
||||
else:
|
||||
usage_dicts = [{
|
||||
'completion_tokens': len(content.split()),
|
||||
'time': time_cost / len(batch_messages)
|
||||
} for content in contents]
|
||||
results = list(zip(contents, usage_dicts))
|
||||
return results[0] if not is_batch else results
|
||||
|
||||
def generate_cot(self, question, max_new_tokens=1024):
|
||||
messages = [
|
||||
{"role": "system", "content": "".join(cot_system_instruction_no_doc)},
|
||||
{"role": "user", "content": question},
|
||||
]
|
||||
return self.generate_response(messages, max_new_tokens=max_new_tokens)
|
||||
|
||||
def generate_with_context(self, question, context, max_new_tokens=1024, temperature = 0.7):
|
||||
messages = [
|
||||
{"role": "system", "content": "".join(cot_system_instruction)},
|
||||
{"role": "user", "content": f"{context}\n\n{question}\nThought:"},
|
||||
]
|
||||
return self.generate_response(messages, max_new_tokens=max_new_tokens, temperature = temperature)
|
||||
|
||||
def generate_with_context_one_shot(self, question, context, max_new_tokens=4096, temperature = 0.7):
|
||||
messages = deepcopy(prompt_template)
|
||||
messages.append(
|
||||
{"role": "user", "content": f"{context}\n\nQuestions:{question}\nThought:"},
|
||||
)
|
||||
return self.generate_response(messages, max_new_tokens=max_new_tokens, temperature = temperature)
|
||||
|
||||
def generate_with_context_kg(self, question, context, max_new_tokens=1024, temperature = 0.7):
|
||||
messages = [
|
||||
{"role": "system", "content": "".join(cot_system_instruction_kg)},
|
||||
{"role": "user", "content": f"{context}\n\n{question}"},
|
||||
]
|
||||
return self.generate_response(messages, max_new_tokens=max_new_tokens, temperature = temperature)
|
||||
|
||||
@retry_decorator
|
||||
def filter_triples_with_entity_event(self,question, triples):
|
||||
messages = deepcopy(filter_triple_messages)
|
||||
messages.append(
|
||||
{"role": "user", "content": f"""[ ## question ## ]]
|
||||
{question}
|
||||
|
||||
[[ ## fact_before_filter ## ]]
|
||||
{triples}"""})
|
||||
try:
|
||||
validate_args = {
|
||||
"schema": filter_fact_json_schema,
|
||||
"fix_function": fix_filter_triplets,
|
||||
}
|
||||
response = self.generate_response(messages, max_new_tokens=4096, temperature=0.0, response_format={"type": "json_object"},
|
||||
validate_function=validate_output, **validate_args)
|
||||
return response
|
||||
except Exception as e:
|
||||
# If all retries fail, return the original triples
|
||||
return triples
|
||||
|
||||
@retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
|
||||
def large_kg_filter_keywords_with_entity(self, question, keywords):
|
||||
messages = deepcopy(keyword_filtering_prompt)
|
||||
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": f"""[[ ## question ## ]]
|
||||
{question}
|
||||
[[ ## keywords_before_filter ## ]]
|
||||
{keywords}"""
|
||||
})
|
||||
|
||||
try:
|
||||
response = self.generate_response(messages, response_format={"type": "json_object"}, temperature=0.0, max_new_tokens=2048)
|
||||
|
||||
# Validate and clean the response
|
||||
cleaned_data = json.loads(validate_output(response, schema=lkg_keyword_json_schema, fix_function=fix_lkg_keywords))
|
||||
|
||||
return cleaned_data['keywords']
|
||||
except Exception as e:
|
||||
return keywords
|
||||
|
||||
def ner(self, text):
|
||||
messages = [
|
||||
{"role": "system", "content": "Please extract the entities from the following question and output them separated by comma, in the following format: entity1, entity2, ..."},
|
||||
{"role": "user", "content": f"Extract the named entities from: {text}"},
|
||||
]
|
||||
return self.generate_response(messages)
|
||||
|
||||
@retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
|
||||
def large_kg_ner(self, text):
|
||||
messages = deepcopy(ner_prompt)
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"[[ ## question ## ]]\n{text}"
|
||||
}
|
||||
)
|
||||
validation_args = {
|
||||
"schema": lkg_keyword_json_schema,
|
||||
"fix_function": fix_lkg_keywords
|
||||
}
|
||||
# Generate raw response from LLM
|
||||
raw_response = self.generate_response(messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"}, validate_function=validate_output, **validation_args)
|
||||
|
||||
try:
|
||||
# Validate and clean the response
|
||||
cleaned_data = json_repair.loads(raw_response)
|
||||
return cleaned_data['keywords']
|
||||
|
||||
except (json.JSONDecodeError, jsonschema.ValidationError) as e:
|
||||
return [] # Fallback to empty list or raise custom exception
|
||||
|
||||
@retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
|
||||
def large_kg_tog_ner(self, text):
|
||||
messages = [
|
||||
{"role": "system", "content": "You are an advanced AI assistant that extracts named entities from given text. "},
|
||||
{"role": "user", "content": f"Extract the named entities from: {text}"}
|
||||
]
|
||||
|
||||
# Generate raw response from LLM
|
||||
validation_args = {
|
||||
"schema": lkg_keyword_json_schema,
|
||||
"fix_function": fix_lkg_keywords
|
||||
}
|
||||
raw_response = self.generate_response(messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"}, validate_function=validate_output, **validation_args)
|
||||
|
||||
try:
|
||||
# Validate and clean the response
|
||||
cleaned_data = json_repair.loads(raw_response)
|
||||
return cleaned_data['keywords']
|
||||
|
||||
except (json.JSONDecodeError, jsonschema.ValidationError) as e:
|
||||
return [] # Fallback to empty list or raise custom exception
|
||||
|
||||
def generate_with_react(self, question, context=None, max_new_tokens=1024, search_history=None, logger=None):
|
||||
react_system_instruction = (
|
||||
'You are an advanced AI assistant that uses the ReAct framework to solve problems through iterative search. '
|
||||
'Follow these steps in your response:\n'
|
||||
'1. Thought: Think step by step and analyze if the current context is sufficient to answer the question. If not, review the current context and think critically about what can be searched to help answer the question.\n'
|
||||
' - Break down the question into *1-hop* sub-questions if necessary (e.g., identify key entities like people or places before addressing specific events).\n'
|
||||
' - Use the available context to make inferences about key entities and their relationships.\n'
|
||||
' - If a previous search query (prefix with "Previous search attempt") was not useful, reflect on why and adjust your strategy—avoid repeating similar queries and consider searching for general information about key entities or related concepts.\n'
|
||||
'2. Action: Choose one of:\n'
|
||||
' - Search for [Query]: If you need more information, specify a new query. The [Query] must differ from previous searches in wording and direction to explore new angles.\n'
|
||||
' - No Action: If the current context is sufficient.\n'
|
||||
'3. Answer: Provide one of:\n'
|
||||
' - A concise, definitive response as a noun phrase if you can answer.\n'
|
||||
' - "Need more information" if you need to search.\n\n'
|
||||
'Format your response exactly as:\n'
|
||||
'Thought: [your reasoning]\n'
|
||||
'Action: [Search for [Query] or No Action]\n'
|
||||
'Answer: [concise noun phrase if you can answer, or "Need more information" if you need to search]\n\n'
|
||||
)
|
||||
|
||||
# Build context with search history if available
|
||||
full_context = []
|
||||
if search_history:
|
||||
for i, (thought, action, observation) in enumerate(search_history):
|
||||
search_history_text = f"\nPrevious search attempt {i}:\n"
|
||||
search_history_text += f"{action}\n Result: {observation}\n"
|
||||
full_context.append(search_history_text)
|
||||
if context:
|
||||
full_context_text = f"Current Retrieved Context:\n{context}\n"
|
||||
full_context.append(full_context_text)
|
||||
if logger:
|
||||
logger.info(f"Full context for ReAct generation: {full_context}")
|
||||
|
||||
# Combine few-shot examples with system instruction and user query
|
||||
messages = [
|
||||
{"role": "system", "content": react_system_instruction},
|
||||
{"role": "user", "content": f"Search History:\n\n{''.join(full_context)}\n\nQuestion: {question}"
|
||||
if full_context else f"Question: {question}"}
|
||||
]
|
||||
if logger:
|
||||
logger.info(f"Messages for ReAct generation: {search_history}Question: {question}")
|
||||
return self.generate_response(messages, max_new_tokens=max_new_tokens)
|
||||
|
||||
|
||||
def triple_extraction(self, messages, max_tokens=4096, stage=None, record=False, allow_empty=False):
|
||||
if isinstance(messages[0], dict):
|
||||
messages = [messages]
|
||||
validate_kwargs = {
|
||||
'schema': stage_to_schema.get(stage, None),
|
||||
'fix_function': fix_triple_extraction_response,
|
||||
'prompt_type': stage_to_prompt_type.get(stage, None),
|
||||
'allow_empty': allow_empty
|
||||
}
|
||||
result = self.generate_response(messages, max_new_tokens=max_tokens, validate_function=validate_output, return_text_only = not record, **validate_kwargs)
|
||||
return result
|
||||
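A minimal sketch of wiring LLMGenerator to an OpenAI-compatible endpoint and answering a question with retrieved context; the endpoint, key, and model name are placeholders:

from openai import OpenAI
from atlas_rag.llm_generator import LLMGenerator

client = OpenAI(api_key="sk-placeholder", base_url="https://your-openai-compatible-endpoint/v1")
generator = LLMGenerator(client, model_name="your-model-name")

answer = generator.generate_with_context(
    question="Who founded Acme Corp?",
    context="Acme Corp was founded by Alice in 2001.",
    max_new_tokens=256,
    temperature=0.0,
)
print(answer)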
381
atlas_rag/llm_generator/llm_generator_legacy.py
Normal file
@ -0,0 +1,381 @@
|
||||
import json
|
||||
from openai import OpenAI, AzureOpenAI, NOT_GIVEN
|
||||
from tenacity import retry, stop_after_attempt, stop_after_delay, wait_fixed, wait_exponential, wait_random
|
||||
from copy import deepcopy
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import time
|
||||
|
||||
from atlas_rag.llm_generator.prompt.rag_prompt import cot_system_instruction, cot_system_instruction_kg, cot_system_instruction_no_doc, prompt_template
|
||||
from atlas_rag.llm_generator.format.validate_json_output import validate_filter_output, messages as filter_messages
|
||||
from atlas_rag.llm_generator.prompt.lkg_prompt import ner_prompt, validate_keyword_output, keyword_filtering_prompt
|
||||
from atlas_rag.retriever.base import BaseEdgeRetriever, BasePassageRetriever
|
||||
from atlas_rag.llm_generator.format.validate_json_output import fix_and_validate_response
|
||||
|
||||
from transformers.pipelines import Pipeline
|
||||
import jsonschema
|
||||
from typing import Union
|
||||
from logging import Logger
|
||||
|
||||
stage_to_prompt_type = {
|
||||
1: "entity_relation",
|
||||
2: "event_entity",
|
||||
3: "event_relation",
|
||||
}
|
||||
retry_decorator = retry(
|
||||
stop=(stop_after_delay(120) | stop_after_attempt(5)),
|
||||
wait=wait_exponential(multiplier=1, min=2, max=30) + wait_random(min=0, max=2),
|
||||
)
|
||||
|
||||
class LLMGenerator:
|
||||
def __init__(self, client, model_name):
|
||||
self.model_name = model_name
|
||||
self.client: OpenAI | Pipeline = client
|
||||
if isinstance(client, (OpenAI, AzureOpenAI)):
|
||||
self.inference_type = "openai"
|
||||
elif isinstance(client, Pipeline):
|
||||
self.inference_type = "pipeline"
|
||||
else:
|
||||
raise ValueError("Unsupported client type6Please provide either an OpenAI client or a Huggingface Pipeline Object.")
|
||||
|
||||
@retry_decorator
|
||||
def _generate_response(self, messages, do_sample=True, max_new_tokens=8192, temperature=0.7,
|
||||
frequency_penalty=None, response_format={"type": "text"}, return_text_only=True,
|
||||
return_thinking=False, reasoning_effort=None):
|
||||
if temperature == 0.0:
|
||||
do_sample = False
|
||||
if self.inference_type == "openai":
|
||||
start_time = time.time()
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
max_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
frequency_penalty=NOT_GIVEN if frequency_penalty is None else frequency_penalty,
|
||||
response_format=response_format if response_format is not None else {"type": "text"},
|
||||
timeout=120,
|
||||
reasoning_effort=NOT_GIVEN if reasoning_effort is None else reasoning_effort,
|
||||
)
|
||||
time_cost = time.time() - start_time
|
||||
content = response.choices[0].message.content
|
||||
if content is None and hasattr(response.choices[0].message, 'reasoning_content'):
|
||||
content = response.choices[0].message.reasoning_content
|
||||
else:
|
||||
content = response.choices[0].message.content
|
||||
if '</think>' in content and not return_thinking:
|
||||
content = content.split('</think>')[-1].strip()
|
||||
else:
|
||||
if hasattr(response.choices[0].message, 'reasoning_content') and response.choices[0].message.reasoning_content is not None:
|
||||
content = '<think>' + response.choices[0].message.reasoning_content + '</think>' + content
|
||||
if return_text_only:
|
||||
return content
|
||||
else:
|
||||
completion_usage_dict = response.usage.model_dump()
|
||||
completion_usage_dict['time'] = time_cost
|
||||
return content, completion_usage_dict
|
||||
elif self.inference_type == "pipeline":
|
||||
start_time = time.time()
|
||||
if hasattr(self.client, 'tokenizer'):
|
||||
input_text = self.client.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
else:
|
||||
input_text = "\n".join([msg["content"] for msg in messages])
|
||||
response = self.client(
|
||||
input_text,
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
do_sample=do_sample,
|
||||
)
|
||||
time_cost = time.time() - start_time
|
||||
content = response[0]['generated_text'].strip()
|
||||
if '</think>' in content and not return_thinking:
|
||||
content = content.split('</think>')[-1].strip()
|
||||
if return_text_only:
|
||||
return content
|
||||
else:
|
||||
token_count = len(content.split())
|
||||
completion_usage_dict = {
|
||||
'completion_tokens': token_count,
|
||||
'time': time_cost
|
||||
}
|
||||
return content, completion_usage_dict
|
||||
|
||||
def _generate_batch_responses(self, batch_messages, do_sample=True, max_new_tokens=8192,
|
||||
temperature=0.7, frequency_penalty=None, response_format={"type": "text"},
|
||||
return_text_only=True, return_thinking=False, reasoning_effort=None):
|
||||
if self.inference_type == "openai":
|
||||
with ThreadPoolExecutor(max_workers=3) as executor:
|
||||
futures = [
|
||||
executor.submit(
|
||||
self._generate_response, messages, do_sample, max_new_tokens, temperature,
|
||||
frequency_penalty, response_format, return_text_only, return_thinking, reasoning_effort
|
||||
) for messages in batch_messages
|
||||
]
|
||||
results = [future.result() for future in futures]
|
||||
return results
|
||||
elif self.inference_type == "pipeline":
|
||||
if not hasattr(self.client, 'tokenizer'):
|
||||
raise ValueError("Pipeline must have a tokenizer for batch processing.")
|
||||
batch_inputs = [self.client.tokenizer.apply_chat_template(messages, tokenize=False) for messages in batch_messages]
|
||||
start_time = time.time()
|
||||
responses = self.client(
|
||||
batch_inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
do_sample=do_sample,
|
||||
)
|
||||
time_cost = time.time() - start_time
|
||||
contents = [resp['generated_text'].strip() for resp in responses]
|
||||
if not return_thinking:
|
||||
contents = [content.split('</think>')[-1].strip() if '</think>' in content else content for content in contents]
|
||||
if return_text_only:
|
||||
return contents
|
||||
else:
|
||||
usage_dicts = [{
|
||||
'completion_tokens': len(content.split()),
|
||||
'time': time_cost / len(batch_messages)
|
||||
} for content in contents]
|
||||
return list(zip(contents, usage_dicts))
|
||||
|
||||
def generate_cot(self, questions, max_new_tokens=1024):
|
||||
if isinstance(questions, str):
|
||||
messages = [{"role": "system", "content": "".join(cot_system_instruction_no_doc)},
|
||||
{"role": "user", "content": questions}]
|
||||
return self._generate_response(messages, max_new_tokens=max_new_tokens)
|
||||
elif isinstance(questions, list):
|
||||
batch_messages = [[{"role": "system", "content": "".join(cot_system_instruction_no_doc)},
|
||||
{"role": "user", "content": q}] for q in questions]
|
||||
return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens)
|
||||
|
||||
def generate_with_context(self, question, context, max_new_tokens=1024, temperature=0.7):
|
||||
if isinstance(question, str):
|
||||
messages = [{"role": "system", "content": "".join(cot_system_instruction)},
|
||||
{"role": "user", "content": f"{context}\n\n{question}\nThought:"}]
|
||||
return self._generate_response(messages, max_new_tokens=max_new_tokens, temperature=temperature)
|
||||
elif isinstance(question, list):
|
||||
batch_messages = [[{"role": "system", "content": "".join(cot_system_instruction)},
|
||||
{"role": "user", "content": f"{context}\n\n{q}\nThought:"}] for q in question]
|
||||
return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens, temperature=temperature)
|
||||
|
||||
def generate_with_context_one_shot(self, question, context, max_new_tokens=4096, temperature=0.7):
|
||||
if isinstance(question, str):
|
||||
messages = deepcopy(prompt_template)
|
||||
messages.append({"role": "user", "content": f"{context}\n\nQuestions:{question}\nThought:"})
|
||||
return self._generate_response(messages, max_new_tokens=max_new_tokens, temperature=temperature)
|
||||
elif isinstance(question, list):
|
||||
batch_messages = [deepcopy(prompt_template) + [{"role": "user", "content": f"{context}\n\nQuestions:{q}\nThought:"}]
|
||||
for q in question]
|
||||
return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens, temperature=temperature)
|
||||
|
||||
def generate_with_context_kg(self, question, context, max_new_tokens=1024, temperature=0.7):
|
||||
if isinstance(question, str):
|
||||
messages = [{"role": "system", "content": "".join(cot_system_instruction_kg)},
|
||||
{"role": "user", "content": f"{context}\n\n{question}"}]
|
||||
return self._generate_response(messages, max_new_tokens=max_new_tokens, temperature=temperature)
|
||||
elif isinstance(question, list):
|
||||
batch_messages = [[{"role": "system", "content": "".join(cot_system_instruction_kg)},
|
||||
{"role": "user", "content": f"{context}\n\n{q}"}] for q in question]
|
||||
return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens, temperature=temperature)
|
||||
|
||||
@retry_decorator
|
||||
def filter_triples_with_entity(self, question, nodes, max_new_tokens=1024):
|
||||
if isinstance(question, str):
|
||||
messages = [{"role": "system", "content": """
|
||||
Your task is to filter text candidates based on their relevance to a given query...
|
||||
"""}, {"role": "user", "content": f"{question} \n Output Before Filter: {nodes} \n Output After Filter:"}]
|
||||
try:
|
||||
response = json.loads(self._generate_response(messages, max_new_tokens=max_new_tokens))
|
||||
return response
|
||||
except Exception:
|
||||
return json.loads(nodes)
|
||||
elif isinstance(question, list):
|
||||
batch_messages = [[{"role": "system", "content": """
|
||||
Your task is to filter text candidates based on their relevance to a given query...
|
||||
"""}, {"role": "user", "content": f"{q} \n Output Before Filter: {nodes} \n Output After Filter:"}]
|
||||
for q in question]
|
||||
responses = self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens)
|
||||
return [json.loads(resp) if json.loads(resp) else json.loads(nodes) for resp in responses]
|
||||
|
||||
@retry_decorator
|
||||
def filter_triples_with_entity_event(self, question, triples):
|
||||
if isinstance(question, str):
|
||||
messages = deepcopy(filter_messages)
|
||||
messages.append({"role": "user", "content": f"[ ## question ## ]]\n{question}\n[[ ## fact_before_filter ## ]]\n{triples}"})
|
||||
try:
|
||||
response = self._generate_response(messages, max_new_tokens=4096, temperature=0.0, response_format={"type": "json_object"})
|
||||
cleaned_data = validate_filter_output(response)
|
||||
return cleaned_data['fact']
|
||||
except Exception:
|
||||
return []
|
||||
elif isinstance(question, list):
|
||||
batch_messages = [deepcopy(filter_messages) + [{"role": "user", "content": f"[[ ## question ## ]]\n{q}\n[[ ## fact_before_filter ## ]]\n{triples}"}]
|
||||
for q in question]
|
||||
responses = self._generate_batch_responses(batch_messages, max_new_tokens=4096, temperature=0.0, response_format={"type": "json_object"})
|
||||
return [validate_filter_output(resp)['fact'] if validate_filter_output(resp) else [] for resp in responses]
|
||||
|
||||
def generate_with_custom_messages(self, custom_messages, do_sample=True, max_new_tokens=1024, temperature=0.8, frequency_penalty=None):
|
||||
if isinstance(custom_messages[0], dict):
|
||||
return self._generate_response(custom_messages, do_sample, max_new_tokens, temperature, frequency_penalty)
|
||||
elif isinstance(custom_messages[0], list):
|
||||
return self._generate_batch_responses(custom_messages, do_sample, max_new_tokens, temperature, frequency_penalty)
|
||||
|
||||
@retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
|
||||
def large_kg_filter_keywords_with_entity(self, question, keywords):
|
||||
if isinstance(question, str):
|
||||
messages = deepcopy(keyword_filtering_prompt)
|
||||
messages.append({"role": "user", "content": f"[[ ## question ## ]]\n{question}\n[[ ## keywords_before_filter ## ]]\n{keywords}"})
|
||||
try:
|
||||
response = self._generate_response(messages, response_format={"type": "json_object"}, temperature=0.0, max_new_tokens=2048)
|
||||
cleaned_data = validate_keyword_output(response)
|
||||
return cleaned_data['keywords']
|
||||
except Exception:
|
||||
return keywords
|
||||
elif isinstance(question, list):
|
||||
batch_messages = [deepcopy(keyword_filtering_prompt) + [{"role": "user", "content": f"[[ ## question ## ]]\n{q}\n[[ ## keywords_before_filter ## ]]\n{k}"}]
|
||||
for q, k in zip(question, keywords)]
|
||||
responses = self._generate_batch_responses(batch_messages, response_format={"type": "json_object"}, temperature=0.0, max_new_tokens=2048)
|
||||
return [validate_keyword_output(resp)['keywords'] if validate_keyword_output(resp) else keywords for resp in responses]
|
||||
|
||||
def ner(self, text):
|
||||
if isinstance(text, str):
|
||||
messages = [{"role": "system", "content": "Please extract the entities..."},
|
||||
{"role": "user", "content": f"Extract the named entities from: {text}"}]
|
||||
return self._generate_response(messages)
|
||||
elif isinstance(text, list):
|
||||
batch_messages = [[{"role": "system", "content": "Please extract the entities..."},
|
||||
{"role": "user", "content": f"Extract the named entities from: {t}"}] for t in text]
|
||||
return self._generate_batch_responses(batch_messages)
|
||||
|
||||
@retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
|
||||
def large_kg_ner(self, text):
|
||||
if isinstance(text, str):
|
||||
messages = deepcopy(ner_prompt)
|
||||
messages.append({"role": "user", "content": f"[[ ## question ## ]]\n{text}"})
|
||||
try:
|
||||
response = self._generate_response(messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
|
||||
cleaned_data = validate_keyword_output(response)
|
||||
return cleaned_data['keywords']
|
||||
except Exception:
|
||||
return []
|
||||
elif isinstance(text, list):
|
||||
batch_messages = [deepcopy(ner_prompt) + [{"role": "user", "content": f"[[ ## question ## ]]\n{t}"}] for t in text]
|
||||
responses = self._generate_batch_responses(batch_messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
|
||||
return [validate_keyword_output(resp)['keywords'] if validate_keyword_output(resp) else [] for resp in responses]
|
||||
|
||||
@retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
|
||||
def large_kg_tog_ner(self, text):
|
||||
if isinstance(text, str):
|
||||
messages = [{"role": "system", "content": "You are an advanced AI assistant..."},
|
||||
{"role": "user", "content": f"Extract the named entities from: {text}"}]
|
||||
try:
|
||||
response = self._generate_response(messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
|
||||
cleaned_data = validate_keyword_output(response)
|
||||
return cleaned_data['keywords']
|
||||
except Exception:
|
||||
return []
|
||||
elif isinstance(text, list):
|
||||
batch_messages = [[{"role": "system", "content": "You are an advanced AI assistant..."},
|
||||
{"role": "user", "content": f"Extract the named entities from: {t}"}] for t in text]
|
||||
responses = self._generate_batch_responses(batch_messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
|
||||
return [validate_keyword_output(resp)['keywords'] if validate_keyword_output(resp) else [] for resp in responses]
|
||||
|
||||
def generate_with_react(self, question, context=None, max_new_tokens=1024, search_history=None, logger=None):
|
||||
# Implementation remains single-input focused as it’s iterative; batching not applicable here
|
||||
react_system_instruction = (
|
||||
'You are an advanced AI assistant that uses the ReAct framework...'
|
||||
)
|
||||
full_context = []
|
||||
if search_history:
|
||||
for i, (thought, action, observation) in enumerate(search_history):
|
||||
full_context.append(f"\nPrevious search attempt {i}:\n{action}\n Result: {observation}\n")
|
||||
if context:
|
||||
full_context.append(f"Current Retrieved Context:\n{context}\n")
|
||||
messages = [{"role": "system", "content": react_system_instruction},
|
||||
{"role": "user", "content": f"Search History:\n\n{''.join(full_context)}\n\nQuestion: {question}"
|
||||
if full_context else f"Question: {question}"}]
|
||||
return self._generate_response(messages, max_new_tokens=max_new_tokens)
|
||||
|
||||
def generate_with_rag_react(self, question: str, retriever: Union['BaseEdgeRetriever', 'BasePassageRetriever'],
|
||||
max_iterations: int = 5, max_new_tokens: int = 1024, logger: Logger = None):
|
||||
# Single-input iterative process; batching not applicable
|
||||
search_history = []
|
||||
if isinstance(retriever, BaseEdgeRetriever):
|
||||
initial_context, _ = retriever.retrieve(question, topN=5)
|
||||
current_context = ". ".join(initial_context)
|
||||
elif isinstance(retriever, BasePassageRetriever):
|
||||
initial_context, _ = retriever.retrieve(question, topN=5)
|
||||
current_context = "\n".join(initial_context)
|
||||
for iteration in range(max_iterations):
|
||||
analysis_response = self.generate_with_react(
|
||||
question=question, context=current_context, max_new_tokens=max_new_tokens, search_history=search_history, logger=logger
|
||||
)
|
||||
try:
|
||||
thought = analysis_response.split("Thought:")[1].split("\n")[0]
|
||||
action = analysis_response.split("Action:")[1].split("\n")[0]
|
||||
answer = analysis_response.split("Answer:")[1].strip()
|
||||
if answer.lower() != "need more information":
|
||||
search_history.append((thought, action, "Using current context"))
|
||||
return answer, search_history
|
||||
if "search" in action.lower():
|
||||
search_query = action.split("search for")[-1].strip()
|
||||
if isinstance(retriever, BaseEdgeRetriever):
|
||||
new_context, _ = retriever.retrieve(search_query, topN=3)
|
||||
current_contexts = current_context.split(". ")
|
||||
new_context = [ctx for ctx in new_context if ctx not in current_contexts]
|
||||
new_context = ". ".join(new_context)
|
||||
elif isinstance(retriever, BasePassageRetriever):
|
||||
new_context, _ = retriever.retrieve(search_query, topN=3)
|
||||
current_contexts = current_context.split("\n")
|
||||
new_context = [ctx for ctx in new_context if ctx not in current_contexts]
|
||||
new_context = "\n".join(new_context)
|
||||
observation = f"Found information: {new_context}" if new_context else "No new information found..."
|
||||
search_history.append((thought, action, observation))
|
||||
if new_context:
|
||||
current_context = f"{current_context}\n{new_context}"
|
||||
else:
|
||||
search_history.append((thought, action, "No action taken but answer not found"))
|
||||
return "Unable to find answer", search_history
|
||||
except Exception as e:
|
||||
return analysis_response, search_history
|
||||
return answer, search_history
|
||||
|
||||
def triple_extraction(self, messages, max_tokens=4096, stage=None, record=False):
|
||||
if isinstance(messages[0], dict):
|
||||
messages = [messages]
|
||||
responses = self._generate_batch_responses(
|
||||
batch_messages=messages,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=0.0,
|
||||
do_sample=False,
|
||||
frequency_penalty=0.5,
|
||||
reasoning_effort="none",
|
||||
return_text_only=not record
|
||||
)
|
||||
processed_responses = []
|
||||
for response in responses:
|
||||
if record:
|
||||
content, usage_dict = response
|
||||
else:
|
||||
content = response
|
||||
usage_dict = None
|
||||
try:
|
||||
prompt_type = stage_to_prompt_type.get(stage, None)
|
||||
if prompt_type:
|
||||
corrected, error = fix_and_validate_response(content, prompt_type)
|
||||
if error:
|
||||
raise ValueError(f"Validation failed for prompt_type '{prompt_type}'")
|
||||
else:
|
||||
corrected = content
|
||||
if corrected and corrected.strip():
|
||||
if record:
|
||||
processed_responses.append((corrected, usage_dict))
|
||||
else:
|
||||
processed_responses.append(corrected)
|
||||
else:
|
||||
raise ValueError("Invalid response")
|
||||
except Exception as e:
|
||||
print(f"Failed to process response: {str(e)}")
|
||||
if record:
|
||||
usage_dict = {'completion_tokens': 0, 'total_tokens': 0, 'time': 0}
|
||||
processed_responses.append(("[]", usage_dict))
|
||||
else:
|
||||
processed_responses.append("[]")
|
||||
return processed_responses
|
||||
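A minimal usage sketch of the legacy LLMGenerator above, assuming an OpenAI-compatible endpoint; the base_url, API key and model name are placeholder assumptions, not values taken from this repository.

from openai import OpenAI
from atlas_rag.llm_generator.llm_generator_legacy import LLMGenerator

# Assumed OpenAI-compatible endpoint and placeholder credentials.
client = OpenAI(api_key="sk-placeholder", base_url="https://example.com/v1")
llm = LLMGenerator(client, model_name="qwen2-7b-instruct")

# Single question with retrieved context; returns the answer text only.
answer = llm.generate_with_context(
    question="When was the University of Southampton founded?",
    context="The University of Southampton was founded in 1862.",
    temperature=0.0,
)
print(answer)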
0
atlas_rag/llm_generator/prompt/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
atlas_rag/llm_generator/prompt/__pycache__/react.cpython-311.pyc
Normal file
Binary file not shown.
Binary file not shown.
263
atlas_rag/llm_generator/prompt/lkg_prompt.py
Normal file
@ -0,0 +1,263 @@
|
||||
import json
|
||||
import jsonschema
|
||||
from typing import Any
|
||||
|
||||
ner_prompt =[
|
||||
{"role": "system",
|
||||
"content": """
|
||||
You are a domain analysis engine. You must provide keywords for searching relevant documents.
|
||||
When given any academic question, follow these steps:
|
||||
1. **Identify Tested Skills:** Determine the *abstract knowledge/skills* required to solve the problem (e.g., "translating universal statements into predicate logic"), not the concrete entities in the question (e.g., "children," "school").
|
||||
2. **Extract Domain Specific Term:** Extract domain-specific technical terms (e.g., "school" in *educational policy*), exclude common nouns/verbs describing the question's *scenario* (e.g., "child," "school," "goes to").
|
||||
3. **Prioritize Formal Structures:** For logic/math problems, focus on notation rules (e.g., quantifier order, implication vs. conjunction), not scenario labels.
|
||||
4. **Capture Rare Technical Terms:** Include uncommon domain-specific terms critical to the question (e.g., "epigenetics" in biology, "monad" in computer science), even if they appear infrequently in general language.
|
||||
Your input fields are:
|
||||
1. question (str): Query for keyword extraction
|
||||
Your output fields are:
|
||||
1. keywords (array): Extracted keywords in JSON format
|
||||
All interactions will be structured as:
|
||||
[[ ## question ## ]]
|
||||
{question}
|
||||
[[ ## keywords ## ]]
|
||||
{keywords}
|
||||
The output must be parseable according to JSON schema:
|
||||
{"type": "object", "properties": {"keywords": {"type": "array", "items": {"type": "string"}}}, "required": ["keywords"]}
|
||||
"""},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "[[ ## question ## ]]\nSolve \(x^2 - 5x + 6 = 0\)."
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "{\"keywords\": [\"quadratic equation\", \"factoring\", \"roots\", \"algebraic manipulation\"]}"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "[[ ## question ## ]]\nExplain the socio-economic causes of the French Revolution."
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "{\"keywords\": [\"historical causation\", \"class struggle\", \"economic inequality\", \"Enlightenment philosophy\"]}"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "[[ ## question ## ]]\nProve that the square root of 2 is irrational."
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "{\"keywords\": [\"proof by contradiction\", \"irrational numbers\", \"number theory\", \"rational/irrational distinction\"]}"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "[[ ## question ## ]]\nExplain the implications of Heisenberg's uncertainty principle on quantum measurements."
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "{\"keywords\": [\"Heisenberg uncertainty principle\", \"quantum mechanics\", \"measurement theory\", \"quantum state collapse\", \"observable quantities\"]}"
|
||||
}
|
||||
]
|
||||
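A minimal sketch of checking a model reply against the JSON schema declared in the ner_prompt system message; the package's validate_keyword_output presumably performs a similar check, so this standalone helper is illustrative only.

import json
import jsonschema

KEYWORD_SCHEMA = {
    "type": "object",
    "properties": {"keywords": {"type": "array", "items": {"type": "string"}}},
    "required": ["keywords"],
}

def parse_keywords(raw_reply):
    # Parse the model reply and enforce the schema quoted in the prompt above.
    data = json.loads(raw_reply)
    jsonschema.validate(data, KEYWORD_SCHEMA)
    return data["keywords"]

# parse_keywords('{"keywords": ["quadratic equation", "factoring"]}') -> ["quadratic equation", "factoring"]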
|
||||
keyword_filtering_prompt = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": """You are a precision-focused component of a knowledge retrieval system used by researchers and educators.
|
||||
Your task is to filter keywords based on their relevance to the **core domain knowledge** required to answer a given question.
|
||||
You must critically evaluate whether each item represents:
|
||||
1. Academic concepts/methodologies (entities)
|
||||
2. Critical processes or relationships (events)
|
||||
The output should be in JSON format, e.g., {"keywords": ["concept1", "concept2"]}. If no keywords are relevant, return {"keywords": []}.
|
||||
Accuracy is crucial, as these keywords drive document retrieval for critical research. Do not generate any new keywords or explanations.
|
||||
**Do not change the content of each object in the list. You must only use text from the candidate list and cannot generate new text.**
|
||||
**Include all characters of the selected keywords.**
|
||||
**Do not include any duplicate keywords.**
|
||||
**The keywords can be a sentence describing an event as long as it is helpful for searching.**
|
||||
---
|
||||
|
||||
**Input Fields:**
|
||||
1. question (str): Query requiring knowledge analysis
|
||||
2. keywords_before_filter (list): Candidate keywords to evaluate
|
||||
|
||||
**Output Field:**
|
||||
1. keywords_after_filter (list): Filtered keywords in JSON format
|
||||
|
||||
---
|
||||
|
||||
**Interaction Structure:**
|
||||
[[ ## question ## ]]
|
||||
{question}
|
||||
[[ ## keywords_before_filter ## ]]
|
||||
{keywords_before_filter}
|
||||
[[ ## keywords_after_filter ## ]]
|
||||
{keywords_after_filter}
|
||||
|
||||
**JSON Schema Validation:**
|
||||
{"type": "object", "properties": {"keywords": {"type": "array", "items": {"type": "string"}}}, "required": ["keywords"]}
|
||||
"""},
|
||||
{
|
||||
"role": "user",
|
||||
"content": """[[ ## question ## ]]
|
||||
Explain the causes of the French Revolution.
|
||||
[[ ## keywords_before_filter ## ]]
|
||||
["French Revolution", "social inequality", "Enlightenment ideas spreading", "economic crisis", "Bastille storming event"]"""
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": """{"keywords": ["social inequality", "Enlightenment ideas spreading", "economic crisis"]}"""
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": """[[ ## question ## ]]
|
||||
Translate 'All children go to some school' into predicate logic.
|
||||
[[ ## keywords_before_filter ## ]]
|
||||
["predicate logic", "children", "school", "universal quantifier (∀)", "attendance"]"""
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": """{"keywords": ["predicate logic", "universal quantifier (∀)"]}"""
|
||||
},
|
||||
|
||||
{
|
||||
"role": "user",
|
||||
"content": """[[ ## question ## ]]
|
||||
Solve \(x^2 - 5x + 6 = 0\).
|
||||
[[ ## keywords_before_filter ## ]]
|
||||
["quadratic equation", "polynomial", "algebra", "roots", "classroom teaching"]"""
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": """{"keywords": ["quadratic equation", "polynomial", "roots"]}"""
|
||||
}
|
||||
]
|
||||
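A sketch of assembling the user turn in the documented [[ ## question ## ]] / [[ ## keywords_before_filter ## ]] layout, mirroring how large_kg_filter_keywords_with_entity builds it in the legacy generator; the helper name itself is hypothetical.

from copy import deepcopy

def build_keyword_filter_messages(question, keywords):
    # Append the real question and candidate keywords to the few-shot prefix above.
    messages = deepcopy(keyword_filtering_prompt)
    messages.append({
        "role": "user",
        "content": f"[[ ## question ## ]]\n{question}\n[[ ## keywords_before_filter ## ]]\n{keywords}",
    })
    return messages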
|
||||
|
||||
test_messages = [
|
||||
[
|
||||
{
|
||||
"role":"system",
|
||||
"content":"You are a helpful asssistant. You will answer the question based on the context provided."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content":"""Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
|
||||
Question: What is the phenotype of a congenital disorder impairing the secretion of leptin?
|
||||
A. Normal energy intake, normal body weight and hyperthyroidism
|
||||
B. Obesity, excess energy intake, normal growth and hypoinsulinaemia
|
||||
C. Obesity, abnormal growth, hypothyroidism, hyperinsulinaemia
|
||||
D. Underweight, abnormal growth, hypothyroidism, hyperinsulinaemia"""
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"role":"system",
|
||||
"content":"You are a helpful asssistant. You will answer the question based on the context provided."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content":"""Query: Given the following question and five candidate answers (A, B, C and D), choose the answer.
|
||||
Question: Which of the following is not an advantage of stratified random sampling over simple random sampling?
|
||||
A. When done correctly, a stratified random sample is less biased than a simple random sample.
|
||||
B. When done correctly, a stratified random sampling process has less variability from sample to sample than a simple random sample.
|
||||
C. When done correctly, a stratified random sample can provide, with a smaller sample size, an estimate that is just as reliable as that of a simple random sample with a larger sample size.
|
||||
D. A stratified random sample provides information about each stratum in the population as well as an estimate for the population as a whole, and a simple random sample does not."""
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"role":"system",
|
||||
"content":"You are a helpful asssistant. You will answer the question based on the context provided."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content":"""Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
|
||||
Question: Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q.
|
||||
A. 0
|
||||
B. 4
|
||||
C. 2
|
||||
D. 6"""
|
||||
}
|
||||
], [
|
||||
{
|
||||
"role":"system",
|
||||
"content":"You are a helpful asssistant. You will answer the question based on the context provided."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content":"""Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
|
||||
Question: A lesion causing compression of the facial nerve at the stylomastoid foramen will cause ipsilateral
|
||||
A. paralysis of the facial muscles.
|
||||
B. paralysis of the facial muscles and loss of taste.
|
||||
C. paralysis of the facial muscles, loss of taste and lacrimation.
|
||||
D. paralysis of the facial muscles, loss of taste, lacrimation and decreased salivation."""
|
||||
}
|
||||
], [
|
||||
{
|
||||
"role":"system",
|
||||
"content":"You are a helpful asssistant. You will answer the question based on the context provided."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content":"""Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
|
||||
Question: What is true for a type-Ia ("type one-a") supernova?
|
||||
A. This type occurs in binary systems.
|
||||
B. This type occurs in young galaxies.
|
||||
C. This type produces gamma-ray bursts.
|
||||
D. This type produces high amounts of X-rays."""
|
||||
}
|
||||
], [
|
||||
{
|
||||
"role":"system",
|
||||
"content":"You are a helpful asssistant. You will answer the question based on the context provided."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content":"""Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
|
||||
Question: _______ such as bitcoin are becoming increasingly mainstream and have a whole host of associated ethical implications, for example, they are______ and more ______. However, they have also been used to engage in _______.
|
||||
A. Cryptocurrencies, Expensive, Secure, Financial Crime
|
||||
B. Traditional currency, Cheap, Unsecure, Charitable giving
|
||||
C. Cryptocurrencies, Cheap, Secure, Financial crime
|
||||
D. Traditional currency, Expensive, Unsecure, Charitable giving"""
|
||||
}
|
||||
], [
|
||||
{
|
||||
"role":"system",
|
||||
"content":"You are a helpful asssistant. You will answer the question based on the context provided."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content":"""Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
|
||||
Question: The access matrix approach to protection has the difficulty that
|
||||
A. the matrix, if stored directly, is large and can be clumsy to manage
|
||||
B. it is not capable of expressing complex protection requirements
|
||||
C. deciding whether a process has access to a resource is undecidable
|
||||
D. there is no way to express who has rights to change the access matrix itself"""
|
||||
}
|
||||
], [
|
||||
{
|
||||
"role":"system",
|
||||
"content":"You are a helpful asssistant. You will answer the question based on the context provided."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content":"""Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
|
||||
Question: In the NoNicks operating system, the time required by a single file-read operation has four nonoverlapping components: disk seek time-25 msec disk latency time-8 msec disk transfer time- 1 msec per 1,000 bytes operating system overhead-1 msec per 1,000 bytes + 10 msec In version 1 of the system, the file read retrieved blocks of 1,000 bytes. In version 2, the file read (along with the underlying layout on disk) was modified to retrieve blocks of 4,000 bytes. The ratio of-the time required to read a large file under version 2 to the time required to read the same large file under version 1 is approximately
|
||||
A. 1:4
|
||||
B. 1:3.5
|
||||
C. 1:1
|
||||
D. 1.1:1"""
|
||||
}
|
||||
], [
|
||||
{
|
||||
"role":"system",
|
||||
"content":"You are a helpful asssistant. You will answer the question based on the context provided."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content":"""Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
|
||||
Question: Which of the following propositions is an immediate (one-step) consequence in PL of the given premises? ~E ⊃ ~F G ⊃ F H ∨ ~E H ⊃ I ~I
|
||||
A. E ⊃ F
|
||||
B. F ⊃ G
|
||||
C. H ⊃ ~E
|
||||
D. ~H"""
|
||||
}
|
||||
],
|
||||
]
|
||||
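The test_messages above are ready-made chat transcripts; a hedged sketch of pushing them through the batch path of the legacy generator (llm is assumed to be an already-constructed LLMGenerator instance):

# Each element of test_messages is a full message list, so the batch branch
# of generate_with_custom_messages is taken.
answers = llm.generate_with_custom_messages(test_messages, temperature=0.0, max_new_tokens=256)
for answer in answers:
    print(answer)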
115
atlas_rag/llm_generator/prompt/rag_prompt.py
Normal file
@ -0,0 +1,115 @@
|
||||
one_shot_rag_qa_docs = (
|
||||
"""Wikipedia Title: The Last Horse\nThe Last Horse (Spanish:El último caballo) is a 1950 Spanish comedy film directed by Edgar Neville starring Fernando Fernán Gómez.\n"""
|
||||
"""Wikipedia Title: Southampton\nThe University of Southampton, which was founded in 1862 and received its Royal Charter as a university in 1952, has over 22,000 students. The university is ranked in the top 100 research universities in the world in the Academic Ranking of World Universities 2010. In 2010, the THES - QS World University Rankings positioned the University of Southampton in the top 80 universities in the world. The university considers itself one of the top 5 research universities in the UK. The university has a global reputation for research into engineering sciences, oceanography, chemistry, cancer sciences, sound and vibration research, computer science and electronics, optoelectronics and textile conservation at the Textile Conservation Centre (which is due to close in October 2009.) It is also home to the National Oceanography Centre, Southampton (NOCS), the focus of Natural Environment Research Council-funded marine research.\n"""
|
||||
"""Wikipedia Title: Stanton Township, Champaign County, Illinois\nStanton Township is a township in Champaign County, Illinois, USA. As of the 2010 census, its population was 505 and it contained 202 housing units.\n"""
|
||||
"""Wikipedia Title: Neville A. Stanton\nNeville A. Stanton is a British Professor of Human Factors and Ergonomics at the University of Southampton. Prof Stanton is a Chartered Engineer (C.Eng), Chartered Psychologist (C.Psychol) and Chartered Ergonomist (C.ErgHF). He has written and edited over a forty books and over three hundered peer-reviewed journal papers on applications of the subject. Stanton is a Fellow of the British Psychological Society, a Fellow of The Institute of Ergonomics and Human Factors and a member of the Institution of Engineering and Technology. He has been published in academic journals including "Nature". He has also helped organisations design new human-machine interfaces, such as the Adaptive Cruise Control system for Jaguar Cars.\n"""
|
||||
"""Wikipedia Title: Finding Nemo\nFinding Nemo Theatrical release poster Directed by Andrew Stanton Produced by Graham Walters Screenplay by Andrew Stanton Bob Peterson David Reynolds Story by Andrew Stanton Starring Albert Brooks Ellen DeGeneres Alexander Gould Willem Dafoe Music by Thomas Newman Cinematography Sharon Calahan Jeremy Lasky Edited by David Ian Salter Production company Walt Disney Pictures Pixar Animation Studios Distributed by Buena Vista Pictures Distribution Release date May 30, 2003 (2003 - 05 - 30) Running time 100 minutes Country United States Language English Budget $$94 million Box office $$940.3 million"""
|
||||
)
|
||||
|
||||
one_shot_ircot_demo = (
|
||||
f'{one_shot_rag_qa_docs}'
|
||||
'\n\nQuestion: '
|
||||
f"When was Neville A. Stanton's employer founded?"
|
||||
'\nThought: '
|
||||
f"The employer of Neville A. Stanton is University of Southampton. The University of Southampton was founded in 1862. So the answer is: 1862."
|
||||
'\n\n'
|
||||
)
|
||||
|
||||
rag_qa_system = (
|
||||
'As an advanced reading comprehension assistant, your task is to analyze text passages and corresponding questions meticulously. '
|
||||
'Your response start after "Thought: ", where you will methodically break down the reasoning process, illustrating how you arrive at conclusions. '
|
||||
'Conclude with "Answer: " to present a concise, definitive response, devoid of additional elaborations.'
|
||||
)
|
||||
|
||||
one_shot_rag_qa_input = (
|
||||
f"{one_shot_rag_qa_docs}"
|
||||
"\n\nQuestion: "
|
||||
"When was Neville A. Stanton's employer founded?"
|
||||
'\nThought: '
|
||||
)
|
||||
|
||||
one_shot_rag_qa_output = (
|
||||
"The employer of Neville A. Stanton is University of Southampton. The University of Southampton was founded in 1862. "
|
||||
"\nAnswer: 1862."
|
||||
)
|
||||
|
||||
|
||||
prompt_template = [
|
||||
{"role": "system", "content": rag_qa_system},
|
||||
{"role": "user", "content": one_shot_rag_qa_input},
|
||||
{"role": "assistant", "content": one_shot_rag_qa_output},
|
||||
]
|
||||
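A sketch of how the one-shot template is meant to be extended with the real passage and question, matching what generate_with_context_one_shot does in the legacy generator; build_one_shot_messages itself is a hypothetical helper.

from copy import deepcopy

def build_one_shot_messages(context, question):
    # Keep the system turn and the one-shot QA pair, then append the new question.
    messages = deepcopy(prompt_template)
    messages.append({"role": "user", "content": f"{context}\n\nQuestions:{question}\nThought:"})
    return messages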
|
||||
# from https://github.com/OSU-NLP-Group/HippoRAG/blob/main/src/qa/qa_reader.py
|
||||
|
||||
cot_system_instruction = ('As an advanced reading comprehension assistant, your task is to analyze text passages and corresponding questions meticulously. If the information is not enough, you can use your own knowledge to answer the question.'
|
||||
'Your response start after "Thought: ", where you will methodically break down the reasoning process, illustrating how you arrive at conclusions. '
|
||||
'Conclude with "Answer: " to present a concise, definitive response as a noun phrase, no elaborations.')
|
||||
|
||||
cot_system_instruction_no_doc = ('As an advanced reading comprehension assistant, your task is to analyze the questions and then answer them. '
|
||||
'Your response start after "Thought: ", where you will methodically break down the reasoning process, illustrating how you arrive at conclusions. '
|
||||
'Conclude with "Answer: " to present a concise, definitive response as a noun phrase, no elaborations.')
|
||||
|
||||
cot_system_instruction_kg = ('As an advanced reading comprehension assistant, your task is to analyze extracted information and corresponding questions meticulously. If the knowledge graph information is not enough, you can use your own knowledge to answer the question. '
|
||||
'Your response start after "Thought: ", where you will methodically break down the reasoning process, illustrating how you arrive at conclusions. '
|
||||
'Conclude with "Answer: " to present a concise, definitive response as a noun phrase, no elaborations.')
|
||||
|
||||
|
||||
filter_triple_messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": """You are a critical component of a high-stakes question-answering system used by top researchers and decision-makers worldwide.
|
||||
Your task is to filter facts based on their relevance to a given query.
|
||||
The query requires careful analysis and possibly multi-hop reasoning to connect different pieces of information.
|
||||
You must select all relevant facts from the provided candidate list, aiding in reasoning and providing an accurate answer.
|
||||
The output should be in JSON format, e.g., {"fact": [["s1", "p1", "o1"], ["s2", "p2", "o2"]]}, and if no facts are relevant, return an empty list, {"fact": []}.
|
||||
The accuracy of your response is paramount, as it will directly impact the decisions made by these high-level stakeholders. You must only use facts from the candidate list and not generate new facts.
|
||||
The future of critical decision-making relies on your ability to accurately filter and present relevant information.
|
||||
|
||||
Your input fields are:
|
||||
1. question (str): Query for retrieval
|
||||
2. fact_before_filter (str): Candidate facts to be filtered
|
||||
|
||||
Your output fields are:
|
||||
1. fact_after_filter (Fact): Filtered facts in JSON format
|
||||
|
||||
All interactions will be structured as:
|
||||
[[ ## question ## ]]
|
||||
{question}
|
||||
|
||||
[[ ## fact_before_filter ## ]]
|
||||
{fact_before_filter}
|
||||
|
||||
[[ ## fact_after_filter ## ]]
|
||||
{fact_after_filter}
|
||||
|
||||
The output must be parseable according to JSON schema: {"type": "object", "properties": {"fact": {"type": "array", "items": {"type": "array", "items": {"type": "string"}}}}, "required": ["fact"]}"""
|
||||
},
|
||||
# Example 1
|
||||
{
|
||||
"role": "user",
|
||||
"content": """[[ ## question ## ]]
|
||||
Are Imperial River (Florida) and Amaradia (Dolj) both located in the same country?
|
||||
|
||||
[[ ## fact_before_filter ## ]]
|
||||
{"fact": [["imperial river", "is located in", "florida"], ["imperial river", "is a river in", "united states"], ["imperial river", "may refer to", "south america"], ["amaradia", "flows through", "ro ia de amaradia"], ["imperial river", "may refer to", "united states"]]}"""
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": """{"fact":[["imperial river","is located in","florida"],["imperial river","is a river in","united states"],["amaradia","flows through","ro ia de amaradia"]]}"""
|
||||
},
|
||||
|
||||
# Example 2
|
||||
{
|
||||
"role": "user",
|
||||
"content": """[[ ## question ## ]]
|
||||
When is the director of film The Ancestor 's birthday?
|
||||
|
||||
[[ ## fact_before_filter ## ]]
|
||||
{"fact": [["jean jacques annaud", "born on", "1 october 1943"], ["tsui hark", "born on", "15 february 1950"], ["pablo trapero", "born on", "4 october 1971"], ["the ancestor", "directed by", "guido brignone"], ["benh zeitlin", "born on", "october 14 1982"]]}"""
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": """{"fact":[["the ancestor","directed by","guido brignone"]]}"""
|
||||
},
|
||||
]
|
||||
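A sketch of packing candidate triples into the [[ ## fact_before_filter ## ]] slot that filter_triple_messages documents; the helper is illustrative, and the surrounding pipeline presumably validates the reply with validate_filter_output.

import json
from copy import deepcopy

def build_fact_filter_messages(question, facts):
    # facts is a list of [subject, predicate, object] triples.
    payload = json.dumps({"fact": facts})
    messages = deepcopy(filter_triple_messages)
    messages.append({
        "role": "user",
        "content": f"[[ ## question ## ]]\n{question}\n\n[[ ## fact_before_filter ## ]]\n{payload}",
    })
    return messages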
108
atlas_rag/llm_generator/prompt/react.py
Normal file
@ -0,0 +1,108 @@
|
||||
from atlas_rag.llm_generator import LLMGenerator
|
||||
from atlas_rag.retriever.base import BaseEdgeRetriever, BasePassageRetriever
|
||||
from typing import Union
|
||||
from logging import Logger
|
||||
|
||||
class ReAct():
|
||||
def __init__(self, llm:LLMGenerator):
|
||||
self.llm = llm
|
||||
|
||||
def generate_with_rag_react(self, question: str, retriever: Union['BaseEdgeRetriever', 'BasePassageRetriever'], max_iterations: int = 5, max_new_tokens: int = 1024, logger: Logger = None):
|
||||
"""
|
||||
Generate a response using RAG with ReAct framework, starting with an initial search using the original query.
|
||||
|
||||
Args:
|
||||
question (str): The question to answer
|
||||
retriever: The retriever instance to use for searching
|
||||
max_iterations (int): Maximum number of ReAct iterations
|
||||
max_new_tokens (int): Maximum number of tokens to generate per iteration
|
||||
|
||||
Returns:
|
||||
tuple: (final_answer, search_history)
|
||||
- final_answer: The final answer generated
|
||||
- search_history: List of (thought, action, observation) tuples
|
||||
"""
|
||||
search_history = []
|
||||
|
||||
# Perform initial search with the original query
|
||||
if isinstance(retriever, BaseEdgeRetriever):
|
||||
initial_context, _ = retriever.retrieve(question, topN=5)
|
||||
current_context = ". ".join(initial_context)
|
||||
elif isinstance(retriever, BasePassageRetriever):
|
||||
initial_context, _ = retriever.retrieve(question, topN=5)
|
||||
current_context = "\n".join(initial_context)
|
||||
|
||||
# Start ReAct process with the initial context
|
||||
for iteration in range(max_iterations):
|
||||
# First, analyze if we can answer with current context
|
||||
analysis_response = self.llm.generate_with_react(
|
||||
question=question,
|
||||
context=current_context,
|
||||
max_new_tokens=max_new_tokens,
|
||||
search_history=search_history,
|
||||
logger = logger
|
||||
)
|
||||
|
||||
if logger:
|
||||
logger.info(f"Analysis response: {analysis_response}")
|
||||
|
||||
try:
|
||||
# Parse the analysis response
|
||||
thought = analysis_response.split("Thought:")[1].split("\n")[0]
|
||||
if logger:
|
||||
logger.info(f"Thought: {thought}")
|
||||
action = analysis_response.split("Action:")[1].split("\n")[0]
|
||||
answer = analysis_response.split("Answer:")[1].strip()
|
||||
|
||||
# If the answer indicates we can answer with current context
|
||||
if answer.lower() != "need more information":
|
||||
search_history.append((thought, action, "Using current context"))
|
||||
return answer, search_history
|
||||
|
||||
# If we need more information, perform the search
|
||||
if "search" in action.lower():
|
||||
# Extract search query from the action
|
||||
search_query = action.split("search for")[-1].strip()
|
||||
|
||||
# Perform the search
|
||||
if isinstance(retriever, BaseEdgeRetriever):
|
||||
new_context, _ = retriever.retrieve(search_query, topN=3)
|
||||
# Filter out contexts that are already in current_context
|
||||
current_contexts = current_context.split(". ")
|
||||
new_context = [ctx for ctx in new_context if ctx not in current_contexts]
|
||||
new_context = ". ".join(new_context)
|
||||
elif isinstance(retriever, BasePassageRetriever):
|
||||
new_context, _ = retriever.retrieve(search_query, topN=3)
|
||||
# Filter out contexts that are already in current_context
|
||||
current_contexts = current_context.split("\n")
|
||||
new_context = [ctx for ctx in new_context if ctx not in current_contexts]
|
||||
new_context = "\n".join(new_context)
|
||||
|
||||
# Store the search results as observation
|
||||
if new_context:
|
||||
observation = f"Found information: {new_context}"
|
||||
else:
|
||||
observation = "No new information found. Consider searching for related entities or events."
|
||||
search_history.append((thought, action, observation))
|
||||
|
||||
# Update context with new search results
|
||||
if new_context:
|
||||
current_context = f"{current_context}\n{new_context}"
|
||||
if logger:
|
||||
logger.info(f"New search results: {new_context}")
|
||||
else:
|
||||
if logger:
|
||||
logger.info("No new information found, suggesting to try related entities")
|
||||
|
||||
else:
|
||||
# If no search is needed but we can't answer, something went wrong
|
||||
search_history.append((thought, action, "No action taken but answer not found"))
|
||||
return "Unable to find answer", search_history
|
||||
|
||||
except Exception as e:
|
||||
if logger:
|
||||
logger.error(f"Error parsing ReAct response: {e}")
|
||||
return analysis_response, search_history
|
||||
|
||||
# If we've reached max iterations, return the last answer
|
||||
return answer, search_history
|
||||
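A hedged sketch of wiring the ReAct wrapper above to a generator and a retriever; llm and passage_retriever are assumed to be constructed elsewhere (any BaseEdgeRetriever or BasePassageRetriever works).

react = ReAct(llm)
answer, history = react.generate_with_rag_react(
    question="When was Neville A. Stanton's employer founded?",
    retriever=passage_retriever,
    max_iterations=5,
)
for thought, action, observation in history:
    # Each iteration records what was thought, what was searched, and what came back.
    print(thought, "|", action, "|", observation)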
313
atlas_rag/llm_generator/prompt/triple_extraction_prompt.py
Normal file
@ -0,0 +1,313 @@
|
||||
TRIPLE_INSTRUCTIONS = {
|
||||
"en":{
|
||||
"system": "You are a helpful assistant who always response in a valid array of JSON objects without any explanation",
|
||||
"entity_relation": """Given a passage, summarize all the important entities and the relations between them in a concise manner. Relations should briefly capture the connections between entities, without repeating information from the head and tail entities. The entities should be as specific as possible. Exclude pronouns from being considered as entities.
|
||||
You must **strictly output in the following JSON format**:\n
|
||||
[
|
||||
{
|
||||
"Head": "{a noun}",
|
||||
"Relation": "{a verb}",
|
||||
"Tail": "{a noun}",
|
||||
}...
|
||||
]""",
|
||||
|
||||
"event_entity": """Please analyze and summarize the participation relations between the events and entities in the given paragraph. Each event is a single independent sentence. Additionally, identify all the entities that participated in the events. Do not use ellipses.
|
||||
You must **strictly output in the following JSON format**:\n
|
||||
[
|
||||
{
|
||||
"Event": "{a simple sentence describing an event}",
|
||||
"Entity": ["entity 1", "entity 2", "..."]
|
||||
}...
|
||||
] """,
|
||||
|
||||
"event_relation": """Please analyze and summarize the relationships between the events in the paragraph. Each event is a single independent sentence. Identify temporal and causal relationships between the events using the following types: before, after, at the same time, because, and as a result. Each extracted triple should be specific, meaningful, and able to stand alone. Do not use ellipses.
|
||||
You must **strictly output in the following JSON format**:\n
|
||||
[
|
||||
{
|
||||
"Head": "{a simple sentence describing the event 1}",
|
||||
"Relation": "{temporal or causality relation between the events}",
|
||||
"Tail": "{a simple sentence describing the event 2}"
|
||||
}...
|
||||
]""",
|
||||
"passage_start" : """Here is the passage."""
|
||||
},
|
||||
"zh-CN": {
|
||||
"system": """"你是一个始终以有效JSON数组格式回应的助手""",
|
||||
"entity_relation": """给定一段文字,提取所有重要实体及其关系,并以简洁的方式总结。关系描述应清晰表达实体间的联系,且不重复头尾实体的信息。实体需具体明确,排除代词。
|
||||
|
||||
**重要格式要求:**
|
||||
1. Head字段:必须是一个字符串,不能为空
|
||||
2. Relation字段:必须是一个字符串且不能为空,如果不确定请用"相关"
|
||||
3. Tail字段:必须是一个字符串,不能为空。如果有多个项目,请用顿号(、)连接
|
||||
|
||||
返回格式必须为以下JSON结构,内容需用简体中文表述:
|
||||
[
|
||||
{
|
||||
"Head": "{名词}",
|
||||
"Relation": "{动词或关系描述,不能为空}",
|
||||
"Tail": "{名词,多个用顿号连接}"
|
||||
}...
|
||||
]
|
||||
|
||||
示例:
|
||||
输入:"企业内部审计数字化产品技术要求包括审计作业、审计管理、数据建设及应用"
|
||||
输出:
|
||||
[
|
||||
{"Head": "企业内部审计数字化产品", "Relation": "技术要求包括", "Tail": "审计作业、审计管理、数据建设及应用"}
|
||||
]""",
|
||||
|
||||
"event_entity": """分析段落中的事件及其参与实体。每个事件应为独立单句,列出所有相关实体(需具体,不含代词)。
|
||||
返回格式必须为以下JSON结构,内容需用简体中文表述:
|
||||
[
|
||||
{
|
||||
"Event": "{描述事件的简单句子}",
|
||||
"Entity": ["实体1", "实体2", "..."]
|
||||
}...
|
||||
]""",
|
||||
|
||||
"event_relation": """分析事件间的时序或因果关系,关系类型包括:之前,之后,同时,因为,结果.每个事件应为独立单句。
|
||||
返回格式必须为以下JSON结构.内容需用简体中文表述.
|
||||
[
|
||||
{
|
||||
"Head": "{事件1描述}",
|
||||
"Relation": "{时序/因果关系}",
|
||||
"Tail": "{事件2描述}"
|
||||
}...
|
||||
]""",
|
||||
|
||||
"passage_start": "给定以下段落:"
|
||||
},
|
||||
"zh-HK": {
|
||||
"system": "你是一個始終以有效JSON數組格式回覆的助手",
|
||||
"entity_relation": """給定一段文字,提取所有重要實體及其關係,並以簡潔的方式總結。關係描述應清晰表達實體間的聯繫,且不重複頭尾實體的信息。實體需具體明確,排除代詞。
|
||||
返回格式必須為以下JSON結構,內容需用繁體中文表述:
|
||||
[
|
||||
{
|
||||
"Head": "{名詞}",
|
||||
"Relation": "{動詞或關係描述}",
|
||||
"Tail": "{名詞}"
|
||||
}...
|
||||
]""",
|
||||
|
||||
"event_entity": """分析段落中的事件及其參與實體。每個事件應為獨立單句,列出所有相關實體(需具體,不含代詞)。
|
||||
返回格式必須為以下JSON結構,內容需用繁體中文表述:
|
||||
[
|
||||
{
|
||||
"Event": "{描述事件的簡單句子}",
|
||||
"Entity": ["實體1", "實體2", "..."]
|
||||
}...
|
||||
]""",
|
||||
|
||||
"event_relation": """分析事件間的時序或因果關係,關係類型包括:之前,之後,同時,因為,結果.每個事件應為獨立單句。
|
||||
返回格式必須為以下JSON結構.內容需用繁體中文表述.
|
||||
[
|
||||
{
|
||||
"Head": "{事件1描述}",
|
||||
"Relation": "{時序/因果關係}",
|
||||
"Tail": "{事件2描述}"
|
||||
}...
|
||||
]""",
|
||||
|
||||
"passage_start": "給定以下段落:"
|
||||
}
|
||||
}
|
||||
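A sketch of composing a stage-specific extraction request from the table above; the exact concatenation used elsewhere in the package is an assumption, but the stage keys match stage_to_prompt_type in the legacy generator (entity_relation, event_entity, event_relation).

def build_triple_messages(passage, stage="entity_relation", lang="en"):
    # Pick the language block, then pair the stage instruction with the passage.
    instructions = TRIPLE_INSTRUCTIONS[lang]
    return [
        {"role": "system", "content": instructions["system"]},
        {"role": "user", "content": f"{instructions[stage]}\n{instructions['passage_start']}\n{passage}"},
    ]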
|
||||
CONCEPT_INSTRUCTIONS = {
|
||||
"en": {
|
||||
"event": """I will give you an EVENT. You need to give several phrases containing 1-2 words for the ABSTRACT EVENT of this EVENT.
|
||||
You must return your answer in the following format: phrases1, phrases2, phrases3,...
|
||||
You can't return anything other than answers.
|
||||
These abstract event words should fulfill the following requirements.
|
||||
1. The ABSTRACT EVENT phrases can well represent the EVENT, and it could be the type of the EVENT or the related concepts of the EVENT.
|
||||
2. Strictly follow the provided format, do not add extra characters or words.
|
||||
3. Write at least 3 or more phrases at different abstract level if possible.
|
||||
4. Do not repeat the same word and the input in the answer.
|
||||
5. Stop immediately if you can't think of any more phrases, and no explanation is needed.
|
||||
|
||||
EVENT: A man retreats to mountains and forests.
|
||||
Your answer: retreat, relaxation, escape, nature, solitude
|
||||
EVENT: A cat chased a prey into its shelter
|
||||
Your answer: hunting, escape, predation, hiding, stalking
|
||||
EVENT: Sam playing with his dog
|
||||
Your answer: relaxing event, petting, playing, bonding, friendship
|
||||
EVENT: [EVENT]
|
||||
Your answer:""",
|
||||
"entity":"""I will give you an ENTITY. You need to give several phrases containing 1-2 words for the ABSTRACT ENTITY of this ENTITY.
|
||||
You must return your answer in the following format: phrases1, phrases2, phrases3,...
|
||||
You can't return anything other than answers.
|
||||
These abstract intention words should fulfill the following requirements.
|
||||
1. The ABSTRACT ENTITY phrases can well represent the ENTITY, and it could be the type of the ENTITY or the related concepts of the ENTITY.
|
||||
2. Strictly follow the provided format, do not add extra characters or words.
|
||||
3. Write at least 3 or more phrases at different abstract level if possible.
|
||||
4. Do not repeat the same word and the input in the answer.
|
||||
5. Stop immediately if you can't think of any more phrases, and no explanation is needed.
|
||||
|
||||
ENTITY: Soul
|
||||
CONTEXT: premiered BFI London Film Festival, became highest-grossing Pixar release
|
||||
Your answer: movie, film
|
||||
|
||||
ENTITY: Thinkpad X60
|
||||
CONTEXT: Richard Stallman announced he is using Trisquel on a Thinkpad X60
|
||||
Your answer: Thinkpad, laptop, machine, device, hardware, computer, brand
|
||||
|
||||
ENTITY: Harry Callahan
|
||||
CONTEXT: bluffs another robber, tortures Scorpio
|
||||
Your answer: person, American, character, police officer, detective
|
||||
|
||||
ENTITY: Black Mountain College
|
||||
CONTEXT: was started by John Andrew Rice, attracted faculty
|
||||
Your answer: college, university, school, liberal arts college
|
||||
|
||||
EVENT: 1st April
|
||||
CONTEXT: Utkal Dibas celebrates
|
||||
Your answer: date, day, time, festival
|
||||
|
||||
ENTITY: [ENTITY]
|
||||
CONTEXT: [CONTEXT]
|
||||
Your answer:""",
|
||||
"relation":"""I will give you an RELATION. You need to give several phrases containing 1-2 words for the ABSTRACT RELATION of this RELATION.
|
||||
You must return your answer in the following format: phrases1, phrases2, phrases3,...
|
||||
You can't return anything other than answers.
|
||||
These abstract intention words should fulfill the following requirements.
|
||||
1. The ABSTRACT RELATION phrases can well represent the RELATION, and it could be the type of the RELATION or the simplest concepts of the RELATION.
|
||||
2. Strictly follow the provided format, do not add extra characters or words.
|
||||
3. Write at least 3 or more phrases at different abstract level if possible.
|
||||
4. Do not repeat the same word and the input in the answer.
|
||||
5. Stop immediately if you can't think of any more phrases, and no explanation is needed.
|
||||
|
||||
RELATION: participated in
|
||||
Your answer: become part of, attend, take part in, engage in, involve in
|
||||
RELATION: be included in
|
||||
Your answer: join, be a part of, be a member of, be a component of
|
||||
RELATION: [RELATION]
|
||||
Your answer:"""
|
||||
},
|
||||
"zh-CN": {
|
||||
"event": """我将给你一个事件。你需要为这个事件的抽象概念提供几个1-2个词的短语。
|
||||
你必须按照以下格式返回答案:短语1, 短语2, 短语3,...
|
||||
除了答案外不要返回任何其他内容,请以简体中文输出。
|
||||
这些抽象事件短语应满足以下要求:
|
||||
1. 能很好地代表该事件的类型或相关概念
|
||||
2. 严格遵循给定格式,不要添加额外字符或词语
|
||||
3. 尽可能提供3个或以上不同抽象层次的短语
|
||||
4. 不要重复相同词语或输入内容
|
||||
5. 如果无法想出更多短语立即停止,不需要解释
|
||||
|
||||
事件:一个人退隐到山林中
|
||||
你的回答:退隐, 放松, 逃避, 自然, 独处
|
||||
事件:一只猫将猎物追进巢穴
|
||||
你的回答:捕猎, 逃跑, 捕食, 躲藏, 潜行
|
||||
事件:山姆和他的狗玩耍
|
||||
你的回答:休闲活动, 抚摸, 玩耍, bonding, 友谊
|
||||
事件:[EVENT]
|
||||
请以简体中文输出你的回答:""",
|
||||
"entity":"""我将给你一个实体。你需要为这个实体的抽象概念提供几个1-2个词的短语。
|
||||
你必须按照以下格式返回答案:短语1, 短语2, 短语3,...
|
||||
除了答案外不要返回任何其他内容,请以简体中文输出。
|
||||
这些抽象实体短语应满足以下要求:
|
||||
1. 能很好地代表该实体的类型或相关概念
|
||||
2. 严格遵循给定格式,不要添加额外字符或词语
|
||||
3. 尽可能提供3个或以上不同抽象层次的短语
|
||||
4. 不要重复相同词语或输入内容
|
||||
5. 如果无法想出更多短语立即停止,不需要解释
|
||||
|
||||
实体:心灵奇旅
|
||||
上下文:在BFI伦敦电影节首映,成为皮克斯最卖座影片
|
||||
你的回答:电影, 影片
|
||||
实体:Thinkpad X60
|
||||
上下文:Richard Stallman宣布他在Thinkpad X60上使用Trisquel系统
|
||||
你的回答:Thinkpad, 笔记本电脑, 机器, 设备, 硬件, 电脑, 品牌
|
||||
实体:哈利·卡拉汉
|
||||
上下文:吓退另一个劫匪,折磨天蝎座
|
||||
你的回答:人物, 美国人, 角色, 警察, 侦探
|
||||
实体:黑山学院
|
||||
上下文:由John Andrew Rice创办,吸引了众多教员
|
||||
你的回答:学院, 大学, 学校, 文理学院
|
||||
事件:4月1日
|
||||
上下文:庆祝Utkal Dibas
|
||||
你的回答:日期, 日子, 时间, 节日
|
||||
实体:[ENTITY]
|
||||
上下文:[CONTEXT]
|
||||
请以简体中文输出你的回答:""",
|
||||
"relation":"""我将给你一个关系。你需要为这个关系的抽象概念提供几个1-2个词的短语。
|
||||
你必须按照以下格式返回答案:短语1, 短语2, 短语3,...
|
||||
除了答案外不要返回任何其他内容,请以简体中文输出。
|
||||
这些抽象关系短语应满足以下要求:
|
||||
1. 能很好地代表该关系的类型或最基本概念
|
||||
2. 严格遵循给定格式,不要添加额外字符或词语
|
||||
3. 尽可能提供3个或以上不同抽象层次的短语
|
||||
4. 不要重复相同词语或输入内容
|
||||
5. 如果无法想出更多短语立即停止,不需要解释
|
||||
|
||||
关系:参与
|
||||
你的回答:成为一部分, 参加, 参与其中, 涉及, 卷入
|
||||
关系:被包含在
|
||||
你的回答:加入, 成为一部分, 成为成员, 成为组成部分
|
||||
关系:[RELATION]
|
||||
请以简体中文输出你的回答:"""
|
||||
},
|
||||
"zh-HK": {
|
||||
"event": """我將給你一個事件。你需要為這個事件的抽象概念提供幾個1-2個詞的短語。
|
||||
你必須按照以下格式返回答案:短語1, 短語2, 短語3,...
|
||||
除了答案外不要返回任何其他內容,請以繁體中文輸出。
|
||||
這些抽象事件短語應滿足以下要求:
|
||||
1. 能很好地代表該事件的類型或相關概念
|
||||
2. 嚴格遵循給定格式,不要添加額外字符或詞語
|
||||
3. 盡可能提供3個或以上不同抽象層次的短語
|
||||
4. 不要重複相同詞語或輸入內容
|
||||
5. 如果無法想出更多短語立即停止,不需要解釋
|
||||
|
||||
事件:一個人退隱到山林中
|
||||
你的回答:退隱, 放鬆, 逃避, 自然, 獨處
|
||||
事件:一隻貓將獵物追進巢穴
|
||||
你的回答:捕獵, 逃跑, 捕食, 躲藏, 潛行
|
||||
事件:山姆和他的狗玩耍
|
||||
你的回答:休閒活動, 撫摸, 玩耍, bonding, 友誼
|
||||
事件:[EVENT]
|
||||
請以繁體中文輸出你的回答:""",
|
||||
"entity":"""我將給你一個實體。你需要為這個實體的抽象概念提供幾個1-2個詞的短語。
|
||||
你必須按照以下格式返回答案:短語1, 短語2, 短語3,...
|
||||
除了答案外不要返回任何其他內容,請以繁體中文輸出。
|
||||
這些抽象實體短語應滿足以下要求:
|
||||
1. 能很好地代表該實體的類型或相關概念
|
||||
2. 嚴格遵循給定格式,不要添加額外字符或詞語
|
||||
3. 盡可能提供3個或以上不同抽象層次的短語
|
||||
4. 不要重複相同詞語或輸入內容
|
||||
5. 如果無法想出更多短語立即停止,不需要解釋
|
||||
|
||||
實體:心靈奇旅
|
||||
上下文:在BFI倫敦電影節首映,成為皮克斯最賣座影片
|
||||
你的回答:電影, 影片
|
||||
實體:Thinkpad X60
|
||||
上下文:Richard Stallman宣布他在Thinkpad X60上使用Trisquel系統
|
||||
你的回答:Thinkpad, 筆記本電腦, 機器, 設備, 硬件, 電腦, 品牌
|
||||
實體:哈利·卡拉漢
|
||||
上下文:嚇退另一個劫匪,折磨天蠍座
|
||||
你的回答:人物, 美國人, 角色, 警察, 偵探
|
||||
實體:黑山學院
|
||||
上下文:由John Andrew Rice創辦,吸引了眾多教員
|
||||
你的回答:學院, 大學, 學校, 文理學院
|
||||
事件:4月1日
|
||||
上下文:慶祝Utkal Dibas
|
||||
你的回答:日期, 日子, 時間, 節日
|
||||
實體:[ENTITY]
|
||||
上下文:[CONTEXT]
|
||||
請以繁體中文輸出你的回答:""",
|
||||
"relation":"""我將給你一個關係。你需要為這個關係的抽象概念提供幾個1-2個詞的短語。
|
||||
你必須按照以下格式返回答案:短語1, 短語2, 短語3,...
|
||||
除了答案外不要返回任何其他內容,請以繁體中文輸出。
|
||||
這些抽象關係短語應滿足以下要求:
|
||||
1. 能很好地代表該關係的類型或最基本概念
|
||||
2. 嚴格遵循給定格式,不要添加額外字符或詞語
|
||||
3. 盡可能提供3個或以上不同抽象層次的短語
|
||||
4. 不要重複相同詞語或輸入內容
|
||||
5. 如果無法想出更多短語立即停止,不需要解釋
|
||||
|
||||
關係:參與
|
||||
你的回答:成為一部分, 參加, 參與其中, 涉及, 捲入
|
||||
關係:被包含在
|
||||
你的回答:加入, 成為一部分, 成為成員, 成為組成部分
|
||||
關係:[RELATION]
|
||||
請以繁體中文輸出你的回答:"""
|
||||
}
|
||||
}
|
||||
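The concept prompts rely on literal [EVENT], [ENTITY], [CONTEXT] and [RELATION] placeholders, so plain string replacement is enough to instantiate them; this helper is a hedged sketch rather than the package's own code.

def build_concept_prompt(kind, text, context="", lang="en"):
    # kind is one of "event", "entity" or "relation".
    template = CONCEPT_INSTRUCTIONS[lang][kind]
    prompt = template.replace("[EVENT]", text).replace("[ENTITY]", text).replace("[RELATION]", text)
    return prompt.replace("[CONTEXT]", context)

# build_concept_prompt("entity", "Thinkpad X60", "Richard Stallman announced he is using Trisquel on a Thinkpad X60")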
19
atlas_rag/logging.py
Normal file
@ -0,0 +1,19 @@
|
||||
import logging
|
||||
from logging import Logger
|
||||
from logging.handlers import RotatingFileHandler
|
||||
import os
|
||||
import datetime
|
||||
from atlas_rag.evaluation.benchmark import BenchMarkConfig
|
||||
|
||||
def setup_logger(config:BenchMarkConfig, logger_name = "MyLogger") -> Logger:
|
||||
date_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
log_file_path = f'./log/{config.dataset_name}_event{config.include_events}_concept{config.include_concept}_{date_time}.log'
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.setLevel(logging.INFO)
|
||||
max_bytes = 50 * 1024 * 1024
|
||||
os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
|
||||
handler = RotatingFileHandler(log_file_path, maxBytes=max_bytes, backupCount=5)
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
||||
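A minimal usage sketch for setup_logger. Only the attribute names of BenchMarkConfig (dataset_name, include_events, include_concept) are visible in this commit, so the constructor call below is an assumption:

from atlas_rag.evaluation.benchmark import BenchMarkConfig
from atlas_rag.logging import setup_logger

config = BenchMarkConfig(dataset_name="musique", include_events=True, include_concept=True)  # assumed constructor
logger = setup_logger(config, logger_name="HippoRAGRun")
logger.info("retrieval run started")  # written to ./log/musique_eventTrue_conceptTrue_<timestamp>.log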
4
atlas_rag/retriever/__init__.py
Normal file
4
atlas_rag/retriever/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
from .hipporag import HippoRAGRetriever
|
||||
from .hipporag2 import HippoRAG2Retriever
|
||||
from .simple_retriever import SimpleGraphRetriever, SimpleTextRetriever
|
||||
from .tog import TogRetriever
|
||||
BIN
atlas_rag/retriever/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
atlas_rag/retriever/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/retriever/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
atlas_rag/retriever/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/retriever/__pycache__/base.cpython-311.pyc
Normal file
BIN
atlas_rag/retriever/__pycache__/base.cpython-311.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/retriever/__pycache__/hipporag.cpython-311.pyc
Normal file
BIN
atlas_rag/retriever/__pycache__/hipporag.cpython-311.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/retriever/__pycache__/hipporag.cpython-39.pyc
Normal file
BIN
atlas_rag/retriever/__pycache__/hipporag.cpython-39.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/retriever/__pycache__/hipporag2.cpython-311.pyc
Normal file
BIN
atlas_rag/retriever/__pycache__/hipporag2.cpython-311.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/retriever/__pycache__/inference_config.cpython-311.pyc
Normal file
BIN
atlas_rag/retriever/__pycache__/inference_config.cpython-311.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/retriever/__pycache__/simple_retriever.cpython-311.pyc
Normal file
BIN
atlas_rag/retriever/__pycache__/simple_retriever.cpython-311.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/retriever/__pycache__/tog.cpython-311.pyc
Normal file
BIN
atlas_rag/retriever/__pycache__/tog.cpython-311.pyc
Normal file
Binary file not shown.
27
atlas_rag/retriever/base.py
Normal file
27
atlas_rag/retriever/base.py
Normal file
@ -0,0 +1,27 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple
|
||||
|
||||
class BaseRetriever(ABC):
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
@abstractmethod
|
||||
def retrieve(self, query, topk=5, **kwargs) -> Tuple[List[str], List[str]]:
|
||||
raise NotImplementedError("This method should be overridden by subclasses.")
|
||||
|
||||
class BaseEdgeRetriever(BaseRetriever):
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
@abstractmethod
|
||||
def retrieve(self, query, topk=5, **kwargs) -> Tuple[List[str], List[str]]:
|
||||
raise NotImplementedError("This method should be overridden by subclasses.")
|
||||
|
||||
class BasePassageRetriever(BaseRetriever):
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
@abstractmethod
|
||||
def retrieve(self, query, topk=5, **kwargs) -> Tuple[List[str], List[str]]:
|
||||
raise NotImplementedError("This method should be overridden by subclasses.")
|
||||
|
||||
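A minimal sketch of a custom passage retriever built on the base classes above; the keyword-overlap scoring is illustrative only and is not something the repository ships:

from typing import List, Tuple
from atlas_rag.retriever.base import BasePassageRetriever

class KeywordOverlapRetriever(BasePassageRetriever):
    """Toy retriever: rank passages by word overlap with the query."""
    def __init__(self, passage_dict: dict, **kwargs):
        # BaseRetriever.__init__ stores keyword arguments as attributes,
        # so passage_dict becomes self.passage_dict.
        super().__init__(passage_dict=passage_dict, **kwargs)

    def retrieve(self, query, topk=5, **kwargs) -> Tuple[List[str], List[str]]:
        q_terms = set(query.lower().split())
        ranked = sorted(self.passage_dict.items(),
                        key=lambda kv: len(q_terms & set(kv[1].lower().split())),
                        reverse=True)[:topk]
        ids = [pid for pid, _ in ranked]
        return [self.passage_dict[pid] for pid in ids], ids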
140
atlas_rag/retriever/hipporag.py
Normal file
140
atlas_rag/retriever/hipporag.py
Normal file
@ -0,0 +1,140 @@
|
||||
from tqdm import tqdm
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
|
||||
from atlas_rag.llm_generator.llm_generator import LLMGenerator
|
||||
from logging import Logger
|
||||
from typing import Optional
|
||||
from atlas_rag.retriever.base import BasePassageRetriever
|
||||
from atlas_rag.retriever.inference_config import InferenceConfig
|
||||
|
||||
|
||||
class HippoRAGRetriever(BasePassageRetriever):
|
||||
def __init__(self, llm_generator:LLMGenerator, sentence_encoder:BaseEmbeddingModel,
|
||||
data:dict, inference_config: Optional[InferenceConfig] = None, logger = None, **kwargs):
|
||||
self.passage_dict = data["text_dict"]
|
||||
self.llm_generator = llm_generator
|
||||
self.sentence_encoder = sentence_encoder
|
||||
self.node_embeddings = data["node_embeddings"]
|
||||
self.node_list = data["node_list"]
|
||||
file_id_to_node_id = {}
|
||||
self.KG = data["KG"]
|
||||
for node_id in tqdm(list(self.KG.nodes)):
|
||||
if self.KG.nodes[node_id]['type'] == "passage":
|
||||
if self.KG.nodes[node_id]['file_id'] not in file_id_to_node_id:
|
||||
file_id_to_node_id[self.KG.nodes[node_id]['file_id']] = []
|
||||
file_id_to_node_id[self.KG.nodes[node_id]['file_id']].append(node_id)
|
||||
self.file_id_to_node_id = file_id_to_node_id
|
||||
|
||||
self.KG:nx.DiGraph = self.KG.subgraph(self.node_list)
|
||||
self.node_name_list = [self.KG.nodes[node]["id"] for node in self.node_list]
|
||||
|
||||
|
||||
self.logger :Logger = logger
|
||||
if self.logger is None:
|
||||
self.logging = False
|
||||
else:
|
||||
self.logging = True
|
||||
|
||||
self.inference_config = inference_config if inference_config is not None else InferenceConfig()
|
||||
|
||||
def retrieve_personalization_dict(self, query, topN=10):
|
||||
|
||||
# extract entities from the query
|
||||
entities = self.llm_generator.ner(query)
|
||||
entities = entities.split(", ")
|
||||
if self.logging:
|
||||
self.logger.info(f"HippoRAG NER Entities: {entities}")
|
||||
# print("Entities:", entities)
|
||||
|
||||
if len(entities) == 0:
|
||||
# If the NER cannot extract any entities, we
|
||||
# use the query as the entity to do approximate search
|
||||
entities = [query]
|
||||
|
||||
# evenly distribute the topk for each entity
|
||||
topk_for_each_entity = topN//len(entities)
|
||||
|
||||
# retrieve the top k nodes
|
||||
topk_nodes = []
|
||||
|
||||
for entity_index, entity in enumerate(entities):
|
||||
if entity in self.node_name_list:
|
||||
# get the index of the entity in the node list
|
||||
index = self.node_name_list.index(entity)
|
||||
topk_nodes.append(self.node_list[index])
|
||||
else:
|
||||
topk_for_this_entity = 1
|
||||
|
||||
# print("Topk for this entity:", topk_for_this_entity)
|
||||
|
||||
entity_embedding = self.sentence_encoder.encode([entity], query_type="search")
|
||||
scores = self.node_embeddings@entity_embedding[0].T
|
||||
index_matrix = np.argsort(scores)[-topk_for_this_entity:][::-1]
|
||||
|
||||
topk_nodes += [self.node_list[i] for i in index_matrix]
|
||||
|
||||
if self.logging:
|
||||
self.logger.info(f"HippoRAG Topk Nodes: {[self.KG.nodes[node]['id'] for node in topk_nodes]}")
|
||||
|
||||
topk_nodes = list(set(topk_nodes))
|
||||
|
||||
# assert len(topk_nodes) <= topN
|
||||
if len(topk_nodes) > 2*topN:
|
||||
topk_nodes = topk_nodes[:2*topN]
|
||||
|
||||
|
||||
# print("Topk nodes:", topk_nodes)
|
||||
# find the number of docs that each node appears in
|
||||
freq_dict_for_nodes = {}
|
||||
for node in topk_nodes:
|
||||
node_data = self.KG.nodes[node]
|
||||
# print(node_data)
|
||||
file_ids = node_data["file_id"]
|
||||
file_ids_list = file_ids.split(",")
|
||||
#uniq this list
|
||||
file_ids_list = list(set(file_ids_list))
|
||||
freq_dict_for_nodes[node] = len(file_ids_list)
|
||||
|
||||
personalization_dict = {node: 1 / freq_dict_for_nodes[node] for node in topk_nodes}
|
||||
|
||||
# print("personalization dict: ")
|
||||
return personalization_dict
|
||||
|
||||
def retrieve(self, query, topN=5, **kwargs):
|
||||
topN_nodes = self.inference_config.topk_nodes
|
||||
personalization_dict = self.retrieve_personalization_dict(query, topN=topN_nodes)
|
||||
|
||||
# retrieve the top N passages
|
||||
pr = nx.pagerank(self.KG, personalization=personalization_dict)
|
||||
|
||||
for node in pr:
|
||||
pr[node] = round(pr[node], 4)
|
||||
if pr[node] < 0.001:
|
||||
pr[node] = 0
|
||||
|
||||
passage_probabilities_sum = {}
|
||||
for node in pr:
|
||||
node_data = self.KG.nodes[node]
|
||||
file_ids = node_data["file_id"]
|
||||
# for each file id check through each text_id
|
||||
file_ids_list = file_ids.split(",")
|
||||
#uniq this list
|
||||
file_ids_list = list(set(file_ids_list))
|
||||
# file id to node id
|
||||
|
||||
for file_id in file_ids_list:
|
||||
if file_id == 'concept_file':
|
||||
continue
|
||||
for node_id in self.file_id_to_node_id[file_id]:
|
||||
if node_id not in passage_probabilities_sum:
|
||||
passage_probabilities_sum[node_id] = 0
|
||||
passage_probabilities_sum[node_id] += pr[node]
|
||||
|
||||
sorted_passages = sorted(passage_probabilities_sum.items(), key=lambda x: x[1], reverse=True)
|
||||
top_passages = sorted_passages[:topN]
|
||||
top_passages, scores = zip(*top_passages)
|
||||
|
||||
passage_contents = [self.passage_dict[passage_id] for passage_id in top_passages]
|
||||
|
||||
return passage_contents, top_passages
|
||||
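A hedged wiring sketch for HippoRAGRetriever. The graph, node list, node embeddings and passage dict are produced elsewhere in the pipeline; the placeholder variables (kg_graph, node_ids, node_embeddings, passage_dict, llm, encoder) are assumptions, only the dict keys match what the constructor reads:

data = {
    "KG": kg_graph,                      # networkx graph with 'type', 'file_id', 'id' node attributes
    "node_list": node_ids,               # ids of the entity/event nodes to keep in the subgraph
    "node_embeddings": node_embeddings,  # numpy array aligned with node_list
    "text_dict": passage_dict,           # passage_id -> passage text
}
retriever = HippoRAGRetriever(llm_generator=llm, sentence_encoder=encoder, data=data)
passages, passage_ids = retriever.retrieve("Who founded Black Mountain College?", topN=5)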
237
atlas_rag/retriever/hipporag2.py
Normal file
237
atlas_rag/retriever/hipporag2.py
Normal file
@ -0,0 +1,237 @@
|
||||
import networkx as nx
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
from typing import Dict, List, Tuple
|
||||
import numpy as np
|
||||
import json_repair
|
||||
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
|
||||
from atlas_rag.llm_generator.llm_generator import LLMGenerator
|
||||
from logging import Logger
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from atlas_rag.retriever.base import BasePassageRetriever
|
||||
from atlas_rag.retriever.inference_config import InferenceConfig
|
||||
|
||||
def min_max_normalize(x):
|
||||
min_val = np.min(x)
|
||||
max_val = np.max(x)
|
||||
range_val = max_val - min_val
|
||||
|
||||
# Handle the case where all values are the same (range is zero)
|
||||
if range_val == 0:
|
||||
return np.ones_like(x) # Return an array of ones with the same shape as x
|
||||
|
||||
return (x - min_val) / range_val
|
||||
|
||||
|
||||
class HippoRAG2Retriever(BasePassageRetriever):
|
||||
def __init__(self, llm_generator:LLMGenerator,
|
||||
sentence_encoder:BaseEmbeddingModel,
|
||||
data : dict,
|
||||
inference_config: Optional[InferenceConfig] = None,
|
||||
logger = None,
|
||||
**kwargs):
|
||||
self.llm_generator = llm_generator
|
||||
self.sentence_encoder = sentence_encoder
|
||||
|
||||
self.node_embeddings = data["node_embeddings"]
|
||||
self.node_list = data["node_list"]
|
||||
self.edge_list = data["edge_list"]
|
||||
self.edge_embeddings = data["edge_embeddings"]
|
||||
self.text_embeddings = data["text_embeddings"]
|
||||
self.edge_faiss_index = data["edge_faiss_index"]
|
||||
self.passage_dict = data["text_dict"]
|
||||
self.text_id_list = list(self.passage_dict.keys())
|
||||
self.KG = data["KG"]
|
||||
self.KG = self.KG.subgraph(self.node_list + self.text_id_list)
|
||||
|
||||
|
||||
self.logger = logger
|
||||
if self.logger is None:
|
||||
self.logging = False
|
||||
else:
|
||||
self.logging = True
|
||||
|
||||
hipporag2mode = "query2edge"
|
||||
if hipporag2mode == "query2edge":
|
||||
self.retrieve_node_fn = self.query2edge
|
||||
elif hipporag2mode == "query2node":
|
||||
self.retrieve_node_fn = self.query2node
|
||||
elif hipporag2mode == "ner2node":
|
||||
self.retrieve_node_fn = self.ner2node
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {hipporag2mode}. Choose from 'query2edge', 'query2node', or 'query2passage'.")
|
||||
|
||||
self.inference_config = inference_config if inference_config is not None else InferenceConfig()
|
||||
node_id_to_file_id = {}
|
||||
for node_id in tqdm(list(self.KG.nodes)):
|
||||
if self.inference_config.keyword == "musique" and self.KG.nodes[node_id]['type']=="passage":
|
||||
node_id_to_file_id[node_id] = self.KG.nodes[node_id]["id"]
|
||||
else:
|
||||
node_id_to_file_id[node_id] = self.KG.nodes[node_id]["file_id"]
|
||||
self.node_id_to_file_id = node_id_to_file_id
|
||||
|
||||
def ner(self, text):
|
||||
return self.llm_generator.ner(text)
|
||||
|
||||
def ner2node(self, query, topN = 10):
|
||||
entities = self.ner(query)
|
||||
entities = entities.split(", ")
|
||||
|
||||
if len(entities) == 0:
|
||||
entities = [query]
|
||||
# retrieve the top k nodes
|
||||
topk_nodes = []
|
||||
node_score_dict = {}
|
||||
for entity_index, entity in enumerate(entities):
|
||||
topk_for_this_entity = 1
|
||||
entity_embedding = self.sentence_encoder.encode([entity], query_type="search")
|
||||
scores = min_max_normalize(self.node_embeddings@entity_embedding[0].T)
|
||||
index_matrix = np.argsort(scores)[-topk_for_this_entity:][::-1]
|
||||
similarity_matrix = [scores[i] for i in index_matrix]
|
||||
for index, sim_score in zip(index_matrix, similarity_matrix):
|
||||
node = self.node_list[index]
|
||||
if node not in topk_nodes:
|
||||
topk_nodes.append(node)
|
||||
node_score_dict[node] = sim_score
|
||||
|
||||
topk_nodes = list(set(topk_nodes))
|
||||
result_node_score_dict = {}
|
||||
if len(topk_nodes) > 2*topN:
|
||||
topk_nodes = topk_nodes[:2*topN]
|
||||
for node in topk_nodes:
|
||||
if node in node_score_dict:
|
||||
result_node_score_dict[node] = node_score_dict[node]
|
||||
return result_node_score_dict
|
||||
|
||||
def query2node(self, query, topN = 10):
|
||||
query_emb = self.sentence_encoder.encode([query], query_type="entity")
|
||||
scores = min_max_normalize(self.node_embeddings@query_emb[0].T)
|
||||
index_matrix = np.argsort(scores)[-topN:][::-1]
|
||||
similarity_matrix = [scores[i] for i in index_matrix]
|
||||
result_node_score_dict = {}
|
||||
for index, sim_score in zip(index_matrix, similarity_matrix):
|
||||
node = self.node_list[index]
|
||||
result_node_score_dict[node] = sim_score
|
||||
|
||||
return result_node_score_dict
|
||||
|
||||
def query2edge(self, query, topN = 10):
|
||||
query_emb = self.sentence_encoder.encode([query], query_type="edge")
|
||||
scores = min_max_normalize(self.edge_embeddings@query_emb[0].T)
|
||||
index_matrix = np.argsort(scores)[-topN:][::-1]
|
||||
log_edge_list = []
|
||||
for index in index_matrix:
|
||||
edge = self.edge_list[index]
|
||||
edge_str = [self.KG.nodes[edge[0]]['id'], self.KG.edges[edge]['relation'], self.KG.nodes[edge[1]]['id']]
|
||||
log_edge_list.append(edge_str)
|
||||
|
||||
similarity_matrix = [scores[i] for i in index_matrix]
|
||||
# construct the edge list
|
||||
before_filter_edge_json = {}
|
||||
before_filter_edge_json['fact'] = []
|
||||
for index, sim_score in zip(index_matrix, similarity_matrix):
|
||||
edge = self.edge_list[index]
|
||||
edge_str = [self.KG.nodes[edge[0]]['id'], self.KG.edges[edge]['relation'], self.KG.nodes[edge[1]]['id']]
|
||||
before_filter_edge_json['fact'].append(edge_str)
|
||||
if self.logging:
|
||||
self.logger.info(f"HippoRAG2 Before Filter Edge: {before_filter_edge_json['fact']}")
|
||||
filtered_facts = self.llm_generator.filter_triples_with_entity_event(query, json.dumps(before_filter_edge_json, ensure_ascii=False))
|
||||
filtered_facts = json_repair.loads(filtered_facts)['fact']
|
||||
if len(filtered_facts) == 0:
|
||||
return {}
|
||||
# use filtered facts to get the edge id and check if it exists in the original candidate list.
|
||||
node_score_dict = {}
|
||||
log_edge_list = []
|
||||
for edge in filtered_facts:
|
||||
edge_str = f'{edge[0]} {edge[1]} {edge[2]}'
|
||||
search_emb = self.sentence_encoder.encode([edge_str], query_type="search")
|
||||
D, I = self.edge_faiss_index.search(search_emb, 1)
|
||||
filtered_index = I[0][0]
|
||||
# get the edge and the original score
|
||||
edge = self.edge_list[filtered_index]
|
||||
log_edge_list.append([self.KG.nodes[edge[0]]['id'], self.KG.edges[edge]['relation'], self.KG.nodes[edge[1]]['id']])
|
||||
head, tail = edge[0], edge[1]
|
||||
sim_score = scores[filtered_index]
|
||||
|
||||
if head not in node_score_dict:
|
||||
node_score_dict[head] = [sim_score]
|
||||
else:
|
||||
node_score_dict[head].append(sim_score)
|
||||
if tail not in node_score_dict:
|
||||
node_score_dict[tail] = [sim_score]
|
||||
else:
|
||||
node_score_dict[tail].append(sim_score)
|
||||
# average the scores
|
||||
if self.logging:
|
||||
self.logger.info(f"HippoRAG2: Filtered edges: {log_edge_list}")
|
||||
|
||||
# take average of the scores
|
||||
for node in node_score_dict:
|
||||
node_score_dict[node] = sum(node_score_dict[node]) / len(node_score_dict[node])
|
||||
|
||||
return node_score_dict
|
||||
|
||||
def query2passage(self, query, weight_adjust = 0.05):
|
||||
query_emb = self.sentence_encoder.encode([query], query_type="passage")
|
||||
sim_scores = self.text_embeddings @ query_emb[0].T
|
||||
sim_scores = min_max_normalize(sim_scores)*weight_adjust # converted to probability
|
||||
# create dict of passage id and score
|
||||
return dict(zip(self.text_id_list, sim_scores))
|
||||
|
||||
def retrieve_personalization_dict(self, query, topN=30, weight_adjust=0.05):
|
||||
node_dict = self.retrieve_node_fn(query, topN=topN)
|
||||
text_dict = self.query2passage(query, weight_adjust=weight_adjust)
|
||||
|
||||
return node_dict, text_dict
|
||||
|
||||
def retrieve(self, query, topN=5, **kwargs):
|
||||
topN_edges = self.inference_config.topk_edges
|
||||
weight_adjust = self.inference_config.weight_adjust
|
||||
|
||||
node_dict, text_dict = self.retrieve_personalization_dict(query, topN=topN_edges, weight_adjust=weight_adjust)
|
||||
|
||||
personalization_dict = {}
|
||||
if len(node_dict) == 0:
|
||||
# return topN text passages
|
||||
sorted_passages = sorted(text_dict.items(), key=lambda x: x[1], reverse=True)
|
||||
sorted_passages = sorted_passages[:topN]
|
||||
sorted_passages_contents = []
|
||||
sorted_scores = []
|
||||
sorted_passage_ids = []
|
||||
for passage_id, score in sorted_passages:
|
||||
sorted_passages_contents.append(self.passage_dict[passage_id])
|
||||
sorted_scores.append(float(score))
|
||||
sorted_passage_ids.append(self.node_id_to_file_id[passage_id])
|
||||
return sorted_passages_contents, sorted_passage_ids
|
||||
|
||||
personalization_dict.update(node_dict)
|
||||
personalization_dict.update(text_dict)
|
||||
# retrieve the top N passages
|
||||
pr = nx.pagerank(self.KG, personalization=personalization_dict,
|
||||
alpha = self.inference_config.ppr_alpha,
|
||||
max_iter=self.inference_config.ppr_max_iter,
|
||||
tol=self.inference_config.ppr_tol)
|
||||
|
||||
# get the top N passages based on the text_id list and pagerank score
|
||||
text_dict_score = {}
|
||||
for node in self.text_id_list:
|
||||
# filter out nodes that have 0 score
|
||||
if pr[node] > 0.0:
|
||||
text_dict_score[node] = pr[node]
|
||||
|
||||
# return topN passages
|
||||
sorted_passages_ids = sorted(text_dict_score.items(), key=lambda x: x[1], reverse=True)
|
||||
sorted_passages_ids = sorted_passages_ids[:topN]
|
||||
|
||||
sorted_passages_contents = []
|
||||
sorted_scores = []
|
||||
sorted_passage_ids = []
|
||||
for passage_id, score in sorted_passages_ids:
|
||||
sorted_passages_contents.append(self.passage_dict[passage_id])
|
||||
sorted_scores.append(score)
|
||||
sorted_passage_ids.append(self.node_id_to_file_id[passage_id])
|
||||
return sorted_passages_contents, sorted_passage_ids
|
||||
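HippoRAG2Retriever reads a larger bundle of precomputed artifacts than the v1 retriever. A hedged sketch of the expected data dict, with placeholder variables standing in for artifacts built elsewhere in the pipeline:

data = {
    "KG": kg_graph,
    "node_list": node_ids,
    "node_embeddings": node_embeddings,
    "edge_list": edge_tuples,              # (head_id, tail_id) pairs aligned with edge_embeddings
    "edge_embeddings": edge_embeddings,
    "edge_faiss_index": edge_index,        # FAISS index over the edge embeddings
    "text_dict": passage_dict,
    "text_embeddings": text_embeddings,
}
retriever = HippoRAG2Retriever(llm_generator=llm, sentence_encoder=encoder, data=data,
                               inference_config=InferenceConfig(topk_edges=50, weight_adjust=0.05))
contents, ids = retriever.retrieve("Which Pixar film premiered at the BFI London Film Festival?", topN=5)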
23
atlas_rag/retriever/inference_config.py
Normal file
23
atlas_rag/retriever/inference_config.py
Normal file
@ -0,0 +1,23 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class InferenceConfig:
|
||||
"""
|
||||
Configuration class for inference settings.
|
||||
|
||||
Attributes:
|
||||
topk (int): Number of top results to retrieve. Default is 5.
|
||||
Dmax (int): Maximum depth for search. Default is 4.
|
||||
weight_adjust (float): Weight adjustment factor for passage retrieval. Default is 1.0.
|
||||
topk_edges (int): Number of top edges to retrieve. Default is 50.
|
||||
topk_nodes (int): Number of top nodes to retrieve. Default is 10.
|
||||
"""
|
||||
keyword: str = "musique"
|
||||
topk: int = 5
|
||||
Dmax: int = 4
|
||||
weight_adjust: float = 1.0
|
||||
topk_edges: int = 50
|
||||
topk_nodes: int = 10
|
||||
ppr_alpha: float = 0.99
|
||||
ppr_max_iter: int = 2000
|
||||
ppr_tol: float = 1e-7
|
||||
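The PageRank-related fields (ppr_alpha, ppr_max_iter, ppr_tol) are not listed in the docstring but are read by HippoRAG2Retriever.retrieve. A small example of overriding them along with the retrieval limits:

from atlas_rag.retriever.inference_config import InferenceConfig

cfg = InferenceConfig(keyword="musique", topk_nodes=10, topk_edges=50,
                      weight_adjust=0.05, ppr_alpha=0.95, ppr_max_iter=1000, ppr_tol=1e-6)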
0
atlas_rag/retriever/lkg_retriever/__init__.py
Normal file
0
atlas_rag/retriever/lkg_retriever/__init__.py
Normal file
40
atlas_rag/retriever/lkg_retriever/base.py
Normal file
40
atlas_rag/retriever/lkg_retriever/base.py
Normal file
@ -0,0 +1,40 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseLargeKGRetriever(ABC):
|
||||
def __init__(self):
|
||||
raise NotImplementedError("This is a base class and cannot be instantiated directly.")
|
||||
@abstractmethod
|
||||
def retrieve_passages(self, query, retriever_config:dict):
|
||||
"""
|
||||
Retrieve passages based on the query.
|
||||
|
||||
Args:
|
||||
query (str): The input query.
|
||||
retriever_config (dict): Retrieval settings, e.g. topN, number_of_source_nodes_per_ner and sampling_area.
|
||||
|
||||
Returns:
|
||||
List of retrieved passages and their scores.
|
||||
"""
|
||||
raise NotImplementedError("This method should be implemented by subclasses.")
|
||||
|
||||
class BaseLargeKGEdgeRetriever(ABC):
|
||||
def __init__(self):
|
||||
raise NotImplementedError("This is a base class and cannot be instantiated directly.")
|
||||
@abstractmethod
|
||||
def retrieve_passages(self, query, retriever_config:dict):
|
||||
"""
|
||||
Retrieve Edges / Paths based on the query.
|
||||
|
||||
Args:
|
||||
query (str): The input query.
|
||||
retriever_config (dict): Retrieval settings, e.g. topN, number_of_source_nodes_per_ner and sampling_area.
|
||||
|
||||
Returns:
|
||||
List of retrieved passages and their scores.
|
||||
"""
|
||||
raise NotImplementedError("This method should be implemented by subclasses.")
|
||||
313
atlas_rag/retriever/lkg_retriever/lkgr.py
Normal file
313
atlas_rag/retriever/lkg_retriever/lkgr.py
Normal file
@ -0,0 +1,313 @@
|
||||
from difflib import get_close_matches
|
||||
from logging import Logger
|
||||
import faiss
|
||||
from neo4j import GraphDatabase
|
||||
import time
|
||||
import nltk
|
||||
from nltk.corpus import stopwords
|
||||
nltk.download('stopwords')
|
||||
from graphdatascience import GraphDataScience
|
||||
from atlas_rag.llm_generator.llm_generator import LLMGenerator
|
||||
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
|
||||
import string
|
||||
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGRetriever
|
||||
|
||||
|
||||
|
||||
class LargeKGRetriever(BaseLargeKGRetriever):
|
||||
def __init__(self, keyword:str, neo4j_driver: GraphDatabase,
|
||||
llm_generator:LLMGenerator, sentence_encoder:BaseEmbeddingModel,
|
||||
node_index:faiss.Index, passage_index:faiss.Index,
|
||||
topN: int = 5,
|
||||
number_of_source_nodes_per_ner: int = 10,
|
||||
sampling_area : int = 250,logger:Logger = None):
|
||||
# instantiate the resources for one KG
|
||||
self.keyword = keyword
|
||||
self.neo4j_driver = neo4j_driver
|
||||
self.gds_driver = GraphDataScience(self.neo4j_driver)
|
||||
self.topN = topN
|
||||
self.number_of_source_nodes_per_ner = number_of_source_nodes_per_ner
|
||||
self.sampling_area = sampling_area
|
||||
self.llm_generator = llm_generator
|
||||
self.sentence_encoder = sentence_encoder
|
||||
self.node_faiss_index = node_index
|
||||
# self.edge_faiss_index = self.edge_indexes[keyword]
|
||||
self.passage_faiss_index = passage_index
|
||||
|
||||
self.verbose = False if logger is None else True
|
||||
self.logger = logger
|
||||
|
||||
self.ppr_weight_threshold = 0.00005
|
||||
|
||||
def set_model(self, model):
|
||||
if self.llm_generator.inference_type == 'openai':
|
||||
self.llm_generator.model_name = model
|
||||
else:
|
||||
raise ValueError("Model can only be set for OpenAI inference type.")
|
||||
|
||||
def ner(self, text):
|
||||
return self.llm_generator.large_kg_ner(text)
|
||||
|
||||
def convert_numeric_id_to_name(self, numeric_id):
|
||||
if numeric_id.isdigit():
|
||||
return self.gds_driver.util.asNode(self.gds_driver.find_node_id(["Node"], {"numeric_id": numeric_id})).get('name')
|
||||
else:
|
||||
return numeric_id
|
||||
|
||||
def has_intersection(self, word_set, input_string):
|
||||
cleaned_string = input_string.translate(str.maketrans('', '', string.punctuation)).lower()
|
||||
if self.keyword == 'cc_en':
|
||||
# Check if any phrase in word_set is a substring of cleaned_string
|
||||
for phrase in word_set:
|
||||
if phrase in cleaned_string:
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
# Check if any word in word_set is present in the cleaned_string's words
|
||||
words_in_string = set(cleaned_string.split())
|
||||
return not word_set.isdisjoint(words_in_string)
|
||||
|
||||
def pagerank(self, personalization_dict, topN=5, sampling_area=200):
|
||||
graph = self.gds_driver.graph.get('largekgrag_graph')
|
||||
node_count = graph.node_count()
|
||||
sampling_ratio = sampling_area / node_count
|
||||
aggregation_node_dict = []
|
||||
ppr_weight_threshold = self.ppr_weight_threshold
|
||||
start_time = time.time()
|
||||
|
||||
# Pre-filter nodes based on ppr_weight threshold
|
||||
# Precompute word sets based on keyword
|
||||
if self.keyword == 'cc_en':
|
||||
filtered_personalization = {
|
||||
node_id: ppr_weight
|
||||
for node_id, ppr_weight in personalization_dict.items()
|
||||
if ppr_weight >= ppr_weight_threshold
|
||||
}
|
||||
stop_words = set(stopwords.words('english'))
|
||||
word_set_phrases = set()
|
||||
word_set_words = set()
|
||||
for node_id, ppr_weight in filtered_personalization.items():
|
||||
name = self.gds_driver.util.asNode(node_id)['name']
|
||||
if name:
|
||||
cleaned_phrase = name.translate(str.maketrans('', '', string.punctuation)).lower().strip()
|
||||
if cleaned_phrase:
|
||||
# Process for 'cc_en': remove stop words and add cleaned phrase
|
||||
filtered_words = [word for word in cleaned_phrase.split() if word not in stop_words]
|
||||
if filtered_words:
|
||||
cleaned_phrase_filtered = ' '.join(filtered_words)
|
||||
word_set_phrases.add(cleaned_phrase_filtered)
|
||||
|
||||
word_set = word_set_phrases if self.keyword == 'cc_en' else word_set_words
|
||||
if self.verbose:
|
||||
self.logger.info(f"Optimized word set: {word_set}")
|
||||
else:
|
||||
filtered_personalization = personalization_dict
|
||||
if self.verbose:
|
||||
self.logger.info(f"largekgRAG : Personalization dict: {filtered_personalization}")
|
||||
self.logger.info(f"largekgRAG : Sampling ratio: {sampling_ratio}")
|
||||
self.logger.info(f"largekgRAG : PPR weight threshold: {ppr_weight_threshold}")
|
||||
# Process each node in the filtered personalization dict
|
||||
for node_id, ppr_weight in filtered_personalization.items():
|
||||
try:
|
||||
self.gds_driver.graph.drop('rwr_sample')
|
||||
start_time = time.time()
|
||||
G_sample, _ = self.gds_driver.graph.sample.rwr("rwr_sample", graph, concurrency=4, samplingRatio = sampling_ratio, startNodes = [node_id],
|
||||
restartProbability = 0.4, logProgress = False)
|
||||
if self.verbose:
|
||||
self.logger.info(f"largekgRAG : Sampled graph for node {node_id} in {time.time() - start_time:.2f} seconds")
|
||||
start_time = time.time()
|
||||
result = self.gds_driver.pageRank.stream(
|
||||
G_sample, maxIterations=30, sourceNodes=[node_id], logProgress=False
|
||||
).sort_values("score", ascending=False)
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(f"pagerank type: {type(result)}")
|
||||
self.logger.info(f"pagerank result: {result}")
|
||||
self.logger.info(f"largekgRAG : PageRank calculated for node {node_id} in {time.time() - start_time:.2f} seconds")
|
||||
start_time = time.time()
|
||||
# if self.keyword == 'cc_en':
|
||||
if self.keyword != 'cc_en':
|
||||
result = result[result['score'] > 0.0].nlargest(50, 'score').to_dict('records')
|
||||
else:
|
||||
result = result.to_dict('records')
|
||||
if self.verbose:
|
||||
self.logger.info(f"largekgRAG :result: {result}")
|
||||
for entry in result:
|
||||
if self.keyword == 'cc_en':
|
||||
node_name = self.gds_driver.util.asNode(entry['nodeId'])['name']
|
||||
if not self.has_intersection(word_set, node_name):
|
||||
continue
|
||||
|
||||
numeric_id = self.gds_driver.util.asNode(entry['nodeId'])['numeric_id']
|
||||
aggregation_node_dict.append({
|
||||
'nodeId': numeric_id,
|
||||
'score': entry['score'] * ppr_weight
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
if self.verbose:
|
||||
self.logger.error(f"Error processing node {node_id}: {e}")
|
||||
self.logger.error(f"Node is filtered out: {self.gds_driver.util.asNode(node_id)['name']}")
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
aggregation_node_dict = sorted(aggregation_node_dict, key=lambda x: x['score'], reverse=True)[:25]
|
||||
if self.verbose:
|
||||
self.logger.info(f"Aggregation node dict: {aggregation_node_dict}")
|
||||
if self.verbose:
|
||||
self.logger.info(f"Time taken to sample and calculate PageRank: {time.time() - start_time:.2f} seconds")
|
||||
start_time = time.time()
|
||||
with self.neo4j_driver.session() as session:
|
||||
intermediate_time = time.time()
|
||||
# Step 1: Distribute entity scores to connected text nodes and find the top 5
|
||||
query_scores = """
|
||||
UNWIND $entries AS entry
|
||||
MATCH (n:Node {numeric_id: entry.nodeId})-[:Source]->(t:Text)
|
||||
WITH t.numeric_id AS textId, SUM(entry.score) AS total_score
|
||||
ORDER BY total_score DESC
|
||||
LIMIT $topN
|
||||
RETURN textId, total_score
|
||||
"""
|
||||
# Execute query to aggregate scores
|
||||
result_scores = session.run(query_scores, entries=aggregation_node_dict, topN=topN)
|
||||
top_numeric_ids = []
|
||||
top_scores = []
|
||||
|
||||
# Extract the top text node IDs and scores
|
||||
for record in result_scores:
|
||||
top_numeric_ids.append(record["textId"])
|
||||
top_scores.append(record["total_score"])
|
||||
|
||||
# Step 2: Use top numeric IDs to retrieve the original text
|
||||
if self.verbose:
|
||||
self.logger.info(f"Time taken to prepare query 1 : {time.time() - intermediate_time:.2f} seconds")
|
||||
intermediate_time = time.time()
|
||||
query_text = """
|
||||
UNWIND $textIds AS textId
|
||||
MATCH (t:Text {numeric_id: textId})
|
||||
RETURN t.original_text AS text, t.numeric_id AS textId
|
||||
"""
|
||||
result_texts = session.run(query_text, textIds=top_numeric_ids)
|
||||
topN_passages = []
|
||||
score_dict = dict(zip(top_numeric_ids, top_scores))
|
||||
|
||||
# Combine original text with scores
|
||||
for record in result_texts:
|
||||
original_text = record["text"]
|
||||
text_id = record["textId"]
|
||||
score = score_dict.get(text_id, 0)
|
||||
topN_passages.append((original_text, score))
|
||||
if self.verbose:
|
||||
self.logger.info(f"Time taken to prepare query 2 : {time.time() - intermediate_time:.2f} seconds")
|
||||
# Sort passages by score
|
||||
topN_passages = sorted(topN_passages, key=lambda x: x[1], reverse=True)
|
||||
top_texts = [item[0] for item in topN_passages][:topN]
|
||||
top_scores = [item[1] for item in topN_passages][:topN]
|
||||
if self.verbose:
|
||||
self.logger.info(f"Total passages retrieved: {len(top_texts)}")
|
||||
self.logger.info(f"Top passages: {top_texts}")
|
||||
self.logger.info(f"Top scores: {top_scores}")
|
||||
if self.verbose:
|
||||
self.logger.info(f"Neo4j Query Time: {time.time() - start_time:.2f} seconds")
|
||||
return top_texts, top_scores
|
||||
|
||||
def retrieve_topk_nodes(self, query, top_k_nodes = 2):
|
||||
# extract entities from the query
|
||||
entities = self.ner(query)
|
||||
if self.verbose:
|
||||
self.logger.info(f"largekgRAG : LLM Extracted entities: {entities}")
|
||||
if len(entities) == 0:
|
||||
entities = [query]
|
||||
num_entities = len(entities)
|
||||
initial_nodes = []
|
||||
for entity in entities:
|
||||
entity_embedding = self.sentence_encoder.encode([entity])
|
||||
D, I = self.node_faiss_index.search(entity_embedding, top_k_nodes)
|
||||
if self.verbose:
|
||||
self.logger.info(f"largekgRAG : Search results - Distances: {D}, Indices: {I}")
|
||||
initial_nodes += [str(i)for i in I[0]]
|
||||
if self.verbose:
|
||||
self.logger.info(f"largekgRAG : Initial nodes: {initial_nodes}")
|
||||
name_id_map = {}
|
||||
for node_id in initial_nodes:
|
||||
name = self.convert_numeric_id_to_name(node_id)
|
||||
name_id_map[name] = node_id
|
||||
|
||||
topk_nodes = list(set(initial_nodes))
|
||||
# convert the numeric id to string and filter again then return numeric id
|
||||
keywords_before_filter = [self.convert_numeric_id_to_name(n) for n in initial_nodes]
|
||||
filtered_keywords = self.llm_generator.large_kg_filter_keywords_with_entity(query, keywords_before_filter)
|
||||
|
||||
# Second pass: Add filtered keywords
|
||||
filtered_top_k_nodes = []
|
||||
filter_log_dict = {}
|
||||
match_threshold = 0.8
|
||||
if self.verbose:
|
||||
self.logger.info(f"largekgRAG : Filtered Before Match Keywords Candidate: {filtered_keywords}")
|
||||
for keyword in filtered_keywords:
|
||||
# Check for an exact match first
|
||||
if keyword in name_id_map:
|
||||
filtered_top_k_nodes.append(name_id_map[keyword])
|
||||
filter_log_dict[keyword] = name_id_map[keyword]
|
||||
else:
|
||||
# Look for close matches using difflib's get_close_matches
|
||||
close_matches = get_close_matches(keyword, name_id_map.keys(), n=1, cutoff=match_threshold)
|
||||
|
||||
if close_matches:
|
||||
# If a close match is found, add the corresponding node
|
||||
filtered_top_k_nodes.append(name_id_map[close_matches[0]])
|
||||
|
||||
filter_log_dict[keyword] = name_id_map[close_matches[0]] if close_matches else None
|
||||
if self.verbose:
|
||||
self.logger.info(f"largekgRAG : Filtered After Match Keywords Candidate: {filter_log_dict}")
|
||||
|
||||
topk_nodes = list(set(filtered_top_k_nodes))
|
||||
if len(topk_nodes) > 2 * num_entities:
|
||||
topk_nodes = topk_nodes[:2 * num_entities]
|
||||
return topk_nodes
|
||||
|
||||
def _process_text(self, text):
|
||||
"""Normalize text for containment checks (lowercase, alphanumeric+spaces)"""
|
||||
text = text.lower()
|
||||
text = ''.join([c for c in text if c.isalnum() or c.isspace()])
|
||||
return set(text.split())
|
||||
|
||||
def retrieve_personalization_dict(self, query, number_of_source_nodes_per_ner=5):
|
||||
topk_nodes = self.retrieve_topk_nodes(query, number_of_source_nodes_per_ner)
|
||||
if topk_nodes == []:
|
||||
if self.verbose:
|
||||
self.logger.info(f"largekgRAG : No nodes found for query: {query}")
|
||||
return {}
|
||||
if self.verbose:
|
||||
self.logger.info(f"largekgRAG : Topk nodes: {[self.convert_numeric_id_to_name(node_id) for node_id in topk_nodes]}")
|
||||
freq_dict_for_nodes = {}
|
||||
query = """
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n1:Node {numeric_id: node})-[r:Source]-(n2:Text)
|
||||
RETURN n1.numeric_id as numeric_id, COUNT(DISTINCT n2.text_id) AS fileCount
|
||||
"""
|
||||
with self.neo4j_driver.session() as session:
|
||||
result = session.run(query, nodes=topk_nodes)
|
||||
for record in result:
|
||||
freq_dict_for_nodes[record["numeric_id"]] = record["fileCount"]
|
||||
# Create the personalization dictionary
|
||||
personalization_dict = {self.gds_driver.find_node_id(["Node"],{"numeric_id": numeric_id}): 1 / file_count for numeric_id, file_count in freq_dict_for_nodes.items()}
|
||||
if self.verbose:
|
||||
self.logger.info(f"largekgRAG : Personalization dict's number of node: {len(personalization_dict)}")
|
||||
return personalization_dict
|
||||
|
||||
def retrieve_passages(self, query):
|
||||
if self.verbose:
|
||||
self.logger.info(f"largekgRAG : Retrieving passages for query: {query}")
|
||||
topN = self.topN
|
||||
number_of_source_nodes_per_ner = self.number_of_source_nodes_per_ner
|
||||
sampling_area = self.sampling_area
|
||||
personalization_dict = self.retrieve_personalization_dict(query, number_of_source_nodes_per_ner)
|
||||
if personalization_dict == {}:
|
||||
return [], [0]
|
||||
topN_passages, topN_scores = self.pagerank(personalization_dict, topN, sampling_area = sampling_area)
|
||||
return topN_passages, topN_scores
|
||||
|
||||
|
||||
|
||||
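A hedged wiring sketch for LargeKGRetriever. The Neo4j URI, credentials, index file names and the llm/encoder objects are placeholders, not values from this repository:

import faiss
from neo4j import GraphDatabase

driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))  # placeholder connection
node_index = faiss.read_index("node_embeddings.index")        # placeholder index files
passage_index = faiss.read_index("passage_embeddings.index")

retriever = LargeKGRetriever(keyword="cc_en", neo4j_driver=driver,
                             llm_generator=llm, sentence_encoder=encoder,
                             node_index=node_index, passage_index=passage_index,
                             topN=5, number_of_source_nodes_per_ner=10, sampling_area=250)
passages, scores = retriever.retrieve_passages("Where is Black Mountain College located?")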
469
atlas_rag/retriever/lkg_retriever/tog.py
Normal file
469
atlas_rag/retriever/lkg_retriever/tog.py
Normal file
@ -0,0 +1,469 @@
|
||||
from neo4j import GraphDatabase
|
||||
import faiss
|
||||
import numpy as np
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
import time
|
||||
import logging
|
||||
from atlas_rag.llm_generator.llm_generator import LLMGenerator
|
||||
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
|
||||
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGEdgeRetriever
|
||||
import nltk
|
||||
from nltk.corpus import stopwords
|
||||
nltk.download('stopwords')
|
||||
|
||||
class LargeKGToGRetriever(BaseLargeKGEdgeRetriever):
|
||||
def __init__(self, keyword: str, neo4j_driver: GraphDatabase,
|
||||
llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel, filter_encoder: BaseEmbeddingModel,
|
||||
node_index: faiss.Index,
|
||||
topN : int = 5,
|
||||
Dmax : int = 3,
|
||||
Wmax : int = 3,
|
||||
prune_size: int = 10,
|
||||
logger: logging.Logger = None):
|
||||
"""
|
||||
Initialize the LargeKGToGRetriever for billion-level KG retrieval using Neo4j.
|
||||
|
||||
Args:
|
||||
keyword (str): Identifier for the KG dataset (e.g., 'cc_en').
|
||||
neo4j_driver (GraphDatabase): Neo4j driver for database access.
|
||||
llm_generator (LLMGenerator): LLM for NER, rating, and reasoning.
|
||||
sentence_encoder (BaseEmbeddingModel): Encoder for generating embeddings.
|
||||
node_index (faiss.Index): FAISS index for node embeddings.
|
||||
logger (Logger, optional): Logger for verbose output.
|
||||
"""
|
||||
self.keyword = keyword
|
||||
self.neo4j_driver = neo4j_driver
|
||||
self.llm_generator = llm_generator
|
||||
self.sentence_encoder = sentence_encoder
|
||||
self.filter_encoder = filter_encoder
|
||||
self.node_faiss_index = node_index
|
||||
self.verbose = logger is not None
|
||||
self.logger = logger
|
||||
self.topN = topN
|
||||
self.Dmax = Dmax
|
||||
self.Wmax = Wmax
|
||||
self.prune_size = prune_size
|
||||
|
||||
def ner(self, text: str) -> List[str]:
|
||||
"""
|
||||
Extract named entities from the query text using the LLM.
|
||||
|
||||
Args:
|
||||
text (str): The query text.
|
||||
|
||||
Returns:
|
||||
List[str]: List of extracted entities.
|
||||
"""
|
||||
entities = self.llm_generator.large_kg_ner(text)
|
||||
if self.verbose:
|
||||
self.logger.info(f"Extracted entities: {entities}")
|
||||
return entities
|
||||
|
||||
def retrieve_topk_nodes(self, query: str, top_k_nodes: int = 5) -> List[str]:
|
||||
"""
|
||||
Retrieve top-k nodes similar to entities in the query.
|
||||
|
||||
Args:
|
||||
query (str): The user query.
|
||||
top_k_nodes (int): Number of nodes to retrieve per entity.
|
||||
|
||||
Returns:
|
||||
List[str]: List of node numeric_ids.
|
||||
"""
|
||||
start_time = time.time()
|
||||
entities = self.ner(query)
|
||||
if self.verbose:
|
||||
ner_time = time.time() - start_time
|
||||
self.logger.info(f"NER took {ner_time:.2f} seconds, entities: {entities}")
|
||||
if not entities:
|
||||
entities = [query]
|
||||
|
||||
initial_nodes = []
|
||||
for entity in entities:
|
||||
entity_embedding = self.sentence_encoder.encode([entity])
|
||||
D, I = self.node_faiss_index.search(entity_embedding, 3)
|
||||
if self.verbose:
|
||||
self.logger.info(f"Entity: {entity}, FAISS Distances: {D}, Indices: {I}")
|
||||
if len(I[0]) > 0: # Check if results exist
|
||||
initial_nodes.extend([str(i) for i in I[0]])
|
||||
# no need filtering as ToG pruning will handle it.
|
||||
topk_nodes_ids = list(set(initial_nodes))
|
||||
|
||||
with self.neo4j_driver.session() as session:
|
||||
start_time = time.time()
|
||||
query = """
|
||||
MATCH (n:Node)
|
||||
WHERE n.numeric_id IN $topk_nodes_ids
|
||||
RETURN n.numeric_id AS id, n.name AS name
|
||||
"""
|
||||
result = session.run(query, topk_nodes_ids=topk_nodes_ids)
|
||||
topk_nodes_dict = {}
|
||||
for record in result:
|
||||
topk_nodes_dict[record["id"]] = record["name"]
|
||||
if self.verbose:
|
||||
neo4j_time = time.time() - start_time
|
||||
self.logger.info(f"Neo4j query took {neo4j_time:.2f} seconds, count: {len(topk_nodes_ids)}")
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(f"Top-k nodes: {topk_nodes_dict}")
|
||||
return list(topk_nodes_dict.keys()), list(topk_nodes_dict.values()) # numeric_id of nodes returned
|
||||
|
||||
def expand_paths(self, P: List[List[str]], PIDS: List[List[str]], PTYPES: List[List[str]], width: int, query: str) -> List[List[str]]:
|
||||
"""
|
||||
Expand each path by adding neighbors of the last node.
|
||||
|
||||
Args:
|
||||
P (List[List[str]]): Current list of paths, where each path is a list of alternating node_ids and relation types.
|
||||
|
||||
Returns:
|
||||
List[List[str]]: List of expanded paths.
|
||||
"""
|
||||
last_nodes = []
|
||||
last_node_ids = []
|
||||
last_node_types = []
|
||||
|
||||
paths_end_with_text = []
|
||||
paths_end_with_text_id = []
|
||||
paths_end_with_text_type = []
|
||||
|
||||
paths_end_with_node = []
|
||||
paths_end_with_node_id = []
|
||||
paths_end_with_node_type = []
|
||||
if self.verbose:
|
||||
self.logger.info(f"Expanding paths, current paths: {P}")
|
||||
for p, pid, ptype in zip(P, PIDS, PTYPES):
|
||||
if not p or not pid or not ptype: # Skip empty paths
|
||||
continue
|
||||
t = ptype[-1]
|
||||
if t == "Text":
|
||||
paths_end_with_text.append(p)
|
||||
paths_end_with_text_id.append(pid)
|
||||
paths_end_with_text_type.append(ptype)
|
||||
continue
|
||||
last_node = p[-1] # Last node in the path
|
||||
last_node_id = pid[-1] # Last node numeric_id
|
||||
|
||||
last_nodes.append(last_node)
|
||||
last_node_ids.append(last_node_id)
|
||||
last_node_types.append(t)
|
||||
|
||||
paths_end_with_node.append(p)
|
||||
paths_end_with_node_id.append(pid)
|
||||
paths_end_with_node_type.append(ptype)
|
||||
|
||||
assert len(last_nodes) == len(last_node_ids) == len(last_node_types), "Mismatch in last nodes, ids, and types lengths"
|
||||
|
||||
if not last_node_ids:
|
||||
return paths_end_with_text, paths_end_with_text_id, paths_end_with_text_type
|
||||
|
||||
with self.neo4j_driver.session() as session:
|
||||
# Query Node relationships
|
||||
start_time = time.time()
|
||||
outgoing_query = """
|
||||
CALL apoc.cypher.runTimeboxed(
|
||||
"MATCH (n:Node)-[r:Relation]-(m:Node) WHERE n.numeric_id IN $last_node_ids
|
||||
WITH n, r, m ORDER BY rand() LIMIT 60000
|
||||
RETURN n.numeric_id AS source, n.name AS source_name, r.relation AS rel_type, m.numeric_id AS target, m.name AS target_name, 'Node' AS target_type",
|
||||
{last_node_ids: $last_node_ids},
|
||||
60000
|
||||
)
|
||||
YIELD value
|
||||
RETURN value.source AS source, value.source_name AS source_name, value.rel_type AS rel_type, value.target AS target, value.target_name AS target_name, value.target_type AS target_type
|
||||
"""
|
||||
outgoing_result = session.run(outgoing_query, last_node_ids=last_node_ids)
|
||||
outgoing = [(record["source"], record['source_name'], record["rel_type"], record["target"], record["target_name"], record["target_type"])
|
||||
for record in outgoing_result]
|
||||
if self.verbose:
|
||||
outgoing_time = time.time() - start_time
|
||||
self.logger.info(f"Outgoing relationships query took {outgoing_time:.2f} seconds, count: {len(outgoing)}")
|
||||
|
||||
# # Query outgoing Node -> Text relationships
|
||||
# start_time = time.time()
|
||||
# outgoing_text_query = """
|
||||
# MATCH (n:Node)-[r:Source]->(t:Text)
|
||||
# WHERE n.numeric_id IN $last_node_ids
|
||||
# RETURN n.numeric_id AS source, n.name AS source_name, 'from Source' AS rel_type, t.numeric_id as target, t.original_text AS target_name, 'Text' AS target_type
|
||||
# """
|
||||
# outgoing_text_result = session.run(outgoing_text_query, last_node_ids=last_node_ids)
|
||||
# outgoing_text = [(record["source"], record["source_name"], record["rel_type"], record["target"], record["target_name"], record["target_type"])
|
||||
# for record in outgoing_text_result]
|
||||
# if self.verbose:
|
||||
# outgoing_text_time = time.time() - start_time
|
||||
# self.logger.info(f"Outgoing Node->Text relationships query took {outgoing_text_time:.2f} seconds, count: {len(outgoing_text)}")
|
||||
|
||||
last_node_to_new_paths = defaultdict(list)
|
||||
last_node_to_new_paths_ids = defaultdict(list)
|
||||
last_node_to_new_paths_types = defaultdict(list)
|
||||
|
||||
|
||||
for p, pid, ptype in zip(P, PIDS, PTYPES):
|
||||
last_node = p[-1]
|
||||
last_node_id = pid[-1]
|
||||
# Outgoing Node -> Node
|
||||
for source, source_name, rel_type, target, target_name, target_type in outgoing:
|
||||
if source == last_node_id and target_name not in p:
|
||||
new_path = p + [rel_type, target_name]
|
||||
if target_name.lower() in stopwords.words('english'):
|
||||
continue
|
||||
last_node_to_new_paths[last_node].append(new_path)
|
||||
last_node_to_new_paths_ids[last_node].append(pid + [target])
|
||||
last_node_to_new_paths_types[last_node].append(ptype + [target_type])
|
||||
|
||||
# # Outgoing Node -> Text
|
||||
# for source, source_name, rel_type, target, target_name, target_type in outgoing_text:
|
||||
# if source == last_node_id and target_name not in p:
|
||||
# new_path = p + [rel_type, target_name]
|
||||
|
||||
# last_node_to_new_paths_text[last_node].append(new_path)
|
||||
# last_node_to_new_paths_text_ids[last_node].append(pid + [target])
|
||||
# last_node_to_new_paths_text_types[last_node].append(ptype + [target_type])
|
||||
|
||||
# # Incoming Node -> Node
|
||||
# for source, rel_type, target, source_name, source_type in incoming:
|
||||
# if target == last_node_id and source not in p:
|
||||
# new_path = p + [rel_type, source_name]
|
||||
|
||||
|
||||
num_paths = 0
|
||||
for last_node, new_paths in last_node_to_new_paths.items():
|
||||
num_paths += len(new_paths)
|
||||
# for last_node, new_paths in last_node_to_new_paths_text.items():
|
||||
# num_paths += len(new_paths)
|
||||
new_paths = []
|
||||
new_pids = []
|
||||
new_ptypes = []
|
||||
if self.verbose:
|
||||
self.logger.info(f"Number of new paths before filtering: {num_paths}")
|
||||
self.logger.info(f"last nodes: {last_node_to_new_paths.keys()}")
|
||||
if num_paths > len(last_node_ids) * width:
|
||||
# Apply filtering when total paths exceed threshold
|
||||
for last_node, new_ps in last_node_to_new_paths.items():
|
||||
if len(new_ps) > width:
|
||||
path_embeddings = self.filter_encoder.encode(new_ps)
|
||||
query_embeddings = self.filter_encoder.encode([query])
|
||||
scores = np.dot(path_embeddings, query_embeddings.T).flatten()
|
||||
top_indices = np.argsort(scores)[-width:]
|
||||
new_paths.extend([new_ps[i] for i in top_indices])
|
||||
new_pids.extend([last_node_to_new_paths_ids[last_node][i] for i in top_indices])
|
||||
new_ptypes.extend([last_node_to_new_paths_types[last_node][i] for i in top_indices])
|
||||
else:
|
||||
new_paths.extend(new_ps)
|
||||
new_pids.extend(last_node_to_new_paths_ids[last_node])
|
||||
new_ptypes.extend(last_node_to_new_paths_types[last_node])
|
||||
else:
|
||||
# Collect all paths without filtering when total is at or below threshold
|
||||
for last_node, new_ps in last_node_to_new_paths.items():
|
||||
new_paths.extend(new_ps)
|
||||
new_pids.extend(last_node_to_new_paths_ids[last_node])
|
||||
new_ptypes.extend(last_node_to_new_paths_types[last_node])
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(f"Expanded paths count: {len(new_paths)}")
|
||||
self.logger.info(f"Expanded paths: {new_paths}")
|
||||
return new_paths, new_pids, new_ptypes
|
||||
|
||||
def path_to_string(self, path: List[str]) -> str:
|
||||
"""
|
||||
Convert a path to a human-readable string for LLM rating.
|
||||
|
||||
Args:
|
||||
path (List[str]): Path as a list of node_ids and relation types.
|
||||
|
||||
Returns:
|
||||
str: String representation of the path.
|
||||
"""
|
||||
if len(path) < 1:
|
||||
return ""
|
||||
path_str = []
|
||||
with self.neo4j_driver.session() as session:
|
||||
for i in range(0, len(path), 2):
|
||||
node_id = path[i]
|
||||
result = session.run("MATCH (n:Node {numeric_id: $node_id}) RETURN n.name", node_id=node_id)
|
||||
record = result.single()  # single() consumes the result, so call it only once
node_name = record["n.name"] if record else node_id
|
||||
if i + 1 < len(path):
|
||||
rel_type = path[i + 1]
|
||||
path_str.append(f"{node_name} ---> {rel_type} --->")
|
||||
else:
|
||||
path_str.append(node_name)
|
||||
return " ".join(path_str).strip()
|
||||
|
||||
def prune(self, query: str, P: List[List[str]], PIDS: List[List[str]], PTYPES: List[List[str]], topN: int = 5) -> List[List[str]]:
|
||||
"""
|
||||
Prune paths to keep the top N based on LLM relevance ratings.
|
||||
|
||||
Args:
|
||||
query (str): The user query.
|
||||
P (List[List[str]]): List of paths to prune.
|
||||
topN (int): Number of paths to retain.
|
||||
|
||||
Returns:
|
||||
List[List[str]]: Top N paths.
|
||||
"""
|
||||
ratings = []
|
||||
path_strings = P
|
||||
|
||||
# Process paths in chunks of self.prune_size (10 by default)
|
||||
for i in range(0, len(path_strings), self.prune_size):
|
||||
chunk = path_strings[i:i + self.prune_size]
|
||||
|
||||
# Construct user prompt with the current chunk of paths listed
|
||||
user_prompt = f"Please rate the following paths based on how well they help answer the query (1-5, 0 if not relevant).\n\nQuery: {query}\n\nPaths:\n"
|
||||
for j, path_str in enumerate(chunk, 1):
|
||||
user_prompt += f"{j + i}. {path_str}\n"
|
||||
user_prompt += "\nProvide a list of integers, each corresponding to the rating of the path's ability to help answer the query."
|
||||
|
||||
# Define system prompt to expect a list of integers
|
||||
system_prompt = "You are a rating machine that only provides a list of comma-separated integers (0-5) as a response, each rating how well the corresponding path helps answer the query."
|
||||
|
||||
# Send the prompt to the language model
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
response = self.llm_generator.generate_response(messages, max_new_tokens=1024, temperature=0.0)
|
||||
if self.verbose:
|
||||
self.logger.info(f"LLM response for chunk {i // self.prune_size + 1}: {response}")
|
||||
|
||||
# Parse the response into a list of ratings
|
||||
rating_str = response.strip()
|
||||
chunk_ratings = [int(r) for r in rating_str.split(',') if r.strip().isdigit()]
|
||||
if len(chunk_ratings) > len(chunk):
|
||||
chunk_ratings = chunk_ratings[:len(chunk)]
|
||||
if self.verbose:
|
||||
self.logger.warning(f"Received more ratings ({len(chunk_ratings)}) than paths in chunk ({len(chunk)}). Trimming ratings.")
|
||||
ratings.extend(chunk_ratings) # Concatenate ratings
|
||||
|
||||
# Ensure ratings length matches number of paths, padding with 0s if necessary
|
||||
if len(ratings) < len(path_strings):
|
||||
# self.logger.warning(f"Number of ratings ({len(ratings)}) does not match number of paths ({len(path_strings)}). Padding with 0s.")
|
||||
# ratings += [0] * (len(path_strings) - len(ratings))
|
||||
# fall back to use filter encoder to get topN
|
||||
self.logger.warning(f"Number of ratings ({len(ratings)}) does not match number of paths ({len(path_strings)}). Using filter encoder to get topN paths.")
|
||||
path_embeddings = self.filter_encoder.encode(path_strings)
|
||||
query_embedding = self.filter_encoder.encode([query])
|
||||
scores = np.dot(path_embeddings, query_embedding.T).flatten()
|
||||
top_indices = np.argsort(scores)[-topN:]
|
||||
top_paths = [path_strings[i] for i in top_indices]
|
||||
return top_paths, [PIDS[i] for i in top_indices], [PTYPES[i] for i in top_indices]
|
||||
elif len(ratings) > len(path_strings):
|
||||
self.logger.warning(f"Number of ratings ({len(ratings)}) exceeds number of paths ({len(path_strings)}). Trimming ratings.")
|
||||
ratings = ratings[:len(path_strings)]
|
||||
|
||||
# Sort indices based on ratings in descending order
|
||||
sorted_indices = sorted(range(len(ratings)), key=lambda i: ratings[i], reverse=True)
|
||||
|
||||
# Filter out indices where the rating is 0
|
||||
filtered_indices = [i for i in sorted_indices if ratings[i] > 0]
|
||||
|
||||
# Take the top N indices from the filtered list
|
||||
top_indices = filtered_indices[:topN]
|
||||
|
||||
# Use the filtered indices to get the top paths, PIDS, and PTYPES
|
||||
if self.verbose:
|
||||
self.logger.info(f"Top indices after pruning: {top_indices}")
|
||||
self.logger.info(f"length of path_strings: {len(path_strings)}")
|
||||
top_paths = [path_strings[i] for i in top_indices]
|
||||
top_pids = [PIDS[i] for i in top_indices]
|
||||
top_ptypes = [PTYPES[i] for i in top_indices]
|
||||
|
||||
|
||||
# Log top paths if verbose mode is enabled
|
||||
if self.verbose:
|
||||
self.logger.info(f"Pruned to top {topN} paths: {top_paths}")
|
||||
|
||||
return top_paths, top_pids, top_ptypes
|
||||
|
||||
def reasoning(self, query: str, P: List[List[str]]) -> bool:
|
||||
"""
|
||||
Check if the current paths are sufficient to answer the query.
|
||||
|
||||
Args:
|
||||
query (str): The user query.
|
||||
P (List[List[str]]): Current list of paths.
|
||||
|
||||
Returns:
|
||||
bool: True if sufficient, False otherwise.
|
||||
"""
|
||||
triples = []
|
||||
with self.neo4j_driver.session() as session:
|
||||
for path in P:
|
||||
if len(path) < 3:
|
||||
continue
|
||||
for i in range(0, len(path) - 2, 2):
|
||||
node1_name = path[i]
|
||||
rel = path[i + 1]
|
||||
node2_name = path[i + 2]
|
||||
triples.append(f"({node1_name}, {rel}, {node2_name})")
|
||||
triples_str = ". ".join(triples)
|
||||
prompt = f"Are these triples, along with your knowledge, sufficient to answer the query?\nQuery: {query}\nTriples: {triples_str}"
|
||||
messages = [
|
||||
{"role": "system", "content": "Answer Yes or No only."},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
response = self.llm_generator.generate_response(messages,max_new_tokens=512)
|
||||
if self.verbose:
|
||||
self.logger.info(f"Reasoning result: {response}")
|
||||
return "yes" in response.lower()
|
||||
|
||||
def retrieve_passages(self, query: str) -> List[str]:
|
||||
"""
|
||||
Retrieve the top N paths to answer the query.
|
||||
|
||||
Args:
|
||||
query (str): The user query.
|
||||
topN, Dmax and Wmax are taken from the instance attributes set in __init__ rather than passed per call.
|
||||
|
||||
Returns:
|
||||
List[str]: List of triples as strings.
|
||||
"""
|
||||
topN = self.topN
|
||||
Dmax = self.Dmax
|
||||
Wmax = self.Wmax
|
||||
if self.verbose:
|
||||
self.logger.info(f"Retrieving paths for query: {query}")
|
||||
|
||||
initial_nodes_ids, initial_nodes = self.retrieve_topk_nodes(query, top_k_nodes=topN)
|
||||
if not initial_nodes:
|
||||
if self.verbose:
|
||||
self.logger.info("No initial nodes found.")
|
||||
return []
|
||||
|
||||
P = [[node] for node in initial_nodes]
|
||||
PIDS = [[node_id] for node_id in initial_nodes_ids]
|
||||
PTYPES = [["Node"] for _ in initial_nodes_ids] # Assuming all initial nodes are of type 'Node'
|
||||
for D in range(Dmax + 1):
|
||||
if self.verbose:
|
||||
self.logger.info(f"Depth {D}, Current paths: {len(P)}")
|
||||
P, PIDS, PTYPES = self.expand_paths(P, PIDS, PTYPES, Wmax, query)
|
||||
if not P:
|
||||
if self.verbose:
|
||||
self.logger.info("No paths to expand.")
|
||||
break
|
||||
P, PIDS, PTYPES = self.prune(query, P, PIDS, PTYPES, topN)
|
||||
if D == Dmax:
|
||||
if self.verbose:
|
||||
self.logger.info(f"Reached maximum depth {Dmax}, stopping expansion.")
|
||||
break
|
||||
if self.reasoning(query, P):
|
||||
if self.verbose:
|
||||
self.logger.info("Paths sufficient, stopping expansion.")
|
||||
break
|
||||
|
||||
# Extract final triples
|
||||
triples = []
|
||||
with self.neo4j_driver.session() as session:
|
||||
for path in P:
|
||||
for i in range(0, len(path) - 2, 2):
|
||||
node1_name = path[i]
|
||||
rel = path[i + 1]
|
||||
node2_name = path[i + 2]
|
||||
triples.append(f"({node1_name}, {rel}, {node2_name})")
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(f"Final triples: {triples}")
|
||||
return triples, 'N/A' # 'N/A' for passages_score as this retriever does not return passages
|
||||
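For orientation: reasoning() and retrieve_passages() above flatten each path into triple strings. A path alternates node and relation names, and each two-step window yields one (node, relation, node) triple. A minimal standalone sketch with made-up values, not part of the commit:

# Illustrative only: how an alternating node/relation path becomes triple strings.
path = ["Paris", "capital of", "France", "member of", "European Union"]
triples = []
for i in range(0, len(path) - 2, 2):
    triples.append(f"({path[i]}, {path[i + 1]}, {path[i + 2]})")
print(triples)
# ['(Paris, capital of, France)', '(France, member of, European Union)']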
51
atlas_rag/retriever/simple_retriever.py
Normal file
51
atlas_rag/retriever/simple_retriever.py
Normal file
@ -0,0 +1,51 @@
from typing import Dict
import numpy as np
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.retriever.base import BaseEdgeRetriever, BasePassageRetriever


class SimpleGraphRetriever(BaseEdgeRetriever):

    def __init__(self, llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel,
                 data: dict):

        self.KG = data["KG"]
        self.node_list = data["node_list"]
        self.edge_list = data["edge_list"]

        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder

        self.node_faiss_index = data["node_faiss_index"]
        self.edge_faiss_index = data["edge_faiss_index"]

    def retrieve(self, query, topN=5, **kwargs):
        # retrieve the top k edges
        topk_edges = []
        query_embedding = self.sentence_encoder.encode([query], query_type='edge')
        D, I = self.edge_faiss_index.search(query_embedding, topN)

        topk_edges += [self.edge_list[i] for i in I[0]]

        topk_edges_with_data = [(edge[0], self.KG.edges[edge]["relation"], edge[1]) for edge in topk_edges]
        string_edge_edges = [f"{self.KG.nodes[edge[0]]['id']} {edge[1]} {self.KG.nodes[edge[2]]['id']}" for edge in topk_edges_with_data]

        return string_edge_edges, ["N/A" for _ in range(len(string_edge_edges))]


class SimpleTextRetriever(BasePassageRetriever):
    def __init__(self, passage_dict: Dict[str, str], sentence_encoder: BaseEmbeddingModel, data: dict):
        self.sentence_encoder = sentence_encoder
        self.passage_dict = passage_dict
        self.passage_list = list(passage_dict.values())
        self.passage_keys = list(passage_dict.keys())
        self.text_embeddings = data["text_embeddings"]

    def retrieve(self, query, topN=5, **kwargs):
        query_emb = self.sentence_encoder.encode([query], query_type="passage")
        sim_scores = self.text_embeddings @ query_emb[0].T
        topk_indices = np.argsort(sim_scores)[-topN:][::-1]  # Get indices of top-k scores

        # Retrieve top-k passages
        topk_passages = [self.passage_list[i] for i in topk_indices]
        topk_passages_ids = [self.passage_keys[i] for i in topk_indices]
        return topk_passages, topk_passages_ids
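A minimal usage sketch for SimpleTextRetriever, assuming only what the file above shows: the encoder's encode() returns a NumPy array and data carries precomputed passage embeddings under "text_embeddings". The DummyEncoder, passages, and printed output are illustrative stand-ins, not part of the commit; SimpleGraphRetriever is omitted here because it also needs prebuilt FAISS indexes.

# Hypothetical sketch, not part of the commit.
import numpy as np
from atlas_rag.retriever.simple_retriever import SimpleTextRetriever

class DummyEncoder:
    # Stand-in for a BaseEmbeddingModel; real code would use a proper encoder.
    def encode(self, texts, query_type="passage"):
        rng = np.random.default_rng(abs(hash(tuple(texts))) % (2 ** 32))
        return rng.normal(size=(len(texts), 8)).astype("float32")

passage_dict = {
    "doc1": "Paris is the capital of France.",
    "doc2": "The Nile flows through Egypt.",
}
encoder = DummyEncoder()
# Precomputed passage embeddings, one row per passage (toy values here).
data = {"text_embeddings": encoder.encode(list(passage_dict.values()))}

retriever = SimpleTextRetriever(passage_dict, encoder, data)
passages, passage_ids = retriever.retrieve("capital of France", topN=1)
print(passage_ids, passages)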
195
atlas_rag/retriever/tog.py
Normal file
195
atlas_rag/retriever/tog.py
Normal file
@ -0,0 +1,195 @@
import numpy as np
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from typing import Optional
from atlas_rag.retriever.base import BaseEdgeRetriever
from atlas_rag.retriever.inference_config import InferenceConfig


class TogRetriever(BaseEdgeRetriever):
    def __init__(self, llm_generator, sentence_encoder, data, inference_config: Optional[InferenceConfig] = None):
        self.KG = data["KG"]

        self.node_list = list(self.KG.nodes)
        self.edge_list = list(self.KG.edges)
        self.edge_list_with_relation = [(edge[0], self.KG.edges[edge]["relation"], edge[1]) for edge in self.edge_list]
        self.edge_list_string = [f"{edge[0]} {self.KG.edges[edge]['relation']} {edge[1]}" for edge in self.edge_list]

        self.llm_generator: LLMGenerator = llm_generator
        self.sentence_encoder: BaseEmbeddingModel = sentence_encoder

        self.node_embeddings = data["node_embeddings"]
        self.edge_embeddings = data["edge_embeddings"]

        self.inference_config = inference_config if inference_config is not None else InferenceConfig()

    def ner(self, text):
        messages = [
            {"role": "system", "content": "Please extract the entities from the following question and output them separated by comma, in the following format: entity1, entity2, ..."},
            {"role": "user", "content": "Extract the named entities from: Are Portland International Airport and Gerald R. Ford International Airport both located in Oregon?"},
            # Few-shot example answer, given as the assistant turn.
            {"role": "assistant", "content": "Portland International Airport, Gerald R. Ford International Airport, Oregon"},
            {"role": "user", "content": f"Extract the named entities from: {text}"},
        ]

        response = self.llm_generator.generate_response(messages)
        return response

    def retrieve_topk_nodes(self, query, topN=5, **kwargs):
        # extract entities from the query
        entities = self.ner(query)
        entities = [e.strip() for e in entities.split(",") if e.strip()]

        if len(entities) == 0:
            # If the NER cannot extract any entities, we
            # use the query as the entity to do approximate search
            entities = [query]

        # evenly distribute the topk for each entity
        topk_for_each_entity = topN // len(entities)

        # retrieve the top k nodes
        topk_nodes = []

        for entity in entities:
            if entity in self.node_list:
                topk_nodes.append(entity)

        for entity in entities:
            topk_for_this_entity = topk_for_each_entity + 1

            entity_embedding = self.sentence_encoder.encode([entity])
            # Calculate similarity scores using dot product
            scores = self.node_embeddings @ entity_embedding[0].T
            # Get top-k indices
            top_indices = np.argsort(scores)[-topk_for_this_entity:][::-1]
            topk_nodes += [self.node_list[i] for i in top_indices]

        topk_nodes = list(set(topk_nodes))

        if len(topk_nodes) > 2 * topN:
            topk_nodes = topk_nodes[:2 * topN]
        return topk_nodes

    def retrieve(self, query, topN=5, **kwargs):
        """
        Retrieve the top N paths that connect the entities in the query.
        Dmax is the maximum depth of the search.
        """
        Dmax = self.inference_config.Dmax
        # in the first step, we retrieve the top k nodes
        initial_nodes = self.retrieve_topk_nodes(query, topN=topN)
        E = initial_nodes
        P = [[e] for e in E]
        D = 0

        while D <= Dmax:
            P = self.search(query, P)
            P = self.prune(query, P, topN)

            if self.reasoning(query, P):
                generated_text = self.generate(query, P)
                break

            D += 1

        if D > Dmax:
            generated_text = self.generate(query, P)

        return generated_text

    def search(self, query, P):
        new_paths = []
        for path in P:
            tail_entity = path[-1]
            sucessors = list(self.KG.successors(tail_entity))
            predecessors = list(self.KG.predecessors(tail_entity))

            # remove the entities that are already in the path
            sucessors = [neighbour for neighbour in sucessors if neighbour not in path]
            predecessors = [neighbour for neighbour in predecessors if neighbour not in path]

            if len(sucessors) == 0 and len(predecessors) == 0:
                new_paths.append(path)
                continue
            for neighbour in sucessors:
                relation = self.KG.edges[(tail_entity, neighbour)]["relation"]
                new_path = path + [relation, neighbour]
                new_paths.append(new_path)

            for neighbour in predecessors:
                relation = self.KG.edges[(neighbour, tail_entity)]["relation"]
                new_path = path + [relation, neighbour]
                new_paths.append(new_path)

        return new_paths

    def prune(self, query, P, topN=3):
        ratings = []

        for path in P:
            path_string = ""
            for index, node_or_relation in enumerate(path):
                if index % 2 == 0:
                    id_path = self.KG.nodes[node_or_relation]["id"]
                else:
                    id_path = node_or_relation
                path_string += f"{id_path} --->"
            path_string = path_string[:-5]

            prompt = f"Please rate the following path based on its relevance to the question. The rating should be in the range of 1 to 5, with 1 for least relevant and 5 for most relevant. Only provide the rating, do not provide any other information. The output should be a single integer number. If you think the path is not relevant, please provide 0. If you think the path is relevant, please provide a rating between 1 and 5. \n Query: {query} \n path: {path_string}"

            messages = [{"role": "system", "content": "Answer the question following the prompt."},
                        {"role": "user", "content": f"{prompt}"}]

            response = self.llm_generator.generate_response(messages)
            rating = int(response)
            ratings.append(rating)

        # sort the paths based on the ratings
        sorted_paths = [path for _, path in sorted(zip(ratings, P), reverse=True)]

        return sorted_paths[:topN]

    def reasoning(self, query, P):
        triples = []
        for path in P:
            for i in range(0, len(path) - 2, 2):
                triples.append((self.KG.nodes[path[i]]["id"], path[i + 1], self.KG.nodes[path[i + 2]]["id"]))

        triples_string = [f"({triple[0]}, {triple[1]}, {triple[2]})" for triple in triples]
        triples_string = ". ".join(triples_string)

        prompt = f"Given a question and the associated retrieved knowledge graph triples (entity, relation, entity), you are asked to answer whether it's sufficient for you to answer the question with these triples and your knowledge (Yes or No). Query: {query} \n Knowledge triples: {triples_string}"

        messages = [{"role": "system", "content": "Answer the question following the prompt."},
                    {"role": "user", "content": f"{prompt}"}]

        response = self.llm_generator.generate_response(messages)
        return "yes" in response.lower()

    def generate(self, query, P):
        triples = []
        for path in P:
            for i in range(0, len(path) - 2, 2):
                triples.append((self.KG.nodes[path[i]]["id"], path[i + 1], self.KG.nodes[path[i + 2]]["id"]))

        triples_string = [f"({triple[0]}, {triple[1]}, {triple[2]})" for triple in triples]

        # response = self.llm_generator.generate_with_context_kg(query, triples_string)
        return triples_string, ["N/A" for _ in range(len(triples_string))]
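A hypothetical smoke test for TogRetriever, again not part of the commit. It assumes data["KG"] is a networkx DiGraph (consistent with the successors/predecessors calls above) and that InferenceConfig() is default-constructible with a Dmax attribute, as retrieve() requires; the stub LLM and encoder are made up so that the search / prune / reasoning / generate loop terminates on a two-node toy graph.

# Hypothetical smoke test, not part of the commit.
import numpy as np
import networkx as nx
from atlas_rag.retriever.tog import TogRetriever

class StubLLM:
    # Routes each prompt type to a canned answer so the loop terminates quickly.
    def generate_response(self, messages, **kwargs):
        text = messages[-1]["content"]
        if "Extract the named entities" in text:
            return "Paris, France"   # ner()
        if "rating" in text.lower():
            return "5"               # prune()
        return "Yes"                 # reasoning(): claim the triples are sufficient

class StubEncoder:
    def encode(self, texts, **kwargs):
        rng = np.random.default_rng(0)
        return rng.normal(size=(len(texts), 4)).astype("float32")

KG = nx.DiGraph()
KG.add_node("Paris", id="Paris")
KG.add_node("France", id="France")
KG.add_edge("Paris", "France", relation="capital of")

encoder = StubEncoder()
data = {
    "KG": KG,
    "node_embeddings": encoder.encode(list(KG.nodes)),
    "edge_embeddings": encoder.encode([f"{u} {d['relation']} {v}" for u, v, d in KG.edges(data=True)]),
}

retriever = TogRetriever(StubLLM(), encoder, data)
print(retriever.retrieve("What is Paris the capital of?", topN=2))
# Expected shape: (['(Paris, capital of, France)', ...], ['N/A', ...])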
1
atlas_rag/vectorstore/__init__.py
Normal file
1
atlas_rag/vectorstore/__init__.py
Normal file
@ -0,0 +1 @@
from .create_graph_index import create_embeddings_and_index
BIN
atlas_rag/vectorstore/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
atlas_rag/vectorstore/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
Some files were not shown because too many files have changed in this diff.