from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
import torch
from PIL import Image
import requests
import traceback
import os
from huggingface_hub import login
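# Authenticate with the Hugging Face Hub; HF_TOKEN is expected as an
# environment variable (e.g. a Space secret).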
login(token=os.getenv("HF_TOKEN"))
class Image2Text:
def __init__(self):
# Load the GIT coco model
preprocessor_git_large_coco = AutoProcessor.from_pretrained("microsoft/git-large-coco")
model_git_large_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.preprocessor = preprocessor_git_large_coco
self.model = model_git_large_coco
self.model.to(self.device)
def image_description(
self,
image_url,
max_length=50,
temperature=0.1,
use_sample_image=False,
):
"""
Generate captions for the given image.
-----
Parameters
image_url: Image URL
The image to generate captions for.
max_length: int
The max length of the generated descriptions.
-----
Returns
str
The generated image description.
"""
caption_git_large_coco = ""
if use_sample_image:
image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(image_url, stream=True).raw)
# Generate captions for the image using the GIT coco model
try:
caption_git_large_coco = self._generate_description(image, max_length, False).strip()
return caption_git_large_coco
except Exception as e:
print(e)
traceback.print_exc()
def _generate_description(
self,
image,
max_length=50,
use_float_16=False,
):
"""
Generate captions for the given image.
-----
Parameters
image: PIL.Image
The image to generate captions for.
max_length: int
The max length of the generated descriptions.
use_float_16: bool
Whether to use float16 precision. This can speed up inference, but may lead to worse results (currently accepted but not applied).
-----
Returns
str
The generated caption.
"""
# inputs = preprocessor(image, return_tensors="pt").to(device)
pixel_values = self.preprocessor(images=image, return_tensors="pt").pixel_values.to(self.device)
generated_ids = self.model.generate(
pixel_values=pixel_values,
max_length=max_length,
)
generated_caption = self.preprocessor.batch_decode(generated_ids, skip_special_tokens=True)[0]
return generated_caption
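
# Minimal usage sketch for Image2Text (illustrative only; kept commented out so
# the Space only runs the Gradio app below). The URL is the sample COCO image
# already referenced in image_description.
#
# img2text = Image2Text()
# description = img2text.image_description(
#     "http://images.cocodataset.org/val2017/000000039769.jpg",
#     max_length=50,
# )
# print(description)
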
import json
from pprint import pprint
import bitsandbytes as bnb
import pandas as pd
import torch
import torch.nn as nn
import transformers
from datasets import load_dataset
from huggingface_hub import notebook_login
from peft import (
LoraConfig,
PeftConfig,
PeftModel,
get_peft_model,
prepare_model_for_kbit_training,
)
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
class Social_Media_Captioner:
def __init__(self, use_finetuned: bool=True, temp=0.1):
self.use_finetuned = use_finetuned
self.MODEL_NAME = "vilsonrodrigues/falcon-7b-instruct-sharded"
self.peft_model_name = "ayush-vatsal/caption_qlora_finetune"
self.model_loaded = False
self.device = "cuda:0"
self._load_model()
self.generation_config = self.model.generation_config
self.generation_config.max_new_tokens = 50
self.generation_config.temperature = temp
self.generation_config.top_p = 0.7
self.generation_config.num_return_sequences = 1
self.generation_config.pad_token_id = self.tokenizer.eos_token_id
self.generation_config.eos_token_id = self.tokenizer.eos_token_id
self.cache: list[dict] = [] # e.g. [{"image_description": "A man", "caption": ["A man"]}]
def _load_model(self):
# 4-bit NF4 quantization with double quantization and bfloat16 compute;
# fp32 CPU offload is enabled for modules that do not fit on the GPU.
self.bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
llm_int8_enable_fp32_cpu_offload=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
self.model = AutoModelForCausalLM.from_pretrained(
self.MODEL_NAME,
device_map = "auto",
trust_remote_code = True,
quantization_config = self.bnb_config
)
# Defining the tokenizers
self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
self.tokenizer.pad_token = self.tokenizer.eos_token
# if self.use_finetuned:
# # LORA Config Model
# self.lora_config = LoraConfig(
# r=16,
# lora_alpha=32,
# target_modules=["query_key_value"],
# lora_dropout=0.05,
# bias="none",
# task_type="CAUSAL_LM"
# )
# self.model = get_peft_model(self.model, self.lora_config)
# # Fitting the adapters
# self.peft_config = PeftConfig.from_pretrained(self.peft_model_name)
# self.model = AutoModelForCausalLM.from_pretrained(
# self.peft_config.base_model_name_or_path,
# return_dict = True,
# quantization_config = self.bnb_config,
# device_map= "auto",
# trust_remote_code = True
# )
# self.model = PeftModel.from_pretrained(self.model, self.peft_model_name)
# # Defining the tokenizers
# self.tokenizer = AutoTokenizer.from_pretrained(self.peft_config.base_model_name_or_path)
# self.tokenizer.pad_token = self.tokenizer.eos_token
self.model_loaded = True
print("Model Loaded successfully")
def inference(self, input_text: str, use_cached=True, cache_generation=True) -> str | None:
if not self.model_loaded:
raise Exception("Model not loaded")
try:
prompt = Social_Media_Captioner._prompt(input_text)
if use_cached:
for item in self.cache:
if item['image_description'] == input_text:
# Return the most recently cached caption for this description
return item['caption'][-1]
encoding = self.tokenizer(prompt, return_tensors = "pt").to(self.device)
with torch.inference_mode():
outputs = self.model.generate(
input_ids = encoding.input_ids,
attention_mask = encoding.attention_mask,
generation_config = self.generation_config
)
generated_caption = (self.tokenizer.decode(outputs[0], skip_special_tokens=True).split('Caption: "')[-1]).split('"')[0]
if cache_generation:
for item in self.cache:
if item['image_description'] == input_text:
item['caption'].append(generated_caption)
break
else:
self.cache.append({
'image_description': input_text,
'caption': [generated_caption]
})
return generated_caption
except Exception as e:
print(e)
return None
@staticmethod
def _prompt(input_text="A man walking alone in the road"):
if input_text is None:
raise Exception("Enter a valid input text to generate a valid prompt")
return f"""
Convert the given image description to an appropriate metaphoric caption
Description: {input_text}
Caption:
""".strip()
@staticmethod
def get_trainable_parameters(model):
trainable_params = 0
all_param = 0
for _, param in model.named_parameters():
all_param += param.numel()
if param.requires_grad:
trainable_params += param.numel()
return f"trainable_params: {trainable_params} || all_params: {all_param} || Percentage of trainable params: {100*trainable_params / all_param}"
def __repr__(self):
return f"""
Base Model Name: {self.MODEL_NAME}
PEFT Model Name: {self.peft_model_name}
Using PEFT Finetuned Model: {self.use_finetuned}
Model: {self.model}
------------------------------------------------------------
{Social_Media_Captioner.get_trainable_parameters(self.model)}
"""
class Captions:
def __init__(self, use_finetuned_LLM: bool=True, temp_LLM=0.1):
self.image_to_text = Image2Text()
self.LLM = Social_Media_Captioner(use_finetuned_LLM, temp_LLM)
def generate_captions(
self,
image,
image_url=None,
max_length_GIT=50,
temperature_GIT=0.1,
use_sample_image_GIT=False,
use_cached_LLM=True,
cache_generation_LLM=True
):
if image_url:
image_description = self.image_to_text.image_description(image_url, max_length=max_length_GIT, temperature=temperature_GIT, use_sample_image=use_sample_image_GIT)
else:
image_description = self.image_to_text._generate_description(image, max_length=max_length_GIT)
captions = self.LLM.inference(image_description, use_cached=use_cached_LLM, cache_generation=cache_generation_LLM)
return captions
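
# End-to-end pipeline: the GIT model produces a literal description of the image,
# which the Falcon-based captioner rewrites into a metaphoric social-media caption.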
caption_generator = Captions()
import gradio as gr
def setup(image):
return caption_generator.generate_captions(image = image)
iface = gr.Interface(
fn=setup,
inputs=gr.Image(type="pil", label="Upload Image"),
outputs=gr.Textbox(label="Caption"),
)
iface.launch()