Spaces:
Build error
Build error
import torch | |
from rudalle import get_tokenizer, get_vae | |
from rudalle.utils import seed_everything | |
import sys | |
from rudolph.model.utils import get_i2t_attention_mask, get_t2t_attention_mask | |
from rudolph.model import get_rudolph_model, ruDolphModel, FP16Module | |
from rudolph.pipelines import generate_codebooks, self_reranking_by_image, self_reranking_by_text, show, generate_captions, generate_texts | |
from rudolph.pipelines import zs_clf | |
import gradio as gr | |
from rudolph import utils | |
from PIL import Image | |
# Prompt prefix handed to the caption generator (Russian for "proteins: ").
template = 'белков: '

# Inference device; half precision is requested only when running on CUDA.
device = 'cpu'
half = device == 'cuda'

# Build the 350M ruDOLPH model, then load the weights from the local
# checkpoint file, mapping every tensor onto the CPU while loading.
model = get_rudolph_model('350M', fp16=half, device=device)
model.load_state_dict(
    torch.load(
        "awesomemodel__dalle_1500.pt",
        map_location=torch.device('cpu'),
    )
)

# Tokenizer and image VAE (created with dwt=False) used by the pipeline.
tokenizer = get_tokenizer()
vae = get_vae(dwt=False).to(device)
# Gradio callback: caption the input image with its predicted nutrition values.
def classify_image(inp):
    """Caption an image with predicted nutrition values, labelled in English.

    One caption is generated with the module-level ``model``, ``tokenizer``,
    ``vae`` and ``template`` using fixed sampling settings (seed 43), and the
    Russian nutrition terms in it are replaced by English ones.

    Args:
        inp: Image data accepted by ``PIL.Image.fromarray`` (presumably the
            NumPy array supplied by the Gradio image input -- TODO confirm).

    Returns:
        The first generated caption as a string, with the Russian terms
        translated to English.
    """
    print(type(inp))
    inp = Image.fromarray(inp)
    texts = generate_captions(
        inp,
        tokenizer,
        model,
        vae,
        template=template,
        top_k=16,
        captions_num=1,
        bs=16,
        top_p=0.6,
        seed=43,
        temperature=0.8,
    )

    # Russian -> English, applied in order. The unit 'ккал' is translated to
    # 'calories' like the other terms; the reverse mapping would leave the
    # Russian unit in an otherwise English caption.
    translations = (
        ('белков', 'protein'),
        ('жиров', 'fat'),
        ('углеводов', 'carbs'),
        ('ккал', 'calories'),
    )
    rp = texts[0]
    for russian, english in translations:
        rp = rp.replace(russian, english)

    print(rp)
    return rp
# Input component: uploaded images are delivered at a 128x128 shape.
image = gr.inputs.Image(shape=(128, 128))

# Label output component; constructed here but not passed to the interface,
# which uses a plain "text" output instead.
label = gr.outputs.Label(num_top_classes=3)

# Wire the captioning callback into a Gradio interface with one example image
# and start serving it.
iface = gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs="text",
    description=(
        "https://github.com/sberbank-ai/ru-dolph RuDoplh by SBER AI finetuned "
        "for a image2text task to predict food calories by "
        "https://t.me/lovedeathtransformers Alex Wortega"
    ),
    examples=[
        ['b9c277a3.jpeg'],
    ],
)
iface.launch()