import streamlit as st
from PIL import Image
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration, AutoModelForSeq2SeqLM, AutoTokenizer
import os
from gtts import gTTS
import soundfile as sf
from transformers import VitsTokenizer, VitsModel, set_seed

# Clone and install IndicTransToolkit before importing from it; on a fresh
# environment the import below would otherwise fail.
if not os.path.exists('IndicTransToolkit'):
    os.system('git clone https://github.com/VarunGumma/IndicTransToolkit')
    os.system('cd IndicTransToolkit && python3 -m pip install --editable ./')

from IndicTransToolkit import IndicProcessor
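# Use a GPU when one is available; the BLIP captioning model is loaded once at startup.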
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(DEVICE)


def generate_caption(image_file):
    """Generate an English caption for an image (a file path or file-like object) with BLIP."""
    image = Image.open(image_file).convert("RGB")
    inputs = blip_processor(image, "image of", return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        generated_ids = blip_model.generate(**inputs)
    caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)
    return caption


def translate_caption(caption, target_languages):
    """Translate an English caption into each selected Indic language with IndicTrans2."""
    model_name = "ai4bharat/indictrans2-en-indic-1B"
    tokenizer_IT2 = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model_IT2 = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)

    # Dynamic int8 quantization only has CPU kernels, so apply it only when
    # running on the CPU.
    if DEVICE == "cpu":
        model_IT2 = torch.quantization.quantize_dynamic(
            model_IT2, {torch.nn.Linear}, dtype=torch.qint8
        )
    model_IT2.to(DEVICE)

    ip = IndicProcessor(inference=True)
    src_lang = "eng_Latn"

    input_sentences = [caption]
    translations = {}

    for tgt_lang in target_languages:
        # Add source/target language tags and normalize the text for IndicTrans2.
        batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
        inputs = tokenizer_IT2(batch, truncation=True, padding="longest", return_tensors="pt").to(DEVICE)

        with torch.no_grad():
            generated_tokens = model_IT2.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        with tokenizer_IT2.as_target_tokenizer():
            generated_tokens = tokenizer_IT2.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        # Restore script-specific punctuation and formatting in the decoded text.
        translated_texts = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
        translations[tgt_lang] = translated_texts[0]

    return translations


def generate_audio_gtts(text, lang_code, output_file):
    """Synthesize speech with Google Text-to-Speech and save it to output_file (MP3)."""
    tts = gTTS(text=text, lang=lang_code)
    tts.save(output_file)
    return output_file


def generate_audio_fbmms(text, model_name, output_file):
    """Synthesize speech with a Facebook MMS-TTS (VITS) checkpoint and write the waveform to output_file."""
    tokenizer = VitsTokenizer.from_pretrained(model_name)
    model = VitsModel.from_pretrained(model_name)
    inputs = tokenizer(text=text, return_tensors="pt")
    set_seed(555)  # fixed seed so the stochastic parts of VITS are reproducible
    with torch.no_grad():
        outputs = model(**inputs)
    waveform = outputs.waveform[0].cpu().numpy()
    sf.write(output_file, waveform, samplerate=model.config.sampling_rate)
    return output_file
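# Streamlit app flow: upload an image, caption it, translate the caption, and read it aloud.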
st.title("Multilingual Assistive Model")

uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])

if uploaded_image is not None:
    image = Image.open(uploaded_image)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    st.write("Generating Caption...")
    caption = generate_caption(uploaded_image)
    st.write(f"Caption: {caption}")
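    # IndicTrans2 language tags for the supported targets, mapped to
    # human-readable names for the UI.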
    language_options = {
        "hin_Deva": "Hindi (Devanagari)",
        "mar_Deva": "Marathi (Devanagari)",
        "guj_Gujr": "Gujarati (Gujarati)",
        "urd_Arab": "Urdu (Arabic)",
    }

    target_languages = st.multiselect(
        "Select target languages for translation",
        list(language_options.keys()),
        ["hin_Deva", "mar_Deva"],
        format_func=language_options.get,
    )

    if target_languages:
        st.write("Translating Caption...")
        translations = translate_caption(caption, target_languages)
        st.write("Translations:")
        for lang in target_languages:
            st.write(f"{language_options[lang]}: {translations[lang]}")
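        # Let the user choose between gTTS and Facebook MMS-TTS for speech synthesis.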
        audio_method = st.radio("Choose Audio Generation Method", ("gTTS (Default)", "Facebook MMS-TTS"))

        for lang in target_languages:
            st.write(f"Generating audio for {language_options[lang]}...")
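            # Map IndicTrans2 language tags to gTTS language codes.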
            lang_code = {
                "hin_Deva": "hi",
                "mar_Deva": "mr",
                "guj_Gujr": "gu",
                "urd_Arab": "ur"
            }.get(lang, "en")

            if audio_method == "gTTS (Default)":
                output_file = f"{lang}_audio.mp3"
                audio_file = generate_audio_gtts(translations[lang], lang_code, output_file)
            else:
                # Set this to the MMS-TTS checkpoint for the chosen language,
                # e.g. "facebook/mms-tts-hin" for Hindi (verify the exact
                # checkpoint names on the Hugging Face Hub).
                model_name = "your_facebook_mms_model_name"
                # Save the MMS output as WAV; soundfile's MP3 support depends on
                # the installed libsndfile version.
                output_file = f"{lang}_audio.wav"
                audio_file = generate_audio_fbmms(translations[lang], model_name, output_file)

            st.write(f"Playing {language_options[lang]} audio:")
            st.audio(audio_file)
else:
    st.write("Upload an image to start.")