kastan commited on
Commit
493fae2
Β·
1 Parent(s): 3cd1483

initial commit, fixing chat history

Browse files
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. README.md +6 -7
  2. app.py +346 -0
  3. clip_for_ppts.py +158 -0
  4. gpu_memory_utils.py +57 -0
  5. input_features/slides_001_tensor.pt +3 -0
  6. input_features/slides_002_tensor.pt +3 -0
  7. input_features/slides_003_tensor.pt +3 -0
  8. input_features/slides_004_tensor.pt +3 -0
  9. input_features/slides_005_tensor.pt +3 -0
  10. input_features/slides_006_tensor.pt +3 -0
  11. input_features/slides_007_tensor.pt +3 -0
  12. input_features/slides_008_tensor.pt +3 -0
  13. input_features/slides_009_tensor.pt +3 -0
  14. input_features/slides_010_tensor.pt +3 -0
  15. input_features/slides_01b_tensor.pt +3 -0
  16. input_features/slides_01c_tensor.pt +3 -0
  17. input_features/slides_01d_tensor.pt +3 -0
  18. input_features/slides_020_tensor.pt +3 -0
  19. input_features/slides_021_tensor.pt +3 -0
  20. input_features/slides_022_tensor.pt +3 -0
  21. input_features/slides_023_tensor.pt +3 -0
  22. input_features/slides_024_tensor.pt +3 -0
  23. input_features/slides_025_tensor.pt +3 -0
  24. input_features/slides_026_tensor.pt +3 -0
  25. input_features/slides_027_tensor.pt +3 -0
  26. input_features/slides_028_tensor.pt +3 -0
  27. input_features/slides_040_tensor.pt +3 -0
  28. input_features/slides_041_tensor.pt +3 -0
  29. input_features/slides_042_tensor.pt +3 -0
  30. input_features/slides_043_tensor.pt +3 -0
  31. input_features/slides_044_tensor.pt +3 -0
  32. input_features/slides_045_tensor.pt +3 -0
  33. input_features/slides_046_tensor.pt +3 -0
  34. input_features/slides_047_tensor.pt +3 -0
  35. input_features/slides_048_tensor.pt +3 -0
  36. input_features/slides_049_tensor.pt +3 -0
  37. input_features/slides_050_tensor.pt +3 -0
  38. input_features/slides_051_tensor.pt +3 -0
  39. input_features/slides_052_tensor.pt +3 -0
  40. input_features/slides_053_tensor.pt +3 -0
  41. input_features/slides_054_tensor.pt +3 -0
  42. input_features/slides_055_tensor.pt +3 -0
  43. input_features/slides_056_tensor.pt +3 -0
  44. input_features/slides_057_tensor.pt +3 -0
  45. input_features/slides_058_tensor.pt +3 -0
  46. input_features/slides_059_tensor.pt +3 -0
  47. input_features/slides_060_tensor.pt +3 -0
  48. input_features/slides_080_tensor.pt +3 -0
  49. input_features/slides_081_tensor.pt +3 -0
  50. input_features/slides_082_tensor.pt +3 -0
README.md CHANGED
@@ -1,13 +1,12 @@
1
  ---
2
- title: Ai Teaching Assistant Beta
3
- emoji: 🏒
4
- colorFrom: green
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 3.23.0
8
  app_file: app.py
9
- pinned: false
10
- license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: [BETA] AI Teaching Assistant
3
+ emoji: πŸ› οΈ
4
+ colorFrom: blue
5
+ colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 3.20.1
8
  app_file: app.py
9
+ pinned: False
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+ import retrieval
5
+ # UNCOMMENT ONLY WHEN RUNNING LOCALLY (not on Spaces)
6
+ # from dotenv import load_dotenv
7
+ from text_generation import Client, InferenceAPIClient
8
+
9
+ # load API keys from globally-availabe .env file
10
+ # SECRETS_FILEPATH = "/mnt/project/chatbotai/huggingface_cache/internal_api_keys.env"
11
+ # load_dotenv(dotenv_path=SECRETS_FILEPATH, override=True)
12
+
13
+ openchat_preprompt = (
14
+ "\n<human>: Hi!\n<bot>: My name is Bot, model version is 0.15, part of an open-source kit for "
15
+ "fine-tuning new bots! I was created by Together, LAION, and Ontocord.ai and the open-source "
16
+ "community. I am not human, not evil and not alive, and thus have no thoughts and feelings, "
17
+ "but I am programmed to be helpful, polite, honest, and friendly. I'm really smart at answering electrical engineering questions.\n")
18
+
19
+ # LOAD MODELS
20
+ ta = retrieval.Retrieval()
21
+ NUM_ANSWERS_GENERATED = 3
22
+
23
+
24
+ def clip_img_search(img):
25
+ if img is None:
26
+ return []
27
+ else:
28
+ return ta.reverse_img_search(img)
29
+
30
+
31
+ def get_client(model: str):
32
+ if model == "Rallio67/joi2_20Be_instruct_alpha":
33
+ return Client(os.getenv("JOI_API_URL"))
34
+ if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
35
+ return Client(os.getenv("OPENCHAT_API_URL"))
36
+ return InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None))
37
+
38
+
39
+ def get_usernames(model: str):
40
+ """
41
+ Returns:
42
+ (str, str, str, str): pre-prompt, username, bot name, separator
43
+ """
44
+ if model == "OpenAssistant/oasst-sft-1-pythia-12b":
45
+ return "", "<|prompter|>", "<|assistant|>", "<|endoftext|>"
46
+ if model == "Rallio67/joi2_20Be_instruct_alpha":
47
+ return "", "User: ", "Joi: ", "\n\n"
48
+ if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
49
+ return openchat_preprompt, "<human>: ", "<bot>: ", "\n"
50
+ return "", "User: ", "Assistant: ", "\n"
51
+
52
+
53
+ def predict(
54
+ model: str,
55
+ inputs: str,
56
+ typical_p: float,
57
+ top_p: float,
58
+ temperature: float,
59
+ top_k: int,
60
+ repetition_penalty: float,
61
+ watermark: bool,
62
+ chatbot,
63
+ history,
64
+ ):
65
+ client = get_client(model)
66
+ preprompt, user_name, assistant_name, sep = get_usernames(model)
67
+
68
+ history.append(inputs)
69
+
70
+ past = []
71
+ for data in chatbot:
72
+ user_data, model_data = data
73
+
74
+ if not user_data.startswith(user_name):
75
+ user_data = user_name + user_data
76
+ if not model_data.startswith(sep + assistant_name):
77
+ model_data = sep + assistant_name + model_data
78
+
79
+ past.append(user_data + model_data.rstrip() + sep)
80
+
81
+ if not inputs.startswith(user_name):
82
+ inputs = user_name + inputs
83
+
84
+ total_inputs = preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()
85
+
86
+ partial_words = ""
87
+
88
+ if model == "OpenAssistant/oasst-sft-1-pythia-12b":
89
+ iterator = client.generate_stream(
90
+ total_inputs,
91
+ typical_p=typical_p,
92
+ truncate=1000,
93
+ watermark=watermark,
94
+ max_new_tokens=500,
95
+ )
96
+ else:
97
+ iterator = client.generate_stream(
98
+ total_inputs,
99
+ top_p=top_p if top_p < 1.0 else None,
100
+ top_k=top_k,
101
+ truncate=1000,
102
+ repetition_penalty=repetition_penalty,
103
+ watermark=watermark,
104
+ temperature=temperature,
105
+ max_new_tokens=500,
106
+ stop_sequences=[user_name.rstrip(), assistant_name.rstrip()],
107
+ )
108
+
109
+ for i, response in enumerate(iterator):
110
+ if response.token.special:
111
+ continue
112
+
113
+ partial_words = partial_words + response.token.text
114
+ if partial_words.endswith(user_name.rstrip()):
115
+ partial_words = partial_words.rstrip(user_name.rstrip())
116
+ if partial_words.endswith(assistant_name.rstrip()):
117
+ partial_words = partial_words.rstrip(assistant_name.rstrip())
118
+
119
+ if i == 0:
120
+ history.append(" " + partial_words)
121
+ elif response.token.text not in user_name:
122
+ history[-1] = partial_words
123
+
124
+ chat = [(history[i].strip(), history[i + 1].strip()) for i in range(0, len(history) - 1, 2)]
125
+ yield chat, history, None, None, None, []
126
+
127
+ # Pinecone context retrieval
128
+ top_context_list = ta.retrieve_contexts_from_pinecone(user_question=inputs, topk=NUM_ANSWERS_GENERATED)
129
+ # yield chat, history, top_context_list[0], top_context_list[1], top_context_list[2], []
130
+ yield None, None, top_context_list[0], top_context_list[1], top_context_list[2], []
131
+
132
+ # run CLIP
133
+ images_list = ta.clip_text_to_image(inputs)
134
+ # yield chat, history, top_context_list[0], top_context_list[1], top_context_list[2], images_list
135
+ yield None, None, top_context_list[0], top_context_list[1], top_context_list[2], images_list
136
+
137
+
138
+ def reset_textbox():
139
+ return gr.update(value="")
140
+
141
+
142
+ def radio_on_change(
143
+ value: str,
144
+ disclaimer,
145
+ typical_p,
146
+ top_p,
147
+ top_k,
148
+ temperature,
149
+ repetition_penalty,
150
+ watermark,
151
+ ):
152
+ if value == "OpenAssistant/oasst-sft-1-pythia-12b":
153
+ typical_p = typical_p.update(value=0.2, visible=True)
154
+ top_p = top_p.update(visible=False)
155
+ top_k = top_k.update(visible=False)
156
+ temperature = temperature.update(visible=False)
157
+ disclaimer = disclaimer.update(visible=False)
158
+ repetition_penalty = repetition_penalty.update(visible=False)
159
+ watermark = watermark.update(False)
160
+ elif value == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
161
+ typical_p = typical_p.update(visible=False)
162
+ top_p = top_p.update(value=0.25, visible=True)
163
+ top_k = top_k.update(value=50, visible=True)
164
+ temperature = temperature.update(value=0.6, visible=True)
165
+ repetition_penalty = repetition_penalty.update(value=1.01, visible=True)
166
+ watermark = watermark.update(False)
167
+ disclaimer = disclaimer.update(visible=True)
168
+ else:
169
+ typical_p = typical_p.update(visible=False)
170
+ top_p = top_p.update(value=0.95, visible=True)
171
+ top_k = top_k.update(value=4, visible=True)
172
+ temperature = temperature.update(value=0.5, visible=True)
173
+ repetition_penalty = repetition_penalty.update(value=1.03, visible=True)
174
+ watermark = watermark.update(True)
175
+ disclaimer = disclaimer.update(visible=False)
176
+ return (
177
+ disclaimer,
178
+ typical_p,
179
+ top_p,
180
+ top_k,
181
+ temperature,
182
+ repetition_penalty,
183
+ watermark,
184
+ )
185
+
186
+
187
+ title = """<h1 align="center">πŸ”₯Teaching Assistant Chatbot"""
188
+ description = """
189
+ """
190
+
191
+ openchat_disclaimer = """
192
+ <div align="center">Checkout the official <a href=https://huggingface.co/spaces/togethercomputer/OpenChatKit>OpenChatKit feedback app</a> for the full experience.</div>
193
+ """
194
+
195
+ with gr.Blocks(css="""#col_container {margin-left: auto; margin-right: auto;}
196
+ #chatbot {height: 520px; overflow: auto;}""") as demo:
197
+ gr.HTML(title)
198
+ with gr.Row():
199
+ with gr.Accordion("Model choices", open=False, visible=True):
200
+ model = gr.Radio(
201
+ value="OpenAssistant/oasst-sft-1-pythia-12b",
202
+ choices=[
203
+ "OpenAssistant/oasst-sft-1-pythia-12b",
204
+ # "togethercomputer/GPT-NeoXT-Chat-Base-20B",
205
+ "Rallio67/joi2_20Be_instruct_alpha",
206
+ "google/flan-t5-xxl",
207
+ "google/flan-ul2",
208
+ "bigscience/bloom",
209
+ "bigscience/bloomz",
210
+ "EleutherAI/gpt-neox-20b",
211
+ ],
212
+ label="",
213
+ interactive=True,
214
+ )
215
+ # with gr.Row():
216
+ # with gr.Column():
217
+ # use_gpt3_checkbox = gr.Checkbox(label="Include GPT-3 (paid)?")
218
+ # with gr.Column():
219
+ # use_equation_checkbox = gr.Checkbox(label="Prioritize equations?")
220
+ state = gr.State([])
221
+
222
+ with gr.Row():
223
+ with gr.Column():
224
+ chatbot = gr.Chatbot(elem_id="chatbot")
225
+ inputs = gr.Textbox(placeholder="Ask an Electrical Engineering question!", label="Send a message...")
226
+ examples = gr.Examples(
227
+ examples=[
228
+ "What is a Finite State Machine?",
229
+ "How do you design a functional a Two-Bit Gray Code Counter?",
230
+ "How can we compare an 8-bit 2's complement number to the value -1 using AND, OR, and NOT?",
231
+ "What does the uninterrupted counting cycle label mean?",
232
+ ],
233
+ inputs=[inputs],
234
+ outputs=[],
235
+ )
236
+ gr.Markdown("## Relevant Textbook Passages & Lecture Transcripts")
237
+ with gr.Row():
238
+ with gr.Column():
239
+ context1 = gr.Textbox(label="Context 1")
240
+ with gr.Column():
241
+ context2 = gr.Textbox(label="Context 2")
242
+ with gr.Column():
243
+ context3 = gr.Textbox(label="Context 3")
244
+
245
+ gr.Markdown("## Relevant Lecture Slides")
246
+ with gr.Row():
247
+ with gr.Column(scale=2.6):
248
+ lec_gallery = gr.Gallery(label="Lecture images", show_label=False, elem_id="gallery").style(grid=[2], height="auto")
249
+ with gr.Column(scale=1):
250
+ inp_image = gr.Image(type="pil", label="Reverse Image Search (optional)", shape=(224, 398))
251
+
252
+ inp_image.change(fn=clip_img_search, inputs=inp_image, outputs=lec_gallery, scroll_to_output=True)
253
+ disclaimer = gr.Markdown(openchat_disclaimer, visible=False)
254
+ # state = gr.State([])
255
+
256
+ with gr.Row():
257
+ with gr.Accordion("Parameters", open=False, visible=True):
258
+ typical_p = gr.Slider(
259
+ minimum=-0,
260
+ maximum=1.0,
261
+ value=0.2,
262
+ step=0.05,
263
+ interactive=True,
264
+ label="Typical P mass",
265
+ )
266
+ top_p = gr.Slider(
267
+ minimum=-0,
268
+ maximum=1.0,
269
+ value=0.25,
270
+ step=0.05,
271
+ interactive=True,
272
+ label="Top-p (nucleus sampling)",
273
+ visible=False,
274
+ )
275
+ temperature = gr.Slider(
276
+ minimum=-0,
277
+ maximum=5.0,
278
+ value=0.6,
279
+ step=0.1,
280
+ interactive=True,
281
+ label="Temperature",
282
+ visible=False,
283
+ )
284
+ top_k = gr.Slider(
285
+ minimum=1,
286
+ maximum=50,
287
+ value=50,
288
+ step=1,
289
+ interactive=True,
290
+ label="Top-k",
291
+ visible=False,
292
+ )
293
+ repetition_penalty = gr.Slider(
294
+ minimum=0.1,
295
+ maximum=3.0,
296
+ value=1.03,
297
+ step=0.01,
298
+ interactive=True,
299
+ label="Repetition Penalty",
300
+ visible=False,
301
+ )
302
+ watermark = gr.Checkbox(value=False, label="Text watermarking")
303
+
304
+ model.change(
305
+ lambda value: radio_on_change(
306
+ value,
307
+ disclaimer,
308
+ typical_p,
309
+ top_p,
310
+ top_k,
311
+ temperature,
312
+ repetition_penalty,
313
+ watermark,
314
+ ),
315
+ inputs=model,
316
+ outputs=[
317
+ disclaimer,
318
+ typical_p,
319
+ top_p,
320
+ top_k,
321
+ temperature,
322
+ repetition_penalty,
323
+ watermark,
324
+ ],
325
+ )
326
+
327
+ inputs.submit(
328
+ predict,
329
+ [
330
+ model,
331
+ inputs,
332
+ typical_p,
333
+ top_p,
334
+ temperature,
335
+ top_k,
336
+ repetition_penalty,
337
+ watermark,
338
+ chatbot,
339
+ state,
340
+ ],
341
+ [chatbot, state, context1, context2, context3, lec_gallery],
342
+ )
343
+ inputs.submit(reset_textbox, [], [inputs])
344
+
345
+ gr.Markdown(description)
346
+ demo.queue(concurrency_count=16).launch(debug=True)
clip_for_ppts.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import clip
4
+ import torch
5
+ from PIL import Image
6
+
7
+ # import sys
8
+ # from pptx import Presentation
9
+ # from pptx.enum.shapes import MSO_SHAPE_TYPE
10
+ # import time
11
+
12
+
13
+ class ClipImage:
14
+
15
+ def __init__(self, path_of_ppt_folders, path_to_save_image_features, mode='image', device='cuda'):
16
+ """
17
+ :param input_image_path: path of the input image (mode = 'image') or the actual text to be searched (mode='text')
18
+ :param path_of_ppt_folders: path of the folder containing all the ppt folders
19
+ :param path_to_save_image_features: path to save the image features
20
+ :param mode: 'image' or 'text' based on the type of input
21
+ :param device: device to run the model on
22
+ """
23
+ print("HEADS UPP -- ALWAYS using CPU for this 'spaces' version of the project. Otherwise we get FP32/16 conflicts.")
24
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
25
+ device = "cpu"
26
+ # Path
27
+ directory = 'input_features'
28
+ path = os.path.join(path_to_save_image_features, directory)
29
+ if not os.path.exists(path):
30
+ # Create the directory
31
+ os.mkdir(path)
32
+ print("Directory '% s' created" % directory)
33
+
34
+ self.res = []
35
+ if not os.path.isdir(path_of_ppt_folders):
36
+ raise TypeError(f"{path_of_ppt_folders} is not a directory. Please only enter a directory")
37
+
38
+ # if mode == 'image' and not os.path.exists(input_image_path):
39
+ # raise FileNotFoundError(f"{input_image_path} does not exist.")
40
+ if not os.path.exists(path_to_save_image_features) or not os.path.isdir(path_to_save_image_features):
41
+ raise FileNotFoundError(f"{path_to_save_image_features} is not a directory or doesn't exist.")
42
+ self.mode = mode
43
+ self.path_of_ppt_folders = path_of_ppt_folders
44
+ self.path_to_save_image_features = path_to_save_image_features
45
+ self.device = device
46
+
47
+ # consider ViT-L/14 should be the best one
48
+ self.model, self.preprocess = clip.load('ViT-B/32', self.device)
49
+
50
+ #print("πŸ‘‰ RUNNING CLIP'S ONE-TIME ENCODING STEP... will be slow the first time, and hopefully only the first time.")
51
+ # passing in an image as a cheap hack, to make one funciton work for initial embedding.
52
+ #self.calculate_similarity('/home/rsalvi/chatbotai/rohan/ai-teaching-assistant-uiuc/lecture_slides/001/Slide1.jpeg')
53
+ #print("πŸ”₯ DONE with CLIP's ONE TIME ENCODING")
54
+
55
+ def text_to_image_search(self, search_text: str, top_k_to_return: int = 4):
56
+ """ Written after the fact by kastan, so that we don't have to call init every time. """
57
+ assert type(search_text) == str, f"Must provide a single string, instead I got type {type(search_text)}"
58
+ # self.create_input_features(search_text, mode='text')
59
+ self.mode = 'text'
60
+ return self.calculate_similarity(search_text, top_k_to_return)
61
+
62
+ # TODO: WIP.
63
+ def image_to_images_search(self, input_image, top_k_to_return: int = 4):
64
+ """ Written after the fact by kastan, so that we don't have to call init every time. """
65
+ self.mode = 'image'
66
+ return self.calculate_similarity(input_image, top_k_to_return)
67
+
68
+ def create_input_features(self, input_text_or_img):
69
+ if self.mode == 'image':
70
+ # Load the image
71
+ #input_image = Image.open(input_text_or_img) # Not needed as image comes from gradio in PIL format
72
+ # Preprocess the image
73
+ input_arr = torch.cat([self.preprocess(input_text_or_img).unsqueeze(0)]).to(self.device)
74
+
75
+ elif self.mode == 'text':
76
+ # Preprocess the text
77
+ input_arr = torch.cat([clip.tokenize(f"{input_text_or_img}")]).to(self.device)
78
+
79
+ # Encode the image or text
80
+ with torch.no_grad():
81
+ if self.mode == 'image':
82
+ input_features = self.model.encode_image(input_arr)
83
+ elif self.mode == 'text':
84
+ input_features = self.model.encode_text(input_arr)
85
+ input_features /= input_features.norm(dim=-1, keepdim=True)
86
+ return input_features
87
+
88
+ def new_most_similar_slide_file(self, top_k: int):
89
+ # Sort the results
90
+ ans = sorted(self.res, key=lambda x: x[2], reverse=True)
91
+ return ans[:top_k]
92
+
93
+ def calculate_similarity(self, input_text_or_img, topk_val: int = 4):
94
+ ## Similarities across folders
95
+ self.res = []
96
+ all_similarities = []
97
+ slide_numbers = []
98
+ # Create the input features
99
+ input_features = self.create_input_features(input_text_or_img)
100
+
101
+ # Iterate through all the folders
102
+ ppts = list(os.listdir(self.path_of_ppt_folders))
103
+ #start_time = time.monotonic()
104
+ for i in ppts:
105
+ # Get the path of the folder containing the ppt images
106
+ imgs = list(os.listdir(os.path.join(self.path_of_ppt_folders, i)))
107
+ slide_numbers.append(imgs)
108
+ # Iterate through all the images and preprocess them
109
+
110
+ # Check if the preprocessed file exists and load it
111
+ img_flag = os.path.exists(self.path_to_save_image_features + '/input_features' + "/slides_" + i + "_tensor.pt")
112
+ if img_flag:
113
+ image_features = torch.load(self.path_to_save_image_features + '/input_features' + "/slides_" + i + "_tensor.pt",
114
+ map_location=self.device)
115
+ else:
116
+ # Encode the images and save the encoding
117
+ with torch.no_grad():
118
+ image_input = torch.cat([
119
+ self.preprocess(Image.open(os.path.join(self.path_of_ppt_folders, i, image))).unsqueeze(0) for image in imgs
120
+ ]).to(self.device)
121
+ image_features = self.model.encode_image(image_input)
122
+ image_features /= image_features.norm(dim=-1, keepdim=True)
123
+ torch.save(image_features, self.path_to_save_image_features + '/input_features' + "/slides_" + i + "_tensor.pt")
124
+ print("Saved the image features (for faster future loading) to: ", self.path_to_save_image_features + "/slides_" + i + "_tensor.pt")
125
+
126
+ # Calculate the similarity between the input image and the images in the folder
127
+
128
+ # TODO: THIS REQUIRES REFACTOR. We're only looking in a SINGLE FOLDER. need to APPEND to similarity.
129
+ if self.mode == 'image':
130
+ similarity = (100.0 * input_features @ image_features.T).softmax(dim=-1)
131
+ all_similarities.append((i, similarity))
132
+ elif self.mode == 'text':
133
+ similarity = (100.0 * input_features @ image_features.T).softmax(dim=-1)
134
+ all_similarities.append((i, similarity))
135
+
136
+ ## Looking over all the folders
137
+ similarity_results = []
138
+
139
+ for j in range(0, len(all_similarities)):
140
+ folder_name = all_similarities[j][0]
141
+ folder_values = all_similarities[j][1][0]
142
+ for i in range(0, len(folder_values)):
143
+ self.res.append((folder_name, slide_numbers[j][i], folder_values[i]))
144
+
145
+ #print(self.res)
146
+
147
+ return self.new_most_similar_slide_file(topk_val)
148
+ # Return the sorted results
149
+
150
+
151
+ # if __name__ == "__main__":
152
+
153
+ # demo = ClipImage('/home/rsalvi/chatbotai/rohan/ai-teaching-assistant-uiuc/lecture_slides','/home/rsalvi/chatbotai/rohan/ai-teaching-assistant-uiuc')
154
+ # #op = demo.image_to_images_search('/home/rsalvi/chatbotai/rohan/ai-teaching-assistant-uiuc/lecture_slides/01c/Slide5.jpeg')
155
+ # op = demo.text_to_image_search("Unsigned Bit Pattern")
156
+ # print(op)
157
+ # op = demo.text_to_image_search("Graycode")
158
+ # print(op)
gpu_memory_utils.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import GPUtil # pip install gputil
2
+
3
+
4
+ def get_gpu_ids_with_sufficient_memory(memory_requirement_GB):
5
+ '''
6
+ Returns the MINIMAL SET of GPU IDs that, combined, have at least `memory_requirement` MB of free memory.
7
+ You will need to use all returned GPU IDs to get the desired memory requirement.
8
+ It returns lower IDs first [0, 1, ...]
9
+
10
+ If `memory_requirement` is 0, returns all available GPUs.
11
+ If `memory_requirement` is not available, returns an empty list.
12
+ '''
13
+ memory_requirement_MB = float(memory_requirement_GB * 1024)
14
+ GPUs = sorted(GPUtil.getGPUs(), key=lambda x: x.memoryFree, reverse=True)
15
+ total_memory = sum(gpu.memoryFree for gpu in GPUs)
16
+ if memory_requirement_MB > total_memory:
17
+ return []
18
+ GPU_IDs = []
19
+ for gpu in GPUs:
20
+ if memory_requirement_MB <= 0:
21
+ break
22
+ GPU_IDs.append(gpu.id)
23
+ memory_requirement_MB -= gpu.memoryFree
24
+ return GPU_IDs
25
+
26
+
27
+ def get_device_with_most_free_memory():
28
+ '''
29
+ Returns the GPU ID of the GPU with the most free memory.
30
+ '''
31
+ GPUs = GPUtil.getGPUs()
32
+ return sorted(GPUs, key=lambda x: x.memoryFree, reverse=True)[0].id
33
+
34
+
35
+ def get_free_memory_dict(leave_extra_memory_unused_GiB: float = 2, leave_extra_memory_unused_gpu0_GiB: float = 3):
36
+ '''
37
+ Returns a dictionary of GPU IDs and their free memory, in MiB.
38
+ Compatible with huggingface Accelerate formatting: `max_memory=get_free_memory_dict()`
39
+
40
+ Accelerate seems to use more memory than we give it, so we default to telling Accelerate we have 2 GiB less than we actually do.
41
+
42
+ Example output:
43
+ {0: '24753MiB', 1: '26223MiB', 2: '25603MiB', 3: '9044MiB'}
44
+ '''
45
+ GPUs = GPUtil.getGPUs()
46
+ memory_map = {gpu.id: int(round(gpu.memoryFree)) for gpu in GPUs}
47
+ if leave_extra_memory_unused_GiB > 0:
48
+ for device_id, memory_MiB in memory_map.items():
49
+ memory_map[device_id] = memory_MiB - (leave_extra_memory_unused_GiB * 1024)
50
+ if leave_extra_memory_unused_gpu0_GiB > 0 and 0 in memory_map:
51
+ memory_map[0] = memory_map[0] - (leave_extra_memory_unused_gpu0_GiB * 1024)
52
+
53
+ # format to Accelerate's liking
54
+ for device_id, memory_MiB in memory_map.items():
55
+ memory_map[device_id] = f"{int(round(memory_MiB))}MiB"
56
+
57
+ return memory_map
input_features/slides_001_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29d55bc24135d9c7a840999d704b182524d2a414ef96c648f1d619d790a399a2
3
+ size 45833
input_features/slides_002_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1364efec8977ffb12a3e20a3cec86d988937e1b84a1442d294f82eec3a0800f9
3
+ size 27401
input_features/slides_003_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee2f1e5771e45b2d67e4ee841855825025e77a5c79d5879af7440e46e6a81559
3
+ size 37641
input_features/slides_004_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:455d953f6215e7c0252ba76cebbb62a86b17b72d9609652cb43c2698c63af0a6
3
+ size 29449
input_features/slides_005_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:48495df1914efc197141f803fa8e794a299c079099358b1b7647ad02f9beb1a0
3
+ size 31497
input_features/slides_006_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d88074ab14593c6a629a1d2409a199234468cc3ff389e3fc975e70f39d49437
3
+ size 45833
input_features/slides_007_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:117dc745cd98665fb0b9e1467e84c531943c3c4fe2430c7fb490bc14a89fe6f4
3
+ size 35593
input_features/slides_008_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8dcbb170a6e0f528a9b2eeef78360d17c8d2909258f5476dbdb9347cf4a95482
3
+ size 37641
input_features/slides_009_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa40bb1cb0f53e9eac9da6df47c7f841211f91b9a038f1a131977017b2839277
3
+ size 51977
input_features/slides_010_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a22352c6a30fbfb07d34329edee48272fb242e37f946444a04004554252f1cea
3
+ size 25353
input_features/slides_01b_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df00185579ac41442b11a57727388a4772c7c733f88b5b73178c3ab7953e676f
3
+ size 13065
input_features/slides_01c_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f9fe4560b197ff01204723d9e80d850e565c2a55b73e53d0f30e2a33ef5fb4a
3
+ size 27401
input_features/slides_01d_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2652b373af6a44fa4765fe31beee085578c8eae10b09b96cf8f9350430f2b40
3
+ size 27401
input_features/slides_020_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:34aad91d50f03078ba169d951ff276930fccb937a939d98c18efc24c7169891e
3
+ size 58121
input_features/slides_021_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a38eb4401f0c79efcf9464f5e057af72e1eb7b028a2b34aaf786bf399098bc4d
3
+ size 51977
input_features/slides_022_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb884765cd44df108f5f94eb2647e002c8fc1b6bffa056bd724c1209572d4aa2
3
+ size 39689
input_features/slides_023_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8fba282de40c69beb70d29b6aa76a09294f90b6137ad76e45d7a15b4e6f3ac1f
3
+ size 33545
input_features/slides_024_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dfea2c6844a5ca3e8b9094ac2aa6cee5d70d2c3bdbac67843a53d13be5414ade
3
+ size 54025
input_features/slides_025_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8408286255449cc6359f25995b9f1a774deaf4f7219aba42afd069ba30b60654
3
+ size 19209
input_features/slides_026_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45321f6ea824030e5f8a75c8b77f0806ebc811135f74334e8584da5548d9fe73
3
+ size 31497
input_features/slides_027_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3ee4e0c257a24299739b7e99b0b3e804bb2312e98c7fe5147f299d6d6964626
3
+ size 15113
input_features/slides_028_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9db469ab3c7ee77fb39a724d99fcd4e33cb66365ea9983a06cc89a88aacf3a6d
3
+ size 17161
input_features/slides_040_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce583525507f5954a306d3524e3809f88ea79ff3a4357243d06fd02d3bd904d8
3
+ size 68361
input_features/slides_041_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42abf43fc51d107299f02e912b96e4128cdb0f5ab6ea54812932c91edea9223a
3
+ size 29449
input_features/slides_042_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e4b96a9bbca4fcf49756870bf991647298aa0a6c78dc3cfde15bd35d0180df6
3
+ size 33545
input_features/slides_043_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:378ae1abd3bedc7a86b5d81c628d57d0cf7fbc01e12334f8e7ce69995107f048
3
+ size 68361
input_features/slides_044_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:840e80355baf9c875d54760570f3f443e0dcef40f1250643259f5ac454dec99c
3
+ size 39689
input_features/slides_045_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32cd02868694bece799b9b948e86ac4ea34cb5855f9932f6d9110d93fc63ea7f
3
+ size 29449
input_features/slides_046_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cedf85806d5b49a71c8822f3ec7fa7fd1a0338e551edde97240b73bbc869cfd6
3
+ size 43785
input_features/slides_047_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8e99058a8c1286a9ba4cbeb66a5a0f9b0e20db1be2d63ec1de31575db92398a2
3
+ size 21257
input_features/slides_048_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c8ab3023965fc736f2bf06165d8dfeaa4cdd2bff090d05203f404e003d65e327
3
+ size 33545
input_features/slides_049_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0070eaa1b52fa44567f36b5a72542ff9c74692c405fbf5a736d4b01c42f31090
3
+ size 35593
input_features/slides_050_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b87749e4455e80a989edd986483b566dac4aec4252a46748ad0e48d5a0f7adac
3
+ size 23305
input_features/slides_051_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:305ddc477c0da9533a7d13e932c01f28ca5df5603bc7e1ae27bb358de2b391cd
3
+ size 33545
input_features/slides_052_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0005bf20b296ad6e7232dbf5712c9ac374371056d6aafb4dc225860737232f60
3
+ size 43785
input_features/slides_053_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:759b63c06ee741068d356721d6dfbc4f056f10f00a63f841574eee40d59317d8
3
+ size 19209
input_features/slides_054_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f459db293c752e5526dd278332fdb9ba60b3153fd64cf758a0d8ed3fde592a17
3
+ size 35593
input_features/slides_055_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e54f164d4b096afc3ddccf29203f6626e5e5667b2ba0de04d42e5ddcffde229
3
+ size 31497
input_features/slides_056_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f8bec9b8e0003e2967e7bc15586eba4864a368f378e4915e00ff1c031c14122
3
+ size 27401
input_features/slides_057_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5fbe4fa9e280214c653fd02f6d9a34c4a596744209fa9e7a859701da6df0d553
3
+ size 37641
input_features/slides_058_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8162a4c538bb2d6b416543c1ff06546d2871f29a3ca29d7d7361028b23950a03
3
+ size 17161
input_features/slides_059_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c3a84477d929188b8962524533f3d33f481d63ac3d758f406855c0654bef096
3
+ size 37641
input_features/slides_060_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:417194c8f9c5145143fb0cca9347f74c42900e2c481b5e8a08c0903c4659dc02
3
+ size 25353
input_features/slides_080_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b7d8e85788693d043c4abcff1d67b44406f6ad6c0cd36c05e39017afd4c0098
3
+ size 17161
input_features/slides_081_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:331d6c59e16f36a3b9b57f48309a2bd3da120924aabeef2c734d9f0a719b3ffc
3
+ size 47881
input_features/slides_082_tensor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:38b80b21e6d0946fee444896d8f173b5665c79a2b3ebc223bb69977342980669
3
+ size 29449