VDNT11 committed on
Commit ba28811 · verified · 1 Parent(s): 8389277

Update app.py

Files changed (1)
  1. app.py +99 -63
app.py CHANGED
@@ -1,93 +1,129 @@
- from flask import Flask, request, render_template, send_from_directory
  from PIL import Image
  import torch
- from transformers import BlipProcessor, BlipForConditionalGeneration, AutoModelForSeq2SeqLM, AutoTokenizer
- from gtts import gTTS
  import os
- import soundfile as sf
- from transformers import VitsTokenizer, VitsModel, set_seed
  from IndicTransToolkit import IndicProcessor

- # Initialize Flask app
  app = Flask(__name__)
- UPLOAD_FOLDER = "./static/uploads/"
- AUDIO_FOLDER = "./static/audio/"
- os.makedirs(UPLOAD_FOLDER, exist_ok=True)
- os.makedirs(AUDIO_FOLDER, exist_ok=True)
- app.config["UPLOAD_FOLDER"] = UPLOAD_FOLDER
- app.config["AUDIO_FOLDER"] = AUDIO_FOLDER

- # Load models
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to("cuda" if torch.cuda.is_available() else "cpu")
- model_name = "ai4bharat/indictrans2-en-indic-1B"
- tokenizer_IT2 = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
- model_IT2 = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
- model_IT2 = torch.quantization.quantize_dynamic(
-     model_IT2, {torch.nn.Linear}, dtype=torch.qint8
- )
- model_IT2.to("cuda" if torch.cuda.is_available() else "cpu")
- ip = IndicProcessor(inference=True)

- # Functions
- def generate_caption(image_path):
-     image = Image.open(image_path).convert("RGB")
      inputs = blip_processor(image, "image of", return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
      with torch.no_grad():
          generated_ids = blip_model.generate(**inputs)
-     return blip_processor.decode(generated_ids[0], skip_special_tokens=True)

  def translate_caption(caption, target_languages):
      src_lang = "eng_Latn"
-     input_sentences = [caption]
      translations = {}
-
      for tgt_lang in target_languages:
          batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
-         inputs = tokenizer_IT2(batch, truncation=True, padding="longest", return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
          with torch.no_grad():
              generated_tokens = model_IT2.generate(
-                 **inputs, min_length=0, max_length=256, num_beams=5, num_return_sequences=1
              )
          with tokenizer_IT2.as_target_tokenizer():
-             translated_tokens = tokenizer_IT2.batch_decode(generated_tokens.detach().cpu().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
-         translations[tgt_lang] = ip.postprocess_batch(translated_tokens, lang=tgt_lang)[0]
      return translations

- def generate_audio_gtts(text, lang_code, output_file):
      tts = gTTS(text=text, lang=lang_code)
-     tts.save(output_file)
-     return output_file

- @app.route("/", methods=["GET", "POST"])
  def index():
-     if request.method == "POST":
-         image_file = request.files.get("image")
-         if image_file:
-             image_path = os.path.join(app.config["UPLOAD_FOLDER"], image_file.filename)
-             image_file.save(image_path)
-
-             caption = generate_caption(image_path)
-             target_languages = request.form.getlist("languages")
-             translations = translate_caption(caption, target_languages)
-
-             audio_files = {}
-             lang_codes = {
-                 "hin_Deva": "hi", "guj_Gujr": "gu", "urd_Arab": "ur", "mar_Deva": "mr"
-             }
-             for lang, translation in translations.items():
-                 lang_code = lang_codes.get(lang, "en")
-                 audio_file_path = os.path.join(app.config["AUDIO_FOLDER"], f"{lang}.mp3")
-                 audio_files[lang] = generate_audio_gtts(translation, lang_code, audio_file_path)
-
-             return render_template(
-                 "index.html", image_path=image_path, caption=caption, translations=translations, audio_files=audio_files
-             )
-     return render_template("index.html")

- @app.route("/audio/<filename>")
- def audio(filename):
-     return send_from_directory(app.config["AUDIO_FOLDER"], filename)

- if __name__ == "__main__":
-     app.run(host="0.0.0.0", port=7860)

+ from flask import Flask, render_template, request, jsonify, send_file
  from PIL import Image
  import torch
+ from transformers import (
+     BlipProcessor,
+     BlipForConditionalGeneration,
+     AutoModelForSeq2SeqLM,
+     AutoTokenizer
+ )
  import os
+ from gtts import gTTS
+ import tempfile
  from IndicTransToolkit import IndicProcessor

  app = Flask(__name__)

+ # Initialize models
+ if not os.path.exists('IndicTransToolkit'):
+     os.system('git clone https://github.com/VarunGumma/IndicTransToolkit')
+     os.system('cd IndicTransToolkit && python3 -m pip install --editable ./')
+
+ # Global variables for models
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to("cuda" if torch.cuda.is_available() else "cpu")

+ # Cache for translated results
+ translation_cache = {}
+
+ def generate_caption(image):
      inputs = blip_processor(image, "image of", return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
      with torch.no_grad():
          generated_ids = blip_model.generate(**inputs)
+     caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)
+     return caption

  def translate_caption(caption, target_languages):
+     model_name = "ai4bharat/indictrans2-en-indic-1B"
+     tokenizer_IT2 = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+     model_IT2 = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
+     model_IT2 = torch.quantization.quantize_dynamic(model_IT2, {torch.nn.Linear}, dtype=torch.qint8)
+
+     ip = IndicProcessor(inference=True)
      src_lang = "eng_Latn"
+     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+     model_IT2.to(DEVICE)
+
      translations = {}
+     input_sentences = [caption]
+
      for tgt_lang in target_languages:
          batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
+         inputs = tokenizer_IT2(batch, truncation=True, padding="longest", return_tensors="pt").to(DEVICE)
+
          with torch.no_grad():
              generated_tokens = model_IT2.generate(
+                 **inputs,
+                 use_cache=True,
+                 min_length=0,
+                 max_length=256,
+                 num_beams=5,
+                 num_return_sequences=1,
              )
+
          with tokenizer_IT2.as_target_tokenizer():
+             generated_tokens = tokenizer_IT2.batch_decode(
+                 generated_tokens.detach().cpu().tolist(),
+                 skip_special_tokens=True,
+                 clean_up_tokenization_spaces=True
+             )
+
+         translated_texts = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
+         translations[tgt_lang] = translated_texts[0]
+
      return translations

+ def generate_audio_gtts(text, lang_code):
+     # Create audio directory if it doesn't exist
+     os.makedirs('static/audio', exist_ok=True)
+
+     # Build a deterministic file name for this text/language pair
+     temp_filename = f"static/audio/audio_{hash(text)}_{lang_code}.mp3"
+
+     # Generate audio file
      tts = gTTS(text=text, lang=lang_code)
+     tts.save(temp_filename)
+
+     return temp_filename

+ @app.route('/')
  def index():
+     return render_template('index.html')

+ @app.route('/process', methods=['POST'])
+ def process_image():
+     if 'image' not in request.files:
+         return jsonify({'error': 'No image uploaded'}), 400
+
+     image_file = request.files['image']
+     target_languages = request.form.getlist('languages[]')
+
+     # Process image
+     image = Image.open(image_file).convert('RGB')
+     caption = generate_caption(image)
+
+     # Generate translations
+     translations = translate_caption(caption, target_languages)
+
+     # Generate audio files
+     audio_files = {}
+     lang_codes = {
+         "hin_Deva": "hi",
+         "guj_Gujr": "gu",
+         "urd_Arab": "ur",
+         "mar_Deva": "mr"
+     }
+
+     for lang in target_languages:
+         lang_code = lang_codes.get(lang, "en")
+         audio_path = generate_audio_gtts(translations[lang], lang_code)
+         audio_files[lang] = audio_path.replace('static/', '')
+
+     return jsonify({
+         'caption': caption,
+         'translations': translations,
+         'audio_files': audio_files
+     })

+ if __name__ == '__main__':
+     app.run(host='0.0.0.0', port=7860)
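
For reference, a minimal client-side sketch of the new /process endpoint added above. It assumes the app is running locally on port 7860 and that a test image named example.jpg exists; both names are illustrative and not part of the commit. The field names (image, languages[]), the JSON keys, and the static-relative audio paths follow the handler in this diff:

import requests

BASE_URL = "http://localhost:7860"  # assumption: local dev server

# POST an image plus the target-language list; the handler reads
# request.files['image'] and request.form.getlist('languages[]').
with open("example.jpg", "rb") as f:  # hypothetical test image
    resp = requests.post(
        f"{BASE_URL}/process",
        files={"image": f},
        data={"languages[]": ["hin_Deva", "guj_Gujr"]},
    )
resp.raise_for_status()
result = resp.json()

print("Caption:", result["caption"])
for lang, text in result["translations"].items():
    print(lang, "->", text)

# audio_files holds paths relative to static/, so they are served
# by Flask's default /static/ route.
for lang, rel_path in result["audio_files"].items():
    audio = requests.get(f"{BASE_URL}/static/{rel_path}")
    with open(f"{lang}.mp3", "wb") as out:
        out.write(audio.content)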