import os
import time
from typing import Optional, Union

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from transformers import BertModel, BertTokenizer

# Size of the CVAE latent vector z.
LATENT_DIM = 128
# Width of both the encoder's image-feature vector and the text-conditioning vector.
HIDDEN_DIM = 256


# Text encoder
class TextEncoder(nn.Module):
    """Project the BERT embedding at position 0 ([CLS]) to a conditioning vector.

    Args:
        hidden_size: Not used by this module; kept so existing callers still work.
        output_size: Length of the returned conditioning vector.
    """

    def __init__(self, hidden_size, output_size):
        super().__init__()
        # Pretrained BERT (may be downloaded on first use). When a CVAE checkpoint
        # is loaded afterwards, load_state_dict overwrites these weights.
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(self.bert.config.hidden_size, output_size)

    def forward(self, input_ids, attention_mask):
        """Return a (batch, output_size) tensor computed from the first token's hidden state."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.fc(outputs.last_hidden_state[:, 0, :])


# CVAE model
class CVAE(nn.Module):
    """Text-conditioned VAE over 4-channel (RGBA) 16x16 images.

    The encoder's convolutions (stride 1, 2, 2; kernel 3, padding 1) reduce a
    16x16 input to 4x4, matching the ``128 * 4 * 4`` flatten size. The decoder
    mirrors this back up to a (batch, 4, 16, 16) tensor squashed by Tanh into
    [-1, 1].
    """

    def __init__(self, text_encoder):
        super().__init__()
        self.text_encoder = text_encoder

        # Encoder: (B, 4, 16, 16) -> (B, HIDDEN_DIM)
        self.encoder = nn.Sequential(
            nn.Conv2d(4, 32, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, HIDDEN_DIM)
        )

        # Image features and text conditioning are concatenated before these heads.
        self.fc_mu = nn.Linear(HIDDEN_DIM + HIDDEN_DIM, LATENT_DIM)
        self.fc_logvar = nn.Linear(HIDDEN_DIM + HIDDEN_DIM, LATENT_DIM)

        # Decoder: (B, LATENT_DIM + HIDDEN_DIM) -> (B, 4, 16, 16) in [-1, 1]
        self.decoder_input = nn.Linear(LATENT_DIM + HIDDEN_DIM, 128 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 4, 3, stride=1, padding=1),
            nn.Tanh()
        )

    def encode(self, x, c):
        """Return ``(mu, logvar)`` for image batch ``x`` conditioned on ``c`` (B, HIDDEN_DIM)."""
        features = self.encoder(x)
        joint = torch.cat([features, c], dim=1)
        return self.fc_mu(joint), self.fc_logvar(joint)

    def decode(self, z, c):
        """Decode latent ``z`` (B, LATENT_DIM) with condition ``c`` (B, HIDDEN_DIM) to images."""
        joint = torch.cat([z, c], dim=1)
        x = self.decoder_input(joint).view(-1, 128, 4, 4)
        return self.decoder(x)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, c):
        """Return ``(reconstruction, mu, logvar)`` for images ``x`` and condition ``c``."""
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, c), mu, logvar


# Initialize the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


def clean_image(image: Image.Image, threshold: float = 0.75) -> Image.Image:
    """Binarize the alpha channel so every pixel is fully opaque or fully transparent.

    Alpha values ``<= int(threshold * 255)`` become 0; all others become 255.
    The input is converted to RGBA first, so images without an alpha channel
    are accepted (they previously raised IndexError). Returns a new RGBA image.
    """
    rgba = np.array(image.convert("RGBA"))
    cutoff = int(threshold * 255)
    alpha = rgba[:, :, 3]  # view into `rgba`; edits below write through
    alpha[alpha <= cutoff] = 0
    alpha[alpha > cutoff] = 255
    return Image.fromarray(rgba)


def generate_image(
    model: CVAE,
    text_prompt: str,
    device: torch.device,
    input_image: Optional[Image.Image] = None,
    img_control: float = 0.5
) -> Image.Image:
    """Sample one image from ``model`` conditioned on ``text_prompt``.

    Args:
        model: Loaded CVAE (its ``text_encoder`` is used for the prompt).
        text_prompt: Conditioning text.
        device: Device the model lives on.
        input_image: Optional reference image blended into the sample.
        img_control: Blend weight for ``input_image`` (1.0 = only the
            reference, 0.0 = only the generated sample).

    Returns:
        An RGBA PIL image at the decoder's native resolution.
    """
    encoded_input = tokenizer(text_prompt, padding=True, truncation=True, return_tensors="pt")
    input_ids = encoded_input['input_ids'].to(device)
    attention_mask = encoded_input['attention_mask'].to(device)

    with torch.no_grad():
        text_encoding = model.text_encoder(input_ids, attention_mask)
        z = torch.randn(1, LATENT_DIM, device=device)
        # (1, 4, H, W) in [-1, 1] because the decoder ends in Tanh.
        generated = model.decode(z, text_encoding)

        if input_image is not None:
            height, width = generated.shape[-2:]
            reference = input_image.convert("RGBA").resize((width, height), resample=Image.NEAREST)
            reference = transforms.ToTensor()(reference).unsqueeze(0).to(device)
            # ToTensor yields [0, 1]; map it into the decoder's [-1, 1] space so
            # the blend (and the (x + 1) / 2 mapping below) treats both alike.
            reference = reference * 2 - 1
            generated = img_control * reference + (1 - img_control) * generated

    generated = generated.squeeze(0).cpu()
    generated = ((generated + 1) / 2).clamp(0, 1)
    return transforms.ToPILImage()(generated)


# Model loading with caching, keyed by (path, device).
# NOTE(review): unbounded — every distinct path typed into the UI keeps a full
# BERT-sized model in memory for the life of the process.
_model_cache = {}


def load_model(model_path: str, device: torch.device) -> CVAE:
    """Build a CVAE, load the state dict at ``model_path``, and cache it.

    The path comes straight from a user-editable textbox. ``weights_only=True``
    prevents ``torch.load`` from unpickling arbitrary objects, which would
    otherwise allow code execution from a malicious checkpoint file.
    """
    cache_key = (model_path, str(device))
    model = _model_cache.get(cache_key)
    if model is None:
        text_encoder = TextEncoder(hidden_size=HIDDEN_DIM, output_size=HIDDEN_DIM)
        model = CVAE(text_encoder).to(device)
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        _model_cache[cache_key] = model
    return model


def generate_image_gradio(
    prompt: str,
    model_path: str,
    clean_image_flag: bool,
    size: int,
    input_image: Optional[Image.Image] = None,
    img_control: float = 0.5
) -> tuple[Image.Image, str]:
    """Gradio callback: load the model, generate, optionally clean, and resize.

    Returns the image plus a human-readable timing string. Failures while
    loading, generating, or resizing are surfaced as ``gr.Error``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        model = load_model(model_path, device)
    except Exception as e:
        raise gr.Error(f"Failed to load model: {str(e)}") from e

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    try:
        generated_image = generate_image(model, prompt, device, input_image, img_control)
    except Exception as e:
        raise gr.Error(f"Failed to generate image: {str(e)}") from e
    generation_time = time.perf_counter() - start_time

    if clean_image_flag:
        generated_image = clean_image(generated_image)

    try:
        # Slider/API values may arrive as floats; Pillow requires integer sizes.
        side = int(size)
        generated_image = generated_image.resize((side, side), resample=Image.NEAREST)
    except Exception as e:
        raise gr.Error(f"Failed to resize image: {str(e)}") from e

    return generated_image, f"Generation time: {generation_time:.4f} seconds"


def gradio_interface() -> gr.Blocks:
    """Build the Gradio Blocks UI wired to ``generate_image_gradio``."""
    with gr.Blocks() as demo:
        gr.Markdown("# Image Generator from Text Prompt")

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Text Prompt")
                model_path = gr.Textbox(label="Model Path", value="BitRoss.pth")
                clean_image_flag = gr.Checkbox(label="Clean Image", value=False)
                size = gr.Slider(minimum=16, maximum=1024, step=16, label="Image Size", value=16)
                img_control = gr.Slider(minimum=0, maximum=1, step=0.1, label="Image Control", value=0.5)
                input_image = gr.Image(label="Input Image (optional)", type="pil")
                generate_button = gr.Button("Generate Image")

            with gr.Column():
                output_image = gr.Image(label="Generated Image")
                generation_time = gr.Textbox(label="Generation Time")

        # Errors raised as gr.Error inside the callback are shown in the UI.
        generate_button.click(
            fn=generate_image_gradio,
            inputs=[prompt, model_path, clean_image_flag, size, input_image, img_control],
            outputs=[output_image, generation_time],
            api_name="generate"  # Explicit API endpoint name
        )

    return demo


if __name__ == "__main__":
    demo = gradio_interface()
    # NOTE(review): "0.0.0.0" listens on all network interfaces, and the UI lets
    # remote users choose which file path is loaded as a model. Restrict this
    # (e.g. "127.0.0.1") unless that exposure is intended.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        # Configure CORS if needed
        # allowed_paths=["/custom/path"],
        # cors_allowed_origins=["*"]
    )