tensorgirl committed
Commit 7938cd4 (verified) · 1 Parent(s): 79ce274

Update main.py

Files changed (1):
  1. main.py +125 -132
main.py CHANGED
@@ -1,132 +1,125 @@
- from fastapi import FastAPI
- from app import predict
- import os
- from huggingface_hub import login
- from pydantic import BaseModel
- import sys
- from langchain.chat_models import ChatOpenAI
- from langchain.prompts import PromptTemplate
- from langchain.memory import ConversationBufferMemory
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
- from langchain_core.output_parsers import StrOutputParser
- from langchain_core.runnables import RunnablePassthrough
- import os
- import PyPDF2 as pdf
- import gradio as gr
- from langchain_community.document_loaders import PyPDFLoader
- import os
- from langchain_text_splitters import RecursiveCharacterTextSplitter
- from langchain_community.embeddings.sentence_transformer import (
-     SentenceTransformerEmbeddings,
- )
- from langchain_chroma import Chroma
- from sentence_transformers import SentenceTransformer
- from langchain_core.messages import AIMessage, HumanMessage
- from fastapi import FastAPI, Request, UploadFile, File
-
- os.environ['HF_HOME'] = '/hug/cache/'
- os.environ['TRANSFORMERS_CACHE'] = '/blabla/cache/'
-
- app = FastAPI()
- app.recursion_limit = 10**4
-
- def predict(message, db):
-
-     llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
-     template = """You are a general purpose chatbot. Be friendly and kind. Help people answer their questions. Use the context below to answer the questions
-     {context}
-     Question: {question}
-     Helpful Answer:"""
-     QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"],template=template,)
-     memory = ConversationBufferMemory(
-         memory_key="chat_history",
-         return_messages=True
-     )
-
-     retriever = db.as_retriever(k=3)
-
-     contextualize_q_system_prompt = """Given a chat history and the latest user question \
-     which might reference context in the chat history, formulate a standalone question \
-     which can be understood without the chat history. Do NOT answer the question, \
-     just reformulate it if needed and otherwise return it as is."""
-     contextualize_q_prompt = ChatPromptTemplate.from_messages(
-         [
-             ("system", contextualize_q_system_prompt),
-             MessagesPlaceholder(variable_name="chat_history"),
-             ("human", "{question}"),
-         ]
-     )
-     contextualize_q_chain = contextualize_q_prompt | llm | StrOutputParser()
-
-     def contextualized_question(input: dict):
-         if input.get("chat_history"):
-             return contextualize_q_chain
-         else:
-             return input["question"]
-
-     rag_chain = (
-         RunnablePassthrough.assign(
-             context=contextualized_question | retriever
-         )
-         | QA_CHAIN_PROMPT
-         | llm
-     )
-     history = []
-     ai_msg = rag_chain.invoke({"question": message, "chat_history": history})
-     print(ai_msg)
-     bot_response = ai_msg.content.strip()
-
-     # Ensure history is correctly formatted as a list of tuples (user_message, bot_response)
-     history.append((HumanMessage(content=message), AIMessage(content=bot_response)))
-
-     docs = db.similarity_search(message,k=3)
-     extra = "\n" + "*"*100 + "\n"
-     additional_info = []
-     for d in docs:
-         citations = d.metadata["source"] + " pg." + str(d.metadata["page"])
-         additional_info = d.page_content
-         extra += citations + "\n" + additional_info + "\n" + "*"*100 + "\n"
-     # Return the bot's response and the updated history
-     return bot_response + extra
-
- def upload_file(file_path):
-
-     loaders = []
-     print(file_path)
-     loaders.append(PyPDFLoader(file_path))
-
-     documents = []
-     for loader in loaders:
-         documents.extend(loader.load())
-
-     text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=16)
-     docs = text_splitter.split_documents(documents)
-
-     model = "thenlper/gte-large"
-     embedding_function = SentenceTransformerEmbeddings(model_name=model)
-     print(f"Model's maximum sequence length: {SentenceTransformer(model).max_seq_length}")
-     collection_name = "Autism"
-     persist_directory = "./chroma"
-     print(len(docs))
-     db = Chroma.from_documents(docs, embedding_function)
-     print("Done Processing, you can query")
-
-     return db
-
-
- class Item(BaseModel):
-     code: str
-
- @app.get("/")
- async def root():
-     return {"Code Review Automation":"Version 1.0 'First Draft'"}
-
- @app.post("/UploadFile/")
- def predict(question: str, file: UploadFile = File(...)):
-     contents = file.file.read()
-     with open(file.filename, 'wb') as f:
-         f.write(contents)
-
-     db = upload_file(file.filename)
-     result = predict(question, db)
-     return {"answer":result}
-
 
+ from fastapi import FastAPI
+ import os
+ from langchain.chat_models import ChatOpenAI
+ from langchain.prompts import PromptTemplate
+ from langchain.memory import ConversationBufferMemory
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.runnables import RunnablePassthrough
+ import os
+ from langchain_community.document_loaders import PyPDFLoader
+ import os
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain_community.embeddings.sentence_transformer import (
+     SentenceTransformerEmbeddings,
+ )
+ from langchain_chroma import Chroma
+ from sentence_transformers import SentenceTransformer
+ from langchain_core.messages import AIMessage, HumanMessage
+ from fastapi import FastAPI, Request, UploadFile, File
+
+ os.environ['HF_HOME'] = '/hug/cache/'
+ os.environ['TRANSFORMERS_CACHE'] = '/blabla/cache/'
+
+ app = FastAPI()
+
+ def predict(message, db):
+
+     llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
+     template = """You are a general purpose chatbot. Be friendly and kind. Help people answer their questions. Use the context below to answer the questions
+     {context}
+     Question: {question}
+     Helpful Answer:"""
+     QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"],template=template,)
+     memory = ConversationBufferMemory(
+         memory_key="chat_history",
+         return_messages=True
+     )
+
+     retriever = db.as_retriever(k=3)
+
+     contextualize_q_system_prompt = """Given a chat history and the latest user question \
+     which might reference context in the chat history, formulate a standalone question \
+     which can be understood without the chat history. Do NOT answer the question, \
+     just reformulate it if needed and otherwise return it as is."""
+     contextualize_q_prompt = ChatPromptTemplate.from_messages(
+         [
+             ("system", contextualize_q_system_prompt),
+             MessagesPlaceholder(variable_name="chat_history"),
+             ("human", "{question}"),
+         ]
+     )
+     contextualize_q_chain = contextualize_q_prompt | llm | StrOutputParser()
+
+     def contextualized_question(input: dict):
+         if input.get("chat_history"):
+             return contextualize_q_chain
+         else:
+             return input["question"]
+
+     rag_chain = (
+         RunnablePassthrough.assign(
+             context=contextualized_question | retriever
+         )
+         | QA_CHAIN_PROMPT
+         | llm
+     )
+     history = []
+     ai_msg = rag_chain.invoke({"question": message, "chat_history": history})
+     print(ai_msg)
+     bot_response = ai_msg.content.strip()
+
+     # Ensure history is correctly formatted as a list of tuples (user_message, bot_response)
+     history.append((HumanMessage(content=message), AIMessage(content=bot_response)))
+
+     docs = db.similarity_search(message,k=3)
+     extra = "\n" + "*"*100 + "\n"
+     additional_info = []
+     for d in docs:
+         citations = d.metadata["source"] + " pg." + str(d.metadata["page"])
+         additional_info = d.page_content
+         extra += citations + "\n" + additional_info + "\n" + "*"*100 + "\n"
+     # Return the bot's response and the updated history
+     return bot_response + extra
+
+ def upload_file(file_path):
+
+     loaders = []
+     print(file_path)
+     loaders.append(PyPDFLoader(file_path))
+
+     documents = []
+     for loader in loaders:
+         documents.extend(loader.load())
+
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=16)
+     docs = text_splitter.split_documents(documents)
+
+     model = "thenlper/gte-large"
+     embedding_function = SentenceTransformerEmbeddings(model_name=model)
+     print(f"Model's maximum sequence length: {SentenceTransformer(model).max_seq_length}")
+     collection_name = "Autism"
+     persist_directory = "./chroma"
+     print(len(docs))
+     db = Chroma.from_documents(docs, embedding_function)
+     print("Done Processing, you can query")
+
+     return db
+
+
+ class Item(BaseModel):
+     code: str
+
+ @app.get("/")
+ async def root():
+     return {"Code Review Automation":"Version 1.0 'First Draft'"}
+
+ @app.post("/UploadFile/")
+ def predict(question: str, file: UploadFile = File(...)):
+     contents = file.file.read()
+     with open(file.filename, 'wb') as f:
+         f.write(contents)
+
+     db = upload_file(file.filename)
+     result = predict(question, db)
+     return {"answer":result}
+
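
Review notes

Two issues in the new revision will break it at import time or on the first request: `class Item(BaseModel)` is kept while `from pydantic import BaseModel` was removed, and the `/UploadFile/` handler is also named `predict`, so it rebinds the module-level name and `predict(question, db)` inside it calls the handler itself instead of the RAG function. A minimal follow-up sketch, not part of this commit; it assumes the `predict(message, db)` and `upload_file` definitions from main.py are in scope, and `answer_question` is a hypothetical name chosen only to avoid the shadowing:

    from pydantic import BaseModel   # restored: Item still subclasses it
    from fastapi import UploadFile, File

    class Item(BaseModel):
        code: str

    @app.post("/UploadFile/")
    def answer_question(question: str, file: UploadFile = File(...)):
        # Distinct name, so the module-level RAG predict() is no longer shadowed.
        contents = file.file.read()
        with open(file.filename, 'wb') as f:
            f.write(contents)

        db = upload_file(file.filename)   # index the uploaded PDF
        result = predict(question, db)    # the RAG function defined earlier
        return {"answer": result}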
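Inside `predict`, `history = []` is re-created on every call and the `ConversationBufferMemory` is constructed but never read, so the chat-history branch of `contextualized_question` can never fire; the append also stores a `(HumanMessage, AIMessage)` tuple, while `MessagesPlaceholder` expects a flat list of messages. A sketch of one way to keep history across calls, assuming a single-process deployment and a hypothetical module-level store:

    chat_history = []  # hypothetical module-level store; resets only on restart

    def predict(message, db):
        ...  # chain construction as in main.py
        ai_msg = rag_chain.invoke({"question": message, "chat_history": chat_history})
        bot_response = ai_msg.content.strip()
        # Flat messages, not tuples, so MessagesPlaceholder can format them.
        chat_history.extend([HumanMessage(content=message),
                             AIMessage(content=bot_response)])
        ...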
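In `upload_file`, `collection_name` and `persist_directory` are assigned but never passed to `Chroma.from_documents`, so the index lives only in process memory and is rebuilt from scratch on every upload. A minimal sketch of wiring them through, assuming the standard `langchain_chroma` keyword arguments:

    # Persist the collection so it survives restarts and repeat queries.
    db = Chroma.from_documents(
        docs,
        embedding_function,
        collection_name=collection_name,      # "Autism"
        persist_directory=persist_directory,  # "./chroma"
    )

Relatedly, `db.as_retriever(k=3)` most likely does not set the result count; the documented form is `db.as_retriever(search_kwargs={"k": 3})`.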
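To exercise the endpoint once the issues above are addressed, a hypothetical client call, assuming the app is served with `uvicorn main:app --port 8000`; because `question` is declared as a bare `str`, FastAPI reads it from the query string while the PDF travels in the multipart body:

    import requests

    # paper.pdf is a placeholder filename for any local PDF.
    with open("paper.pdf", "rb") as f:
        r = requests.post(
            "http://localhost:8000/UploadFile/",
            params={"question": "What does this document conclude?"},
            files={"file": ("paper.pdf", f, "application/pdf")},
        )
    print(r.json()["answer"])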