Spaces:
Runtime error
Runtime error
Commit
•
a5162e0
1
Parent(s):
ea667dd
Upload 4 files
Browse files- .env-sample +1 -0
- app.py +182 -0
- prompt_generation.py +117 -0
- utils.py +18 -0
.env-sample
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
OPENAI_API_KEY=""
|
app.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from streamlit_chat import message
|
3 |
+
from utils import PAGE, read_pdf
|
4 |
+
from prompt_generation import OpenAILLM
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
|
7 |
+
# Read environment variables from a local .env file into the process
# environment (the repo's .env-sample shows OPENAI_API_KEY is expected)
# so the OpenAI-backed LLM client can authenticate.
load_dotenv()
10 |
+
def init():
    """Initialise per-session state on the first run of the script.

    Streamlit re-executes the whole script on every interaction, so all
    state that must survive reruns lives in ``st.session_state`` and is
    created exactly once, guarded by the ``current_page`` sentinel key.
    """
    if 'current_page' not in st.session_state:
        st.session_state.current_page = PAGE.MAIN
        # Fixed number of questions per generated quiz.
        st.session_state.mcq_question_number = 3
        st.session_state.llm = OpenAILLM(mcq_question_number=st.session_state.mcq_question_number)
        st.session_state.chat_start = False
        st.session_state.chat_messages = []
        # Bug fix (robustness): mcq_page() reads this key, but previously only
        # main_page()'s quiz button ever set it. Initialising it here prevents
        # a KeyError if the MCQ page is reached through any other path.
        st.session_state.current_question = 0
|
17 |
+
|
18 |
+
# Setting page title and header
# NOTE(review): st.set_page_config is expected to be the first Streamlit
# call of a script run — keep it ahead of any other st.* usage.
st.set_page_config(page_title="AILearningBuddy", page_icon=":book:")
# Centered page title (raw HTML is needed for the centering style).
st.markdown("<h1 style='text-align: center;'>AI Learning Buddy</h1>", unsafe_allow_html=True)
|
21 |
+
|
22 |
+
|
23 |
+
def main_page():
    """Landing page: document upload plus navigation into the LEARN/TEST flows."""
    # Header
    st.header("Main Page")

    # Upload docs
    file = st.file_uploader("Upload documents", type=["pdf", "txt"])
    MAX_FILE_SIZE = 2 * 1024 * 1024  # 2 MB
    if file is not None:
        # Check the file size
        file_size = file.size
        if file_size > MAX_FILE_SIZE:
            st.error(f"File size should not exceed {MAX_FILE_SIZE / (1024 * 1024)} MB. Please upload a smaller file.")
        # Bug fix: success used to be reported *before* the type check, so an
        # unsupported type showed "File uploaded successfully!" followed by an
        # error. Success is now reported only after the file is ingested.
        elif file.type == "application/pdf":
            st.session_state.llm.upload_text(read_pdf(file))
            st.success("File uploaded successfully!")
        elif file.type == "text/plain":
            st.session_state.llm.upload_text(file.read().decode("utf-8"))
            st.success("File uploaded successfully!")
        else:
            # Defensive: the uploader is already restricted to pdf/txt.
            st.error("Unsupported file type.")

    # Display buttons if file is uploaded
    if st.session_state.llm.is_text_uploaded():
        col1, col2 = st.columns([1, 1])

        with col1:
            st.markdown("<h4>LEARN</h4>", unsafe_allow_html=True)

            if st.button("Create summary", key="summary_button"):
                st.session_state.current_page = PAGE.SUMMARY
                st.rerun()
            if st.button("Chat about the file", key="chat_button"):
                st.session_state.current_page = PAGE.CHAT
                st.session_state.chat_start = True
                st.rerun()
        with col2:
            st.markdown("<h4>TEST</h4>", unsafe_allow_html=True)

            if st.button("Create quiz", key="mcq_button"):
                st.session_state.current_page = PAGE.MCQ
                st.session_state.current_question = 0
                st.rerun()
|
69 |
+
|
70 |
+
|
71 |
+
def summary_page():
    """Render the document-summary page with a back-navigation button."""
    # Back button: return to the main page and drop the uploaded document.
    go_back = st.button(":back: Main Page")
    if go_back:
        st.session_state.current_page = PAGE.MAIN
        st.session_state.llm.empty_text()
        st.rerun()

    st.header("Summary")
    # Ask the LLM for a summary of the uploaded document and display it.
    st.write(st.session_state.llm.get_text_summary())
|
84 |
+
|
85 |
+
|
86 |
+
def chat_page():
    """Chat-with-the-document page.

    ``st.session_state.chat_messages`` is a flat list: each turn appends the
    user input followed by the model response, so entries alternate
    user / AI. Index 0 holds the document dump returned by ``start_chat()``
    and is deliberately never displayed.
    """
    # Header
    if st.button(":back: Main Page"):
        st.session_state.current_page = PAGE.MAIN
        st.session_state.chat_start = False
        st.session_state.chat_messages = []
        st.session_state.llm.empty_text()
        st.rerun()
    st.header("Chat About the Document")

    # Response and user container
    response_container = st.container()
    user_container = st.container()

    with user_container:
        with st.form(key='my_form', clear_on_submit=True):
            user_input = st.text_area("Type here:", key='input', height=100)
            send_button = st.form_submit_button(label='Send')

    if send_button or st.session_state.chat_start:
        # Get the model response, and save it
        if st.session_state.chat_start:
            # First visit: start_chat() feeds the document into the chat
            # memory and returns (memory dump, greeting); the dump becomes
            # chat_messages[0], which the display loop below skips.
            user_input, model_response = st.session_state.llm.start_chat()
            st.session_state.chat_start = False
        else:
            model_response = st.session_state.llm.get_chat_response(user_input)
        st.session_state.chat_messages += [user_input, model_response]

    # Display chat messages
    with response_container:
        # Start at 1 to skip the hidden document dump at index 0; odd
        # indices are AI responses, even indices are user messages.
        for i in range(1, len(st.session_state.chat_messages)):
            if i % 2:
                message(st.session_state.chat_messages[i], key=f'{str(i)}_AI', avatar_style="pixel-art")
            else:
                message(st.session_state.chat_messages[i], is_user=True, key=f'{str(i)}_user',
                        avatar_style="adventurer-neutral")
|
122 |
+
|
123 |
+
|
124 |
+
def mcq_page():
    """Quiz page: renders one multiple-choice question per pass, then results.

    Uses ``st.session_state.current_question`` as the cursor; the results
    branch is reached once it passes ``mcq_question_number``.
    """
    # Header
    if st.button(":back: Main Page"):
        st.session_state.current_page = PAGE.MAIN
        st.session_state.current_question = 0
        st.session_state.llm.empty_text()
        st.rerun()

    # Setup MCQ
    if st.session_state.current_question == 0:
        st.session_state.llm.start_mcq()

    # For every MCQ question
    if st.session_state.current_question < st.session_state.mcq_question_number:
        # QA header
        st.header(f"Question {st.session_state.current_question + 1} / {st.session_state.mcq_question_number}")

        # Generate the QA text
        # NOTE(review): this fires a fresh LLM generation on *every* rerun of
        # the page — including the rerun triggered by the form submission
        # below — so mcq_record_answer() may attach the user's choice to a
        # newer question than the one they actually saw. Verify this rerun
        # flow; an st.rerun() after the increment (matching the other
        # handlers) is probably intended.
        question, answers = st.session_state.llm.get_mcq_question()

        # QA form
        with st.form(key='my_form', clear_on_submit=True):
            selected_answer = st.radio(f"{question}:", answers)
            send_button = st.form_submit_button(label="Next")
            if send_button:
                print("SELECTED ANSWER: ", selected_answer)  # NOTE(review): leftover debug print
                st.session_state.llm.mcq_record_answer(selected_answer)
                st.session_state.current_question += 1
    else:
        # Results header
        st.header("Results")

        # For the last QA, show score
        st.session_state.current_question += 1
        score, score_perc = st.session_state.llm.get_mcq_score()
        st.markdown("<h4>" + f"Score: {score} / {st.session_state.mcq_question_number} ({score_perc} %)" + "</h4>", unsafe_allow_html=True)

        # List your answers and the correct ones
        for i, qa in enumerate(st.session_state.llm.mcq_answer_sheet):
            question, answer, user_answer = qa['question'], qa['answer'], qa['user_answer']
            st.write("---")
            st.write(f"**Question {i+1}/{st.session_state.mcq_question_number}:** {question}")
            st.write(f"**Correct answer:** {answer}")
            st.write(f"**User answer:** {user_answer}")
|
168 |
+
|
169 |
+
|
170 |
+
# Main structure: set up session state, then render whichever page is active.
init()

# Page selector — dispatch table mapping each PAGE value to its renderer.
_PAGE_RENDERERS = {
    PAGE.MAIN: main_page,
    PAGE.SUMMARY: summary_page,
    PAGE.CHAT: chat_page,
    PAGE.MCQ: mcq_page,
}
# .get with a no-op default mirrors the silent fallthrough of a match
# statement with no matching case.
_PAGE_RENDERERS.get(st.session_state.current_page, lambda: None)()
|
prompt_generation.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_openai import ChatOpenAI
|
2 |
+
from langchain.chains import ConversationChain
|
3 |
+
from langchain.chains.conversation.memory import ConversationBufferMemory, ConversationSummaryMemory
|
4 |
+
from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
|
5 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
6 |
+
from langchain.docstore.document import Document
|
7 |
+
from langchain.chains.summarize import load_summarize_chain
|
8 |
+
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
|
9 |
+
import json
|
10 |
+
import random
|
11 |
+
|
12 |
+
|
13 |
+
class OpenAILLM:
    """Wrapper around an OpenAI chat model backing the app's three features:
    document summarisation, document Q&A chat, and MCQ quiz generation.
    """

    def __init__(self, temperature: float = 1.,
                 model_name: str = 'gpt-4o',
                 mcq_question_number: int = 10):
        """Build the LLM, the summarisation and conversation chains, and the
        quiz bookkeeping.

        :param temperature: sampling temperature passed to ChatOpenAI
        :param model_name: OpenAI model identifier
        :param mcq_question_number: number of questions per generated quiz
        """
        # Model-related instantiations
        self.llm = ChatOpenAI(temperature=temperature, model_name=model_name)
        self.Memory = ConversationBufferMemory
        self.chain_summary = load_summarize_chain(self.llm, chain_type="map_reduce", verbose=True)
        self.chain_chat = ConversationChain(llm=self.llm, verbose=False, memory=self.Memory())

        # Other utils instantiation
        self.docs = []
        self.text_splitter = RecursiveCharacterTextSplitter()
        self.chat_document_intro = "Read the following document: "
        self.chat_message_begin = "What would you like to know about the uploaded document?"
        self.mcq_question_number = mcq_question_number
        self.mcq_intro = """
        Generate a question, correct answer and 3 possible false answers from the inputted document.
        Make sure that it is unique from the ones you have generated before!
        Only create 3 possible false answers and a correct answers!
        """
        self.mcq_answer_sheet = []

    def upload_text(self, text):
        """Split raw text into chunks and store them as langchain Documents."""
        chunks = self.text_splitter.split_text(text)
        # Keyword argument makes the Document construction explicit.
        self.docs = [Document(page_content=chunk) for chunk in chunks]

    def is_text_uploaded(self):
        """Return True when a document has been uploaded and split."""
        return bool(self.docs)

    def empty_text(self):
        """Drop the stored document and reset the conversation memory."""
        self.docs = []
        self.chain_chat.memory = self.Memory()

    def get_text_summary(self):
        """Return a map-reduce summary of the uploaded document."""
        return self.chain_summary.run(self.docs)

    def start_chat(self):
        """Feed the document into the chat memory and open the conversation.

        :return: (string dump of the memory, canned opening message) — the
            memory dump stands in for the "user message" of the first turn.
        """
        # Add document to the system's context
        self.chain_chat.memory.save_context({"input": self.chat_document_intro}, {"output": ""})
        for doc in self.docs:
            self.chain_chat.memory.save_context({"input": doc.page_content}, {"output": ""})

        return str(self.chain_chat.memory), self.chat_message_begin

    def get_chat_response(self, user_input: str):
        """Return the model's reply to user_input within the running chat."""
        return self.chain_chat.predict(input=user_input)

    def start_mcq(self):
        """Prepare a quiz: build the JSON-structured question prompt and
        upload the document into the chat memory."""
        # Instantiate response schema to define JSON output
        response_schemas = [
            ResponseSchema(name="question", description="Question generated from provided document."),
            ResponseSchema(name="answer", description="One correct answer for the asked question."),
            ResponseSchema(name="choices",
                           description="3 available false options for a multiple-choice question in comma separated."),
        ]
        output_format_instructions = StructuredOutputParser.from_response_schemas(
            response_schemas).get_format_instructions()

        # Define the prompt that will be used for MCQ questions
        prompt = PromptTemplate(
            template="{task_instructions}\n {output_format_instructions}",
            input_variables=["task_instructions", "output_format_instructions"]
        )

        # Get the MCQ query based on the prompt (by filling in the prompt values)
        self.mcq_query = prompt.format(task_instructions=self.mcq_intro,
                                       output_format_instructions=output_format_instructions)

        # Bug fix: reset any previous quiz — the sheet was never cleared, so a
        # second quiz inherited old answers and get_mcq_score() over-counted.
        self.mcq_answer_sheet = []

        # Upload the document to the model
        self.start_chat()

    @staticmethod
    def _strip_json_fence(text: str) -> str:
        """Strip a Markdown ```json ... ``` fence if (and only if) present.

        The original sliced fixed offsets off both ends unconditionally,
        which corrupted responses the model returned without a fence.
        """
        text = text.strip()
        if text.startswith("```"):
            # Drop the opening fence line (``` or ```json).
            text = text.split("\n", 1)[1] if "\n" in text else text[3:]
            if text.endswith("```"):
                text = text[:-3]
        return text.strip()

    def get_mcq_question(self, max_attempts: int = 5):
        """Generate one multiple-choice question from the document.

        Retries on malformed model output, but — unlike the original, which
        looped forever — gives up after max_attempts tries.

        :param max_attempts: number of generation attempts before giving up
        :return: (question text, answers shuffled so the correct one is not
            always first); the unshuffled entry is appended to
            ``mcq_answer_sheet`` with the correct answer at ``answers[0]``.
        :raises RuntimeError: when no attempt yields parseable JSON
        """
        last_error = None
        for _ in range(max_attempts):
            try:
                response = self.chain_chat.predict(input=self.mcq_query)
                response_parsed = json.loads(self._strip_json_fence(response))

                question = response_parsed["question"]
                answers = [response_parsed["answer"]] + [false_answer.strip() for false_answer in
                                                         response_parsed["choices"].split(',')]
                break
            except Exception as e:  # malformed JSON / missing keys: retry
                last_error = e
        else:
            raise RuntimeError("Failed to generate a parseable MCQ question") from last_error

        self.mcq_answer_sheet.append({
            "question": question,
            "answer": answers[0],
            "user_answer": None,
            "choices": answers
        })
        return question, random.sample(answers, len(answers))

    def mcq_record_answer(self, answer):
        """Record the user's choice for the most recently generated question."""
        self.mcq_answer_sheet[-1]["user_answer"] = answer

    def get_mcq_score(self):
        """Return (number of correct answers, percentage score).

        The percentage is rounded *after* scaling to 100 to avoid float
        artifacts such as 66.67000000000001 from the original
        ``round(x, 4) * 100``.
        """
        score = sum(sheet['answer'] == sheet['user_answer'] for sheet in self.mcq_answer_sheet)
        score_perc = round(score / self.mcq_question_number * 100, 2)
        return score, score_perc
|
utils.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
import PyPDF2
|
3 |
+
|
4 |
+
# Identifiers for the app's navigable pages, declared via the functional
# Enum API; member names and values match the original class-based
# declaration exactly (MAIN=1, SUMMARY=2, MCQ=3, CHAT=4).
PAGE = Enum("PAGE", [("MAIN", 1), ("SUMMARY", 2), ("MCQ", 3), ("CHAT", 4)])
|
9 |
+
|
10 |
+
|
11 |
+
# Function to read PDF content
def read_pdf(file):
    """Extract and concatenate the text of every page in a PDF.

    :param file: path or binary file-like object accepted by PyPDF2.PdfReader
    :return: concatenated page text (may be "" for image-only PDFs)
    """
    pdf_reader = PyPDF2.PdfReader(file)
    # Iterate pages directly rather than indexing range(len(...)); join once
    # instead of quadratic string +=. Guard with `or ""` because
    # extract_text() can yield None/"" for pages without a text layer.
    return "".join(page.extract_text() or "" for page in pdf_reader.pages)
|