import os
import torch
import gradio as gr
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The processor (image preprocessing + tokenizer) is loaded from the base
# "microsoft/git-base" repo while the weights come from a separate repo.
# NOTE(review): this assumes both share the same preprocessing/vocabulary —
# confirm if either checkpoint changes.
processor = AutoProcessor.from_pretrained("microsoft/git-base")
model = AutoModelForCausalLM.from_pretrained("sam749/sd-portrait-caption").to(device)


def generate_captions(images: "list[Image.Image]", max_length=200):
    """Generate one caption per input image.

    Args:
        images: PIL images to caption (a single PIL image is also accepted
            by the processor and yields a one-element result).
        max_length: Maximum generated sequence length forwarded to
            ``model.generate``. Coerced to ``int`` because UI sliders may
            deliver the value as a float.

    Returns:
        A list of decoded caption strings, one per image, with special
        tokens stripped.
    """
    # Preprocess the image(s) into a batch of pixel tensors on the model's device.
    inputs = processor(images=images, return_tensors="pt").to(device)
    generated_ids = model.generate(
        pixel_values=inputs.pixel_values,
        max_length=int(max_length),
    )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)


def generate_caption(image, max_length=200):
    """Caption a single image; this is the Gradio handler.

    Args:
        image: A PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submitted without providing an image.
        max_length: Maximum generated sequence length (from the slider).

    Returns:
        The caption string for ``image``.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of a processor traceback.
    """
    if image is None:
        raise gr.Error("Please upload or paste an image first.")
    return generate_captions([image], max_length)[0]


inputs = [
    gr.Image(sources=["upload", "clipboard"], height=400, type="pil"),
    gr.Slider(minimum=10, maximum=400, value=200, label='max length', step=8),
]
outputs = [
    gr.Text(label="Generated Caption"),
]

demo = gr.Interface(
    fn=generate_caption,
    inputs=inputs,
    outputs=outputs,
    title="Stable Diffusion Portrait Captioner",
    theme="gradio/monochrome",
    api_name="caption",
    submit_btn=gr.Button("caption it", variant="primary"),
    allow_flagging="never",
)
# Bound the number of pending requests waiting for the model.
demo.queue(
    max_size=10,
)

# Only start the server when run as a script, not when this module is imported.
if __name__ == "__main__":
    demo.launch()