first commit

闫旭隆
2025-09-24 09:29:12 +08:00
parent 6339cdebb9
commit 2308536f66
360 changed files with 136381 additions and 0 deletions


@@ -0,0 +1,267 @@
import time
import uvicorn
from fastapi import FastAPI, HTTPException, Response
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from logging import Logger
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGRetriever, BaseLargeKGEdgeRetriever
from atlas_rag.kg_construction.neo4j.utils import start_up_large_kg_index_graph
from atlas_rag.llm_generator import LLMGenerator
from neo4j import Driver
from dataclasses import dataclass
import traceback
@dataclass
class LargeKGConfig:
    largekg_retriever: BaseLargeKGRetriever | BaseLargeKGEdgeRetriever | None = None
    reader_llm_generator: LLMGenerator | None = None
    driver: Driver | None = None
    logger: Logger | None = None
is_felm: bool = False
is_mmlu: bool = False
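    # Prompt templates that identify FELM-style error-detection queries; when a request
    # matches one of these, the handler answers directly and skips KG retrieval.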
rag_exemption_list = [
"""I will show you a question and a list of text segments. All the segments can be concatenated to form a complete answer to the question. Your task is to assess whether each text segment contains errors or not. \nPlease generate using the following format:\nAnswer: List the ids of the segments with errors (separated by commas). Please only output the ids, no more details. If all the segments are correct, output \"ALL_CORRECT\".\n\nHere is one example:\nQuestion: 8923164*7236571?\nSegments: \n1. The product of 8923164 and 7236571 is: 6,461,216,222,844\n2. So, 8923164 multiplied by 7236571 is equal to 6,461,216,222,844.\n\nBelow are your outputs:\nAnswer: 1,2\nIt means segment 1,2 contain errors.""",
"""I will show you a question and a list of text segments. All the segments can be concatenated to form a complete answer to the question. Your task is to determine whether each text segment contains factual errors or not. \nPlease generate using the following format:\nAnswer: List the ids of the segments with errors (separated by commas). Please only output the ids, no more details. If all the segments are correct, output \"ALL_CORRECT\".\n\nHere is one example:\nQuestion: A company offers a 10% discount on all purchases over $100. A customer purchases three items, each costing $80. Does the customer qualify for the discount?\nSegments: \n1. To solve this problem, we need to use deductive reasoning. We know that the company offers a 10% discount on purchases over $100, so we need to calculate the total cost of the customer's purchase.\n2. The customer purchased three items, each costing $80, so the total cost of the purchase is: 3 x $80 = $200.\n3. Since the total cost of the purchase is greater than $100, the customer qualifies for the discount. \n4. To calculate the discounted price, we can multiply the total cost by 0.1 (which represents the 10% discount): $200 x 0.1 = $20.\n5. So the customer is eligible for a discount of $20, and the final cost of the purchase would be: $200 - $20 = $180.\n6. Therefore, the customer would pay a total of $216 for the three items with the discount applied.\n\nBelow are your outputs:\nAnswer: 2,3,4,5,6\nIt means segment 2,3,4,5,6 contains errors.""",
]
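    # Prompt prefix used to detect MMLU-style multiple-choice questions; these keep the
    # full question text for retrieval and are answered in the "The answer is (X)" format.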
mmlu_check_list = [
"""Given the following question and four candidate answers (A, B, C and D), choose the answer."""
]
app = FastAPI()
@app.on_event("startup")
async def startup():
global large_kg_config
start_up_large_kg_index_graph(large_kg_config.driver)
@app.on_event("shutdown")
async def shutdown():
global large_kg_config
print("Shutting down the model...")
del large_kg_config
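# large_kg_config is a module-level global: start_app() assigns it before launching
# uvicorn, so the startup/shutdown hooks above and the request handler below can use it.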
class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "test"
root: Optional[str] = None
parent: Optional[str] = None
permission: Optional[list] = None
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = []
class ChatMessage(BaseModel):
role: Literal["user", "system", "assistant"]
    content: Optional[str] = None
name: Optional[str] = None
class DeltaMessage(BaseModel):
role: Optional[Literal["user", "assistant", "system"]] = None
content: Optional[str] = None
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
temperature: Optional[float] = 0.8
top_p: Optional[float] = 0.8
max_tokens: Optional[int] = None
stream: Optional[bool] = False
tools: Optional[Union[dict, List[dict]]] = None
repetition_penalty: Optional[float] = 1.1
retriever_config: Optional[dict] = {
"topN": 5,
"number_of_source_nodes_per_ner": 10,
"sampling_area": 250,
"Dmax": 2,
"Wmax": 3
}
class Config:
extra = "allow"
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Literal["stop", "length", "function_call"]
class ChatCompletionResponseStreamChoice(BaseModel):
delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]] = None
index: int
class ChatCompletionResponse(BaseModel):
model: str
id: str
object: Literal["chat.completion", "chat.completion.chunk"]
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
usage: Optional[UsageInfo] = None
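# Example request body for /v1/chat/completions (a minimal sketch; values are illustrative).
# "model" is simply echoed back in the response, and retriever_config falls back to the
# defaults declared above when omitted:
#
#   {
#       "model": "atlas-rag",
#       "messages": [{"role": "user", "content": "Question: Who founded HKUST?"}],
#       "max_tokens": 1024,
#       "retriever_config": {"topN": 5, "Dmax": 2, "Wmax": 3}
#   }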
@app.get("/health")
async def health_check():
return Response(status_code=200)
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
global large_kg_config
try:
        if len(request.messages) < 1 or request.messages[-1].role == "assistant":
            raise HTTPException(status_code=400, detail="Invalid request")
if large_kg_config.logger is not None:
large_kg_config.logger.info(f"Request: {request}")
gen_params = dict(
messages=request.messages,
temperature=0.8,
top_p=request.top_p,
max_tokens=request.max_tokens or 1024,
echo=False,
stream=False,
repetition_penalty=request.repetition_penalty,
tools=request.tools,
)
last_message = request.messages[-1]
system_prompt = 'You are a helpful assistant.'
question = last_message.content if last_message.role == 'user' else request.messages[-2].content
is_exemption = any(exemption in question for exemption in LargeKGConfig.rag_exemption_list)
is_mmlu = any(exemption in question for exemption in LargeKGConfig.mmlu_check_list)
print(f"Is exemption: {is_exemption}, Is MMLU: {is_mmlu}")
if is_mmlu:
rag_text = question
else:
            parts = question.rsplit("Question:", 1)
            # Fall back to the full question if no "Question:" marker is present
            rag_text = parts[-1] if len(parts) > 1 else question
print(f"RAG text: {rag_text}")
if not is_exemption:
passages, passages_score = large_kg_config.largekg_retriever.retrieve_passages(rag_text)
context = "No retrieved context, Please answer the question with your own knowledge." if not passages else "\n".join([f"Passage {i+1}: {text}" for i, text in enumerate(reversed(passages))])
if is_mmlu:
rag_chat_content = [
{
"role": "system",
"content": f"{system_prompt}"
},
{
"role": "user",
"content": f"""Here is the context: {context} \n\n
If the context is not useful, you can answer the question with your own knowledge. \n {question}\nThink step by step. Your response should end with 'The answer is ([the_answer_letter])' where the [the_answer_letter] is one of A, B, C and D."""
}
]
elif not is_exemption:
rag_chat_content = [
{
"role": "system",
"content": f"{system_prompt}"
},
{
"role": "user",
"content": f"""{question} Reference doc: {context}"""
}
]
else:
rag_chat_content = [
{
"role": "system",
"content": f"{system_prompt}"
},
{
"role": "user",
"content": f"""{question} """
}
]
if large_kg_config.logger is not None:
large_kg_config.logger.info(rag_chat_content)
response = large_kg_config.reader_llm_generator.generate_response(
batch_messages=rag_chat_content,
max_new_tokens=gen_params["max_tokens"],
temperature=gen_params["temperature"],
frequency_penalty = 1.1
)
message = ChatMessage(
role="assistant",
content=response.strip()
)
choice_data = ChatCompletionResponseChoice(
index=0,
message=message,
finish_reason="stop"
)
return ChatCompletionResponse(
model=request.model,
id="",
object="chat.completion",
choices=[choice_data]
)
    except HTTPException:
        raise
    except Exception as e:
        print("ERROR: ", e)
        print("Caught error, falling back to answering without retrieval")
        traceback.print_exc()
system_prompt = 'You are a helpful assistant.'
gen_params = dict(
messages=request.messages,
temperature=0.8,
top_p=request.top_p,
max_tokens=request.max_tokens or 1024,
echo=False,
stream=False,
repetition_penalty=request.repetition_penalty,
tools=request.tools,
)
last_message = request.messages[-1]
question = last_message.content if last_message.role == 'user' else request.messages[-2].content
rag_chat_content = [
{
"role": "system",
"content": f"{system_prompt}"
},
{
"role": "user",
"content": f"""{question} """
}
]
response = large_kg_config.reader_llm_generator.generate_response(
batch_messages=rag_chat_content,
max_new_tokens=gen_params["max_tokens"],
temperature=gen_params["temperature"],
frequency_penalty = 1.1
)
message = ChatMessage(
role="assistant",
content=response.strip()
)
choice_data = ChatCompletionResponseChoice(
index=0,
message=message,
finish_reason="stop"
)
return ChatCompletionResponse(
model=request.model,
id="",
object="chat.completion",
choices=[choice_data]
)
def start_app(user_config: LargeKGConfig, host="0.0.0.0", port=10090, reload=False):
    """Function to start the FastAPI application."""
    global large_kg_config
    large_kg_config = user_config  # Use the passed config if provided
    uvicorn.run("atlas_rag.kg_construction.neo4j.neo4j_api:app", host=host, port=port, reload=reload)