from threading import Thread

import gradio as gr
import spaces
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, TextIteratorStreamer

TITLE = "E621 Tagger"
DESCRIPTION = "Tag images with E621 tags"

MODEL_ID = "estrogen/paligemma2-3b-e621-224"

# Load the fine-tuned PaliGemma 2 checkpoint and its processor once at startup,
# then move the model to the GPU so generation runs on CUDA.
model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID)
model.to("cuda")
processor = AutoProcessor.from_pretrained(MODEL_ID)
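
# Optional alternative (an assumption, not part of the original Space): loading the
# weights in bfloat16 roughly halves GPU memory use at a small precision cost, e.g.
#   import torch
#   model = PaliGemmaForConditionalGeneration.from_pretrained(
#       MODEL_ID, torch_dtype=torch.bfloat16
#   ).to("cuda")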


@spaces.GPU
def tag_image(image, max_new_tokens=128, temperature=1, top_p=1, min_p=0):
    # PaliGemma expects the <image> token followed by the task prefix
    # ("tag en" for this tagging fine-tune).
    inputs = processor(images=image, text="<image>tag en", return_tensors="pt").to("cuda")

    # Stream decoded tokens as they are produced; skip_special_tokens keeps
    # the trailing <eos> out of the displayed tags.
    streamer = TextIteratorStreamer(
        tokenizer=processor.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        inputs, streamer=streamer, max_new_tokens=max_new_tokens,
        use_cache=True, cache_implementation="hybrid",
        do_sample=True, temperature=temperature, top_p=top_p, min_p=min_p,
    )

    # Run generation in a background thread and yield the accumulated text so
    # Gradio streams partial results to the output textbox.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    text = ""
    for new_text in streamer:
        text += new_text
        yield text
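
# Quick local sanity check (a sketch; assumes a file named "example.png" exists next
# to this script and that the `spaces` decorator is a no-op outside a ZeroGPU Space):
#   for partial in tag_image(Image.open("example.png")):
#       print(partial, end="\r")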


gr.Interface(
    fn=tag_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(label="Max new tokens", minimum=1, maximum=1024, value=128, step=1),
        gr.Slider(label="Temperature", minimum=0, maximum=2, value=1),
        gr.Slider(label="Top p", minimum=0, maximum=1, value=1),
        gr.Slider(label="Min p", minimum=0, maximum=1, value=0),
    ],
    outputs=gr.Textbox(type="text"),
    title=TITLE,
    description=DESCRIPTION,
).launch()