import warnings

from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

try:
    from vouchervision.utils_LLM import SystemLoadMonitor
except ImportError:
    from utils_LLM import SystemLoadMonitor

warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")


class FlorenceOCR:
    # def __init__(self, logger, model_id='microsoft/Florence-2-base'):
    def __init__(self, logger, model_id='microsoft/Florence-2-large'):
        self.MAX_TOKENS = 1024
        self.logger = logger
        self.model_id = model_id
        self.monitor = SystemLoadMonitor(logger)

        self.model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).eval().cuda()
        self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

        # Second model: a 4-bit quantized instruct LLM used to re-insert
        # spaces into the raw OCR output.
        # self.model_id_clean = "mistralai/Mistral-7B-v0.3"
        self.model_id_clean = "unsloth/mistral-7b-instruct-v0.3-bnb-4bit"
        self.tokenizer_clean = AutoTokenizer.from_pretrained(self.model_id_clean)

        # Configure bitsandbytes 4-bit quantization. (BitsAndBytesConfig takes no
        # `quant_method` argument; the method is implied by the config class.)
        quant_config = BitsAndBytesConfig(load_in_4bit=True)
        self.model_clean = AutoModelForCausalLM.from_pretrained(
            self.model_id_clean,
            quantization_config=quant_config,
            low_cpu_mem_usage=True,
        )

    def ocr_florence(self, image, task_prompt='<OCR>', text_input=None):
        # '<OCR>' is the Florence-2 task token for plain text recognition;
        # post_process_generation keys its result by this task string.
        self.monitor.start_monitoring_usage()

        # Open the image if a path was provided
        if isinstance(image, str):
            image = Image.open(image)

        if text_input is None:
            prompt = task_prompt
        else:
            prompt = task_prompt + text_input

        inputs = self.processor(text=prompt, images=image, return_tensors="pt")
        # Move input_ids and pixel_values to the same device as the model
        inputs = {key: value.to(self.model.device) for key, value in inputs.items()}

        generated_ids = self.model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=self.MAX_TOKENS,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer_dict = self.processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height),
        )
        parsed_answer_text = parsed_answer_dict[task_prompt]

        # Prepare input for the second model
        inputs_clean = self.tokenizer_clean(
            "Insert spaces into this text to make all the words valid. "
            f"This text contains scientific names of plants, locations, habitat, coordinate words: {parsed_answer_text}",
            return_tensors="pt",
        )
        inputs_clean = {key: value.to(self.model_clean.device) for key, value in inputs_clean.items()}

        outputs_clean = self.model_clean.generate(**inputs_clean, max_new_tokens=self.MAX_TOKENS)
        text_with_spaces = self.tokenizer_clean.decode(outputs_clean[0], skip_special_tokens=True)

        # Extract only the LLM response from the decoded text (the prompt is echoed back)
        response_start = text_with_spaces.find(parsed_answer_text)
        if response_start != -1:
            text_with_spaces = text_with_spaces[response_start + len(parsed_answer_text):].strip()

        print(text_with_spaces)

        self.monitor.stop_inference_timer()  # Starts tool timer too
        usage_report = self.monitor.stop_monitoring_report_usage()

        return text_with_spaces, parsed_answer_text, parsed_answer_dict, usage_report


def main():
    # img_path = '/home/brlab/Downloads/gem_2024_06_26__02-26-02/Cropped_Images/By_Class/label/1.jpg'
    img_path = 'D:/D_Desktop/BR_1839468565_Ochnaceae_Campylospermum_reticulatum_label.jpg'
    image = Image.open(img_path)

    # ocr = FlorenceOCR(logger=None, model_id='microsoft/Florence-2-base')
    ocr = FlorenceOCR(logger=None, model_id='microsoft/Florence-2-large')
    # Returns: cleaned text, raw OCR text, full parsed dict, usage report
    results_text, results_dirty, results_all, usage_report = ocr.ocr_florence(image, task_prompt='<OCR>', text_input=None)
    print(results_text)


if __name__ == '__main__':
    main()