File size: 4,231 Bytes
a57bfff
 
 
 
 
 
 
 
704dff6
a57bfff
 
 
 
 
 
704dff6
 
a57bfff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4727cc0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Import necessary libraries
import requests
import io
from PIL import Image
import matplotlib.pyplot as plt
from transformers import MarianMTModel, MarianTokenizer, pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import os  # For accessing environment variables

# Constants for model names and API URLs
class Constants:
    """Model identifiers, the image endpoint URL, and shared HTTP headers.

    All attributes are evaluated once, when this class body runs at import
    time; changing the environment afterwards does not update ``HEADERS``.
    """

    # Model identifier loaded by the translation step.
    TRANSLATION_MODEL_NAME = "Helsinki-NLP/opus-mt-mul-en"
    # HTTP endpoint that image-generation prompts are POSTed to.
    IMAGE_GENERATION_API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
    # Model identifier loaded by the creative text generation step.
    GPT_NEO_MODEL_NAME = "EleutherAI/gpt-neo-125M"

    # Get the Hugging Face API token from environment variables. Surrounding
    # whitespace (e.g. a trailing newline from a secrets file) is stripped
    # because it is not valid inside an HTTP header value.
    _HF_API_TOKEN = (os.getenv("HF_API_TOKEN") or "").strip()
    # Only send an Authorization header when a token is actually configured;
    # formatting a missing token would yield the literal value "Bearer None".
    HEADERS = {"Authorization": f"Bearer {_HF_API_TOKEN}"} if _HF_API_TOKEN else {}

# Translation Class
class Translator:
    """Translate input text to English through a ``translation`` pipeline."""

    def __init__(self):
        """Load the tokenizer and model named by ``Constants.TRANSLATION_MODEL_NAME``."""
        self.tokenizer = MarianTokenizer.from_pretrained(Constants.TRANSLATION_MODEL_NAME)
        self.model = MarianMTModel.from_pretrained(Constants.TRANSLATION_MODEL_NAME)
        self.pipeline = pipeline("translation", model=self.model, tokenizer=self.tokenizer)

    def translate(self, tamil_text, max_length=40):
        """Translate Tamil text to English.

        Args:
            tamil_text: The text to translate.
            max_length: Forwarded to the pipeline as ``max_length``. The
                default of 40 keeps the previously hard-coded limit.

        Returns:
            The translated text. An empty string is returned for empty or
            whitespace-only input without invoking the model. If the pipeline
            fails, the string ``"Translation error: <message>"`` is returned
            instead of raising.
        """
        # Nothing to translate: skip the model call entirely.
        if isinstance(tamil_text, str) and not tamil_text.strip():
            return ""
        try:
            translation = self.pipeline(tamil_text, max_length=max_length)
            return translation[0]['translation_text']
        except Exception as e:
            # Best-effort boundary: callers receive the failure as text.
            return f"Translation error: {str(e)}"


# Image Generation Class
class ImageGenerator:
    """Request an image for a text prompt from the image generation endpoint."""

    def __init__(self, timeout=120):
        """Store the endpoint URL and the request timeout.

        Args:
            timeout: Seconds to wait for the HTTP request before giving up.
        """
        self.api_url = Constants.IMAGE_GENERATION_API_URL
        self.timeout = timeout

    def generate(self, prompt):
        """Generate an image based on the given prompt.

        Args:
            prompt: Text sent to the endpoint as the ``inputs`` field.

        Returns:
            A ``PIL.Image.Image`` decoded from the response body when the
            endpoint answers with status 200; ``None`` when the prompt is
            empty, the status is not 200, or any error occurs (a message is
            printed in each of those cases).
        """
        # An empty prompt cannot describe an image; avoid the network call.
        if prompt is None or (isinstance(prompt, str) and not prompt.strip()):
            print("Image generation skipped: empty prompt")
            return None
        try:
            # Without a timeout a stalled connection would block the caller
            # indefinitely.
            response = requests.post(
                self.api_url,
                headers=Constants.HEADERS,
                json={"inputs": prompt},
                timeout=self.timeout,
            )
            if response.status_code == 200:
                image_bytes = response.content
                return Image.open(io.BytesIO(image_bytes))
            print(f"Image generation failed: Status code {response.status_code}")
            return None
        except Exception as e:
            # Best-effort boundary: report the problem and return no image.
            print(f"Image generation error: {str(e)}")
            return None


# Creative Text Generation Class
class CreativeTextGenerator:
    """Continue a piece of text with a causal language model."""

    def __init__(self):
        """Load the tokenizer and model named by ``Constants.GPT_NEO_MODEL_NAME``."""
        self.tokenizer = AutoTokenizer.from_pretrained(Constants.GPT_NEO_MODEL_NAME)
        self.model = AutoModelForCausalLM.from_pretrained(Constants.GPT_NEO_MODEL_NAME)

    def generate(self, translated_text, max_length=100):
        """Generate creative text based on translated text.

        Args:
            translated_text: The text the model should continue.
            max_length: Forwarded to ``model.generate`` as ``max_length``.
                The default of 100 keeps the previously hard-coded limit.

        Returns:
            The decoded output of the model with special tokens removed. An
            empty string is returned for empty or whitespace-only input
            without invoking the model. If tokenization or generation fails,
            the string ``"Text generation error: <message>"`` is returned
            instead of raising, in the same way translation failures are
            reported as text.
        """
        # There is nothing to continue from an empty prompt.
        if translated_text is None or (
            isinstance(translated_text, str) and not translated_text.strip()
        ):
            return ""
        try:
            encoded = self.tokenizer(translated_text, return_tensors='pt')
            generated_text_ids = self.model.generate(
                encoded.input_ids,
                # Pass the mask and a padding id explicitly so generate()
                # does not have to guess them.
                attention_mask=encoded.attention_mask,
                max_length=max_length,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            return self.tokenizer.decode(generated_text_ids[0], skip_special_tokens=True)
        except Exception as e:
            # Best-effort boundary: callers receive the failure as text.
            return f"Text generation error: {str(e)}"


# Main Application Class
class TransArtApp:
    """Tie together translation, image generation, and creative text generation."""

    # Prefix of the string the translator returns when translation fails.
    _TRANSLATION_ERROR_PREFIX = "Translation error:"

    def __init__(self):
        """Create the translator and the two generators used by ``process``."""
        self.translator = Translator()
        self.image_generator = ImageGenerator()
        self.creative_text_generator = CreativeTextGenerator()

    def process(self, tamil_text):
        """Handle the full workflow: translate, generate image, and creative text.

        Args:
            tamil_text: The text to translate and build on.

        Returns:
            A ``(translated_text, creative_text, image)`` tuple. When the
            translation is empty or is a translation error message, the
            generators are not called and ``(translated_text, "", None)`` is
            returned so the error text is never used as a prompt.
        """
        translated_text = self.translator.translate(tamil_text)
        # A failed translation comes back as an error message, not a prompt;
        # feeding it to the generators would produce unrelated output.
        # NOTE: a genuine translation that starts with the same prefix would
        # also be skipped.
        if not translated_text or str(translated_text).startswith(
            self._TRANSLATION_ERROR_PREFIX
        ):
            return translated_text, "", None
        image = self.image_generator.generate(translated_text)
        creative_text = self.creative_text_generator.generate(translated_text)
        return translated_text, creative_text, image


# Function to display images
def show_image(image):
    """Render *image* with matplotlib, axes hidden.

    When *image* is falsy (for example ``None``) nothing is drawn and a
    short notice is printed instead.
    """
    # Guard clause: bail out early when there is nothing to draw.
    if not image:
        print("No image to display.")
        return

    plt.imshow(image)
    plt.axis('off')  # Hide axes
    plt.show()


# Create the single, module-level TransArt app instance. Constructing it
# loads the translation and text-generation models, so this happens once at
# import time and the same instance is reused by gradio_interface below.
app = TransArtApp()

# Gradio interface function
def gradio_interface(tamil_text):
    """Gradio callback: run the module-level app on *tamil_text*.

    Returns a 3-tuple of translated text, creative text, and image.
    """
    english, story, picture = app.process(tamil_text)
    return (english, story, picture)


# Create the Gradio interface: one text input, and three outputs (two text
# fields and one image) matching the order of the tuple returned by
# gradio_interface.
interface = gr.Interface(
    fn=gradio_interface,
    inputs="text",
    outputs=["text", "text", "image"],
    title="Tamil to English Translation, Image Generation & Creative Text",
    description="Enter Tamil text to translate to English, generate an image, and create creative text based on the translation."
)

# Launch the Gradio app only when this file is run as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    interface.launch()