File size: 4,330 Bytes
ba28811
72a8f7d
 
ba28811
 
 
 
 
72a8f7d
ba28811
 
72a8f7d
 
 
 
ba28811
 
 
 
 
 
72a8f7d
 
 
ba28811
 
 
 
72a8f7d
 
 
ba28811
 
72a8f7d
 
ba28811
 
 
 
 
 
72a8f7d
ba28811
 
 
72a8f7d
ba28811
 
72a8f7d
 
ba28811
 
72a8f7d
 
ba28811
 
 
 
 
 
72a8f7d
ba28811
72a8f7d
ba28811
 
 
 
 
 
 
 
 
72a8f7d
 
ba28811
 
 
 
 
 
 
 
72a8f7d
ba28811
 
 
72a8f7d
ba28811
72a8f7d
ba28811
72a8f7d
ba28811
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72a8f7d
ba28811
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import hashlib
import os
import tempfile

import torch
from flask import Flask, render_template, request, jsonify, send_file
from gtts import gTTS
from PIL import Image
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    AutoModelForSeq2SeqLM,
    AutoTokenizer
)

from IndicTransToolkit import IndicProcessor

app = Flask(__name__)

# Initialize models
# Bootstrap the IndicTransToolkit dependency on first run by cloning the repo
# and installing it in editable mode.
# NOTE(review): `from IndicTransToolkit import IndicProcessor` at the top of
# this file executes BEFORE this clone/install block, so a fresh environment
# would fail at import time without ever reaching this bootstrap — confirm the
# deployment image always has the package pre-installed.
if not os.path.exists('IndicTransToolkit'):
    os.system('git clone https://github.com/VarunGumma/IndicTransToolkit')
    os.system('cd IndicTransToolkit && python3 -m pip install --editable ./')

# Global variables for models
# BLIP image-captioning processor/model, loaded once at import time and moved
# to GPU when available; used by generate_caption().
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to("cuda" if torch.cuda.is_available() else "cpu")

# Cache for translated results
translation_cache = {}

def generate_caption(image):
    """Generate an English caption for an image with the global BLIP model.

    Args:
        image: A PIL image (RGB) to caption.

    Returns:
        The decoded caption string. The "image of" conditional-generation
        prompt is passed to the processor, matching the original behavior.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_inputs = blip_processor(image, "image of", return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = blip_model.generate(**model_inputs)
    return blip_processor.decode(output_ids[0], skip_special_tokens=True)

def translate_caption(caption, target_languages):
    """Translate an English caption into the requested Indic languages.

    Args:
        caption: English source sentence (e.g. a BLIP caption).
        target_languages: Iterable of IndicTrans2 language tags such as
            "hin_Deva" or "guj_Gujr".

    Returns:
        Dict mapping each target language tag to its translated string.
    """
    model_name = "ai4bharat/indictrans2-en-indic-1B"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the tokenizer/model once and reuse across requests — reloading a
    # 1B-parameter checkpoint on every call is prohibitively slow.
    # NOTE(review): trust_remote_code executes code from the model repo;
    # acceptable only because it is pinned to ai4bharat's official model.
    if not hasattr(translate_caption, "_model"):
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # BUG FIX: the original referenced AutoModelForSeq2SeqTranslation,
        # which does not exist in transformers (NameError on every call).
        # IndicTrans2 loads via AutoModelForSeq2SeqLM.
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
        # Dynamic int8 quantization of Linear layers to cut memory/latency.
        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
        # NOTE(review): dynamically-quantized models are CPU-oriented; moving
        # one to CUDA (original behavior, kept here) may not be supported —
        # verify on a GPU host.
        model.to(DEVICE)
        translate_caption._tokenizer = tokenizer
        translate_caption._model = model
    tokenizer_IT2 = translate_caption._tokenizer
    model_IT2 = translate_caption._model

    ip = IndicProcessor(inference=True)
    src_lang = "eng_Latn"

    translations = {}
    input_sentences = [caption]

    for tgt_lang in target_languages:
        # Reuse the module-level translation_cache (previously declared but
        # never used) so repeated captions skip the expensive beam search.
        cache_key = (caption, tgt_lang)
        if cache_key in translation_cache:
            translations[tgt_lang] = translation_cache[cache_key]
            continue

        batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
        inputs = tokenizer_IT2(batch, truncation=True, padding="longest", return_tensors="pt").to(DEVICE)

        with torch.no_grad():
            generated_tokens = model_IT2.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode in target-tokenizer mode as required by the IndicTrans2
        # tokenizer's source/target vocabulary split.
        with tokenizer_IT2.as_target_tokenizer():
            generated_tokens = tokenizer_IT2.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )

        translated_texts = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
        translations[tgt_lang] = translated_texts[0]
        translation_cache[cache_key] = translated_texts[0]

    return translations

def generate_audio_gtts(text, lang_code):
    """Synthesize `text` to an MP3 under static/audio/ and return its path.

    Args:
        text: The text to speak.
        lang_code: gTTS/ISO language code (e.g. "hi", "gu").

    Returns:
        Relative path to the generated MP3 file (e.g. "static/audio/...").
    """
    # Create audio directory if it doesn't exist
    os.makedirs('static/audio', exist_ok=True)

    # BUG FIX: the original used hash(text), which is salted per process
    # (PYTHONHASHSEED) — the same text produced a different filename after
    # every restart, defeating reuse and accumulating orphan files. A content
    # digest is deterministic across runs.
    digest = hashlib.md5(f"{lang_code}:{text}".encode("utf-8")).hexdigest()
    filename = f"static/audio/audio_{digest}_{lang_code}.mp3"

    # Skip the network round-trip to Google TTS if this exact text/language
    # pair was already synthesized.
    if not os.path.exists(filename):
        tts = gTTS(text=text, lang=lang_code)
        tts.save(filename)

    return filename

@app.route('/')
def index():
    """Serve the upload UI from templates/index.html."""
    return render_template('index.html')

@app.route('/process', methods=['POST'])
def process_image():
    """Caption an uploaded image, translate it, and synthesize audio.

    Expects a multipart POST with an 'image' file and zero or more
    'languages[]' form values (IndicTrans2 tags).

    Returns:
        JSON with 'caption', 'translations' (tag -> text), and 'audio_files'
        (tag -> path relative to /static/); 400 on a missing/invalid upload.
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image uploaded'}), 400

    image_file = request.files['image']
    # An empty-filename part passes the key check above but carries no data.
    if image_file.filename == '':
        return jsonify({'error': 'No image uploaded'}), 400

    target_languages = request.form.getlist('languages[]')

    # Process image. Non-image bytes raise UnidentifiedImageError (an OSError
    # subclass) — report a client error instead of an unhandled 500.
    try:
        image = Image.open(image_file).convert('RGB')
    except OSError:
        return jsonify({'error': 'Uploaded file is not a valid image'}), 400

    caption = generate_caption(image)

    # Generate translations
    translations = translate_caption(caption, target_languages)

    # Generate audio files: map IndicTrans2 tags to gTTS language codes,
    # falling back to English for unknown tags.
    audio_files = {}
    lang_codes = {
        "hin_Deva": "hi",
        "guj_Gujr": "gu",
        "urd_Arab": "ur",
        "mar_Deva": "mr"
    }

    for lang in target_languages:
        lang_code = lang_codes.get(lang, "en")
        audio_path = generate_audio_gtts(translations[lang], lang_code)
        # Client URLs are served relative to /static/.
        audio_files[lang] = audio_path.replace('static/', '')

    return jsonify({
        'caption': caption,
        'translations': translations,
        'audio_files': audio_files
    })

if __name__ == '__main__':
    # Bind on all interfaces; port 7860 — presumably chosen for a Hugging
    # Face Spaces deployment, where it is the expected port.
    app.run(host='0.0.0.0', port=7860)