brunogreen25 commited on
Commit
a5162e0
1 Parent(s): ea667dd

Upload 4 files

Browse files
Files changed (4) hide show
  1. .env-sample +1 -0
  2. app.py +182 -0
  3. prompt_generation.py +117 -0
  4. utils.py +18 -0
.env-sample ADDED
@@ -0,0 +1 @@
 
 
1
+ OPENAI_API_KEY=""
app.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_chat import message
3
+ from utils import PAGE, read_pdf
4
+ from prompt_generation import OpenAILLM
5
+ from dotenv import load_dotenv
6
+
7
+ load_dotenv()
8
+
9
+
10
def init():
    """Seed the Streamlit session state (first run only) and set up the page chrome."""
    state = st.session_state

    # Initialise per-session values exactly once; 'current_page' doubles as the flag.
    if 'current_page' not in state:
        state.current_page = PAGE.MAIN
        state.mcq_question_number = 3
        state.llm = OpenAILLM(mcq_question_number=state.mcq_question_number)
        state.chat_start = False
        state.chat_messages = []

    # Page title, icon and centred header.
    st.set_page_config(page_title="AILearningBuddy", page_icon=":book:")
    st.markdown("<h1 style='text-align: center;'>AI Learning Buddy</h1>", unsafe_allow_html=True)
21
+
22
+
23
def main_page():
    """Landing page: document upload plus navigation into learn/test modes."""
    state = st.session_state
    st.header("Main Page")

    # Upload docs, capped at 2 MB.
    MAX_FILE_SIZE = 2 * 1024 * 1024  # 2 MB
    file = st.file_uploader("Upload documents", type=["pdf", "txt"])
    if file is not None:
        if file.size > MAX_FILE_SIZE:
            st.error(f"File size should not exceed {MAX_FILE_SIZE / (1024 * 1024)} MB. Please upload a smaller file.")
        else:
            st.success("File uploaded successfully!")

            # Decode the payload by MIME type and hand the text to the LLM wrapper.
            if file.type == "application/pdf":
                state.llm.upload_text(read_pdf(file))
            elif file.type == "text/plain":
                state.llm.upload_text(file.read().decode("utf-8"))
            else:
                st.error("Unsupported file type.")

    # Navigation buttons appear only once a document has been loaded.
    if state.llm.is_text_uploaded():
        col1, col2 = st.columns([1, 1])

        with col1:
            st.markdown("<h4>LEARN</h4>", unsafe_allow_html=True)

            if st.button("Create summary", key="summary_button"):
                state.current_page = PAGE.SUMMARY
                st.rerun()
            if st.button("Chat about the file", key="chat_button"):
                state.current_page = PAGE.CHAT
                state.chat_start = True
                st.rerun()

        with col2:
            st.markdown("<h4>TEST</h4>", unsafe_allow_html=True)

            if st.button("Create quiz", key="mcq_button"):
                state.current_page = PAGE.MCQ
                state.current_question = 0
                st.rerun()
69
+
70
+
71
def summary_page():
    """Show a back button and the LLM-generated summary of the uploaded document."""
    # Back navigation clears the uploaded document before returning to main.
    if st.button(":back: Main Page"):
        st.session_state.current_page = PAGE.MAIN
        st.session_state.llm.empty_text()
        st.rerun()

    st.header("Summary")

    # Ask the model for a summary and render it directly.
    st.write(st.session_state.llm.get_text_summary())
84
+
85
+
86
def chat_page():
    """Chat view: a form for user input and a container rendering the conversation."""
    state = st.session_state

    # Back navigation resets the chat transcript and the uploaded document.
    if st.button(":back: Main Page"):
        state.current_page = PAGE.MAIN
        state.chat_start = False
        state.chat_messages = []
        state.llm.empty_text()
        st.rerun()
    st.header("Chat About the Document")

    # Responses render above the input form.
    response_container = st.container()
    user_container = st.container()

    with user_container:
        with st.form(key='my_form', clear_on_submit=True):
            user_input = st.text_area("Type here:", key='input', height=100)
            send_button = st.form_submit_button(label='Send')

    if send_button or state.chat_start:
        # On the very first visit, prime the model with the document and greet;
        # afterwards, forward the user's message.
        if state.chat_start:
            user_input, model_response = state.llm.start_chat()
            state.chat_start = False
        else:
            model_response = state.llm.get_chat_response(user_input)
        state.chat_messages += [user_input, model_response]

        # Display chat messages; odd indices are model turns, even are user turns.
        with response_container:
            for i, msg in enumerate(state.chat_messages):
                if i == 0:
                    # Index 0 holds the injected document context — never shown.
                    continue
                if i % 2:
                    message(msg, key=f'{str(i)}_AI', avatar_style="pixel-art")
                else:
                    message(msg, is_user=True, key=f'{str(i)}_user',
                            avatar_style="adventurer-neutral")
122
+
123
+
124
def mcq_page():
    """Render the multiple-choice quiz: one question per step, then a results sheet.

    Fix: removed a leftover debug ``print`` of the selected answer.
    """
    # Back navigation resets quiz progress and the uploaded document.
    if st.button(":back: Main Page"):
        st.session_state.current_page = PAGE.MAIN
        st.session_state.current_question = 0
        st.session_state.llm.empty_text()
        st.rerun()

    # Set up the MCQ prompt/parser for the quiz.
    # NOTE(review): this also fires again on the rerun triggered by answering the
    # first question (current_question is still 0 then) — confirm that is intended.
    if st.session_state.current_question == 0:
        st.session_state.llm.start_mcq()

    # While questions remain, show one question with its shuffled answers.
    if st.session_state.current_question < st.session_state.mcq_question_number:
        # QA header
        st.header(f"Question {st.session_state.current_question + 1} / {st.session_state.mcq_question_number}")

        # Generate the QA text (LLM call).
        question, answers = st.session_state.llm.get_mcq_question()

        # QA form: record the chosen answer and advance on submit.
        with st.form(key='my_form', clear_on_submit=True):
            selected_answer = st.radio(f"{question}:", answers)
            send_button = st.form_submit_button(label="Next")
            if send_button:
                st.session_state.llm.mcq_record_answer(selected_answer)
                st.session_state.current_question += 1
    else:
        # Results header
        st.header("Results")

        # Show the final score.
        st.session_state.current_question += 1
        score, score_perc = st.session_state.llm.get_mcq_score()
        st.markdown("<h4>" + f"Score: {score} / {st.session_state.mcq_question_number} ({score_perc} %)" + "</h4>", unsafe_allow_html=True)

        # List the user's answers next to the correct ones.
        for i, qa in enumerate(st.session_state.llm.mcq_answer_sheet):
            question, answer, user_answer = qa['question'], qa['answer'], qa['user_answer']
            st.write("---")
            st.write(f"**Question {i+1}/{st.session_state.mcq_question_number}:** {question}")
            st.write(f"**Correct answer:** {answer}")
            st.write(f"**User answer:** {user_answer}")
168
+
169
+
170
# Main structure: one-time session setup, then render whichever page is active.
init()

# Page selector: dispatch table keyed on the page enum held in session state.
_PAGE_HANDLERS = {
    PAGE.MAIN: main_page,
    PAGE.SUMMARY: summary_page,
    PAGE.CHAT: chat_page,
    PAGE.MCQ: mcq_page,
}
_handler = _PAGE_HANDLERS.get(st.session_state.current_page)
if _handler is not None:
    _handler()
prompt_generation.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_openai import ChatOpenAI
2
+ from langchain.chains import ConversationChain
3
+ from langchain.chains.conversation.memory import ConversationBufferMemory, ConversationSummaryMemory
4
+ from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
5
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain.docstore.document import Document
7
+ from langchain.chains.summarize import load_summarize_chain
8
+ from langchain.output_parsers import StructuredOutputParser, ResponseSchema
9
+ import json
10
+ import random
11
+
12
+
13
class OpenAILLM:
    """Wrapper around an OpenAI chat model providing document summarisation,
    document-grounded chat, and multiple-choice quiz (MCQ) generation.

    Fixes over the previous revision:
      * ``get_mcq_question`` no longer loops forever on persistently malformed
        model output — it retries a bounded number of times, then raises.
      * JSON is extracted by locating the outermost ``{...}`` instead of slicing
        off assumed ```` ```json ```` fences, so un-fenced replies also parse.
      * ``get_mcq_score`` rounds the final percentage (not the ratio), avoiding
        float artifacts such as 33.330000000000005.
      * Leftover debug ``print("HERE")`` removed.
    """

    def __init__(self, temperature: float = 1.,
                 model_name: str = 'gpt-4o',
                 mcq_question_number: int = 10):
        """Build the LLM, the summarisation chain and the conversation chain.

        Args:
            temperature: Sampling temperature passed to the chat model.
            model_name: OpenAI model identifier.
            mcq_question_number: Number of questions a quiz consists of.
        """
        # Model-related instantiations.
        self.llm = ChatOpenAI(temperature=temperature, model_name=model_name)
        # Memory *factory*: empty_text() builds a fresh instance to wipe history.
        self.Memory = ConversationBufferMemory
        self.chain_summary = load_summarize_chain(self.llm, chain_type="map_reduce", verbose=True)
        self.chain_chat = ConversationChain(llm=self.llm, verbose=False, memory=self.Memory())

        # Other utils instantiation.
        self.docs = []  # chunked Documents of the currently uploaded text
        self.text_splitter = RecursiveCharacterTextSplitter()
        self.chat_document_intro = "Read the following document: "
        self.chat_message_begin = "What would you like to know about the uploaded document?"
        self.mcq_question_number = mcq_question_number
        self.mcq_intro = """
        Generate a question, correct answer and 3 possible false answers from the inputted document.
        Make sure that it is unique from the ones you have generated before!
        Only create 3 possible false answers and a correct answers!
        """
        self.mcq_answer_sheet = []  # one dict per generated question (see get_mcq_question)

    def upload_text(self, text):
        """Split *text* into chunks and store them as LangChain Documents."""
        chunks = self.text_splitter.split_text(text)
        self.docs = [Document(page_content=chunk) for chunk in chunks]

    def is_text_uploaded(self):
        """Return True when a document has been uploaded (and not yet cleared)."""
        return bool(self.docs)

    def empty_text(self):
        """Drop the uploaded document and wipe the chat history."""
        self.docs = []
        self.chain_chat.memory = self.Memory()

    def get_text_summary(self):
        """Summarise the uploaded document via the map-reduce summarisation chain."""
        return self.chain_summary.run(self.docs)

    def start_chat(self):
        """Inject the uploaded document into chat memory and return the opener.

        Returns:
            Tuple of (memory dump used as the hidden first "user" message,
            canned greeting shown as the first model message).
        """
        # Add the document to the conversation context chunk by chunk.
        self.chain_chat.memory.save_context({"input": self.chat_document_intro}, {"output": ""})
        for doc in self.docs:
            self.chain_chat.memory.save_context({"input": doc.page_content}, {"output": ""})

        return str(self.chain_chat.memory), self.chat_message_begin

    def get_chat_response(self, user_input: str):
        """Return the model's reply to *user_input* within the ongoing conversation."""
        return self.chain_chat.predict(input=user_input)

    def start_mcq(self):
        """Prepare the MCQ prompt (with structured-output instructions) and
        upload the document into the chat context."""
        # Response schema defining the JSON the model must emit.
        response_schemas = [
            ResponseSchema(name="question", description="Question generated from provided document."),
            ResponseSchema(name="answer", description="One correct answer for the asked question."),
            ResponseSchema(name="choices",
                           description="3 available false options for a multiple-choice question in comma separated."),
        ]
        output_format_instructions = StructuredOutputParser.from_response_schemas(
            response_schemas).get_format_instructions()

        # Prompt template combining task and format instructions.
        prompt = PromptTemplate(
            template="{task_instructions}\n {output_format_instructions}",
            input_variables=["task_instructions", "output_format_instructions"]
        )

        # The concrete query reused for every question of this quiz.
        self.mcq_query = prompt.format(task_instructions=self.mcq_intro,
                                       output_format_instructions=output_format_instructions)

        # Upload the document to the model.
        self.start_chat()

    @staticmethod
    def _extract_json(response: str) -> dict:
        """Parse a JSON object out of a model reply that may (or may not) be
        wrapped in markdown code fences.

        Raises:
            ValueError: if no ``{...}`` span is present.
            json.JSONDecodeError: if the span is not valid JSON.
        """
        start, end = response.find("{"), response.rfind("}")
        if start == -1 or end == -1:
            raise ValueError(f"No JSON object found in response: {response!r}")
        return json.loads(response[start:end + 1])

    def get_mcq_question(self, max_attempts: int = 5):
        """Generate one MCQ, record it on the answer sheet and return it.

        Args:
            max_attempts: How often to retry when the model output cannot be
                parsed (the previous revision retried forever).

        Returns:
            Tuple of (question text, answers shuffled with the correct one included).

        Raises:
            RuntimeError: when no parsable question was produced in *max_attempts* tries.
        """
        for _ in range(max_attempts):
            try:
                response = self.chain_chat.predict(input=self.mcq_query)
                payload = self._extract_json(response)

                question = payload["question"]
                # Correct answer first, then the stripped false choices.
                answers = [payload["answer"]] + [false_answer.strip() for false_answer in
                                                 payload["choices"].split(',')]
                break
            except Exception as e:
                print(e)
        else:
            raise RuntimeError(f"Could not generate a valid MCQ question in {max_attempts} attempts.")

        self.mcq_answer_sheet.append({
            "question": question,
            "answer": answers[0],
            "user_answer": None,
            "choices": answers
        })
        return question, random.sample(answers, len(answers))

    def mcq_record_answer(self, answer):
        """Store the user's answer for the most recently generated question."""
        self.mcq_answer_sheet[-1]["user_answer"] = answer

    def get_mcq_score(self):
        """Return (number of correct answers, percentage rounded to 2 decimals)."""
        score = sum(sheet['answer'] == sheet['user_answer'] for sheet in self.mcq_answer_sheet)
        score_perc = round(score / self.mcq_question_number * 100, 2)

        return score, score_perc
utils.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ import PyPDF2
3
+
4
class PAGE(Enum):
    """Identifiers for the app's top-level pages (used for navigation state)."""
    MAIN = 1     # upload / landing page
    SUMMARY = 2  # document summary view
    MCQ = 3      # multiple-choice quiz view
    CHAT = 4     # document chat view
9
+
10
+
11
+ # Function to read PDF content
12
def read_pdf(file):
    """Return the concatenated text of every page of an uploaded PDF.

    Fix: ``page.extract_text()`` can return ``None`` for pages with no
    extractable text (scans, images), which previously raised a TypeError
    on string concatenation — such pages now contribute an empty string.
    """
    pdf_reader = PyPDF2.PdfReader(file)
    return "".join(page.extract_text() or "" for page in pdf_reader.pages)