Spaces:
Running
Running
import gradio as gr | |
import re | |
import torch | |
from transformers.utils import logging | |
from transformers import ViltProcessor, ViltForQuestionAnswering, T5Tokenizer, T5ForConditionalGeneration | |
import httpcore | |
setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy') # set SyncHTTPTransport attribute for googletrans dependency | |
from googletrans import Translator | |
from googletrans import LANGCODES | |
# Language names a user may mention to choose the answer language. Only the
# first word of each LANGCODES key is kept (so multi-word names are reduced to
# their first word); "mandarin" and "cantonese" are added because
# lang_code_match special-cases them.
acceptable_languages = {lang_name.split()[0] for lang_name in LANGCODES}
acceptable_languages.update(("mandarin", "cantonese"))

# Show INFO-level messages on the "transformers" logger, which this app also
# uses for its own pipeline tracing.
logging.set_verbosity_info()
logger = logging.get_logger("transformers")
# Translation
def google_translate(question, dest):
    """Translate *question* into language *dest* using googletrans.

    Returns a ``(translated_text, source_language)`` tuple, where the source
    language is the ``src`` value reported by googletrans. Both values are
    also logged.
    """
    result = Translator().translate(question, dest=dest)
    text, src = result.text, result.src
    logger.info("Translation text: " + text)
    logger.info("Translation src: " + src)
    return text, src
# Lang to lang_code mapping
def lang_code_match(accaptable_lang):
    """Map a language name from ``acceptable_languages`` to a googletrans code.

    ``acceptable_languages`` keeps only the first word of each LANGCODES key,
    so multi-word languages (e.g. "haitian creole", "scots gaelic",
    "kurdish (kurmanji)", "myanmar (burmese)") arrive here as their first word
    only. A direct ``LANGCODES[...]`` lookup used to raise KeyError for them.
    Such names are now resolved by matching the first word of each LANGCODES
    key.

    Args:
        accaptable_lang: Language name; the caller in this file lowercases it.

    Returns:
        The language code to pass as ``dest`` to ``Translator.translate``.

    Raises:
        KeyError: If the name matches no known language.
    """
    # Chinese varieties are mapped explicitly:
    # mandarin -> 'zh-cn'; cantonese or generic chinese -> 'zh-tw'.
    if accaptable_lang == 'mandarin':
        return 'zh-cn'
    if accaptable_lang in ('cantonese', 'chinese'):
        return 'zh-tw'
    # Exact full-name match (single-word languages).
    if accaptable_lang in LANGCODES:
        return LANGCODES[accaptable_lang]
    # Fall back to the first word of multi-word names, mirroring how
    # acceptable_languages was built.
    for full_name, code in LANGCODES.items():
        if full_name.split()[0] == accaptable_lang:
            return code
    raise KeyError(accaptable_lang)
# Find destination language
def find_dest_language(sentence, src_lang):
    """Choose the language the final answer should be written in.

    Searches *sentence* case-insensitively for the leftmost whole-word
    mention of a name from ``acceptable_languages`` and returns its code via
    ``lang_code_match``. If no language is mentioned, *src_lang* is returned
    unchanged (the caller passes the question's source language).
    """
    alternatives = '|'.join(acceptable_languages)
    found = re.search(r'\b(' + alternatives + r')\b', sentence, flags=re.IGNORECASE)
    if not found:
        logger.info("Destination lang:" + src_lang)
        return src_lang
    dest = lang_code_match(found.group(0).lower())
    logger.info("Destination lang: " + dest)
    return dest
# Remove destination language context
def remove_language_phrase(sentence):
    """Strip language mentions so the VQA model sees only the question itself.

    Removes every case-insensitive, whole-word occurrence of a name from
    ``acceptable_languages``. Each removal also takes an optional preceding
    "in" plus one whitespace character, and any run of whitespace or
    ``, ; : . ! ?`` that immediately follows the name. The result is then
    trimmed. Because closing punctuation is included, a trailing "?" or "."
    after the language name is removed as well.
    """
    names = '|'.join(acceptable_languages)
    language_phrase = r'(\b(in\s)?(' + names + r')\b)[\s,;:.!?]*'
    stripped = re.sub(language_phrase, '', sentence, flags=re.IGNORECASE)
    cleaned_sentence = stripped.strip()
    logger.info("Language Phrase Removed: " + cleaned_sentence)
    return cleaned_sentence
# Load the ViLT VQA processor and model once at import time; every vilt_vqa
# call reuses them.
_VILT_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"
vilt_processor = ViltProcessor.from_pretrained(_VILT_CHECKPOINT)
vilt_model = ViltForQuestionAnswering.from_pretrained(_VILT_CHECKPOINT)
def vilt_vqa(image, question, top_k=10):
    """Answer *question* about *image* with ViLT's answer-classification head.

    Args:
        image: Image for ``vilt_processor`` (the Gradio input is a PIL image).
        question: English question text.
        top_k: Number of highest-scoring candidate answers to log for
            debugging. It is clamped to the number of answer labels, so
            ``torch.topk`` never asks for more entries than exist.

    Returns:
        The label of the highest-scoring answer class.
    """
    inputs = vilt_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        logits = vilt_model(**inputs).logits
    # The head scores a fixed set of labelled answers; the argmax is the answer.
    best_idx = logits.argmax(-1).item()
    answer = vilt_model.config.id2label[best_idx]
    logger.info("ViLT: " + answer)
    # Log the strongest runner-up candidates too; only the indices are needed.
    k = min(top_k, logits.shape[-1])
    _, topk_indices = torch.topk(logits, k, dim=-1)
    topk_answers = [vilt_model.config.id2label[i.item()] for i in topk_indices[0]]
    logger.info(f"ViLT top {k} answers: {topk_answers}")
    return answer
# Load the FLAN-T5 tokenizer and model once at import time. device_map="auto"
# lets transformers place the weights on whatever device(s) are available.
_FLAN_T5_CHECKPOINT = "google/flan-t5-large"
t5_tokenizer = T5Tokenizer.from_pretrained(_FLAN_T5_CHECKPOINT)
t5_model = T5ForConditionalGeneration.from_pretrained(_FLAN_T5_CHECKPOINT, device_map="auto")
def flan_t5_complete_sentence(question, answer):
    """Turn a short VQA answer into a complete English sentence with FLAN-T5.

    Args:
        question: The English question that was asked.
        answer: The short answer predicted by the VQA model.

    Returns:
        The decoded FLAN-T5 output for the single prompt (generation capped
        by ``max_length=50``).
    """
    input_text = (
        f"A question: {question} An answer: {answer}. Based on these, answer "
        "the question with a complete sentence without extra information."
    )
    logger.info("T5 input: " + input_text)
    # The model was loaded with device_map="auto" and may live on a GPU, so
    # the inputs go to the same device to keep generate() from mixing devices.
    inputs = t5_tokenizer(input_text, return_tensors="pt").to(t5_model.device)
    outputs = t5_model.generate(**inputs, max_length=50)
    # batch_decode returns one string per sequence; there is exactly one prompt.
    result_sentence = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    logger.info("T5 output: " + result_sentence)
    return result_sentence
# Main function
def vqa_main(image, question):
    """Run the cross-lingual VQA pipeline for one Gradio request.

    Steps:
      1. Translate the question to English and note its source language.
      2. Pick the answer language: one named in the question, otherwise the
         source language.
      3. Remove the language mention and ask ViLT for a short answer.
      4. Expand that answer into a sentence with FLAN-T5.
      5. Translate the sentence into the answer language and return it.
    """
    english_question, asked_in = google_translate(question, dest='en')
    target_lang = find_dest_language(english_question, asked_in)
    vqa_question = remove_language_phrase(english_question)
    short_answer = vilt_vqa(image, vqa_question)
    sentence = flan_t5_complete_sentence(vqa_question, short_answer)
    final_answer, _ = google_translate(sentence, dest=target_lang)
    logger.info("Final Answer: " + final_answer)
    return final_answer
# Home page text shown by gr.Interface: the page title, the description, and
# the article (the supported-language list). The stated count must match the
# number of names listed.
title = "Interactive demo: Cross-Lingual VQA"
description = """
Upload an image, type a question, and click 'submit', or click one of the examples to load its image and question.
Note: This web demo runs on a CPU, so it may occasionally take a few minutes to produce an answer. For better performance, please consider duplicating it into your own Space and upgrading to a GPU runtime.
"""
article = """
Supported 106 Languages: Afrikaans, Albanian, Amharic, Arabic, Armenian, Azerbaijani, Basque, Belarusian, Bengali, Bosnian, Bulgarian, Catalan, Cebuano, Chichewa, Chinese (Simplified), Chinese (Traditional), Corsican, Croatian, Czech, Danish, Dutch, English, Esperanto, Estonian, Filipino, Finnish, French, Frisian, Galician, Georgian, German, Greek, Gujarati, Haitian Creole, Hausa, Hawaiian, Hebrew, Hindi, Hmong, Hungarian, Icelandic, Igbo, Indonesian, Irish, Italian, Japanese, Javanese, Kannada, Kazakh, Khmer, Korean, Kurdish (Kurmanji), Kyrgyz, Lao, Latin, Latvian, Lithuanian, Luxembourgish, Macedonian, Malagasy, Malay, Malayalam, Maltese, Maori, Marathi, Mongolian, Myanmar (Burmese), Nepali, Norwegian, Odia, Pashto, Persian, Polish, Portuguese, Punjabi, Romanian, Russian, Samoan, Scots Gaelic, Serbian, Sesotho, Shona, Sindhi, Sinhala, Slovak, Slovenian, Somali, Spanish, Sundanese, Swahili, Swedish, Tajik, Tamil, Telugu, Thai, Turkish, Ukrainian, Urdu, Uyghur, Uzbek, Vietnamese, Welsh, Xhosa, Yiddish, Yoruba, Zulu
"""
# Download the example images used by the Gradio examples, as
# (source URL, local filename) pairs.
_EXAMPLE_IMAGES = (
    ('http://farm3.staticflickr.com/2710/4520550856_7a9f9ea59d_z.jpg', 'apple.jpg'),
    ('https://media.istockphoto.com/id/1174602891/photo/two-monkeys-mom-and-cub-eat-bananas.jpg?s=612x612&w=0&k=20&c=r7VXi9d1wHhyq3iAk9D2Z3yTZiOJMlLNtjdVRBEjG7g=', 'monkey.jpg'),
)
for _url, _filename in _EXAMPLE_IMAGES:
    torch.hub.download_url_to_file(_url, _filename)
# Define home page components
image = gr.Image(type="pil")
question = gr.Textbox(label="Question")
answer = gr.Textbox(label="Predicted answer")
# Each example is [image file, question]; the image files are downloaded at startup.
examples = [
    ["apple.jpg", "Qu'est-ce que j'ai dans la main en anglais?"],
    ["monkey.jpg", "In Korean, what are these animals called?"],
    ["apple.jpg", "What color is this? Answer in Uyghur."],
    ["monkey.jpg", "Maymunlar ne yiyor, Çince cevap ver."],
]
# Use the labelled `answer` textbox as the output component. Passing the bare
# string "text" creates a default textbox without the "Predicted answer" label
# and leaves `answer` unused.
demo = gr.Interface(
    fn=vqa_main,
    inputs=[image, question],
    outputs=answer,
    examples=examples,
    title=title,
    description=description,
    article=article,
)
demo.launch(debug=True, show_error=True)