from flask import Flask, render_template, request, jsonify, send_file
from PIL import Image
import torch
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    AutoModelForSeq2SeqLM,
    AutoTokenizer
)
import os
from gtts import gTTS
import tempfile
# Clone and install IndicTransToolkit before importing it
if not os.path.exists('IndicTransToolkit'):
    os.system('git clone https://github.com/VarunGumma/IndicTransToolkit')
    os.system('cd IndicTransToolkit && python3 -m pip install --editable ./')

from IndicTransToolkit import IndicProcessor

app = Flask(__name__)

# Run on GPU when available, otherwise fall back to CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Global variables for models
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(DEVICE)
# Cache for translated results
translation_cache = {}
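# NOTE: the cache is not yet consulted below; translations are recomputed on every request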

def generate_caption(image):
    # Condition BLIP on the short prompt "image of" to produce an English caption
    inputs = blip_processor(image, "image of", return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        generated_ids = blip_model.generate(**inputs)
    caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)
    return caption

def translate_caption(caption, target_languages):
    model_name = "ai4bharat/indictrans2-en-indic-1B"
    tokenizer_IT2 = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model_IT2 = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
    if DEVICE == "cpu":
        # Dynamic int8 quantization of the linear layers cuts memory use; it only applies to CPU inference
        model_IT2 = torch.quantization.quantize_dynamic(model_IT2, {torch.nn.Linear}, dtype=torch.qint8)
    model_IT2.to(DEVICE)
    ip = IndicProcessor(inference=True)
    src_lang = "eng_Latn"
    translations = {}
    input_sentences = [caption]
    for tgt_lang in target_languages:
        batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
        inputs = tokenizer_IT2(batch, truncation=True, padding="longest", return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            generated_tokens = model_IT2.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )
        # Decode with the target-language tokenizer, then restore the Indic script via the processor
        with tokenizer_IT2.as_target_tokenizer():
            generated_tokens = tokenizer_IT2.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
        translated_texts = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
        translations[tgt_lang] = translated_texts[0]
    return translations
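
# Example: translate_caption("two dogs on a beach", ["hin_Deva"])
# returns a dict like {"hin_Deva": "<Hindi translation of the caption>"}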

def generate_audio_gtts(text, lang_code):
    # Create the audio directory if it doesn't exist
    os.makedirs('static/audio', exist_ok=True)
    # Name the file from the text hash and language code; hash() varies across
    # runs, so these files are per-process rather than a persistent cache
    temp_filename = f"static/audio/audio_{hash(text)}_{lang_code}.mp3"
    # Generate the audio file with Google Text-to-Speech
    tts = gTTS(text=text, lang=lang_code)
    tts.save(temp_filename)
    return temp_filename

@app.route('/')
def index():
    return render_template('index.html')

@app.route('/process', methods=['POST'])
def process_image():
    if 'image' not in request.files:
        return jsonify({'error': 'No image uploaded'}), 400
    image_file = request.files['image']
    target_languages = request.form.getlist('languages[]')

    # Process image
    image = Image.open(image_file).convert('RGB')
    caption = generate_caption(image)

    # Generate translations
    translations = translate_caption(caption, target_languages)

    # Generate audio files, mapping IndicTrans2 language tags to gTTS codes
    audio_files = {}
    lang_codes = {
        "hin_Deva": "hi",
        "guj_Gujr": "gu",
        "urd_Arab": "ur",
        "mar_Deva": "mr"
    }
    for lang in target_languages:
        lang_code = lang_codes.get(lang, "en")
        audio_path = generate_audio_gtts(translations[lang], lang_code)
        # Return paths relative to /static so the front end can use them directly
        audio_files[lang] = audio_path.replace('static/', '')

    return jsonify({
        'caption': caption,
        'translations': translations,
        'audio_files': audio_files
    })

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
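
# Example request (hypothetical image path; assumes the server is running locally):
#   curl -X POST http://localhost:7860/process \
#        -F "image=@photo.jpg" \
#        -F "languages[]=hin_Deva" -F "languages[]=mar_Deva"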