# Source: HuggingFace Space by ixxan — "Update app.py", commit 653ba6d (verified)
import gradio as gr
import re
import torch
from transformers.utils import logging
from transformers import ViltProcessor, ViltForQuestionAnswering, T5Tokenizer, T5ForConditionalGeneration
import httpcore
# Monkey-patch shim: newer httpcore removed SyncHTTPTransport, which old
# googletrans still references at import time; any attribute value satisfies it.
setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy') # set SyncHTTPTransport attribute for googletrans dependency
from googletrans import Translator
from googletrans import LANGCODES
# List of acceptable languages.
# Only the FIRST word of each LANGCODES key is kept (e.g. "haitian creole" ->
# "haitian"), so some entries are not themselves valid LANGCODES keys —
# lang_code_match() must handle those truncated names.
acceptable_languages = set(L.split()[0] for L in LANGCODES)
# Extra aliases for Chinese variants that LANGCODES does not list by name.
acceptable_languages.add("mandarin")
acceptable_languages.add("cantonese")
logging.set_verbosity_info()
# Module-level logger shared by every function below.
logger = logging.get_logger("transformers")
# Translation
def google_translate(question, dest):
    """Translate *question* into the language *dest* via googletrans.

    Returns a ``(translated_text, detected_source_lang)`` tuple.
    """
    result = Translator().translate(question, dest=dest)
    logger.info("Translation text: " + result.text)
    logger.info("Translation src: " + result.src)
    return (result.text, result.src)
# Lang to lang_code mapping
def lang_code_match(acceptable_lang):
    """Map a language name from ``acceptable_languages`` to a googletrans code.

    Handles the Chinese aliases explicitly, then falls back to LANGCODES.
    Because ``acceptable_languages`` keeps only the first word of multi-word
    LANGCODES keys (e.g. "haitian" from "haitian creole"), an exact lookup can
    fail; in that case the first LANGCODES entry whose first word matches is
    used instead of raising KeyError.
    """
    # Exception for chinese langs
    if acceptable_lang == 'mandarin':
        return 'zh-cn'
    elif acceptable_lang == 'cantonese' or acceptable_lang == 'chinese':
        return 'zh-tw'
    # Default: exact LANGCODES key
    elif acceptable_lang in LANGCODES:
        return LANGCODES[acceptable_lang]
    # Fallback for truncated first-word names ("haitian", "scots", ...)
    for name, code in LANGCODES.items():
        if name.split()[0] == acceptable_lang:
            return code
    raise KeyError(acceptable_lang)
# Find destination language
def find_dest_language(sentence, src_lang):
    """Choose the language code the final answer should be translated into.

    Scans *sentence* (the question, already in English) for a mention of any
    acceptable language name; if one is found its googletrans code is
    returned, otherwise the question's own source language *src_lang* is used.
    """
    # re.escape guards against any name containing regex metacharacters.
    pattern = r'\b(' + '|'.join(re.escape(lang) for lang in acceptable_languages) + r')\b'
    match = re.search(pattern, sentence, flags=re.IGNORECASE)
    if match:
        lang_code = lang_code_match(match.group(0).lower())
        logger.info("Destination lang: " + lang_code)
        return lang_code
    else:
        # Fix: log message previously lacked the space after the colon,
        # unlike the branch above.
        logger.info("Destination lang: " + src_lang)
        return src_lang
# Remove destination language context
def remove_language_phrase(sentence):
    """Strip the destination-language phrase from an English question.

    Removes "in <language>" or a bare "<language>" mention plus any trailing
    non-closing punctuation, so the VQA model sees only the actual question.
    """
    # re.escape guards against any name containing regex metacharacters.
    names = '|'.join(re.escape(lang) for lang in acceptable_languages)
    pattern = r'(\b(in\s)?(' + names + r')\b)[\s,;:.!?]*'
    cleaned_sentence = re.sub(pattern, '', sentence, flags=re.IGNORECASE).strip()
    logger.info("Language Phrase Removed: " + cleaned_sentence)
    return cleaned_sentence
# Load Vilt
# Load Vilt — pretrained ViLT fine-tuned for VQA (downloads weights on first run).
vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
def vilt_vqa(image, question):
    """Answer *question* about *image* with ViLT; returns the top answer label."""
    encoding = vilt_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        model_out = vilt_model(**encoding)
    scores = model_out.logits
    best_idx = scores.argmax(-1).item()
    answer = vilt_model.config.id2label[best_idx]
    logger.info("ViLT: " + answer)
    # Log the 10 highest-scoring labels for debugging; only the best is returned.
    _, top_indices = torch.topk(scores, 10, dim=-1)
    topk_answers = [vilt_model.config.id2label[i.item()] for i in top_indices[0]]
    logger.info("ViLT top 10 answers: " + str(topk_answers))
    return answer
# Load FLAN-T5
# Load FLAN-T5 — used to turn ViLT's one-word answers into full sentences.
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
t5_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto")
def flan_t5_complete_sentence(question, answer):
    """Rephrase a terse VQA *answer* to *question* as a complete English sentence.

    Prompts FLAN-T5 with the question/answer pair and returns the generated
    sentence (at most 50 tokens).
    """
    input_text = f"A question: {question} An answer: {answer}. Based on these, answer the question with a complete sentence without extra information."
    logger.info("T5 input: " + input_text)
    inputs = t5_tokenizer(input_text, return_tensors="pt")
    outputs = t5_model.generate(**inputs, max_length=50)
    # batch_decode returns one string per generated sequence; we generate one.
    # (The original ''.join(...) around this was a no-op on a single string.)
    result_sentence = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    logger.info("T5 output: " + result_sentence)
    return result_sentence
# Main function
def vqa_main(image, question):
    """Full cross-lingual VQA pipeline.

    Translates the question to English, works out which language the answer
    should be in, strips the language phrase, runs ViLT + FLAN-T5, then
    translates the resulting sentence back for the user.
    """
    english_question, src_lang = google_translate(question, dest='en')
    target_lang = find_dest_language(english_question, src_lang)
    vqa_question = remove_language_phrase(english_question)
    short_answer = vilt_vqa(image, vqa_question)
    full_sentence = flan_t5_complete_sentence(vqa_question, short_answer)
    final_answer, answer_src_lang = google_translate(full_sentence, dest=target_lang)
    logger.info("Final Answer: " + final_answer)
    return final_answer
# Home page text
# Home page text — rendered verbatim by gr.Interface below.
title = "Interactive demo: Cross-Lingual VQA"
description = """
Upload an image, type a question, click 'submit', or click one of the examples to load them.
Note: This web demo is running on a CPU thus, may take a few minutes for completing output at times. For better performance, please consider migrating to your own space and upgrading to a GPU runtime.
"""
# Footer listing every language googletrans can translate to/from.
article = """
Supported 107 Languages: Afrikaans, Albanian, Amharic, Arabic, Armenian, Azerbaijani, Basque, Belarusian, Bengali, Bosnian, Bulgarian, Catalan, Cebuano, Chichewa, Chinese (Simplified), Chinese (Traditional), Corsican, Croatian, Czech, Danish, Dutch, English, Esperanto, Estonian, Filipino, Finnish, French, Frisian, Galician, Georgian, German, Greek, Gujarati, Haitian Creole, Hausa, Hawaiian, Hebrew, Hindi, Hmong, Hungarian, Icelandic, Igbo, Indonesian, Irish, Italian, Japanese, Javanese, Kannada, Kazakh, Khmer, Korean, Kurdish (Kurmanji), Kyrgyz, Lao, Latin, Latvian, Lithuanian, Luxembourgish, Macedonian, Malagasy, Malay, Malayalam, Maltese, Maori, Marathi, Mongolian, Myanmar (Burmese), Nepali, Norwegian, Odia, Pashto, Persian, Polish, Portuguese, Punjabi, Romanian, Russian, Samoan, Scots Gaelic, Serbian, Sesotho, Shona, Sindhi, Sinhala, Slovak, Slovenian, Somali, Spanish, Sundanese, Swahili, Swedish, Tajik, Tamil, Telugu, Thai, Turkish, Ukrainian, Urdu, Uyghur, Uzbek, Vietnamese, Welsh, Xhosa, Yiddish, Yoruba, Zulu
"""
# Load example images
# Load example images into the working directory so the examples below resolve.
torch.hub.download_url_to_file('http://farm3.staticflickr.com/2710/4520550856_7a9f9ea59d_z.jpg', 'apple.jpg')
torch.hub.download_url_to_file('https://media.istockphoto.com/id/1174602891/photo/two-monkeys-mom-and-cub-eat-bananas.jpg?s=612x612&w=0&k=20&c=r7VXi9d1wHhyq3iAk9D2Z3yTZiOJMlLNtjdVRBEjG7g=', 'monkey.jpg')
# Define home page variables (Gradio input/output components).
image = gr.Image(type="pil")
question = gr.Textbox(label="Question")
answer = gr.Textbox(label="Predicted answer")
# Clickable example pairs (image path, question) — questions span French,
# English-with-target-language, and Turkish to demo cross-lingual support.
examples = [
    ["apple.jpg", "Qu'est-ce que j'ai dans la main en anglais?"],
    ["monkey.jpg", "In Korean, what are these animals called?"],
    ["apple.jpg", "What color is this? Answer in Uyghur."],
    ["monkey.jpg", "Maymunlar ne yiyor, Çince cevap ver."]
]
# Wire everything into a Gradio interface and start the app.
demo = gr.Interface(fn=vqa_main,
                    inputs=[image, question],
                    outputs="text",
                    examples=examples,
                    title=title,
                    description=description,
                    article=article)
demo.launch(debug=True, show_error = True)