Spaces:
Runtime error
Runtime error
File size: 4,330 Bytes
ba28811 72a8f7d ba28811 72a8f7d ba28811 72a8f7d ba28811 72a8f7d ba28811 72a8f7d ba28811 72a8f7d ba28811 72a8f7d ba28811 72a8f7d ba28811 72a8f7d ba28811 72a8f7d ba28811 72a8f7d ba28811 72a8f7d ba28811 72a8f7d ba28811 72a8f7d ba28811 72a8f7d ba28811 72a8f7d ba28811 72a8f7d ba28811 72a8f7d ba28811 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
from flask import Flask, render_template, request, jsonify, send_file
from PIL import Image
import torch
from transformers import (
BlipProcessor,
BlipForConditionalGeneration,
AutoTokenizer
)
import os
from gtts import gTTS
import tempfile
app = Flask(__name__)

# Make IndicTransToolkit importable *before* importing from it.  The import
# used to sit above the clone/install step, so a missing package aborted the
# module with ImportError and the bootstrap below could never run.
try:
    from IndicTransToolkit import IndicProcessor
except ImportError:
    import importlib
    import subprocess
    import sys

    if not os.path.exists('IndicTransToolkit'):
        subprocess.run(
            ['git', 'clone', 'https://github.com/VarunGumma/IndicTransToolkit'],
            check=True,
        )
    # Install into the interpreter that is running this app (not whatever
    # "python3" happens to be on PATH) and fail loudly if pip fails.
    subprocess.run(
        [sys.executable, '-m', 'pip', 'install', '--editable', './'],
        cwd='IndicTransToolkit',
        check=True,
    )
    # An editable install is normally picked up at interpreter start-up, so
    # expose the checkout to the already-running process explicitly.
    # NOTE(review): assumes the package directory lives at the repository
    # root of the checkout - confirm against the upstream layout.
    sys.path.insert(0, os.path.abspath('IndicTransToolkit'))
    # The failed import above may have cached the bare checkout directory as
    # a namespace package; drop it so the real package is resolved.
    sys.modules.pop('IndicTransToolkit', None)
    importlib.invalidate_caches()
    from IndicTransToolkit import IndicProcessor

# Global BLIP captioning model and processor, loaded once at import time and
# placed on the GPU when one is available.
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large"
).to("cuda" if torch.cuda.is_available() else "cpu")

# Cache for translated results
translation_cache = {}
def generate_caption(image):
    """Return the text BLIP generates for ``image``.

    The image is handed to the module-level ``blip_processor`` together with
    the text ``"image of"``, the resulting tensors are moved to the GPU when
    one is available, and the first generated sequence is decoded with
    special tokens stripped.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_inputs = blip_processor(image, "image of", return_tensors="pt")
    model_inputs = model_inputs.to(device)

    # Inference only: no gradients are needed for generation.
    with torch.no_grad():
        output_ids = blip_model.generate(**model_inputs)

    first_sequence = output_ids[0]
    return blip_processor.decode(first_sequence, skip_special_tokens=True)
# Lazily populated (tokenizer, model, device) triple shared by every request,
# so the translation model is loaded once instead of on each call.
_indictrans2_bundle = None


def _load_indictrans2():
    """Load the IndicTrans2 tokenizer and model once and return the cached triple."""
    global _indictrans2_bundle
    if _indictrans2_bundle is None:
        # Imported locally: the file-level transformers import only brings in
        # the BLIP classes and AutoTokenizer.
        from transformers import AutoModelForSeq2SeqLM

        model_name = "ai4bharat/indictrans2-en-indic-1B"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
        if device == "cpu":
            # Dynamic int8 quantization of the Linear layers; only applied on
            # CPU because the quantized model must not be moved to CUDA.
            model = torch.quantization.quantize_dynamic(
                model, {torch.nn.Linear}, dtype=torch.qint8
            )
        else:
            model = model.to(device)
        _indictrans2_bundle = (tokenizer, model, device)
    return _indictrans2_bundle


def translate_caption(caption, target_languages):
    """Translate an English caption into each requested target language.

    Args:
        caption: English source sentence.
        target_languages: Iterable of language tags that are passed to
            ``IndicProcessor`` as ``tgt_lang`` (for example ``"hin_Deva"``).

    Returns:
        A dict mapping each target language tag to its translated text.
        An empty ``target_languages`` yields ``{}`` without loading the model.
    """
    if not target_languages:
        return {}

    tokenizer_IT2, model_IT2, device = _load_indictrans2()
    ip = IndicProcessor(inference=True)
    src_lang = "eng_Latn"
    input_sentences = [caption]
    translations = {}

    for tgt_lang in target_languages:
        batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
        inputs = tokenizer_IT2(
            batch, truncation=True, padding="longest", return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            generated_tokens = model_IT2.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # NOTE(review): as_target_tokenizer() is kept from the original code;
        # confirm the installed tokenizer version still provides it.
        with tokenizer_IT2.as_target_tokenizer():
            decoded = tokenizer_IT2.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        translated_texts = ip.postprocess_batch(decoded, lang=tgt_lang)
        # Only one sentence was submitted, so take the single result.
        translations[tgt_lang] = translated_texts[0]

    return translations
def generate_audio_gtts(text, lang_code):
    """Synthesize ``text`` to an MP3 under ``static/audio`` and return its path.

    Args:
        text: The text to speak.
        lang_code: Language code passed to gTTS as ``lang``.

    Returns:
        The relative path of the written file, of the form
        ``static/audio/audio_<digest>_<lang_code>.mp3``.
    """
    import hashlib

    # Create audio directory if it doesn't exist
    os.makedirs('static/audio', exist_ok=True)

    # Name the file after a content digest rather than the builtin hash():
    # hash() of a str is salted per process, so the same text would get a
    # different file after every restart (or in every worker) and files would
    # accumulate.  With a stable digest a repeated request overwrites its file.
    digest = hashlib.sha256(text.encode('utf-8')).hexdigest()[:16]
    temp_filename = f"static/audio/audio_{digest}_{lang_code}.mp3"

    # Generate audio file
    tts = gTTS(text=text, lang=lang_code)
    tts.save(temp_filename)
    return temp_filename
@app.route('/')
def index():
    """Serve the page rendered from the ``index.html`` template."""
    page = render_template('index.html')
    return page
@app.route('/process', methods=['POST'])
def process_image():
    """Caption an uploaded image, translate the caption and synthesize audio.

    Expects a multipart form with an ``image`` file and zero or more
    ``languages[]`` values.

    Returns:
        A JSON object with ``caption``, ``translations`` (language tag to
        text) and ``audio_files`` (language tag to a path with the leading
        ``static/`` removed).  Responds with HTTP 400 and an ``error`` message
        when no image is supplied or the upload cannot be decoded as an image.
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image uploaded'}), 400

    image_file = request.files['image']
    # Drop duplicate selections (keeping first-seen order) so each language is
    # translated and synthesized only once.
    target_languages = list(dict.fromkeys(request.form.getlist('languages[]')))

    # Process image.  PIL raises OSError (including UnidentifiedImageError)
    # for empty, truncated or non-image uploads; report that as a client
    # error instead of letting it surface as an HTTP 500.
    try:
        image = Image.open(image_file).convert('RGB')
    except OSError:
        return jsonify({'error': 'Invalid image file'}), 400

    caption = generate_caption(image)

    # Generate translations
    translations = translate_caption(caption, target_languages)

    # Map translation language tags to the codes handed to gTTS.
    lang_codes = {
        "hin_Deva": "hi",
        "guj_Gujr": "gu",
        "urd_Arab": "ur",
        "mar_Deva": "mr",
    }

    # Generate audio files
    audio_files = {}
    for lang in target_languages:
        # Languages without a mapping fall back to "en", as before.
        lang_code = lang_codes.get(lang, "en")
        audio_path = generate_audio_gtts(translations[lang], lang_code)
        audio_files[lang] = audio_path.replace('static/', '')

    return jsonify({
        'caption': caption,
        'translations': translations,
        'audio_files': audio_files
    })
if __name__ == '__main__':
    # Listen on every network interface on port 7860 when run as a script.
    app.run(host='0.0.0.0', port=7860)