app.py
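"""FastAPI service exposing a /generate endpoint for retrieval-augmented
question answering: a facebook/rag-sequence-nq question encoder and a custom
RagRetriever select documents from a local dataset, and distilgpt2 writes the
final answer from the retrieved context."""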
import os
import logging
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import RagTokenizer, RagRetriever, RagModel, AutoTokenizer, AutoModelForCausalLM
from dotenv import load_dotenv
import torch
import numpy as np
import atexit
from datasets import load_from_disk
from torch.cuda.amp import autocast
import gc
# Cap the CUDA allocator split size to reduce memory fragmentation on small GPUs.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Read the Hugging Face token from a local .env file.
load_dotenv()
huggingface_token = os.getenv('HUGGINGFACE_TOKEN')
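# Assumed .env contents (the token value below is a placeholder):
#   HUGGINGFACE_TOKEN=hf_xxxxxxxxxxxxxxxx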
app = FastAPI(root_path="/proxy/8000")

# Permissive CORS so a browser frontend on any origin can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Globals populated once by initialize_models() at startup.
rag_tokenizer = None
rag_retriever = None
rag_model = None
gpt2_tokenizer = None
gpt2_model = None
dataset = None

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def initialize_models():
    """Load the RAG question encoder/retriever, the distilgpt2 generator, and the dataset."""
    global rag_tokenizer, rag_retriever, rag_model, gpt2_tokenizer, gpt2_model, dataset
    rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq", use_auth_token=huggingface_token)
    rag_retriever = RagRetriever.from_pretrained("custom_rag_retriever", use_auth_token=huggingface_token)
    rag_model = RagModel.from_pretrained("facebook/rag-sequence-nq", use_auth_token=huggingface_token)
    gpt2_tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
    gpt2_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
    # distilgpt2 ships without a pad token; reuse EOS so padding works.
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
    dataset_path = "custom_dataset_with_embeddings"
    dataset = load_from_disk(dataset_path)
    rag_model.to(device)
    gpt2_model.to(device)


initialize_models()
class Query(BaseModel):
    question: str
def postprocess_output(text):
    """Trim a generated string back to its last complete sentence."""
    end_punctuation = {".", "!", "?"}
    for i in range(len(text) - 1, -1, -1):
        if text[i] in end_punctuation:
            return text[:i + 1]
    return text
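# Example: postprocess_output("It works. Partial tra") -> "It works.";
# text with no terminal punctuation is returned unchanged.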
def clean_context(context):
    """Strip template labels and placeholder lines from the assembled context."""
    context = context.replace("Title: No Title\n", "")
    context = context.replace("Description: ", "")
    context = context.replace("More Info: No URL\n", "")
    return context
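# Example: "Title: No Title\nDescription: some text\nMore Info: No URL\n"
# is reduced to "some text\n".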
@app.post("/generate")
def generate_answer(query: Query):
    try:
        logger.info("Received question: %s", query.question)

        # Step 1: encode the question with the RAG question encoder.
        inputs = rag_tokenizer(query.question, return_tensors="pt").to(device)
        input_ids = inputs.input_ids
        logger.info("Tokenized input_ids: %s", input_ids)
        question_hidden_states = rag_model.question_encoder(input_ids)[0]
        logger.info("Encoded question hidden states: %s", question_hidden_states.shape)

        # Step 2: retrieve documents; the retriever expects numpy arrays.
        question_hidden_states_np = question_hidden_states.detach().cpu().numpy()
        question_input_ids_np = input_ids.detach().cpu().numpy()
        logger.info(
            "Calling retriever with question_input_ids: %s and question_hidden_states: %s",
            question_input_ids_np, question_hidden_states_np,
        )
        retrieved_docs = rag_retriever(
            question_input_ids=question_input_ids_np,
            question_hidden_states=question_hidden_states_np,
        )
        logger.info("Retrieved documents: %s", retrieved_docs)

        if not retrieved_docs or 'doc_ids' not in retrieved_docs:
            raise ValueError("No documents retrieved or 'doc_ids' missing in retrieved documents.")
        doc_ids = retrieved_docs['doc_ids']
        if doc_ids.size == 0 or np.all(doc_ids == -1):
            raise ValueError("No valid document IDs found in retrieved documents.")
        valid_doc_ids = doc_ids[doc_ids != -1]
        if valid_doc_ids.size == 0:
            raise ValueError("No valid document IDs found in retrieved documents.")

        # Step 3: build a plain-text context from the retrieved dataset rows.
        contexts = []
        for doc_id in valid_doc_ids:
            doc = dataset[int(doc_id)]
            context = (
                f"Title: {doc['title']}\n"
                f"Description: {doc['text']}\n"
                f"More Info: {doc.get('url', 'No URL')}\n"
            )
            contexts.append(context)
        detailed_context = "\n\n".join(contexts)
        cleaned_context = clean_context(detailed_context)

        # Step 4: prompt distilgpt2 with the cleaned context plus the question.
        gpt2_input_text = f"Context: {cleaned_context}\n\nQuestion: {query.question}\n\nAnswer:"
        gpt2_inputs = gpt2_tokenizer(
            gpt2_input_text, return_tensors="pt", padding=True, truncation=True, max_length=512
        ).to(device)
        logger.info("Context input for GPT-2 model: %s", gpt2_input_text)

        torch.cuda.empty_cache()
        # Mixed precision only applies on CUDA; disable autocast on CPU-only hosts.
        with autocast(enabled=torch.cuda.is_available()):
            with torch.no_grad():
                outputs = gpt2_model.generate(
                    gpt2_inputs.input_ids,
                    attention_mask=gpt2_inputs.attention_mask,
                    num_return_sequences=1,
                    num_beams=3,
                    max_new_tokens=150,
                    eos_token_id=gpt2_tokenizer.eos_token_id,
                )
        answer = gpt2_tokenizer.decode(outputs[0], skip_special_tokens=True)
        answer = postprocess_output(answer)
        logger.info("Generated answer: %s", answer)

        # Free GPU memory held by intermediate tensors before returning.
        del inputs, input_ids, question_hidden_states, retrieved_docs, outputs
        gc.collect()
        torch.cuda.empty_cache()
        return {"answer": answer}
    except Exception as e:
        logger.exception("Error generating answer")
        torch.cuda.empty_cache()
        raise HTTPException(status_code=500, detail=f"Error generating answer: {str(e)}")
def clean_up_resources():
    logger.info("Cleaning up resources...")
    global rag_tokenizer, rag_retriever, rag_model, gpt2_tokenizer, gpt2_model, dataset
    del rag_tokenizer, rag_retriever, rag_model, gpt2_tokenizer, gpt2_model, dataset
    gc.collect()
    torch.cuda.empty_cache()


atexit.register(clean_up_resources)
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
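# Example request (hitting uvicorn directly; the root_path "/proxy/8000" only
# matters when the app sits behind a matching reverse proxy):
#
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is this service about?"}'
#
# The response is JSON of the form {"answer": "..."}.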