|
--- |
|
license: mit |
|
--- |
|
Part of the Advanced NLP Project for Team Shrine - Adnan Qidwai, Harshit Karwal and Shrikara Arun. |
|
CleanCaption is an image captioning model that forgets a specified object from the image when generating the caption. It is a fine-tuned version of `microsoft/Florence-2-large-ft`. |
|
|
|
Usage: |
|
```python |
|
from transformers import AutoProcessor, AutoModelForCausalLM |
|
from PIL import Image |
|
import torch |
|
|
|
# Pick the compute device in order of preference: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# The processor is loaded from the base Florence-2 checkpoint, while the
# model weights are loaded from the CleanCaption repository.
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-large-ft",
    trust_remote_code=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "sudokara/CleanCaption",
    trust_remote_code=True,
)
# Put the model in evaluation mode, then move it to the selected device.
model = model.eval()
model = model.to(device)
|
|
|
def forget(prompt, image_path):
    """Caption an image while asking the model to leave an object out.

    Uses the module-level ``processor``, ``model`` and ``device``.

    Args:
        prompt: What to omit from the caption. It is converted with ``str``
            and appended to the ``"Forget from caption: "`` instruction.
        image_path: Path handed to ``Image.open``; the image is converted
            to RGB before being processed.

    Returns:
        The decoded first generated sequence with every ``<s>`` and
        ``</s>`` marker removed.
    """
    rgb_image = Image.open(image_path).convert("RGB")

    # Spaces and colons are trimmed from both ends, so an empty prompt
    # collapses to the bare instruction "Forget from caption".
    instruction = f"Forget from caption: {str(prompt)}".strip(' :')

    batch = processor(text=instruction, images=rgb_image, return_tensors="pt")
    batch = batch.to(device)

    output_ids = model.generate(
        input_ids=batch["input_ids"],
        pixel_values=batch["pixel_values"],
        max_new_tokens=1024,
        do_sample=True,
        num_beams=3,
    )

    # Only the first returned sequence is decoded; the sequence start/end
    # markers are then removed from the text.
    caption = processor.decode(output_ids[0])
    for marker in ('<s>', '</s>'):
        caption = caption.replace(marker, '')
    return caption
|
|
|
# Example: caption "image.png" while forgetting "water".
image_path = "image.png"
caption = forget(prompt="water", image_path=image_path)
print(caption)
|
``` |
|
|