imSleepy committed
Commit ac47d9a · verified · 1 Parent(s): 38e1b11

added fastAPI at chatbot.py

Files changed (1)
chatbot.py +66 -55
chatbot.py CHANGED
@@ -1,55 +1,66 @@
- from transformers import T5Tokenizer, T5ForConditionalGeneration
- from sentence_transformers import SentenceTransformer
- from pinecone import Pinecone
-
- device = 'cpu'
-
- # Initialize Pinecone instance
- pc = Pinecone(api_key='89eeb534-da10-4068-92f7-12eddeabe1e5')
-
- # Check if the index exists; if not, create it
- index_name = 'abstractive-question-answering'
- index = pc.Index(index_name)
-
- def load_models():
-     print("Loading models...")
-
-     retriever = SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base")
-     tokenizer = T5Tokenizer.from_pretrained('t5-base')
-     generator = T5ForConditionalGeneration.from_pretrained('t5-base').to(device)
-
-     return retriever, generator, tokenizer
-
- retriever, generator, tokenizer = load_models()
-
- def process_query(query):
-     # Query Pinecone
-     xq = retriever.encode([query]).tolist()
-     xc = index.query(vector=xq, top_k=1, include_metadata=True)
-
-     # Print the response to check the structure
-     print("Pinecone response:", xc)
-
-     # Check if 'matches' exists and is a list
-     if 'matches' in xc and isinstance(xc['matches'], list):
-         context = [m['metadata']['Output'] for m in xc['matches']]
-         context_str = " ".join(context)
-         formatted_query = f"answer the question: {query} context: {context_str}"
-     else:
-         # Handle the case where 'matches' isn't found or isn't in the expected format
-         context_str = ""
-         formatted_query = f"answer the question: {query} context: {context_str}"
-
-     # Generate answer using T5 model
-     output_text = context_str
-     if len(output_text.splitlines()) > 5:
-         return output_text
-
-     if output_text.lower() == "none":
-         return "The topic is not covered in the student manual."
-
-     inputs = tokenizer.encode(formatted_query, return_tensors="pt", max_length=512, truncation=True).to(device)
-     ids = generator.generate(inputs, num_beams=4, min_length=10, max_length=60, repetition_penalty=1.2)
-     answer = tokenizer.decode(ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
-
-     return answer
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
+ from sentence_transformers import SentenceTransformer
+ from pinecone import Pinecone
+
+ device = 'cpu'
+
+ # Initialize Pinecone instance
+ pc = Pinecone(api_key='89eeb534-da10-4068-92f7-12eddeabe1e5')
+
+ # Connect to the existing Pinecone index
+ index_name = 'abstractive-question-answering'
+ index = pc.Index(index_name)
+
+ # Initialize FastAPI app
+ app = FastAPI()
+
+ # Load the models once at startup
+ def load_models():
+     print("Loading models...")
+
+     retriever = SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base")
+     tokenizer = T5Tokenizer.from_pretrained('t5-base')
+     generator = T5ForConditionalGeneration.from_pretrained('t5-base').to(device)
+
+     return retriever, generator, tokenizer
+
+ retriever, generator, tokenizer = load_models()
+
+ class QueryInput(BaseModel):  # request body: {"input": "<question>"}
+     input: str
+
+ @app.post("/predict")
+ def predict(query: QueryInput):
+     query_text = query.input
+     # Embed the query and retrieve the closest match from Pinecone
+     xq = retriever.encode([query_text]).tolist()
+     xc = index.query(vector=xq, top_k=1, include_metadata=True)
+
+     # Check if 'matches' exists and is a list
+     if 'matches' in xc and isinstance(xc['matches'], list):
+         context = [m['metadata']['Output'] for m in xc['matches']]
+         context_str = " ".join(context)
+         formatted_query = f"answer the question: {query_text} context: {context_str}"
+     else:
+         # Handle the case where 'matches' isn't found or isn't in the expected format
+         context_str = ""
+         formatted_query = f"answer the question: {query_text} context: {context_str}"
+
+     # Generate answer using T5 model (or return retrieved context directly)
+     output_text = context_str
+     if len(output_text.splitlines()) > 5:  # long context: return it as-is
+         return {"response": output_text}
+
+     if output_text.lower() == "none":  # retrieval found nothing relevant
+         return {"response": "The topic is not covered in the student manual."}
+
+     inputs = tokenizer.encode(formatted_query, return_tensors="pt", max_length=512, truncation=True).to(device)
+     ids = generator.generate(inputs, num_beams=4, min_length=10, max_length=60, repetition_penalty=1.2)
+     answer = tokenizer.decode(ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+
+     return {"response": answer}
+
+ # To run the server (use uvicorn when deploying):
+ # uvicorn chatbot:app --reload
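
For local testing, the new endpoint can be exercised with a small client like the sketch below. It assumes the server was started with the uvicorn command above (so it listens on 127.0.0.1:8000); the `requests` package and the sample question are illustrative and not part of the commit.

import requests

# POST a question to /predict; the JSON body must match the QueryInput
# schema, i.e. {"input": "<question>"}.
resp = requests.post(
    "http://127.0.0.1:8000/predict",
    json={"input": "What is the grading policy?"},  # hypothetical question
)
print(resp.json()["response"])  # predict() wraps its answer in {"response": ...}

The same call from the command line: curl -X POST http://127.0.0.1:8000/predict -H "Content-Type: application/json" -d '{"input": "What is the grading policy?"}'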