import os

# Clone and install the IndicTransToolkit repository before importing it below.
# Note: an editable install may only become importable after the interpreter
# restarts, since its .pth file is processed at startup.
if not os.path.exists('IndicTransToolkit'):
    os.system('git clone https://github.com/VarunGumma/IndicTransToolkit')
    os.system('cd IndicTransToolkit && python3 -m pip install --editable ./')

import streamlit as st
import torch
import soundfile as sf
from PIL import Image
from gtts import gTTS
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    VitsTokenizer,
    VitsModel,
    set_seed,
)
from IndicTransToolkit import IndicProcessor
# Run on GPU when available, otherwise on CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize BLIP for image captioning
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(DEVICE)
# Function to generate an English caption (accepts a path or file-like object)
def generate_caption(image_file):
    image = Image.open(image_file).convert("RGB")
    # "image of" is a text prefix for BLIP's conditional captioning
    inputs = blip_processor(image, "image of", return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        generated_ids = blip_model.generate(**inputs)
    caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)
    return caption
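
# The text prefix is optional: the BLIP processor can also be called without
# it for unconditional captioning. A minimal sketch (not used in the flow above):
# inputs = blip_processor(image, return_tensors="pt").to(DEVICE)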
# Function for translation using IndicTrans2
def translate_caption(caption, target_languages):
    # Load model and tokenizer
    model_name = "ai4bharat/indictrans2-en-indic-1B"
    tokenizer_IT2 = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model_IT2 = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
    # Dynamic int8 quantization shrinks the 1B model and speeds up CPU inference;
    # PyTorch only supports it on CPU, so skip it when a GPU is available
    if DEVICE == "cpu":
        model_IT2 = torch.quantization.quantize_dynamic(
            model_IT2, {torch.nn.Linear}, dtype=torch.qint8
        )
    model_IT2.to(DEVICE)  # Move the model to the device
    ip = IndicProcessor(inference=True)
    # Source language (English, Latin script)
    src_lang = "eng_Latn"
    input_sentences = [caption]
    translations = {}
    for tgt_lang in target_languages:
        # Preprocess input sentences (adds IndicTrans2 language tags, normalizes text)
        batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
        # Tokenize the sentences and generate input encodings
        inputs = tokenizer_IT2(batch, truncation=True, padding="longest", return_tensors="pt").to(DEVICE)
        # Generate translations using beam search
        with torch.no_grad():
            generated_tokens = model_IT2.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )
        # Decode the generated tokens into text
        with tokenizer_IT2.as_target_tokenizer():
            generated_tokens = tokenizer_IT2.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        # Postprocess the translations (restore the target script)
        translated_texts = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
        translations[tgt_lang] = translated_texts[0]
    return translations
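
# Note: translate_caption() reloads the 1B-parameter model on every call. A
# minimal sketch of caching the load across Streamlit reruns, assuming
# Streamlit >= 1.18 for st.cache_resource (not wired into the flow above):
#
# @st.cache_resource
# def load_indictrans2(name="ai4bharat/indictrans2-en-indic-1B"):
#     tok = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
#     mdl = AutoModelForSeq2SeqLM.from_pretrained(name, trust_remote_code=True)
#     return tok, mdl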
# Function to generate audio using gTTS (requires internet; lang_code is ISO-639-1)
def generate_audio_gtts(text, lang_code, output_file):
    tts = gTTS(text=text, lang=lang_code)
    tts.save(output_file)
    return output_file
# Function to generate audio using a Facebook MMS-TTS (VITS) checkpoint
def generate_audio_fbmms(text, model_name, output_file):
    tokenizer = VitsTokenizer.from_pretrained(model_name)
    model = VitsModel.from_pretrained(model_name)
    inputs = tokenizer(text=text, return_tensors="pt")
    set_seed(555)  # VITS generation is stochastic; fix the seed for repeatable audio
    with torch.no_grad():
        outputs = model(**inputs)
    waveform = outputs.waveform[0].cpu().numpy()
    sf.write(output_file, waveform, samplerate=model.config.sampling_rate)
    return output_file
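
# MMS-TTS checkpoints on the Hugging Face Hub are named by ISO-639-3 code; the
# names below follow that convention but are assumptions -- verify each exists
# on the Hub before relying on it (Urdu ships as script-specific variants).
MMS_MODEL_BY_LANG = {
    "hin_Deva": "facebook/mms-tts-hin",  # Hindi (assumed checkpoint name)
    "mar_Deva": "facebook/mms-tts-mar",  # Marathi (assumed checkpoint name)
    "guj_Gujr": "facebook/mms-tts-guj",  # Gujarati (assumed checkpoint name)
    "urd_Arab": "facebook/mms-tts-urd-script_arabic",  # Urdu (assumed checkpoint name)
}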
# Streamlit UI
st.title("Multilingual Assistive Model")

uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])

if uploaded_image is not None:
    # Display the uploaded image
    image = Image.open(uploaded_image)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Generate Caption
    st.write("Generating Caption...")
    caption = generate_caption(uploaded_image)
    st.write(f"Caption: {caption}")
    # Select target languages for translation
    language_options = {
        "hin_Deva": "Hindi (Devanagari)",
        "mar_Deva": "Marathi (Devanagari)",
        "guj_Gujr": "Gujarati (Gujarati script)",
        "urd_Arab": "Urdu (Arabic script)",
    }
    target_languages = st.multiselect(
        "Select target languages for translation",
        list(language_options.keys()),
        ["hin_Deva", "mar_Deva"],
    )
    # Generate Translations
    if target_languages:
        st.write("Translating Caption...")
        translations = translate_caption(caption, target_languages)
        st.write("Translations:")
        for lang in target_languages:
            st.write(f"{language_options[lang]}: {translations[lang]}")

        # Select audio generation method
        audio_method = st.radio("Choose Audio Generation Method", ("gTTS (Default)", "Facebook MMS-TTS"))

        # Generate audio for each target language
        for lang in target_languages:
            st.write(f"Generating audio for {language_options[lang]}...")
            # Map IndicTrans2 language tags to gTTS ISO-639-1 codes
            lang_code = {
                "hin_Deva": "hi",  # Hindi
                "mar_Deva": "mr",  # Marathi
                "guj_Gujr": "gu",  # Gujarati
                "urd_Arab": "ur",  # Urdu
            }.get(lang, "en")
            if audio_method == "gTTS (Default)":
                # gTTS writes MP3
                audio_file = generate_audio_gtts(translations[lang], lang_code, f"{lang}_audio.mp3")
            else:
                # soundfile reliably writes WAV (MP3 support depends on the
                # libsndfile build), so use a .wav file for the MMS-TTS output
                model_name = MMS_MODEL_BY_LANG[lang]  # assumed checkpoint names; see the note above
                audio_file = generate_audio_fbmms(translations[lang], model_name, f"{lang}_audio.wav")
            st.write(f"Playing {language_options[lang]} audio:")
            st.audio(audio_file)
else:
    st.write("Upload an image to start.")