import os

import streamlit as st
import torch
import soundfile as sf
from PIL import Image
from gtts import gTTS
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    VitsTokenizer,
    VitsModel,
    set_seed,
)

# Clone and install IndicTransToolkit on first run, then import it.
# (Importing it before installation would fail on a fresh environment.)
if not os.path.exists('IndicTransToolkit'):
    os.system('git clone https://github.com/VarunGumma/IndicTransToolkit')
    os.system('cd IndicTransToolkit && python3 -m pip install --editable ./')

from IndicTransToolkit import IndicProcessor

# Device used for BLIP and IndicTrans2 inference.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the BLIP captioning model once at startup.
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(DEVICE)

@st.cache_resource
def generate_caption(image_path):
    # Generate an English caption for the uploaded image with BLIP,
    # using "image of" as the conditioning prompt.
    image = Image.open(image_path).convert("RGB")
    inputs = blip_processor(image, "image of", return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        generated_ids = blip_model.generate(**inputs)
    caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)
    return caption

@st.cache_resource
def translate_caption(caption, target_languages):
    # Load the IndicTrans2 English-to-Indic model and tokenizer.
    # st.cache_resource keys the result on (caption, target_languages),
    # so this body runs once per unique caption/language combination.
    model_name = "ai4bharat/indictrans2-en-indic-1B"
    tokenizer_IT2 = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model_IT2 = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
    # Dynamic int8 quantization to reduce memory use (quantized Linear layers run on CPU).
    model_IT2 = torch.quantization.quantize_dynamic(
        model_IT2, {torch.nn.Linear}, dtype=torch.qint8
    )
    model_IT2.to(DEVICE)

    ip = IndicProcessor(inference=True)

    # Source language is always English, since captions come from BLIP.
    src_lang = "eng_Latn"

    input_sentences = [caption]
    translations = {}

    for tgt_lang in target_languages:
        batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)

        inputs = tokenizer_IT2(batch, truncation=True, padding="longest", return_tensors="pt").to(DEVICE)

        with torch.no_grad():
            generated_tokens = model_IT2.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        with tokenizer_IT2.as_target_tokenizer():
            generated_tokens = tokenizer_IT2.batch_decode(generated_tokens.detach().cpu().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)

        translated_texts = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
        translations[tgt_lang] = translated_texts[0]

    return translations
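# Illustrative call (hypothetical caption, shows the output shape only):
#   translate_caption("a dog running on a beach", ["hin_Deva"]) -> {"hin_Deva": "<Hindi text>"}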

@st.cache_resource
def generate_audio_gtts(text, lang_code, output_file):
    # Synthesize speech with Google TTS and save it as an MP3 file.
    tts = gTTS(text=text, lang=lang_code)
    tts.save(output_file)
    return output_file

@st.cache_resource
def generate_audio_fbmms(text, model_name, output_file):
    # Alternative TTS path using a Facebook MMS (VITS) checkpoint; not used in the UI below.
    tokenizer = VitsTokenizer.from_pretrained(model_name)
    model = VitsModel.from_pretrained(model_name)
    inputs = tokenizer(text=text, return_tensors="pt")
    set_seed(555)  # make the stochastic vocoder output reproducible
    with torch.no_grad():
        outputs = model(**inputs)
    waveform = outputs.waveform[0].cpu().numpy()
    sf.write(output_file, waveform, samplerate=model.config.sampling_rate)
    return output_file
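# Illustrative alternative to gTTS (model id assumed, not verified here):
#   generate_audio_fbmms(translations["hin_Deva"], "facebook/mms-tts-hin", "hin_Deva_mms.wav")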

# Streamlit UI
st.title("Multilingual Assistive Model")

uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])

if uploaded_image is not None:
    # Display the uploaded image
    image = Image.open(uploaded_image)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Generate Caption
    st.write("Generating Caption...")
    caption = generate_caption(uploaded_image)
    st.write(f"Caption: {caption}")

    # Select target languages for translation
    target_languages = st.multiselect(
        "Select target languages for translation", 
        ["hin_Deva", "mar_Deva", "guj_Gujr", "urd_Arab"], 
        ["hin_Deva", "mar_Deva"]
    )

    if target_languages:
        st.write("Translating Caption...")
        translations = translate_caption(caption, target_languages)
        st.write("Translations:")
        for lang, translation in translations.items():
            st.write(f"{lang}: {translation}")
        
        for lang in target_languages:
            st.write(f"Using gTTS for {lang}...")
            # Map IndicTrans2 language tags to gTTS language codes.
            lang_code = {
                "hin_Deva": "hi",  # Hindi
                "mar_Deva": "mr",  # Marathi
                "guj_Gujr": "gu",  # Gujarati
                "urd_Arab": "ur"   # Urdu
            }.get(lang, "en")
            output_file = f"{lang}_gTTS.mp3"
            audio_file = generate_audio_gtts(translations[lang], lang_code, output_file)

            st.write(f"Playing {lang} audio:")
            st.audio(audio_file)
else:
    st.write("Upload an image to start.")