VDNT11 committed
Commit
09eb658
1 Parent(s): 1199291

Upload 2 files

Files changed (2)
  1. app.py +143 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,143 @@
+ import streamlit as st
+ from PIL import Image
+ import torch
+ from transformers import BlipProcessor, BlipForConditionalGeneration, AutoModelForSeq2SeqLM, AutoTokenizer
+ import os
+ from gtts import gTTS
+ import soundfile as sf
+ from transformers import VitsTokenizer, VitsModel, set_seed
+
+ # Clone and install the IndicTransToolkit repository on first run
+ if not os.path.exists('IndicTransToolkit'):
+     os.system('git clone https://github.com/VarunGumma/IndicTransToolkit')
+     os.system('cd IndicTransToolkit && python3 -m pip install --editable ./')
+
+ # Import IndicProcessor only after the toolkit has been installed
+ from IndicTransToolkit import IndicProcessor
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Initialize BLIP for image captioning
+ blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(DEVICE)
+
+ # Generate an English caption for an image (path or file-like object)
+ def generate_caption(image_file):
+     image = Image.open(image_file).convert("RGB")
+     inputs = blip_processor(image, "image of", return_tensors="pt").to(DEVICE)
+     with torch.no_grad():
+         generated_ids = blip_model.generate(**inputs)
+     caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)
+     return caption
+
+ # Translate the caption into the selected Indic languages using IndicTrans2
+ def translate_caption(caption, target_languages):
+     # Load model and tokenizer
+     model_name = "ai4bharat/indictrans2-en-indic-1B"
+     tokenizer_IT2 = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+     model_IT2 = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
+     # Dynamic int8 quantization of the linear layers (CPU-oriented) to cut memory use
+     model_IT2 = torch.quantization.quantize_dynamic(
+         model_IT2, {torch.nn.Linear}, dtype=torch.qint8
+     )
+     model_IT2.to(DEVICE)
+
+     ip = IndicProcessor(inference=True)
+
+     # Source language (English)
+     src_lang = "eng_Latn"
+     input_sentences = [caption]
+     translations = {}
+
+     for tgt_lang in target_languages:
+         # Preprocess input sentences
+         batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
+
+         # Tokenize the sentences and generate input encodings
+         inputs = tokenizer_IT2(batch, truncation=True, padding="longest", return_tensors="pt").to(DEVICE)
+
+         # Generate translations using the model
+         with torch.no_grad():
+             generated_tokens = model_IT2.generate(
+                 **inputs,
+                 use_cache=True,
+                 min_length=0,
+                 max_length=256,
+                 num_beams=5,
+                 num_return_sequences=1,
+             )
+
+         # Decode the generated tokens into text
+         with tokenizer_IT2.as_target_tokenizer():
+             generated_tokens = tokenizer_IT2.batch_decode(
+                 generated_tokens.detach().cpu().tolist(),
+                 skip_special_tokens=True,
+                 clean_up_tokenization_spaces=True,
+             )
+
+         # Postprocess the translations
+         translated_texts = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
+         translations[tgt_lang] = translated_texts[0]
+
+     return translations
+
+ # Generate audio using gTTS
+ def generate_audio_gtts(text, lang_code, output_file):
+     tts = gTTS(text=text, lang=lang_code)
+     tts.save(output_file)
+     return output_file
+
+ # Generate audio using Facebook MMS-TTS (VITS)
+ def generate_audio_fbmms(text, model_name, output_file):
+     tokenizer = VitsTokenizer.from_pretrained(model_name)
+     model = VitsModel.from_pretrained(model_name)
+     inputs = tokenizer(text=text, return_tensors="pt")
+     set_seed(555)  # fixed seed so the stochastic VITS output is reproducible
+     with torch.no_grad():
+         outputs = model(**inputs)
+     waveform = outputs.waveform[0].cpu().numpy()
+     sf.write(output_file, waveform, samplerate=model.config.sampling_rate)
+     return output_file
+
+ # Streamlit UI
+ st.title("Multilingual Assistive Model")
+
+ uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
+
+ if uploaded_image is not None:
+     # Display the uploaded image
+     image = Image.open(uploaded_image)
+     st.image(image, caption="Uploaded Image", use_column_width=True)
+
+     # Generate caption (rewind the upload first, since displaying it consumed the stream)
+     st.write("Generating Caption...")
+     uploaded_image.seek(0)
+     caption = generate_caption(uploaded_image)
+     st.write(f"Caption: {caption}")
+
+     # Select target languages for translation
+     target_languages = st.multiselect(
+         "Select target languages for translation",
+         ["hin_Deva", "mar_Deva", "guj_Gujr", "urd_Arab"],  # add more languages as needed
+         ["hin_Deva", "mar_Deva"]
+     )
+
+     # Generate translations
+     if target_languages:
+         st.write("Translating Caption...")
+         translations = translate_caption(caption, target_languages)
+         st.write("Translations:")
+         for lang, translation in translations.items():
+             st.write(f"{lang}: {translation}")
+
+         # Default to gTTS for text-to-speech
+         for lang in target_languages:
+             st.write(f"Using gTTS for {lang}...")
+             lang_code = {
+                 "hin_Deva": "hi",  # Hindi
+                 "mar_Deva": "mr",  # Marathi
+                 "guj_Gujr": "gu",  # Gujarati
+                 "urd_Arab": "ur",  # Urdu
+             }.get(lang, "en")
+             output_file = f"{lang}_gTTS.mp3"
+             audio_file = generate_audio_gtts(translations[lang], lang_code, output_file)
+
+             st.write(f"Playing {lang} audio:")
+             st.audio(audio_file)
+ else:
+     st.write("Upload an image to start.")
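
Note: generate_audio_fbmms is defined above but never called; the UI always uses gTTS. A minimal sketch of wiring it in for Hindi follows. The checkpoint name "facebook/mms-tts-hin" follows MMS's per-language naming scheme and is an assumption for illustration, not something this commit pins:

# Illustrative sketch only -- this commit never calls generate_audio_fbmms.
# "facebook/mms-tts-hin" is an assumed per-language MMS checkpoint name.
wav_file = generate_audio_fbmms(
    translations["hin_Deva"],       # translated caption text
    "facebook/mms-tts-hin",         # VITS-based MMS TTS model for Hindi
    "hin_Deva_mms.wav",             # soundfile writes an uncompressed WAV
)
st.audio(wav_file)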
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ streamlit
+ torch
+ transformers
+ Pillow
+ git+https://github.com/VarunGumma/IndicTransToolkit.git
+ gtts
+ soundfile
+ matplotlib
+ numpy
+ pandas
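
With these dependencies installed (pip install -r requirements.txt), the app starts with streamlit run app.py. Since requirements.txt already installs IndicTransToolkit from GitHub, the runtime git clone at the top of app.py is normally redundant and serves only as a fallback.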