# File metadata captured with this source (kept as comments so the script parses):
# Yash Malviya
# Added everything
# d001321
# raw
# history blame
# No virus
# 2.88 kB
import os
import subprocess
import sys
import tempfile

import gradio as gr
import pandas as pd
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig
# Install flash-attn at startup without building its CUDA extension locally
# (FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE).
# - The current environment is extended rather than replaced, so PATH, HOME,
#   CUDA-related variables, etc. remain visible to pip and the build.
# - pip is run through this interpreter (`sys.executable -m pip`) with an
#   argument list, so no shell is involved and the package is installed into
#   the same environment that will import it.
# Best-effort: a failed install is not raised here (check=False); loading the
# model with attn_implementation="flash_attention_2" below is expected to
# fail loudly if flash-attn is unavailable.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
# Load the model and processor
model_id = "yifeihu/TB-OCR-preview-0.1"
# NOTE: DEVICE is not used by the loading code below. The model is pinned to
# CUDA: flash_attention_2 requires a CUDA GPU (and bitsandbytes 4-bit loading
# typically does too), so there is no working CPU fallback in this script.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda",
    # Allows transformers to execute custom model code from the model repo.
    trust_remote_code=True,
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
    # Passing `load_in_4bit=True` straight to from_pretrained is deprecated;
    # an explicit BitsAndBytesConfig requests the same 4-bit loading.
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
    # Forwarded to the repo's custom processor; presumably the number of
    # image crops used per input image — confirm against the model card.
    num_crops=16,
)
# Define the OCR function
def phi_ocr(image):
    """Transcribe the text in one image to Markdown using the loaded model.

    Args:
        image: A PIL image; it is passed to the processor as ``[image]``
            together with a fixed "convert to markdown" prompt.

    Returns:
        The decoded generated text (prompt tokens removed, special tokens
        skipped), truncated at the first ``"<image_end>"`` marker if present.
    """
    question = "Convert the text to markdown format."
    prompt_message = [{
        "role": "user",
        "content": f"<|image_1|>\n{question}",
    }]
    prompt = processor.tokenizer.apply_chat_template(
        prompt_message, tokenize=False, add_generation_prompt=True
    )
    # Move inputs to wherever the model actually lives instead of
    # hard-coding "cuda".
    inputs = processor(prompt, [image], return_tensors="pt").to(model.device)
    generation_args = {
        "max_new_tokens": 1024,
        # Greedy decoding. `temperature` is intentionally omitted: it is
        # ignored when do_sample=False and only triggers a warning.
        "do_sample": False,
    }
    generate_ids = model.generate(
        **inputs, eos_token_id=processor.tokenizer.eos_token_id, **generation_args
    )
    # generate() returns prompt + continuation; keep only the new tokens.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
    response = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return response.split("<image_end>")[0]
# Define the function to process multiple images and save results to a CSV
def process_images(input_images):
    """Run OCR on every uploaded image and save the results to a CSV file.

    Args:
        input_images: A sequence of uploaded images. Each item may be a PIL
            image, a file path, or a file object exposing its path as
            ``.name`` (e.g. the temp-file wrappers older Gradio versions
            pass for ``gr.File``). ``None`` — as sent when the upload is
            cleared — is treated as no images.

    Returns:
        A ``(status_message, csv_path)`` tuple. ``csv_path`` is ``None``
        when there was nothing to process.
    """
    images = list(input_images or [])
    if not images:
        return "No images uploaded.", None

    results = [
        {"index": index, "extracted_text": phi_ocr(_load_image(image))}
        for index, image in enumerate(images)
    ]

    # Convert to DataFrame and save to CSV. Each call writes into its own
    # fresh temporary directory so concurrent users of the app cannot
    # overwrite (or download) each other's results; the file keeps its
    # original name for the download.
    df = pd.DataFrame(results)
    output_name = "extracted_entities.csv"
    output_csv = os.path.join(tempfile.mkdtemp(), output_name)
    df.to_csv(output_csv, index=False)
    return f"Processed {len(images)} images and saved to {output_name}", output_csv


def _load_image(item):
    """Return a PIL image for one uploaded item (PIL image, path, or file object)."""
    if isinstance(item, Image.Image):
        return item
    path = getattr(item, "name", item)
    with Image.open(path) as img:
        # convert() loads the pixel data into a new image, so the file
        # handle can be closed when the `with` block exits.
        return img.convert("RGB")


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# OCR with TB-OCR-preview-0.1")
    gr.Markdown("Upload multiple images to extract and convert text to markdown format.")
    gr.Markdown("[Check out the model here](https://huggingface.co/yifeihu/TB-OCR-preview-0.1)")
    with gr.Row():
        # gr.Image holds a single image and has no `multiple` option, so a
        # multi-file upload is used; process_images opens each file itself.
        input_images = gr.File(
            label="Upload Images",
            file_count="multiple",
            file_types=["image"],
        )
        output_text = gr.Textbox(label="Status")
        output_csv_link = gr.File(label="Download CSV")
    input_images.change(
        fn=process_images,
        inputs=input_images,
        outputs=[output_text, output_csv_link],
    )
demo.launch()