"""Gradio demo: generate a caption for an uploaded image with ruDolph 350M."""
import torch
from rudalle import get_tokenizer, get_vae
from rudalle.utils import seed_everything
import sys
from rudolph.model.utils import get_i2t_attention_mask, get_t2t_attention_mask
from rudolph.model import get_rudolph_model, ruDolphModel, FP16Module
from rudolph.pipelines import generate_codebooks, self_reranking_by_image, self_reranking_by_text, show, generate_captions, generate_texts
from rudolph.pipelines import zs_clf
import gradio as gr
from rudolph import utils
# `Image` was used below but never imported; Pillow is already required by
# both gradio and rudalle.
from PIL import Image

device = 'cuda'

# Text prefix that generate_captions continues from ('на картинке' is Russian
# for "in the picture"). The original script referenced an undefined
# `template` name. NOTE(review): value taken from the ruDolph captioning
# example; confirm it is the intended prompt.
CAPTION_TEMPLATE = 'на картинке'

# Load everything once at start-up on the configured device (previously the
# model ignored `device` and hard-coded 'cuda').
model = get_rudolph_model('350M', fp16=True, device=device)
tokenizer = get_tokenizer()
vae = get_vae(dwt=False).to(device)


def classify_image(inp, template=CAPTION_TEMPLATE):
    """Generate a caption for an image coming from the Gradio UI.

    Despite its name, this does not classify. It asks ruDolph for a caption.

    Args:
        inp: The uploaded image. This is a numpy array (``gr.inputs.Image``
            with its default ``type="numpy"``) or an already-built PIL image.
        template: Text prefix passed to ``generate_captions``.

    Returns:
        The result of ``generate_captions`` for a single caption
        (``captions_num=1``).
    """
    # Only convert when needed, so PIL images are accepted as well.
    if not isinstance(inp, Image.Image):
        inp = Image.fromarray(inp)
    texts = generate_captions(
        inp,
        tokenizer,
        model,
        vae,
        template=template,
        top_k=16,
        captions_num=1,
        bs=16,
        top_p=0.6,
        seed=43,
        temperature=0.8,
    )
    return texts


image = gr.inputs.Image(shape=(128, 128))
iface = gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs="text",
    examples=[['b9c277a3.jpeg']],
)
iface.launch(share=True)