zh-clip / app.py
xxx1's picture
Update app.py
dcf82fb
from PIL import Image
import gradio as gr
import requests
# import torch
from models.zhclip import ZhCLIPProcessor, ZhCLIPModel # From https://www.github.com/yue-gang/ZH-CLIP
# Checkpoint identifier handed to both from_pretrained calls below
# (presumably a Hugging Face Hub repo id — confirm).
version = 'nlpcver/zh-clip-vit-roberta-large-patch14'
# Model and processor are created once at import time and reused as
# module-level globals by get_result for every request.
model = ZhCLIPModel.from_pretrained(version)
processor = ZhCLIPProcessor.from_pretrained(version)
def get_result(image, text, text1):
    """Score how well each of two captions matches an image with ZH-CLIP.

    Args:
        image: Input image as delivered by the UI image component;
            ``None`` when the user has not supplied one.
        text: First candidate caption.
        text1: Second candidate caption.

    Returns:
        The softmax over the image/caption feature dot products, i.e. one
        probability per caption in the order ``text``, ``text1``.
        NOTE(review): presumably a ``(1, 2)`` torch tensor for a single
        image — confirm against ZhCLIPModel's output shapes.

    Raises:
        gr.Error: If no image was provided.
    """
    # Function-scope import: the module-level ``import torch`` is commented
    # out in this file, and torch is only needed here.
    import torch

    if image is None:
        # Surface a readable message in the UI instead of an opaque failure
        # from inside the processor/model.
        raise gr.Error("Please provide an image before submitting.")

    inputs = processor(
        text=[text, text1], images=image, return_tensors="pt", padding=True
    )
    # Inference only: disable autograd so no graph is built or retained per
    # request, and the returned tensor carries no ``grad_fn``.
    with torch.no_grad():
        outputs = model(**inputs)
        image_features = outputs.image_features
        text_features = outputs.text_features
        text_probs = (image_features @ text_features.T).softmax(dim=-1)
    return text_probs
# Build the demo UI: an image plus two caption boxes on the left, the
# matching result on the right, and Clear / Submit buttons.
with gr.Blocks(
    css="""
.message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
#component-21 > div.wrap.svelte-w6rprc {height: 600px;}
"""
) as iface:
    # Session state; it is only ever reset by the Clear button below.
    state = gr.State([])

    with gr.Row():
        # Left column: inputs and action buttons.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Image Input")
            with gr.Row():
                with gr.Column(scale=1):
                    chat_input = gr.Textbox(lines=1, label="Captions0 Input")
                    chat_input1 = gr.Textbox(lines=1, label="Captions1 Input")
            with gr.Row():
                clear_button = gr.Button(value="Clear", interactive=True, width=30)
                submit_button = gr.Button(
                    value="Submit", interactive=True, variant="primary"
                )
        # Right column: model output.
        with gr.Column():
            caption_output = gr.Textbox(lines=0, label="ITM")

    # Reset every component. The returned tuple must line up one-to-one with
    # the outputs list: previously five values were returned for three
    # outputs and the second caption box was never cleared.
    clear_button.click(
        lambda: (None, "", "", [], ""),
        [],
        [image_input, chat_input, chat_input1, state, caption_output],
        queue=False,
    )

    # Run the image/caption matching and show the result in the ITM box.
    submit_button.click(
        get_result,
        [
            image_input,
            chat_input,
            chat_input1,
        ],
        [caption_output],
    )
# Enable request queuing before launching. The queue is configured here, so
# the deprecated ``enable_queue`` argument to ``launch`` is not needed.
iface.queue(concurrency_count=1, api_open=False, max_size=10)
iface.launch()