Yash Malviya
added try
060de37
raw
history blame contribute delete
No virus
2.89 kB
import os
import subprocess
import sys

import gradio as gr
import pandas as pd
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig
# Install flash-attn without CUDA build.
# The child process must inherit the current environment: passing only
# FLASH_ATTENTION_SKIP_CUDA_BUILD as `env` would drop PATH, HOME, CUDA_HOME,
# etc., so pip might not be found (or a different pip would run). Invoking pip
# via the running interpreter makes sure the package is installed into the
# same Python that imports it below. Failures are not fatal here (best-effort,
# as before); model loading with flash_attention_2 will surface them.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)
# Load the model and processor.
model_id = "yifeihu/TB-OCR-preview-0.1"
# NOTE(review): DEVICE is not used below -- the model is placed on "cuda"
# unconditionally (4-bit weights + flash_attention_2), so a GPU is presumably
# required regardless of what this evaluates to.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda",
    trust_remote_code=True,  # allow the repo's custom model code to run
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
    # Passing `load_in_4bit` straight to from_pretrained is deprecated;
    # transformers expects the quantization options in a BitsAndBytesConfig.
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
    # Forwarded to the repo's remote processor; presumably the number of image
    # crops it tiles each input into -- TODO confirm against the model card.
    num_crops=16,
)
# Define the OCR function
def phi_ocr(image):
    """Run the OCR model on one image and return its markdown transcription.

    Args:
        image: The image to transcribe. It is passed to ``processor`` as the
            only entry of its image list (a PIL image when called from the UI).

    Returns:
        The decoded generated text, cut off at the first ``<image_end>`` marker.
    """
    question = "Convert the text to markdown format."
    prompt_message = [{
        'role': 'user',
        # Presumably the placeholder for the first image in the list passed
        # to processor() below -- exactly one image is supplied.
        'content': f'<|image_1|>\n{question}',
    }]
    prompt = processor.tokenizer.apply_chat_template(
        prompt_message, tokenize=False, add_generation_prompt=True
    )
    # Follow wherever the model was placed rather than hard-coding "cuda".
    inputs = processor(prompt, [image], return_tensors="pt").to(model.device)
    generation_args = {
        "max_new_tokens": 1024,
        # Greedy decoding. A temperature has no effect when do_sample=False
        # and only makes transformers emit a warning, so none is passed.
        "do_sample": False,
    }
    generate_ids = model.generate(
        **inputs, eos_token_id=processor.tokenizer.eos_token_id, **generation_args
    )
    # Drop the prompt tokens so only the newly generated text is decoded.
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    response = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return response.split("<image_end>")[0]
# Define the function to process multiple images and save results to a CSV
def process_images(input_images):
    """OCR each image with ``phi_ocr`` and write the results to a CSV file.

    Args:
        input_images: A list or tuple of images accepted by ``phi_ocr``. A
            single image is treated as a one-element list, and ``None`` (e.g.
            when the upload widget is cleared) is treated as no images.

    Returns:
        A ``(status_message, csv_path)`` tuple. The CSV is written to
        ``extracted_entities.csv`` in the current working directory with
        columns ``index`` and ``extracted_text``. When there is nothing to
        process, no file is written and ``csv_path`` is ``None``.
    """
    if input_images is None:
        images = []
    elif isinstance(input_images, (list, tuple)):
        images = list(input_images)
    else:
        # A lone image is not iterable; wrap it instead of crashing.
        images = [input_images]

    if not images:
        return "No images uploaded.", None

    results = [
        {'index': index, 'extracted_text': phi_ocr(image)}
        for index, image in enumerate(images)
    ]
    # Convert to DataFrame and save to CSV
    df = pd.DataFrame(results, columns=['index', 'extracted_text'])
    output_csv = "extracted_entities.csv"
    df.to_csv(output_csv, index=False)
    return f"Processed {len(images)} images and saved to {output_csv}", output_csv
# Gradio UI
def _ocr_uploaded_files(files):
    """Open uploaded image files as RGB PIL images and OCR them.

    gr.File returns file paths (or, on older Gradio versions, tempfile
    wrappers exposing ``.name``), whereas ``process_images`` passes each item
    to the OCR processor, which is given PIL images here.
    """
    images = []
    for file in files or []:
        path = getattr(file, "name", file)
        with Image.open(path) as img:
            # convert() loads the pixels into a new image, so it stays usable
            # after the file is closed; RGB mirrors gr.Image(type="pil").
            images.append(img.convert("RGB"))
    return process_images(images)


with gr.Blocks() as demo:
    gr.Markdown("# OCR with TB-OCR-preview-0.1")
    gr.Markdown("Upload multiple images to extract and convert text to markdown format.")
    gr.Markdown("[Check out ](https://huggingface.co/yifeihu/TB-OCR-preview-0.1)")
    with gr.Row():
        # gr.Image takes a single image and has no `multiple` option (and its
        # `tool`/`source` arguments were removed in Gradio 4), so accept a
        # multi-file upload restricted to image types instead.
        input_images = gr.File(
            label="Upload Images", file_count="multiple", file_types=["image"]
        )
        output_text = gr.Textbox(label="Status")
        output_csv_link = gr.File(label="Download CSV")
    input_images.change(
        fn=_ocr_uploaded_files,
        inputs=input_images,
        outputs=[output_text, output_csv_link],
    )

demo.launch()