# NOTE: this file was scraped from a Hugging Face Space page; the original
# header ("Spaces: Running", "File size: 4,550 Bytes", commit hashes
# c5e57d6/a145e37, and the line-number gutter) was web-page chrome, not code,
# and has been converted to this comment so the module parses.
import copy
import os
import random
import warnings

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

try:
    from vouchervision.utils_LLM import SystemLoadMonitor
except:
    from utils_LLM import SystemLoadMonitor

warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
class FlorenceOCR:
    """Two-stage OCR for herbarium-label images.

    Stage 1 runs Microsoft's Florence-2 vision-language model to extract raw
    text from the image. Stage 2 feeds that raw (often space-less) string to a
    4-bit quantized Mistral instruct model that re-inserts word boundaries.
    """

    # def __init__(self, logger, model_id='microsoft/Florence-2-base'):
    def __init__(self, logger, model_id='microsoft/Florence-2-large'):
        """
        Args:
            logger: project logger forwarded to SystemLoadMonitor; may be None.
            model_id: Hugging Face checkpoint id for the Florence-2 OCR model.
        """
        self.MAX_TOKENS = 1024
        self.logger = logger
        self.model_id = model_id
        self.monitor = SystemLoadMonitor(logger)

        # Fall back to CPU when CUDA is unavailable instead of crashing on a
        # hard-coded .cuda() call.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, trust_remote_code=True
        ).eval().to(self.device)
        self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

        # Second-stage "cleanup" LLM that restores spaces in the OCR output.
        # self.model_id_clean = "mistralai/Mistral-7B-v0.3"
        self.model_id_clean = "unsloth/mistral-7b-instruct-v0.3-bnb-4bit"
        self.tokenizer_clean = AutoTokenizer.from_pretrained(self.model_id_clean)

        # 4-bit quantization keeps the 7B cleanup model within consumer VRAM.
        # (The former quant_method="bnb" kwarg was removed: BitsAndBytesConfig
        # does not accept it.)
        quant_config = BitsAndBytesConfig(load_in_4bit=True)
        self.model_clean = AutoModelForCausalLM.from_pretrained(
            self.model_id_clean,
            quantization_config=quant_config,
            low_cpu_mem_usage=True,
        )

    def ocr_florence(self, image, task_prompt='<OCR>', text_input=None):
        """Run Florence-2 OCR on `image`, then restore word spacing with the
        cleanup LLM.

        Args:
            image: a PIL.Image, or a filesystem path to an image.
            task_prompt: Florence-2 task token (e.g. '<OCR>').
            text_input: optional extra text appended to the task prompt.

        Returns:
            Tuple of (text_with_spaces, parsed_answer_text,
            parsed_answer_dict, usage_report).
        """
        self.monitor.start_monitoring_usage()

        # Accept a path for convenience.
        if isinstance(image, str):
            image = Image.open(image)

        prompt = task_prompt if text_input is None else task_prompt + text_input

        inputs = self.processor(text=prompt, images=image, return_tensors="pt")
        # Move all tensors to the same device as the Florence model.
        inputs = {key: value.to(self.model.device) for key, value in inputs.items()}

        generated_ids = self.model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=self.MAX_TOKENS,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        # Keep special tokens: post_process_generation needs them to parse the
        # task-specific output format.
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer_dict = self.processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height),
        )
        parsed_answer_text = parsed_answer_dict[task_prompt]

        # Stage 2: ask the cleanup model to re-insert spaces.
        inputs_clean = self.tokenizer_clean(
            f"Insert spaces into this text to make all the words valid. This text contains scientific names of plants, locations, habitat, coordinate words: {parsed_answer_text}",
            return_tensors="pt",
        )
        inputs_clean = {key: value.to(self.model_clean.device) for key, value in inputs_clean.items()}
        outputs_clean = self.model_clean.generate(**inputs_clean, max_new_tokens=self.MAX_TOKENS)

        # Decode only the newly generated tokens. Slicing by prompt length is
        # more robust than the previous str.find(parsed_answer_text) approach,
        # which silently returned prompt+response whenever the tokenizer
        # re-normalized the OCR text.
        prompt_len = inputs_clean["input_ids"].shape[1]
        text_with_spaces = self.tokenizer_clean.decode(
            outputs_clean[0][prompt_len:], skip_special_tokens=True
        ).strip()

        # Route debug output through the logger when one was provided; keep the
        # print fallback so logger=None callers see the same console output.
        if self.logger is not None:
            self.logger.info(text_with_spaces)
        else:
            print(text_with_spaces)

        self.monitor.stop_inference_timer()  # Starts tool timer too
        usage_report = self.monitor.stop_monitoring_report_usage()

        return text_with_spaces, parsed_answer_text, parsed_answer_dict, usage_report
def main():
    """Smoke-test FlorenceOCR on a single specimen-label image."""
    # img_path = '/home/brlab/Downloads/gem_2024_06_26__02-26-02/Cropped_Images/By_Class/label/1.jpg'
    img_path = 'D:/D_Desktop/BR_1839468565_Ochnaceae_Campylospermum_reticulatum_label.jpg'
    label_image = Image.open(img_path)

    # ocr = FlorenceOCR(logger=None, model_id='microsoft/Florence-2-base')
    engine = FlorenceOCR(logger=None, model_id='microsoft/Florence-2-large')
    results_text, results_all, results_dirty, usage_report = engine.ocr_florence(
        label_image, task_prompt='<OCR>', text_input=None
    )
    print(results_text)


if __name__ == '__main__':
    main()
# (end of scraped file)