from __future__ import annotations import gradio as gr import torch from gradio_client import Client, file DESCRIPTION = "# Comparing image captioning models" ORIGINAL_SPACE_INFO = """\ - [GIT-large fine-tuned on COCO](https://huggingface.co/spaces/hysts/image-captioning-with-git) - [BLIP-large](https://huggingface.co/spaces/hysts/image-captioning-with-blip) - [BLIP-2 OPT 6.7B](https://huggingface.co/spaces/merve/BLIP2-with-transformers) - [BLIP-2 T5-XXL](https://huggingface.co/spaces/hysts/BLIP2) - [InstructBLIP](https://huggingface.co/spaces/hysts/InstructBLIP) - [Fuyu-8B](https://huggingface.co/spaces/adept/fuyu-8b-demo) """ torch.hub.download_url_to_file("http://images.cocodataset.org/val2017/000000039769.jpg", "cats.jpg") torch.hub.download_url_to_file( "https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png", "stop_sign.png" ) torch.hub.download_url_to_file( "https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg", "astronaut.jpg" ) def generate_caption_git(image_path: str) -> str: try: client = Client("hysts/image-captioning-with-git") return client.predict(file(image_path), api_name="/caption") except Exception: gr.Warning("The GIT-large Space is currently unavailable. Please try again later.") return "" def generate_caption_blip(image_path: str) -> str: try: client = Client("hysts/image-captioning-with-blip") return client.predict(file(image_path), "A picture of", api_name="/caption") except Exception: gr.Warning("The BLIP-large Space is currently unavailable. Please try again later.") return "" def generate_caption_blip2_opt(image_path: str) -> str: try: client = Client("merve/BLIP2-with-transformers") return client.predict( file(image_path), "Beam search", 1, # temperature 1, # length penalty 1.5, # repetition penalty api_name="/caption", ) except Exception: gr.Warning("The BLIP2 OPT6.7B Space is currently unavailable. Please try again later.") return "" def generate_caption_blip2_t5xxl(image_path: str) -> str: try: client = Client("hysts/BLIP2") return client.predict( file(image_path), "Beam search", 1, # temperature 1, # length penalty 1.5, # repetition penalty 50, # max length 1, # min length 5, # number of beams 0.9, # top p api_name="/caption", ) except Exception: gr.Warning("The BLIP2 T5-XXL Space is currently unavailable. Please try again later.") return "" def generate_caption_instructblip(image_path: str) -> str: try: client = Client("hysts/InstructBLIP") return client.predict( file(image_path), "Describe the image.", "Beam search", 5, # beam size 256, # max length 1, # min length 0.9, # top p 1.5, # repetition penalty 1.0, # length penalty 1.0, # temperature api_name="/run", ) except Exception: gr.Warning("The InstructBLIP Space is currently unavailable. Please try again later.") return "" def generate_caption_fuyu(image_path: str) -> str: try: client = Client("adept/fuyu-8b-demo") return client.predict( file(image_path), "Generate a coco style caption.\n", fn_index=3, ) except Exception: gr.Warning("The Fuyu-8B Space is currently unavailable. Please try again later.") return "" with gr.Blocks(css="style.css") as demo: gr.Markdown(DESCRIPTION) with gr.Row(): with gr.Column(): input_image = gr.Image(type="filepath") run_button = gr.Button("Caption") with gr.Column(): out_git = gr.Textbox(label="GIT-large fine-tuned on COCO") out_blip = gr.Textbox(label="BLIP-large") out_blip2_opt = gr.Textbox(label="BLIP-2 OPT 6.7B") out_blip2_t5xxl = gr.Textbox(label="BLIP-2 T5-XXL") out_instructblip = gr.Textbox(label="InstructBLIP") out_fuyu = gr.Textbox(label="Fuyu-8B") gr.Examples( examples=[ "cats.jpg", "stop_sign.png", "astronaut.jpg", ], inputs=input_image, ) with gr.Accordion(label="The original Spaces can be found here:", open=False): gr.Markdown(ORIGINAL_SPACE_INFO) fn_out_pairs = [ (generate_caption_git, out_git), (generate_caption_blip, out_blip), (generate_caption_blip2_opt, out_blip2_opt), (generate_caption_blip2_t5xxl, out_blip2_t5xxl), (generate_caption_instructblip, out_instructblip), (generate_caption_fuyu, out_fuyu), ] for fn, out in fn_out_pairs: run_button.click( fn=fn, inputs=input_image, outputs=out, api_name=False, ) if __name__ == "__main__": demo.queue(max_size=20, api_open=False).launch(show_api=False)