# app.py by rrquizon1 - commit 4d1d4a2 ("Create app.py"), 2.23 kB.
import functools

from transformers import MarianTokenizer, MarianMTModel
from gtts import gTTS
import gradio as gr
import gradio as gr
import torch
import torchvision
import torchvision.transforms as transforms
import requests
from einops import rearrange
from transformers import AutoFeatureExtractor, DeiTForImageClassificationWithTeacher
import matplotlib
# Checkpoint used for image classification.
_CLASSIFIER_NAME = 'facebook/deit-base-distilled-patch16-384'
# MarianMT language pair; the same codes are reused for gTTS speech.
_SRC_LANG = "en"
_TRG_LANG = "tl"


@functools.lru_cache(maxsize=1)
def _load_classifier():
    """Load the DeiT feature extractor and classifier once and reuse them."""
    feature_extractor = AutoFeatureExtractor.from_pretrained(_CLASSIFIER_NAME)
    classifier = DeiTForImageClassificationWithTeacher.from_pretrained(_CLASSIFIER_NAME)
    return feature_extractor, classifier


@functools.lru_cache(maxsize=4)
def _load_translator(src, trg):
    """Load the MarianMT tokenizer and model for src->trg once and reuse them."""
    model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    translator = MarianMTModel.from_pretrained(model_name)
    return tokenizer, translator


def _classify(img):
    """Return the top-1 predicted label for *img*, cleaned and lower-cased."""
    feature_extractor, classifier = _load_classifier()
    inputs = feature_extractor(images=img, return_tensors="pt")
    # Inference only: no need to track gradients.
    with torch.no_grad():
        logits = classifier(**inputs).logits
    # The index is mapped through the checkpoint's own label set.
    predicted_class_idx = logits.argmax(-1).item()
    label = classifier.config.id2label[predicted_class_idx]
    # Labels may hold several comma-separated names with underscores;
    # keep only the first name, with spaces instead of underscores.
    label = label.replace("_", " ").split(',', 1)[0]
    return label.lower()


def _translate(text, src=_SRC_LANG, trg=_TRG_LANG):
    """Translate *text* from *src* to *trg* with MarianMT and return the string."""
    tokenizer, translator = _load_translator(src, trg)
    batch = tokenizer([text], return_tensors="pt")
    generated_ids = translator.generate(**batch)
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]


def _speak(text, lang, path):
    """Synthesize *text* in *lang* with gTTS, save it to *path*, return *path*."""
    # gTTS writes MP3 data, so *path* should carry an .mp3 suffix for
    # consumers that infer the audio type from the extension.
    gTTS(text=text, lang=lang).save(path)
    return path


def imgtrans(img):
    """Classify an image and return its label as Filipino and English text + audio.

    Pipeline: DeiT image classification -> English label -> MarianMT
    translation en -> tl -> gTTS speech for both labels.

    Args:
        img: Image as delivered by the Gradio Image input and passed straight
            to the DeiT feature extractor (presumably a numpy array - TODO
            confirm against the Gradio version in use).

    Returns:
        Tuple ``(fil_sound, fil, eng_sound, english)``: path of the Filipino
        MP3, the Filipino label text, path of the English MP3, and the
        lower-cased English label, in the order the Interface outputs expect.
    """
    english = _classify(img)
    fil = _translate(english)
    # NOTE(review): fixed file names in the working directory; concurrent
    # requests can overwrite each other's audio - consider tempfile if needed.
    fil_sound = _speak(fil, _TRG_LANG, 'filtrans.mp3')
    eng_sound = _speak(english, _SRC_LANG, 'engtrans.mp3')
    return fil_sound, fil, eng_sound, english
# Gradio UI: one image in; Filipino audio + text and English audio + text out.
# Components are created in the same order as before (input first, then the
# outputs), followed by the bundled example images.
image_input = gr.inputs.Image(shape=(224, 224), label='Insert Image')
result_outputs = [
    gr.outputs.Audio(label='Filipino Pronunciation'),
    gr.outputs.Textbox(label='Filipino Label'),
    gr.outputs.Audio(label='English Pronunciation'),
    gr.outputs.Textbox(label='English label'),
]
example_images = [
    '220px-Modern_British_LED_Traffic_Light.jpg',
    'aki_dog.jpg',
    'cat.jpg',
    'dog.jpg',
    'plasticbag.jpg',
    'telephone.jpg',
    'vpavic_211006_4796_0061.jpg',
    'watch.jpg',
    'wonder_cat.jpg',
    'hammer.jpg',
]
interface = gr.Interface(
    fn=imgtrans,
    inputs=image_input,
    outputs=result_outputs,
    examples=example_images,
)
interface.launch()