Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,93 +1,129 @@
|
|
1 |
-
from flask import Flask, request,
|
2 |
from PIL import Image
|
3 |
import torch
|
4 |
-
from transformers import
|
5 |
-
|
|
|
|
|
|
|
|
|
6 |
import os
|
7 |
-
|
8 |
-
|
9 |
from IndicTransToolkit import IndicProcessor
|
10 |
|
11 |
-
# Initialize Flask app
|
12 |
app = Flask(__name__)
|
13 |
-
UPLOAD_FOLDER = "./static/uploads/"
|
14 |
-
AUDIO_FOLDER = "./static/audio/"
|
15 |
-
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
16 |
-
os.makedirs(AUDIO_FOLDER, exist_ok=True)
|
17 |
-
app.config["UPLOAD_FOLDER"] = UPLOAD_FOLDER
|
18 |
-
app.config["AUDIO_FOLDER"] = AUDIO_FOLDER
|
19 |
|
20 |
-
#
|
|
|
|
|
|
|
|
|
|
|
21 |
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
|
22 |
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to("cuda" if torch.cuda.is_available() else "cpu")
|
23 |
-
model_name = "ai4bharat/indictrans2-en-indic-1B"
|
24 |
-
tokenizer_IT2 = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
25 |
-
model_IT2 = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
|
26 |
-
model_IT2 = torch.quantization.quantize_dynamic(
|
27 |
-
model_IT2, {torch.nn.Linear}, dtype=torch.qint8
|
28 |
-
)
|
29 |
-
model_IT2.to("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
-
ip = IndicProcessor(inference=True)
|
31 |
|
32 |
-
#
|
33 |
-
|
34 |
-
|
|
|
35 |
inputs = blip_processor(image, "image of", return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
|
36 |
with torch.no_grad():
|
37 |
generated_ids = blip_model.generate(**inputs)
|
38 |
-
|
|
|
39 |
|
40 |
def translate_caption(caption, target_languages):
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
src_lang = "eng_Latn"
|
42 |
-
|
|
|
|
|
43 |
translations = {}
|
44 |
-
|
|
|
45 |
for tgt_lang in target_languages:
|
46 |
batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
|
47 |
-
inputs = tokenizer_IT2(batch, truncation=True, padding="longest", return_tensors="pt").to(
|
|
|
48 |
with torch.no_grad():
|
49 |
generated_tokens = model_IT2.generate(
|
50 |
-
**inputs,
|
|
|
|
|
|
|
|
|
|
|
51 |
)
|
|
|
52 |
with tokenizer_IT2.as_target_tokenizer():
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
return translations
|
56 |
|
57 |
-
def generate_audio_gtts(text, lang_code
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
tts = gTTS(text=text, lang=lang_code)
|
59 |
-
tts.save(
|
60 |
-
|
|
|
61 |
|
62 |
-
@app.route(
|
63 |
def index():
|
64 |
-
|
65 |
-
image_file = request.files.get("image")
|
66 |
-
if image_file:
|
67 |
-
image_path = os.path.join(app.config["UPLOAD_FOLDER"], image_file.filename)
|
68 |
-
image_file.save(image_path)
|
69 |
-
|
70 |
-
caption = generate_caption(image_path)
|
71 |
-
target_languages = request.form.getlist("languages")
|
72 |
-
translations = translate_caption(caption, target_languages)
|
73 |
-
|
74 |
-
audio_files = {}
|
75 |
-
lang_codes = {
|
76 |
-
"hin_Deva": "hi", "guj_Gujr": "gu", "urd_Arab": "ur", "mar_Deva": "mr"
|
77 |
-
}
|
78 |
-
for lang, translation in translations.items():
|
79 |
-
lang_code = lang_codes.get(lang, "en")
|
80 |
-
audio_file_path = os.path.join(app.config["AUDIO_FOLDER"], f"{lang}.mp3")
|
81 |
-
audio_files[lang] = generate_audio_gtts(translation, lang_code, audio_file_path)
|
82 |
-
|
83 |
-
return render_template(
|
84 |
-
"index.html", image_path=image_path, caption=caption, translations=translations, audio_files=audio_files
|
85 |
-
)
|
86 |
-
return render_template("index.html")
|
87 |
|
88 |
-
@app.route(
|
89 |
-
def
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
-
if __name__ ==
|
93 |
-
app.run(host=
|
|
|
1 |
+
import os
import tempfile

import torch
from flask import Flask, render_template, request, jsonify, send_file
from gtts import gTTS
from IndicTransToolkit import IndicProcessor
from PIL import Image
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BlipForConditionalGeneration,
    BlipProcessor,
)

# transformers does not export `AutoModelForSeq2SeqTranslation`; importing it
# raised ImportError at startup. The name is kept bound to the real seq2seq
# auto-class so any code that still refers to it keeps working.
AutoModelForSeq2SeqTranslation = AutoModelForSeq2SeqLM
|
14 |
|
|
|
15 |
app = Flask(__name__)


def _bootstrap_indictrans_toolkit(repo_dir='IndicTransToolkit'):
    """Best-effort clone and editable install of IndicTransToolkit.

    Does nothing when ``repo_dir`` already exists. Failures are logged and
    swallowed on purpose: the original ``os.system`` calls never aborted
    start-up either, so a missing ``git`` must not take the app down.
    """
    import logging
    import subprocess
    import sys

    if os.path.exists(repo_dir):
        return

    log = logging.getLogger(__name__)
    # Argument lists instead of shell strings; the second step runs inside the
    # fresh checkout (the original used `cd IndicTransToolkit && ...`).
    steps = (
        (['git', 'clone', 'https://github.com/VarunGumma/IndicTransToolkit', repo_dir], None),
        # Install into the interpreter that is running this app, not whatever
        # `python3` happens to be first on PATH.
        ([sys.executable or 'python3', '-m', 'pip', 'install', '--editable', './'], repo_dir),
    )
    for argv, cwd in steps:
        try:
            completed = subprocess.run(argv, cwd=cwd, check=False)
        except OSError as exc:
            log.warning("IndicTransToolkit bootstrap: cannot run %s: %s", argv[0], exc)
            return
        if completed.returncode != 0:
            log.warning(
                "IndicTransToolkit bootstrap: %s exited with status %d",
                argv[0],
                completed.returncode,
            )
            return


# Initialize models
# NOTE(review): the top-of-file `from IndicTransToolkit import IndicProcessor`
# executes before this bootstrap, so the bootstrap cannot rescue a missing
# package within the same process - confirm whether it is still needed.
_bootstrap_indictrans_toolkit()

# Global variables for models
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-large"
blip_processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT).to(
    "cuda" if torch.cuda.is_available() else "cpu"
)

# Cache for translated results
translation_cache = {}
|
28 |
+
|
29 |
+
def generate_caption(image):
    """Return a BLIP-generated caption for ``image``.

    The image is passed to the module-level ``blip_processor`` together with
    the literal prompt ``"image of"``; the resulting tensors are moved to CUDA
    when available (CPU otherwise) and fed to ``blip_model.generate``. The
    first generated sequence is decoded with special tokens stripped.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    encoded = blip_processor(image, "image of", return_tensors="pt")
    encoded = encoded.to(device)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        output_ids = blip_model.generate(**encoded)

    return blip_processor.decode(output_ids[0], skip_special_tokens=True)
|
35 |
|
36 |
_IT2_MODEL_NAME = "ai4bharat/indictrans2-en-indic-1B"
# Lazily populated with (tokenizer, model, processor, device) on first use.
_IT2_STACK = None


def _load_translation_stack():
    """Load the IndicTrans2 tokenizer, model and processor once and reuse them."""
    global _IT2_STACK
    if _IT2_STACK is None:
        # transformers has no `AutoModelForSeq2SeqTranslation`; the seq2seq
        # auto-class is `AutoModelForSeq2SeqLM`.
        from transformers import AutoModelForSeq2SeqLM

        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(_IT2_MODEL_NAME, trust_remote_code=True)
        model = AutoModelForSeq2SeqLM.from_pretrained(_IT2_MODEL_NAME, trust_remote_code=True)
        if device == "cpu":
            # Dynamic int8 quantization is applied only for CPU inference.
            # NOTE(review): quantized Linear kernels are presumably CPU-only,
            # so the model is left unquantized when it will run on CUDA.
            model = torch.quantization.quantize_dynamic(
                model, {torch.nn.Linear}, dtype=torch.qint8
            )
        model.to(device)
        _IT2_STACK = (tokenizer, model, IndicProcessor(inference=True), device)
    return _IT2_STACK


def translate_caption(caption, target_languages):
    """Translate an English caption into each requested target language.

    Args:
        caption: English source sentence.
        target_languages: iterable of target language tags, passed through
            as ``tgt_lang`` to ``IndicProcessor`` (e.g. ``"hin_Deva"``).

    Returns:
        dict mapping each target language tag to its translated string.
        Empty when no target languages are given; in that case the model is
        not loaded at all.
    """
    if not target_languages:
        return {}

    # The 1B-parameter model is loaded on the first call only, instead of
    # being re-downloaded/re-quantized on every request.
    tokenizer_IT2, model_IT2, ip, device = _load_translation_stack()

    src_lang = "eng_Latn"
    input_sentences = [caption]
    translations = {}

    for tgt_lang in target_languages:
        batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
        inputs = tokenizer_IT2(
            batch, truncation=True, padding="longest", return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            generated_tokens = model_IT2.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode with the tokenizer switched to its target-side mode.
        with tokenizer_IT2.as_target_tokenizer():
            decoded = tokenizer_IT2.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        translated_texts = ip.postprocess_batch(decoded, lang=tgt_lang)
        translations[tgt_lang] = translated_texts[0]

    return translations
|
75 |
|
76 |
+
def generate_audio_gtts(text, lang_code):
    """Synthesize ``text`` to an MP3 under ``static/audio`` and return its path.

    The file name is derived from a SHA-256 digest of the text plus the
    language code, so the same request maps to the same file across process
    restarts (the built-in ``hash()`` is salted per process and may be
    negative). An existing non-empty file is reused instead of being
    synthesized again.

    Args:
        text: text to speak.
        lang_code: gTTS language code (e.g. ``"hi"``).

    Returns:
        Relative path of the MP3, starting with ``static/audio/``.
    """
    import hashlib

    audio_dir = 'static/audio'
    # Create audio directory if it doesn't exist
    os.makedirs(audio_dir, exist_ok=True)

    digest = hashlib.sha256(text.encode('utf-8')).hexdigest()[:16]
    audio_path = f"{audio_dir}/audio_{digest}_{lang_code}.mp3"

    # Reuse a previously generated clip for the same text/language.
    if os.path.isfile(audio_path) and os.path.getsize(audio_path) > 0:
        return audio_path

    # Write to a side file first so a failed save never leaves a truncated
    # clip at the final path that would later be treated as a valid one.
    partial_path = audio_path + ".part"
    tts = gTTS(text=text, lang=lang_code)
    tts.save(partial_path)
    os.replace(partial_path, audio_path)

    return audio_path
|
88 |
|
89 |
+
@app.route('/')
def index():
    """Serve the landing page by rendering the ``index.html`` template."""
    page = render_template('index.html')
    return page
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
+
@app.route('/process', methods=['POST'])
def process_image():
    """Caption an uploaded image, translate the caption and synthesize audio.

    Expects a multipart POST with an ``image`` file part and zero or more
    ``languages[]`` form values.

    Returns:
        JSON with ``caption``, ``translations`` (language tag -> text) and
        ``audio_files`` (language tag -> audio path with the ``static/``
        prefix removed). Responds with HTTP 400 and an ``error`` message when
        no image part is present or the upload cannot be decoded as an image.
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image uploaded'}), 400

    image_file = request.files['image']
    target_languages = request.form.getlist('languages[]')

    # Process image. Pillow signals unreadable/unsupported uploads with
    # OSError subclasses; report those as a client error instead of a 500.
    try:
        image = Image.open(image_file).convert('RGB')
    except OSError:
        return jsonify({'error': 'Uploaded file is not a valid image'}), 400
    caption = generate_caption(image)

    # Generate translations
    translations = translate_caption(caption, target_languages)

    # Generate audio files
    lang_codes = {
        "hin_Deva": "hi",
        "guj_Gujr": "gu",
        "urd_Arab": "ur",
        "mar_Deva": "mr",
    }
    audio_files = {}
    for lang, translated_text in translations.items():
        # Languages without a mapping fall back to the "en" voice.
        lang_code = lang_codes.get(lang, "en")
        audio_path = generate_audio_gtts(translated_text, lang_code)
        # Strip the "static/" prefix from the path returned to the client.
        audio_files[lang] = audio_path.replace('static/', '')

    return jsonify({
        'caption': caption,
        'translations': translations,
        'audio_files': audio_files,
    })
|
127 |
|
128 |
+
if __name__ == '__main__':
    # Script entry point: start the Flask server on every interface, port 7860.
    app.run(port=7860, host='0.0.0.0')
|