# Hugging Face Spaces status (captured together with this source): Runtime error
import os
import re
import uuid

import soundfile as sf
import torch
from flask import Flask, request, render_template, send_from_directory
from gtts import gTTS
from IndicTransToolkit import IndicProcessor
from PIL import Image, UnidentifiedImageError
from transformers import BlipProcessor, BlipForConditionalGeneration, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import VitsTokenizer, VitsModel, set_seed
# --- Flask application and storage folders --------------------------------
UPLOAD_FOLDER = "./static/uploads/"
AUDIO_FOLDER = "./static/audio/"

app = Flask(__name__)
app.config.update(UPLOAD_FOLDER=UPLOAD_FOLDER, AUDIO_FOLDER=AUDIO_FOLDER)

# Create both folders up front so later file saves never hit a missing directory.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(AUDIO_FOLDER, exist_ok=True)
# --- Model loading -----------------------------------------------------------
# Every model is loaded once at import time and shared by all requests.
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# BLIP (large) produces the English caption for an uploaded image.
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large"
).to(_DEVICE)

# IndicTrans2 (English -> Indic, 1B) translates that caption.
model_name = "ai4bharat/indictrans2-en-indic-1B"
tokenizer_IT2 = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model_IT2 = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
if _DEVICE == "cpu":
    # Dynamic int8 quantization only has CPU kernels. Quantizing and then moving
    # the model to CUDA (as before) breaks generation on GPU hosts, so on CUDA
    # the float model is kept instead.
    model_IT2 = torch.quantization.quantize_dynamic(
        model_IT2, {torch.nn.Linear}, dtype=torch.qint8
    )
model_IT2.to(_DEVICE)

# Pre/post-processing helper that translate_caption wraps around generation.
ip = IndicProcessor(inference=True)
# --- Pipeline functions -------------------------------------------------------
def generate_caption(image_path, prompt="image of"):
    """Return an English caption for the image stored at *image_path*.

    The image is converted to RGB and captioned by the BLIP model, conditioned
    on the text *prompt* (default ``"image of"``, the previous hard-coded value).
    The decoded text is returned with special tokens removed.

    Errors from ``PIL.Image.open`` (e.g. ``UnidentifiedImageError`` for a file
    that is not an image) propagate to the caller.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # The context manager closes the file handle promptly; convert() returns a
    # new, fully loaded image that stays valid after the handle is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    inputs = blip_processor(image, prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = blip_model.generate(**inputs)
    return blip_processor.decode(generated_ids[0], skip_special_tokens=True)
def translate_caption(caption, target_languages):
    """Translate the English *caption* into each requested IndicTrans2 language.

    Args:
        caption: English source sentence.
        target_languages: iterable of IndicTrans2 target codes (e.g. "hin_Deva").

    Returns:
        dict mapping each target code to its translated sentence.
    """
    results = {}
    for target in target_languages:
        prepared = ip.preprocess_batch([caption], src_lang="eng_Latn", tgt_lang=target)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        encoded = tokenizer_IT2(
            prepared, truncation=True, padding="longest", return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            output_ids = model_IT2.generate(
                **encoded,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        with tokenizer_IT2.as_target_tokenizer():
            decoded = tokenizer_IT2.batch_decode(
                output_ids.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        # One input sentence in, so the first post-processed entry is the answer.
        results[target] = ip.postprocess_batch(decoded, lang=target)[0]
    return results
def generate_audio_gtts(text, lang_code, output_file):
    """Synthesize *text* with gTTS in language *lang_code* and save it to *output_file*.

    Returns the *output_file* path unchanged so callers can collect it.
    """
    gTTS(text=text, lang=lang_code).save(output_file)
    return output_file
# gTTS voice codes for the IndicTrans2 target codes this app speaks; any other
# target code falls back to English speech ("en").
_GTTS_LANG_CODES = {
    "hin_Deva": "hi",
    "guj_Gujr": "gu",
    "urd_Arab": "ur",
    "mar_Deva": "mr",
}

# Accepted shape of a target-language form value: three lowercase letters, "_",
# then a capitalized four-letter script tag (e.g. "hin_Deva"). Form values are
# untrusted and are interpolated into audio file names, so nothing else passes.
_LANG_CODE_RE = re.compile(r"[a-z]{3}_[A-Z][a-z]{3}")


def _safe_upload_name(filename):
    """Return a random, collision-free file name keeping only a plain extension of *filename*."""
    _, ext = os.path.splitext(os.path.basename(filename or ""))
    suffix = ext[1:]
    if not (suffix.isascii() and suffix.isalnum()):
        ext = ""
    return f"{uuid.uuid4().hex}{ext.lower()}"


@app.route("/", methods=["GET", "POST"])
def index():
    """Render the upload form, or caption, translate and speak an uploaded image.

    A GET, or a POST without an ``image`` file, renders the empty form. A POST
    with an image saves it under UPLOAD_FOLDER with a random name. It then
    captions the image and translates the caption into every valid
    ``languages`` form value. Finally it writes one MP3 per translation into
    AUDIO_FOLDER. A file that PIL cannot open as an image yields HTTP 400.
    """
    if request.method != "POST":
        return render_template("index.html")
    image_file = request.files.get("image")
    if not image_file:  # missing field or empty filename
        return render_template("index.html")

    # Never use the client-supplied name directly: it may contain "../" and let
    # a request write anywhere the process can. A random name also keeps
    # concurrent uploads with the same name from overwriting each other.
    image_path = os.path.join(
        app.config["UPLOAD_FOLDER"], _safe_upload_name(image_file.filename)
    )
    image_file.save(image_path)

    try:
        caption = generate_caption(image_path)
    except UnidentifiedImageError:
        return (
            render_template("index.html", error="The uploaded file is not a valid image."),
            400,
        )

    target_languages = [
        lang for lang in request.form.getlist("languages") if _LANG_CODE_RE.fullmatch(lang)
    ]
    translations = translate_caption(caption, target_languages)

    # NOTE(review): audio names are per-language, not per-request, so concurrent
    # requests for the same language overwrite each other's MP3.
    audio_files = {}
    for lang, translation in translations.items():
        lang_code = _GTTS_LANG_CODES.get(lang, "en")
        audio_file_path = os.path.join(app.config["AUDIO_FOLDER"], f"{lang}.mp3")
        audio_files[lang] = generate_audio_gtts(translation, lang_code, audio_file_path)

    return render_template(
        "index.html",
        image_path=image_path,
        caption=caption,
        translations=translations,
        audio_files=audio_files,
    )
@app.route("/audio/<filename>")
def audio(filename):
    """Serve a generated audio file from AUDIO_FOLDER.

    The view was previously unregistered, so it could never be reached.
    ``send_from_directory`` refuses paths that escape the folder, so *filename*
    can come straight from the URL.
    """
    return send_from_directory(app.config["AUDIO_FOLDER"], filename)
if __name__ == "__main__":
    # Listen on every interface so the hosting container can reach the server;
    # 7860 is presumably the port the hosting platform forwards to.
    app.run(port=7860, host="0.0.0.0")