from PIL import Image
import gradio as gr
import torch
import requests
from models.zhclip import ZhCLIPProcessor, ZhCLIPModel  # From https://www.github.com/yue-gang/ZH-CLIP

version = 'nlpcver/zh-clip-vit-roberta-large-patch14'
model = ZhCLIPModel.from_pretrained(version)
processor = ZhCLIPProcessor.from_pretrained(version)


def get_result(image, text):
    """Score how well ``text`` matches ``image`` using the ZH-CLIP model.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing has been uploaded yet.
        text: Query string from the textbox.

    Returns:
        The cosine similarity between the image embedding and the text
        embedding, formatted to four decimal places (higher = better match).

    Raises:
        gr.Error: If the image or the text is missing, so the UI shows a
            readable message instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not text or not text.strip():
        raise gr.Error("Please enter some text.")

    inputs = processor(text=[text], images=image, return_tensors="pt", padding=True)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # Normalise both embeddings so the dot product is a cosine similarity
    # (a no-op if the model already returns unit vectors). A softmax over a
    # single candidate text, as used previously, is always exactly 1.0 and
    # therefore carries no matching information.
    image_features = torch.nn.functional.normalize(outputs.image_features, dim=-1)
    text_features = torch.nn.functional.normalize(outputs.text_features, dim=-1)
    # One image x one text -> a 1x1 similarity matrix.
    score = (image_features @ text_features.T).item()
    return f"{score:.4f}"


with gr.Blocks(
    css="""
    .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
    #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
    """
) as iface:
    # Kept so the component order (and the CSS selector above) is unchanged;
    # it is only reset by the Clear button.
    state = gr.State([])
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="VQA Image Input")
            with gr.Row():
                with gr.Column(scale=1):
                    chat_input = gr.Textbox(lines=1, label="VQA Question Input")
            with gr.Row():
                clear_button = gr.Button(value="Clear", interactive=True, width=30)
                submit_button = gr.Button(
                    value="Submit", interactive=True, variant="primary"
                )
        with gr.Column():
            caption_output = gr.Textbox(lines=0, label="ITM")

    # Pressing Enter in the textbox and clicking Submit run the same scoring.
    chat_input.submit(
        get_result,
        [image_input, chat_input],
        [caption_output],
    )
    # Must return exactly one value per output component (3 here).
    clear_button.click(
        lambda: ("", [], ""),
        [],
        [chat_input, state, caption_output],
        queue=False,
    )
    submit_button.click(
        get_result,
        [image_input, chat_input],
        [caption_output],
    )

# Serialise inference (one request at a time) with a bounded waiting queue.
iface.queue(concurrency_count=1, api_open=False, max_size=10)
# The queue is already enabled above; the deprecated `enable_queue` flag is redundant.
iface.launch()