|
import tensorflow as tf |
|
import numpy as np |
|
from PIL import Image |
|
from io import BytesIO |
|
from scipy.stats import truncnorm |
|
from skimage.transform import resize |
|
from transformers import CLIPProcessor, CLIPModel |
|
|
|
class TextToImageGenerator:
    """Generate JPEG-encoded images from text prompts.

    The prompt is embedded with a CLIP text encoder; the embedding is fed,
    together with random Gaussian noise, to a Keras generator model, and the
    generator output is converted to JPEG bytes.
    """

    def __init__(self, generator_path='path/to/generator/model',
                 clip_name='openai/clip-vit-base-patch32',
                 noise_dim=256, image_size=(256, 256)):
        """Load the CLIP encoder, its processor and the Keras generator.

        Args:
            generator_path: Location passed to ``tf.keras.models.load_model``.
            clip_name: Checkpoint name/path used for both the CLIP model and
                its processor.
            noise_dim: Width of the noise vector given to the generator.
            image_size: ``(width, height)`` of the returned image.
        """
        self.clip = CLIPModel.from_pretrained(clip_name)
        self.processor = CLIPProcessor.from_pretrained(clip_name)
        self.generator = tf.keras.models.load_model(generator_path)
        self.noise_dim = noise_dim
        self.image_size = tuple(image_size)

    def generate_image(self, prompt):
        """Return JPEG bytes of an image generated for ``prompt``.

        Args:
            prompt: Text describing the desired image.

        Returns:
            The encoded JPEG image as ``bytes``. Only the first element of the
            generator's output batch is converted.
        """
        # CLIPModel is the PyTorch implementation, so ask the processor for
        # PyTorch tensors; truncation keeps long prompts within the model's
        # maximum sequence length.
        encoded_prompt = self.processor(
            text=prompt, return_tensors="pt", padding=True, truncation=True
        )

        # The encoding holds input_ids/attention_mask; they must be unpacked
        # as keyword arguments rather than passed as a single positional dict.
        text_features = self.clip.get_text_features(**encoded_prompt)

        # Hand the embedding over to TensorFlow via NumPy. detach() drops the
        # autograd graph, which is required before calling .numpy().
        text_features = tf.convert_to_tensor(
            text_features.detach().cpu().numpy()
        )

        # One noise vector per encoded prompt so the batch sizes agree.
        batch_size = int(text_features.shape[0])
        noise = tf.random.normal([batch_size, self.noise_dim])

        generated = self.generator([text_features, noise], training=False)[0]
        return self._postprocess_image(generated)

    @staticmethod
    def _to_uint8_pixels(image_features):
        """Map values from [-1, 1] to uint8 pixels in [0, 255].

        Values outside [-1, 1] are clipped. A trailing channel axis of size 1
        is dropped so the result is a plain 2-D grayscale array.
        """
        scaled = (np.asarray(image_features, dtype=np.float32) + 1) / 2
        scaled = np.clip(scaled, 0, 1)
        pixels = (scaled * 255).astype(np.uint8)
        if pixels.ndim == 3 and pixels.shape[-1] == 1:
            # Image.fromarray does not accept (H, W, 1) arrays.
            pixels = pixels[..., 0]
        return pixels

    def _postprocess_image(self, image_features):
        """Convert generator output to resized, JPEG-encoded bytes.

        Args:
            image_features: Array-like image with values expected in [-1, 1].

        Returns:
            The JPEG file contents as ``bytes``.
        """
        image = Image.fromarray(self._to_uint8_pixels(image_features))

        # JPEG has no alpha channel; normalise anything that is not plain
        # grayscale or RGB before encoding.
        if image.mode not in ("RGB", "L"):
            image = image.convert("RGB")

        image = image.resize(self.image_size)

        image_buffer = BytesIO()
        image.save(image_buffer, format='JPEG')
        return image_buffer.getvalue()
|
|