Upload 2 files
- app.py +143 -0
- requirements.txt +10 -0
app.py
ADDED
@@ -0,0 +1,143 @@
+import streamlit as st
+from PIL import Image
+import torch
+from transformers import BlipProcessor, BlipForConditionalGeneration, AutoModelForSeq2SeqLM, AutoTokenizer
+import os
+from gtts import gTTS
+import soundfile as sf
+from transformers import VitsTokenizer, VitsModel, set_seed
+
+# Clone and install the IndicTransToolkit repository if it is not already present
+if not os.path.exists('IndicTransToolkit'):
+    os.system('git clone https://github.com/VarunGumma/IndicTransToolkit')
+    os.system('cd IndicTransToolkit && python3 -m pip install --editable ./')
+
+# Import IndicProcessor only after the toolkit is installed
+from IndicTransToolkit import IndicProcessor
+
+# Initialize BLIP for image captioning
+blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to("cuda" if torch.cuda.is_available() else "cpu")
+
+# Function to generate captions
+def generate_caption(image_path):
+    image = Image.open(image_path).convert("RGB")
+    inputs = blip_processor(image, "image of", return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
+    with torch.no_grad():
+        generated_ids = blip_model.generate(**inputs)
+    caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)
+    return caption
+
+# Function for translation using IndicTrans2
+def translate_caption(caption, target_languages):
+    # Load model and tokenizer
+    model_name = "ai4bharat/indictrans2-en-indic-1B"
+    tokenizer_IT2 = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+    model_IT2 = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
+    model_IT2 = torch.quantization.quantize_dynamic(
+        model_IT2, {torch.nn.Linear}, dtype=torch.qint8
+    )
+
+    ip = IndicProcessor(inference=True)
+
+    # Source language (English)
+    src_lang = "eng_Latn"
+    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+    model_IT2.to(DEVICE)  # Move model to the device
+
+    # Translate the caption into each selected target language
+    input_sentences = [caption]
+    translations = {}
+
+    for tgt_lang in target_languages:
+        # Preprocess input sentences
+        batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
+
+        # Tokenize the sentences and generate input encodings
+        inputs = tokenizer_IT2(batch, truncation=True, padding="longest", return_tensors="pt").to(DEVICE)
+
+        # Generate translations using the model
+        with torch.no_grad():
+            generated_tokens = model_IT2.generate(
+                **inputs,
+                use_cache=True,
+                min_length=0,
+                max_length=256,
+                num_beams=5,
+                num_return_sequences=1,
+            )
+
+        # Decode the generated tokens into text
+        with tokenizer_IT2.as_target_tokenizer():
+            generated_tokens = tokenizer_IT2.batch_decode(generated_tokens.detach().cpu().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
+
+        # Postprocess the translations
+        translated_texts = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
+        translations[tgt_lang] = translated_texts[0]
+
+    return translations
+
+# Function to generate audio using gTTS
+def generate_audio_gtts(text, lang_code, output_file):
+    tts = gTTS(text=text, lang=lang_code)
+    tts.save(output_file)
+    return output_file
+
+# Function to generate audio using Facebook MMS-TTS
+def generate_audio_fbmms(text, model_name, output_file):
+    tokenizer = VitsTokenizer.from_pretrained(model_name)
+    model = VitsModel.from_pretrained(model_name)
+    inputs = tokenizer(text=text, return_tensors="pt")
+    set_seed(555)
+    with torch.no_grad():
+        outputs = model(**inputs)
+    waveform = outputs.waveform[0].cpu().numpy()
+    sf.write(output_file, waveform, samplerate=model.config.sampling_rate)
+    return output_file
+
+# Streamlit UI
+st.title("Multilingual Assistive Model")
+
+uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
+
+if uploaded_image is not None:
+    # Display the uploaded image
+    image = Image.open(uploaded_image)
+    st.image(image, caption="Uploaded Image", use_column_width=True)
+
+    # Generate Caption
+    st.write("Generating Caption...")
+    caption = generate_caption(uploaded_image)
+    st.write(f"Caption: {caption}")
+
+    # Select target languages for translation
+    target_languages = st.multiselect(
+        "Select target languages for translation",
+        ["hin_Deva", "mar_Deva", "guj_Gujr", "urd_Arab"],  # Add more languages as needed
+        ["hin_Deva", "mar_Deva"]
+    )
+
+    # Generate Translations
+    if target_languages:
+        st.write("Translating Caption...")
+        translations = translate_caption(caption, target_languages)
+        st.write("Translations:")
+        for lang, translation in translations.items():
+            st.write(f"{lang}: {translation}")
+
+        # Default to gTTS for TTS
+        for lang in target_languages:
+            st.write(f"Using gTTS for {lang}...")
+            lang_code = {
+                "hin_Deva": "hi",  # Hindi
+                "mar_Deva": "mr",  # Marathi
+                "guj_Gujr": "gu",  # Gujarati
+                "urd_Arab": "ur"   # Urdu
+            }.get(lang, "en")
+            output_file = f"{lang}_gTTS.mp3"
+            audio_file = generate_audio_gtts(translations[lang], lang_code, output_file)
+
+            st.write(f"Playing {lang} audio:")
+            st.audio(audio_file)
+else:
+    st.write("Upload an image to start.")
requirements.txt
ADDED
@@ -0,0 +1,10 @@
+streamlit
+torch
+transformers
+Pillow
+git+https://github.com/VarunGumma/IndicTransToolkit.git
+gtts
+soundfile
+matplotlib
+numpy
+pandas