Yash Malviya commited on
Commit
d001321
1 Parent(s): 079fac8

Added everything

Browse files
Files changed (3) hide show
  1. README.md +3 -4
  2. app.py +76 -3
  3. requirements.txt +10 -0
README.md CHANGED
@@ -1,13 +1,12 @@
1
  ---
2
- title: ML Hackathon Tesseract
3
  emoji: 💻
4
- colorFrom: gray
5
- colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
- license: apache-2.0
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Tesseract Ocr
3
  emoji: 💻
4
+ colorFrom: red
5
+ colorTo: green
6
  sdk: gradio
7
  sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,7 +1,80 @@
 
 
 
 
1
  import gradio as gr
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
  demo.launch()
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoProcessor
3
+ from PIL import Image
4
+ import requests
5
  import gradio as gr
6
+ import pandas as pd
7
+ import subprocess
8
+ import os
9
 
10
+ # Install flash-attn without CUDA build
11
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
12
+
13
+ # Load the model and processor
14
+ model_id = "yifeihu/TB-OCR-preview-0.1"
15
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ model = AutoModelForCausalLM.from_pretrained(
18
+ model_id,
19
+ device_map="cuda",
20
+ trust_remote_code=True,
21
+ torch_dtype="auto",
22
+ attn_implementation='flash_attention_2',
23
+ load_in_4bit=True
24
+ )
25
+ processor = AutoProcessor.from_pretrained(model_id,
26
+ trust_remote_code=True,
27
+ num_crops=16
28
+ )
29
+
30
+ # Define the OCR function
31
+ def phi_ocr(image):
32
+ question = "Convert the text to markdown format."
33
+ prompt_message = [{
34
+ 'role': 'user',
35
+ 'content': f'<|image_1|>\n{question}',
36
+ }]
37
+ prompt = processor.tokenizer.apply_chat_template(prompt_message, tokenize=False, add_generation_prompt=True)
38
+ inputs = processor(prompt, [image], return_tensors="pt").to("cuda")
39
+ generation_args = {
40
+ "max_new_tokens": 1024,
41
+ "temperature": 0.1,
42
+ "do_sample": False
43
+ }
44
+ generate_ids = model.generate(**inputs, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
45
+ generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
46
+ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
47
+ response = response.split("<image_end>")[0]
48
+ return response
49
+
50
+ # Define the function to process multiple images and save results to a CSV
51
+ def process_images(input_images):
52
+ results = []
53
+ for index, image in enumerate(input_images):
54
+ extracted_text = phi_ocr(image)
55
+ results.append({
56
+ 'index': index,
57
+ 'extracted_text': extracted_text
58
+ })
59
+
60
+ # Convert to DataFrame and save to CSV
61
+ df = pd.DataFrame(results)
62
+ output_csv = "extracted_entities.csv"
63
+ df.to_csv(output_csv, index=False)
64
+
65
+ return f"Processed {len(input_images)} images and saved to {output_csv}", output_csv
66
+
67
+ # Gradio UI
68
+ with gr.Blocks() as demo:
69
+ gr.Markdown("# OCR with TB-OCR-preview-0.1")
70
+ gr.Markdown("Upload multiple images to extract and convert text to markdown format.")
71
+ gr.Markdown("[Check out the model here](https://huggingface.co/yifeihu/TB-OCR-preview-0.1)")
72
+
73
+ with gr.Row():
74
+ input_images = gr.Image(type="pil", label="Upload Images", tool="editor", source="upload", multiple=True)
75
+ output_text = gr.Textbox(label="Status")
76
+ output_csv_link = gr.File(label="Download CSV")
77
+
78
+ input_images.change(fn=process_images, inputs=input_images, outputs=[output_text, output_csv_link])
79
 
 
80
  demo.launch()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ accelerate
3
+ torch
4
+ spaces
5
+ torchvision
6
+ Pillow
7
+ pandas
8
+ gradio
9
+ bitsandbytes
10
+