Tonic committed
Commit 5fda074 · 1 parent: 54f1ed7

update demo for NV-Embed-v1

.github/workflows/publish.yml DELETED
@@ -1,29 +0,0 @@
-name: Python package
-on:
-  push:
-    tags:
-      - "v*.*.*"
-jobs:
-  build:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python 3.11
-        uses: actions/setup-python@v4
-        with:
-          python-version: 3.11
-      - name: Install python dependencies
-        run: |
-          pip install poetry
-          poetry install
-          poetry remove torch
-          poetry run pip install torch --index-url https://download.pytorch.org/whl/cpu
-      - name: Build package
-        run: |
-          poetry build
-      - name: Publish package
-        env:
-          PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
-        run: |
-          poetry config pypi-token.pypi "$PYPI_TOKEN"
-          poetry publish

.github/workflows/tests.yml DELETED
@@ -1,34 +0,0 @@
-name: Integration test
-
-on: [push]
-
-env:
-  TORCH_DEVICE: "cpu"
-
-jobs:
-  build:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python 3.11
-        uses: actions/setup-python@v4
-        with:
-          python-version: 3.11
-      - name: Install python dependencies
-        run: |
-          pip install poetry
-          poetry install
-          poetry remove torch
-          poetry run pip install torch --index-url https://download.pytorch.org/whl/cpu
-      - name: Download benchmark data
-        run: |
-          wget -O benchmark_data.zip "https://drive.google.com/uc?export=download&id=1dbY0kBq2SUa885gmbLPUWSRzy5K7O5XJ"
-          unzip benchmark_data.zip
-          mv bench_data.json data/bench_data.json
-      - name: Run benchmark test
-        run: |
-          poetry run texify_benchmark --max 16
-          poetry run python scripts/verify_benchmark_scores.py data/bench_results.json
-
-
-

README.md CHANGED
@@ -1,9 +1,9 @@
 ---
 license: mit
-title: Tonic's e5
+title: Tonic's NV-Embed
 sdk: gradio
-emoji: 🐣🛌🏻🤗
-colorFrom: red
+emoji: n📽️n🛌🏻
+colorFrom: pink
 colorTo: green
 pinned: true
 app_file: app.py
app.py CHANGED
@@ -6,16 +6,18 @@ import threading
 import queue
 import gradio as gr
 import os
+import json
+import numpy as np
+
 
 title = """
-# 👋🏻Welcome to 🙋🏻‍♂️Tonic's 🐣e5-mistral🛌🏻Embeddings """
+# 👋🏻Welcome to 🙋🏻‍♂️Tonic's 📽️Nvidia 🛌🏻Embed V-1 !"""
+
 description = """
-You can use this ZeroGPU Space to test out the current model [intfloat/e5-mistral-7b-instruct](https://huggingface.co/intfloat/e5-mistral-7b-instruct). 🐣e5-mistral🛌🏻 has a larger context🪟window, a different prompting/return🛠️mechanism and generally better results than other embedding models. use it via API to create embeddings or try out the sentence similarity to see how various optimization parameters affect performance.
-You can also use 🐣e5-mistral🛌🏻 by cloning this space. 🧬🔬🔍 Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic/e5?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></h3>
-Join us : 🌟TeamTonic🌟 is always making cool demos! Join our active builder's🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to 🌟 [DataTonic](https://github.com/Tonic-AI/DataTonic) 🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
+You can use this Space to test out the current model [nvidia/NV-Embed-v1](https://huggingface.co/nvidia/NV-Embed-v1). 🐣a generalist embedding model that ranks No. 1 on the Massive Text Embedding Benchmark (MTEB benchmark)(as of May 24, 2024), with 56 tasks, encompassing retrieval, reranking, classification, clustering, and semantic textual similarity tasks.
+You can also use 📽️Nvidia 🛌🏻Embed V-1 by cloning this space. 🧬🔬🔍 Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic/NV-Embed?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></h3>
+Join us : 🌟TeamTonic🌟 is always making cool demos! Join our active builder's🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to 🌟 [MultiTonic](https://github.com/MultiTonic) 🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
 """
-os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:30'
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 tasks = {
     'ArguAna': 'Given a claim, find documents that refute the claim',
@@ -31,17 +33,45 @@ tasks = {
     'SCIDOCS': 'Given a scientific paper title, retrieve paper abstracts that are cited by the given paper',
     'SciFact': 'Given a scientific claim, retrieve documents that support or refute the claim',
     'Touche2020': 'Given a question, retrieve detailed and persuasive arguments that answer the question',
-    'TRECCOVID': 'Given a query on COVID-19, retrieve documents that answer the query',
+    'Natural Language Inference' : 'Retrieve semantically similar text',
+    'Natural Language Inference' : 'Given a premise, retrieve a hypothesis that is entailed by the premise 20k',
+    'PAQ, MSMARCO' : 'Given a web search query, retrieve relevant passages that answer the query',
+    'PAQ, MSMARCO' : 'Given a question, retrieve passages that answer the question',
+    'SQUAD' : 'Given a question, retrieve Wikipedia passages that answer the question',
+    'StackExchange' : 'Given a question paragraph at StackExchange, retrieve a question duplicated paragraph',
+    'Natural Question' : 'Given a question, retrieve Wikipedia passages that answer the question',
+    'BioASQ' : 'Given a question, retrieve detailed question descriptions that are duplicates to the given question',
+    'STS12, STS22, STSBenchmark' : 'Retrieve semantically similar text.',
+    'AmazonCounterfactualClassification' : 'Classify a given Amazon customer review text as either counterfactual or not-counterfactual',
+    'AmazonReviewsClassification' : 'Classify the given Amazon review into its appropriate rating category',
+    'Banking77Classification' : 'Given a online banking query, find the corresponding intents',
+    'EmotionClassification' : 'Classify the emotion expressed in the given Twitter message into one of the six emotions:anger, fear, joy, love, sadness, and surprise',
+    'ImdbClassification': 'Classify the sentiment expressed in the given movie review text from the IMDB dataset',
+    'MTOPIntentClassification' : 'Classify the intent of the given utterance in task-oriented conversation',
+    'ToxicConversationsClassification' : 'Classify the given comments as either toxic or not toxic',
+    'TweetSentimentExtractionClassification' : 'Classify the sentiment of a given tweet as either positive, negative, or neutral',
+    'ArxivClusteringP2P' : 'Identify the main and secondary category of Arxiv papers based on the titles and abstracts',
+    'ArxivClusteringS2S' : 'Identify the main and secondary category of Arxiv papers based on the titles',
+    'BiorxivClusteringP2P' : 'Identify the main category of Biorxiv papers based on the titles and abstracts',
+    'BiorxivClusteringS2S' : 'Identify the main category of Biorxiv papers based on the titles',
+    'MedrxivClusteringP2P' : 'Identify the main category of Medrxiv papers based on the titles and abstracts',
+    'MedrxivClusteringS2S' : 'Identify the main category of Medrxiv papers based on the titles',
+    'TwentyNewsgroupsClustering' : 'Identify the topic or theme of the given news articles'
 }
 
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:30'
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# Global queue for embedding requests
+# Define the model and tokenizer globally
+tokenizer = AutoTokenizer.from_pretrained('nvidia/NV-Embed-v1', trust_remote_code=True)
+model = AutoModel.from_pretrained('nvidia/NV-Embed-v1', trust_remote_code=True).to(device)
+
+# Embedding requests and response queues
 embedding_request_queue = queue.Queue()
 embedding_response_queue = queue.Queue()
 
-
-tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-mistral-7b-instruct')
-model = AutoModel.from_pretrained('intfloat/e5-mistral-7b-instruct', torch_dtype=torch.float16, device_map=device)
+def clear_cuda_cache():
+    torch.cuda.empty_cache()
 
 def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
     left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
@@ -52,18 +82,22 @@ def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
     batch_size = last_hidden_states.shape[0]
     return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
 
-def clear_cuda_cache():
-    torch.cuda.empty_cache()
-
-def free_memory(*args):
-    for arg in args:
-        del arg
-
-def load_corpus_from_json(file_path):
-    with open(file_path, 'r') as file:
-        data = json.load(file)
-    return data
-
+def format_response(embeddings):
+    return {
+        "data": [
+            {
+                "embedding": embeddings,
+                "index": 0,
+                "object": "embedding"
+            }
+        ],
+        "model": "e5-mistral",
+        "object": "list",
+        "usage": {
+            "prompt_tokens": 17,
+            "total_tokens": 17
+        }
+    }
 
 def embedding_worker():
     while True:
@@ -79,14 +113,13 @@ def embedding_worker():
         embedding_request_queue.task_done()
         clear_cuda_cache()
 
-threading.Thread(target=embedding_worker, daemon=True).start()
-
 def compute_embeddings(selected_task, input_text):
     try:
         task_description = tasks[selected_task]
     except KeyError:
         print(f"Selected task not found: {selected_task}")
        return f"Error: Task '{selected_task}' not found. Please select a valid task."
+
     max_length = 2048
     processed_texts = [f'Instruct: {task_description}\nQuery: {input_text}']
 
@@ -101,124 +134,42 @@ def compute_embeddings(selected_task, input_text):
     clear_cuda_cache()
     return embeddings_list
 
-def decode_embedding(embedding_str):
-    try:
-        embedding = [float(num) for num in embedding_str.split(',')]
-        embedding_tensor = torch.tensor(embedding, dtype=torch.float16, device=device)
-        decoded_embedding = tokenizer.decode(embedding_tensor[0], skip_special_tokens=True)
-        return decoded_embedding.cpu().numpy().tolist()
-    except Exception as e:
-        return f"Error in decoding: {str(e)}"
-
 def compute_similarity(selected_task, sentence1, sentence2, extra_sentence1, extra_sentence2):
     try:
         task_description = tasks[selected_task]
     except KeyError:
         print(f"Selected task not found: {selected_task}")
         return f"Error: Task '{selected_task}' not found. Please select a valid task."
+
     # Compute embeddings for each sentence
     embeddings1 = compute_embeddings(selected_task, sentence1)
     embeddings2 = compute_embeddings(selected_task, sentence2)
     embeddings3 = compute_embeddings(selected_task, extra_sentence1)
     embeddings4 = compute_embeddings(selected_task, extra_sentence2)
 
-    # Convert embeddings to tensors
-    embeddings_tensor1 = torch.tensor(embeddings1).to(device).half()
-    embeddings_tensor2 = torch.tensor(embeddings2).to(device).half()
-    embeddings_tensor3 = torch.tensor(embeddings3).to(device).half()
-    embeddings_tensor4 = torch.tensor(embeddings4).to(device).half()
-
-    # Compute cosine similarity
     similarity1 = compute_cosine_similarity(embeddings1, embeddings2)
     similarity2 = compute_cosine_similarity(embeddings1, embeddings3)
     similarity3 = compute_cosine_similarity(embeddings1, embeddings4)
 
-    # Free memory
-    free_memory(embeddings1, embeddings2, embeddings3, embeddings4)
-
     similarity_scores = {"Similarity 1-2": similarity1, "Similarity 1-3": similarity2, "Similarity 1-4": similarity3}
     clear_cuda_cache()
     return similarity_scores
-
+
 def compute_cosine_similarity(emb1, emb2):
     tensor1 = torch.tensor(emb1).to(device).half()
     tensor2 = torch.tensor(emb2).to(device).half()
     similarity = F.cosine_similarity(tensor1, tensor2).item()
-    free_memory(tensor1, tensor2)
     clear_cuda_cache()
     return similarity
 
-
-def compute_embeddings_batch(input_texts):
-    max_length = 2042
-    processed_texts = [f'Instruct: {task_description}\nQuery: {text}' for text in input_texts]
-
-    batch_dict = tokenizer(processed_texts, max_length=max_length - 1, return_attention_mask=False, padding=False, truncation=True)
-    batch_dict['input_ids'] = [input_ids + [tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
-    batch_dict = tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, return_tensors='pt')
-    batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
-    outputs = model(**batch_dict)
-    embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
-    embeddings = F.normalize(embeddings, p=2, dim=1)
-    clear_cuda_cache()
-    return embeddings.detach().cpu().numpy()
-
-def semantic_search(query_embedding, corpus_embeddings, top_k=5):
-    scores = np.dot(corpus_embeddings, query_embedding.T).flatten()
-    top_k_indices = np.argsort(scores)[::-1][:top_k]
-    return top_k_indices, scores[top_k_indices]
-
-def search_similar_sentences(input_question, corpus_sentences, corpus_embeddings):
-    question_embedding = compute_embeddings_batch([input_question])[0]
-    top_k_indices, top_k_scores = semantic_search(question_embedding, corpus_embeddings)
-    results = [(corpus_sentences[i], top_k_scores[i]) for i in top_k_indices]
-    return results
-
-# openai response object formatting
-def format_response(embeddings):
-    return {
-        "data": [
-            {
-                "embedding": embeddings,
-                "index": 0,
-                "object": "embedding"
-            }
-        ],
-        "model": "e5-mistral",
-        "object": "list",
-        "usage": {
-            "prompt_tokens": 17,
-            "total_tokens": 17
-        }
-    }
-
-def generate_and_format_embeddings(selected_task, input_text):
-    embedding_request_queue.put((selected_task, input_text))
-    response = embedding_response_queue.get()
-    embedding_response_queue.task_done()
-    clear_cuda_cache()
-    return response
-
-
 def app_interface():
-    corpus_sentences = []
-    corpus_embeddings = []
    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(description)
+
        with gr.Row():
            task_dropdown = gr.Dropdown(list(tasks.keys()), label="Select a Task", value=list(tasks.keys())[0])

-        with gr.Tab("Embedding Generation"):
-            input_text_box = gr.Textbox(label="📖Input Text")
-            compute_button = gr.Button("Try🐣🛌🏻e5")
-            output_display = gr.Textbox(label="🐣e5-mistral🛌🏻 Embeddings")
-            compute_button.click(
-                fn=compute_embeddings,
-                inputs=[task_dropdown, input_text_box],
-                outputs=output_display
-            )
-
        with gr.Tab("Sentence Similarity"):
            sentence1_box = gr.Textbox(label="'Focus Sentence' - The 'Subject'")
            sentence2_box = gr.Textbox(label="'Input Sentence' - 1")
@@ -226,83 +177,17 @@ def app_interface():
            extra_sentence2_box = gr.Textbox(label="'Input Sentence' - 3")
            similarity_button = gr.Button("Compute Similarity")
            similarity_output = gr.Textbox(label="🐣e5-mistral🛌🏻 Similarity Scores")
+
            similarity_button.click(
                fn=compute_similarity,
                inputs=[task_dropdown, sentence1_box, sentence2_box, extra_sentence1_box, extra_sentence2_box],
                outputs=similarity_output
            )
-        with gr.Tab("Load Corpus"):
-            json_uploader = gr.File(label="Upload JSON File")
-            load_corpus_button = gr.Button("Load Corpus")
-            corpus_status = gr.Textbox(label="Corpus Status", value="Corpus not loaded")
-
-            def load_corpus(file_info):
-                if file_info is None:
-                    return "No file uploaded. Please upload a JSON file."
-                try:
-                    global corpus_sentences, corpus_embeddings
-                    corpus_sentences = load_corpus_from_json(file_info['name'])
-                    corpus_embeddings = compute_embeddings_batch(corpus_sentences)
-                    return "Corpus loaded successfully with {} sentences.".format(len(corpus_sentences))
-                except Exception as e:
-                    return "Error loading corpus: {}".format(e)
-
-            load_corpus_button.click(
-                fn=load_corpus,
-                inputs=json_uploader,
-                outputs=corpus_status
-            )
-
-        with gr.Tab("Semantic Search"):
-            input_question_box = gr.Textbox(label="Enter your question")
-            search_button = gr.Button("Search")
-            search_results_output = gr.Textbox(label="Search Results")
-
-            def perform_search(input_question):
-                if not corpus_sentences or not corpus_embeddings:
-                    return "Corpus is not loaded. Please load a corpus first."
-                return search_similar_sentences(input_question, corpus_sentences, corpus_embeddings)
-
-            search_button.click(
-                fn=perform_search,
-                inputs=input_question_box,
-                outputs=search_results_output
-            )
-
-        with gr.Tab("Connector-like Embeddings"):
-            with gr.Row():
-                input_text_box_connector = gr.Textbox(label="Input Text", placeholder="Enter text or array of texts")
-                model_dropdown_connector = gr.Dropdown(label="Model", choices=["ArguAna", "ClimateFEVER", "DBPedia", "FEVER", "FiQA2018", "HotpotQA", "MSMARCO", "NFCorpus", "NQ", "QuoraRetrieval", "SCIDOCS", "SciFact", "Touche2020", "TRECCOVID"], value="text-embedding-ada-002")
-                encoding_format_connector = gr.Radio(label="Encoding Format", choices=["float", "base64"], value="float")
-                user_connector = gr.Textbox(label="User", placeholder="Enter user identifier (optional)")
-                submit_button_connector = gr.Button("Generate Embeddings")
-                output_display_connector = gr.JSON(label="Embeddings Output")
-            submit_button_connector.click(
-                fn=generate_and_format_embeddings,
-                inputs=[model_dropdown_connector, input_text_box_connector],
-                outputs=output_display_connector
-            )
-
-        # with gr.Tab("Decode Embedding"):
-        #     embedding_input = gr.Textbox(label="Enter Embedding (comma-separated floats)")
-        #     decode_button = gr.Button("Decode")
-        #     decoded_output = gr.Textbox(label="Decoded Embedding")
-        #
-        #     decode_button.click(
-        #         fn=decode_embedding,
-        #         inputs=embedding_input,
-        #         outputs=decoded_output
-        #     )
-
-        with gr.Row():
-            with gr.Column():
-                input_text_box
-            with gr.Column():
-                compute_button
-                output_display
 
     return demo
 
+embedding_worker_thread = threading.Thread(target=embedding_worker, daemon=True)
+embedding_worker_thread.start()
 
 app_interface().queue()
 app_interface().launch(share=True)
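
For reference, the embedding path the updated app.py wires together can be exercised on its own. The sketch below is a minimal standalone version assuming the same model id, `trust_remote_code` loading, `Instruct:`/`Query:` prompt format, last-token pooling, and L2 normalization shown in the diff; the task and query strings are illustrative placeholders, and the assumption that the remote-code model's forward pass exposes `last_hidden_state` follows the app code above rather than being independently verified.

```python
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained('nvidia/NV-Embed-v1', trust_remote_code=True)
model = AutoModel.from_pretrained('nvidia/NV-Embed-v1', trust_remote_code=True).to(device)

def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    # Take the hidden state of each sequence's final non-padding token.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

# Placeholder task/query, following the demo's "Instruct: ...\nQuery: ..." format.
task = 'Given a claim, find documents that refute the claim'
text = f'Instruct: {task}\nQuery: The Earth is flat.'

batch = tokenizer([text], max_length=2048, padding=True, truncation=True, return_tensors='pt').to(device)
with torch.no_grad():
    outputs = model(**batch)
# Pool and L2-normalize, as in the app's compute path above.
embedding = F.normalize(last_token_pool(outputs.last_hidden_state, batch['attention_mask']), p=2, dim=1)
print(embedding.shape)
```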
benchmark.py DELETED
@@ -1,226 +0,0 @@
-import argparse
-import os.path
-import random
-import time
-from functools import partial
-
-import evaluate
-from tabulate import tabulate
-from tqdm import tqdm
-
-from texify.inference import batch_inference
-from texify.model.model import load_model
-from texify.model.processor import load_processor
-from PIL import Image
-from texify.settings import settings
-import json
-import base64
-import io
-from rapidfuzz.distance import Levenshtein
-
-
-def normalize_text(text):
-    # Replace fences
-    text = text.replace("$", "")
-    text = text.replace("\[", "")
-    text = text.replace("\]", "")
-    text = text.replace("\(", "")
-    text = text.replace("\)", "")
-    text = text.strip()
-    return text
-
-
-def score_text(predictions, references):
-    bleu = evaluate.load("bleu")
-    bleu_results = bleu.compute(predictions=predictions, references=references)
-
-    meteor = evaluate.load('meteor')
-    meteor_results = meteor.compute(predictions=predictions, references=references)
-
-    lev_dist = []
-    for p, r in zip(predictions, references):
-        lev_dist.append(Levenshtein.normalized_distance(p, r))
-
-    return {
-        'bleu': bleu_results["bleu"],
-        'meteor': meteor_results['meteor'],
-        'edit': sum(lev_dist) / len(lev_dist)
-    }
-
-
-def image_to_pil(image):
-    decoded = base64.b64decode(image)
-    return Image.open(io.BytesIO(decoded))
-
-
-def load_images(source_data):
-    images = [sd["image"] for sd in source_data]
-    images = [image_to_pil(image) for image in images]
-    return images
-
-
-def inference_texify(source_data, model, processor):
-    images = load_images(source_data)
-
-    write_data = []
-    for i in tqdm(range(0, len(images), settings.BATCH_SIZE), desc="Texify inference"):
-        batch = images[i:i+settings.BATCH_SIZE]
-        text = batch_inference(batch, model, processor)
-        for j, t in enumerate(text):
-            eq_idx = i + j
-            write_data.append({"text": t, "equation": source_data[eq_idx]["equation"]})
-
-    return write_data
-
-
-def inference_pix2tex(source_data):
-    from pix2tex.cli import LatexOCR
-    model = LatexOCR()
-
-    images = load_images(source_data)
-    write_data = []
-    for i in tqdm(range(len(images)), desc="Pix2tex inference"):
-        try:
-            text = model(images[i])
-        except ValueError:
-            # Happens when resize fails
-            text = ""
-        write_data.append({"text": text, "equation": source_data[i]["equation"]})
-
-    return write_data
-
-
-def image_to_bmp(image):
-    img_out = io.BytesIO()
-    image.save(img_out, format="BMP")
-    return img_out
-
-
-def inference_nougat(source_data, batch_size=1):
-    import torch
-    from nougat.postprocessing import markdown_compatible
-    from nougat.utils.checkpoint import get_checkpoint
-    from nougat.utils.dataset import ImageDataset
-    from nougat.utils.device import move_to_device
-    from nougat import NougatModel
-
-    # Load images, then convert to bmp format for nougat
-    images = load_images(source_data)
-    images = [image_to_bmp(image) for image in images]
-    predictions = []
-
-    ckpt = get_checkpoint(None, model_tag="0.1.0-small")
-    model = NougatModel.from_pretrained(ckpt)
-    if settings.TORCH_DEVICE_MODEL != "cpu":
-        move_to_device(model, bf16=settings.CUDA, cuda=settings.CUDA)
-    model.eval()
-
-    dataset = ImageDataset(
-        images,
-        partial(model.encoder.prepare_input, random_padding=False),
-    )
-
-    # Batch sizes higher than 1 explode memory usage on CPU/MPS
-    dataloader = torch.utils.data.DataLoader(
-        dataset,
-        batch_size=batch_size,
-        pin_memory=True,
-        shuffle=False,
-    )
-
-    for idx, sample in tqdm(enumerate(dataloader), desc="Nougat inference", total=len(dataloader)):
-        model.config.max_length = settings.MAX_TOKENS
-        model_output = model.inference(image_tensors=sample, early_stopping=False)
-        output = [markdown_compatible(o) for o in model_output["predictions"]]
-        predictions.extend(output)
-    return predictions
-
-
-def main():
-    parser = argparse.ArgumentParser(description="Benchmark the performance of texify.")
-    parser.add_argument("--data_path", type=str, help="Path to JSON file with source images/equations", default=os.path.join(settings.DATA_DIR, "bench_data.json"))
-    parser.add_argument("--result_path", type=str, help="Path to JSON file to save results to.", default=os.path.join(settings.DATA_DIR, "bench_results.json"))
-    parser.add_argument("--max", type=int, help="Maximum number of images to benchmark.", default=None)
-    parser.add_argument("--pix2tex", action="store_true", help="Run pix2tex scoring", default=False)
-    parser.add_argument("--nougat", action="store_true", help="Run nougat scoring", default=False)
-    args = parser.parse_args()
-
-    source_path = os.path.abspath(args.data_path)
-    result_path = os.path.abspath(args.result_path)
-    os.makedirs(os.path.dirname(result_path), exist_ok=True)
-    model = load_model()
-    processor = load_processor()
-
-    with open(source_path, "r") as f:
-        source_data = json.load(f)
-
-    if args.max:
-        random.seed(1)
-        source_data = random.sample(source_data, args.max)
-
-    start = time.time()
-    predictions = inference_texify(source_data, model, processor)
-    times = {"texify": time.time() - start}
-    text = [normalize_text(p["text"]) for p in predictions]
-    references = [normalize_text(p["equation"]) for p in predictions]
-
-    scores = score_text(text, references)
-
-    write_data = {
-        "texify": {
-            "scores": scores,
-            "text": [{"prediction": p, "reference": r} for p, r in zip(text, references)]
-        }
-    }
-
-    if args.pix2tex:
-        start = time.time()
-        predictions = inference_pix2tex(source_data)
-        times["pix2tex"] = time.time() - start
-
-        p_text = [normalize_text(p["text"]) for p in predictions]
-
-        p_scores = score_text(p_text, references)
-
-        write_data["pix2tex"] = {
-            "scores": p_scores,
-            "text": [{"prediction": p, "reference": r} for p, r in zip(p_text, references)]
-        }
-
-    if args.nougat:
-        start = time.time()
-        predictions = inference_nougat(source_data)
-        times["nougat"] = time.time() - start
-        n_text = [normalize_text(p) for p in predictions]
-
-        n_scores = score_text(n_text, references)
-
-        write_data["nougat"] = {
-            "scores": n_scores,
-            "text": [{"prediction": p, "reference": r} for p, r in zip(n_text, references)]
-        }
-
-    score_table = []
-    score_headers = ["bleu", "meteor", "edit"]
-    score_dirs = ["⬆", "⬆", "⬇", "⬇"]
-
-    for method in write_data.keys():
-        score_table.append([method, *[write_data[method]["scores"][h] for h in score_headers], times[method]])
-
-    score_headers.append("time taken (s)")
-    score_headers = [f"{h} {d}" for h, d in zip(score_headers, score_dirs)]
-    print()
-    print(tabulate(score_table, headers=["Method", *score_headers]))
-    print()
-    print("Higher is better for BLEU and METEOR, lower is better for edit distance and time taken.")
-    print("Note that pix2tex is unbatched (I couldn't find a batch inference method in the docs), so time taken is higher than it should be.")
-
-    with open(result_path, "w") as f:
-        json.dump(write_data, f, indent=4)
-
-
-if __name__ == "__main__":
-    main()
-
-
-

data/.gitignore DELETED
@@ -1,6 +0,0 @@
-*
-!.gitignore
-!examples
-!examples/*
-!images
-!images/*

data/examples/0.md DELETED
@@ -1,5 +0,0 @@
-The potential $V_{i}$ of cell $\mathcal{C}_ {j}$ centred at position $\mathbf{r}_ {i}$ is related to the surface charge densities $\sigma_ {j}$ of cells $\mathcal{E}_ {j}$ $j\in[1,N]$ through the superposition principle as:
-
-$$V_ {i}\,=\,\sum_ {j=0}^{N}\,\frac{\sigma_ {j}}{4\pi\varepsilon_ {0}}\,\int_{\mathcal{E}_ {j}}\frac{1}{\left|\mathbf{r}_ {i}-\mathbf{r}^{\prime}\right|}\,\mathrm{d}^{2}\mathbf{r}^{\prime}\,=\,\sum_{j=0}^{N}\,Q_ {ij}\,\sigma_{j},$$
-
-where the integral over the surface of cell $\mathcal{C}_ {j}$ only depends on $\mathcal{C}{j}$ shape and on the relative position of the target point $\mathbf{r}_ {i}$ with respect to $\mathcal{C}_ {j}$ location, as $\sigma_ {j}$ is assumed constant over the whole surface of cell $\mathcal{C}_ {j}$.

data/examples/0.png DELETED
Binary file (24.1 kB)
 
data/examples/100.md DELETED
@@ -1 +0,0 @@
-Following , the minimal energy fraction the muon receives in the pion's rest frame is $r_ {\pi}=(m_ {\mu}/m_ {\pi})^2\approx0.57$, when it is emitted against the direction of movement, or 1 when it coincides with the pion's direction.

data/examples/100.png DELETED
Binary file (11.2 kB)
 
data/examples/300.md DELETED
@@ -1,4 +0,0 @@
-
-$$\mid\frac{1}{x}=\frac{1}{c}\mid=\mid\frac{c-x}{xc}\mid=\frac{1}{\left\vert x\right\vert}\cdot\frac{1}{\left\vert c\right\vert}\cdot\left\vert x-c\right\vert$$
-
-The factor $$\frac{1}{\left\vert x\right\vert}$$ is not good if its near 0.

data/examples/300.png DELETED
Binary file (5.48 kB)
 
data/examples/400.md DELETED
@@ -1,9 +0,0 @@
-Then the results are that afterward:
-
-For every value of $\lambda$, there is a probability of $|\langle\Psi|\Psi_\lambda\rangle|^2$ that the system is in state $|\Psi_\lambda\rangle$
-
-This is captured by the density matrix formalism as the transition
-
-$|\Psi\rangle\langle\Psi|\Rightarrow\sum_\lambda|\langle\Psi|\Psi_\lambda\rangle|^2|\Psi_\lambda\rangle\langle\Psi_\lambda|$
-
-atyy I guess thinking about it classically, Demystifier's argument must be right.

data/examples/400.png DELETED
Binary file (20.2 kB)
 
data/images/gui_screen.png DELETED
Binary file (655 kB)
 
data/images/texify_bench.png DELETED
Binary file (27.5 kB)
 
ocr_app.py DELETED
@@ -1,167 +0,0 @@
-import io
-
-import pandas as pd
-import streamlit as st
-from streamlit_drawable_canvas import st_canvas
-import hashlib
-import pypdfium2
-
-from texify.inference import batch_inference
-from texify.model.model import load_model
-from texify.model.processor import load_processor
-from texify.settings import settings
-import subprocess
-import re
-from PIL import Image
-
-MAX_WIDTH = 1000
-
-
-def replace_katex_invalid(string):
-    # KaTeX cannot render all LaTeX, so we need to replace some things
-    string = re.sub(r'\\tag\{.*?\}', '', string)
-    string = re.sub(r'\\Big\{(.*?)\}|\\big\{(.*?)\}', r'\1\2', string)
-    return string
-
-@st.cache_resource()
-def load_model_cached():
-    return load_model()
-
-
-@st.cache_resource()
-def load_processor_cached():
-    return load_processor()
-
-
-@st.cache_data()
-def infer_image(pil_image, bbox, temperature):
-    input_img = pil_image.crop(bbox)
-    model_output = batch_inference([input_img], model, processor, temperature=temperature)
-    return model_output[0]
-
-
-def open_pdf(pdf_file):
-    stream = io.BytesIO(pdf_file.getvalue())
-    return pypdfium2.PdfDocument(stream)
-
-
-@st.cache_data()
-def get_page_image(pdf_file, page_num, dpi=96):
-    doc = open_pdf(pdf_file)
-    renderer = doc.render(
-        pypdfium2.PdfBitmap.to_pil,
-        page_indices=[page_num - 1],
-        scale=dpi / 72,
-    )
-    png = list(renderer)[0]
-    png_image = png.convert("RGB")
-    return png_image
-
-
-@st.cache_data()
-def get_uploaded_image(in_file):
-    return Image.open(in_file).convert("RGB")
-
-
-@st.cache_data()
-def page_count(pdf_file):
-    doc = open_pdf(pdf_file)
-    return len(doc)
-
-
-def get_canvas_hash(pil_image):
-    return hashlib.md5(pil_image.tobytes()).hexdigest()
-
-
-@st.cache_data()
-def get_image_size(pil_image):
-    if pil_image is None:
-        return 800, 600
-    height, width = pil_image.height, pil_image.width
-    if width > MAX_WIDTH:
-        scale = MAX_WIDTH / width
-        height = int(height * scale)
-        width = MAX_WIDTH
-    return height, width
-
-
-st.set_page_config(layout="wide")
-
-top_message = """### Texify
-
-After the model loads, upload an image or a pdf, then draw a box around the equation or text you want to OCR by clicking and dragging. Texify will convert it to Markdown with LaTeX math on the right.
-
-If you have already cropped your image, select "OCR image" in the sidebar instead.
-"""
-
-st.markdown(top_message)
-col1, col2 = st.columns([.7, .3])
-
-model = load_model_cached()
-processor = load_processor_cached()
-
-in_file = st.sidebar.file_uploader("PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"])
-if in_file is None:
-    st.stop()
-
-filetype = in_file.type
-whole_image = False
-if "pdf" in filetype:
-    page_count = page_count(in_file)
-    page_number = st.sidebar.number_input(f"Page number out of {page_count}:", min_value=1, value=1, max_value=page_count)
-
-    pil_image = get_page_image(in_file, page_number)
-else:
-    pil_image = get_uploaded_image(in_file)
-    whole_image = st.sidebar.button("OCR image")
-
-temperature = st.sidebar.slider("Generation temperature:", min_value=0.0, max_value=1.0, value=0.0, step=0.05)
-
-canvas_hash = get_canvas_hash(pil_image) if pil_image else "canvas"
-
-with col1:
-    # Create a canvas component
-    canvas_result = st_canvas(
-        fill_color="rgba(255, 165, 0, 0.1)",  # Fixed fill color with some opacity
-        stroke_width=1,
-        stroke_color="#FFAA00",
-        background_color="#FFF",
-        background_image=pil_image,
-        update_streamlit=True,
-        height=get_image_size(pil_image)[0],
-        width=get_image_size(pil_image)[1],
-        drawing_mode="rect",
-        point_display_radius=0,
-        key=canvas_hash,
-    )
-
-if canvas_result.json_data is not None or whole_image:
-    objects = pd.json_normalize(canvas_result.json_data["objects"])  # need to convert obj to str because PyArrow
-    bbox_list = None
-    if objects.shape[0] > 0:
-        boxes = objects[objects["type"] == "rect"][["left", "top", "width", "height"]]
-        boxes["right"] = boxes["left"] + boxes["width"]
-        boxes["bottom"] = boxes["top"] + boxes["height"]
-        bbox_list = boxes[["left", "top", "right", "bottom"]].values.tolist()
-    if whole_image:
-        bbox_list = [(0, 0, pil_image.width, pil_image.height)]
-
-    if bbox_list:
-        with col2:
-            inferences = [infer_image(pil_image, bbox, temperature) for bbox in bbox_list]
-            for idx, inference in enumerate(reversed(inferences)):
-                st.markdown(f"### {len(inferences) - idx}")
-                katex_markdown = replace_katex_invalid(inference)
-                st.markdown(katex_markdown)
-                st.code(inference)
-                st.divider()
-
-with col2:
-    tips = """
-    ### Usage tips
-    - Don't make your boxes too small or too large. See the examples and the video in the [README](https://github.com/vikParuchuri/texify) for more info.
-    - Texify is sensitive to how you draw the box around the text you want to OCR. If you get bad results, try selecting a slightly different box, or splitting the box into multiple.
-    - You can try changing the temperature value on the left if you don't get good results. This controls how "creative" the model is.
-    - Sometimes KaTeX won't be able to render an equation (red error text), but it will still be valid LaTeX. You can copy the LaTeX and render it elsewhere.
-    """
-    st.markdown(tips)

ocr_image.py DELETED
@@ -1,67 +0,0 @@
-import argparse
-import os.path
-
-from texify.inference import batch_inference
-from texify.model.model import load_model
-from texify.model.processor import load_processor
-from PIL import Image
-from texify.settings import settings
-from texify.util import is_valid_image
-import json
-
-
-def inference_single_image(image_path, json_path, model, processor):
-    image = Image.open(image_path)
-    text = batch_inference([image], model, processor)
-    write_data = [{"image_path": image_path, "text": text[0]}]
-    with open(json_path, "w+") as f:
-        json_repr = json.dumps(write_data, indent=4)
-        f.write(json_repr)
-
-
-def inference_image_dir(image_dir, json_path, model, processor, max=None):
-    image_paths = [os.path.join(image_dir, image_name) for image_name in os.listdir(image_dir)]
-    image_paths = [ip for ip in image_paths if is_valid_image(ip)]
-    if max:
-        image_paths = image_paths[:max]
-
-    write_data = []
-    for i in range(0, len(image_paths), settings.BATCH_SIZE):
-        batch = image_paths[i:i+settings.BATCH_SIZE]
-        images = [Image.open(image_path) for image_path in batch]
-        text = batch_inference(images, model, processor)
-        for image_path, t in zip(batch, text):
-            write_data.append({"image_path": image_path, "text": t})
-
-    with open(json_path, "w+") as f:
-        json_repr = json.dumps(write_data, indent=4)
-        f.write(json_repr)
-
-
-def main():
-    parser = argparse.ArgumentParser(description="OCR an image of a LaTeX equation.")
-    parser.add_argument("image", type=str, help="Path to image or folder of images to OCR.")
-    parser.add_argument("--max", type=int, help="Maximum number of images to OCR if a folder is passes.", default=None)
-    parser.add_argument("--json_path", type=str, help="Path to JSON file to save results to.", default=os.path.join(settings.DATA_DIR, "results.json"))
-    args = parser.parse_args()
-
-    image_path = args.image
-    model = load_model()
-    processor = load_processor()
-
-    json_path = os.path.abspath(args.json_path)
-    os.makedirs(os.path.dirname(json_path), exist_ok=True)
-
-    if os.path.isfile(image_path):
-        inference_single_image(image_path, json_path, model, processor)
-    else:
-        inference_image_dir(image_path, json_path, model, processor, args.max)
-
-    print(f"Wrote results to {json_path}")
-
-
-if __name__ == "__main__":
-    main()
-
-
-

poetry.lock DELETED
The diff for this file is too large to render.
 
pyproject.toml DELETED
@@ -1,47 +0,0 @@
-[tool.poetry]
-name = "texify"
-version = "0.1.6"
-description = "OCR for latex images"
-authors = ["Vik Paruchuri <vik.paruchuri@gmail.com>"]
-readme = "README.md"
-license = "GPL-3.0-or-later"
-repository = "https://github.com/VikParuchuri/texify"
-keywords = ["ocr", "latex", "markdown", "pdf"]
-include = [
-    "ocr_app.py",
-    "ocr_image.py",
-    "run_ocr_app.py",
-    "benchmark.py"
-]
-
-[tool.poetry.dependencies]
-python = ">=3.10,<4.0"
-streamlit = "^1.29.0"
-transformers = "^4.36.2"
-torch = "^2.1.2"
-pydantic = "^2.5.2"
-pydantic-settings = "^2.1.0"
-Pillow = "^10.1.0"
-numpy = "^1.26.2"
-pypdfium2 = "^4.25.0"
-python-dotenv = "^1.0.0"
-watchdog = "^3.0.0"
-ftfy = "^6.1.3"
-tabulate = "^0.9.0"
-streamlit-drawable-canvas-jsretry = "^0.9.3"
-
-[tool.poetry.group.dev.dependencies]
-jupyter = "^1.0.0"
-evaluate = "^0.4.1"
-rapidfuzz = "^3.5.2"
-pyperclip = "^1.8.2"
-nltk = "^3.8.1"
-
-[tool.poetry.scripts]
-texify = "ocr_image:main"
-texify_gui = "run_ocr_app:run_app"
-texify_benchmark = "benchmark:main"
-
-[build-system]
-requires = ["poetry-core"]
-build-backend = "poetry.core.masonry.api"

requirements.txt CHANGED
@@ -1,3 +1,4 @@
 transformers
-torch
-accelerate
+torch==2.2.0
+accelerate
+flash-attn==2.2.0
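
The version pins matter here: flash-attn compiles CUDA extensions against the installed torch build, which is presumably why torch is pinned alongside it. A minimal sanity-check sketch (illustrative, not part of the commit) for confirming the pinned stack imports before launching the Space:

```python
# Illustrative check that the dependencies pinned in requirements.txt resolve.
import torch
import transformers

assert torch.__version__.startswith("2.2"), f"expected torch 2.2.x, got {torch.__version__}"

try:
    import flash_attn  # compiled CUDA extension; needs a matching torch build
    print("flash-attn", flash_attn.__version__)
except ImportError as e:
    print("flash-attn unavailable:", e)

print("torch", torch.__version__, "| transformers", transformers.__version__)
```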
run_ocr_app.py DELETED
@@ -1,8 +0,0 @@
-import subprocess
-import os
-
-
-def run_app():
-    cur_dir = os.path.dirname(os.path.abspath(__file__))
-    ocr_app_path = os.path.join(cur_dir, "ocr_app.py")
-    subprocess.run(["streamlit", "run", ocr_app_path])

scripts/verify_benchmark_scores.py DELETED
@@ -1,20 +0,0 @@
-import json
-import argparse
-
-
-def verify_scores(file_path):
-    with open(file_path, 'r') as file:
-        data = json.load(file)
-
-    scores = data["texify"]["scores"]
-
-    if scores["bleu"] <= 0.6 or scores["meteor"] <= 0.6 or scores["edit"] > 0.2:
-        print(scores)
-        raise ValueError("Scores do not meet the required threshold")
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Verify benchmark scores")
-    parser.add_argument("file_path", type=str, help="Path to the json file")
-    args = parser.parse_args()
-    verify_scores(args.file_path)