OVAWARE committed
Commit 273708c · verified · 1 Parent(s): 2e719f9

Create app.py

Files changed (1)
app.py +140 -0
app.py ADDED
@@ -0,0 +1,140 @@
+ import torch
+ import torch.nn as nn
+ from torchvision import transforms
+ from PIL import Image
+ from transformers import BertTokenizer, BertModel
+ import argparse
+ import numpy as np
+ import os
+ import time  # Used to time each generation
+
+ # Import the model architecture from train.py
+ from train import CVAE, TextEncoder, LATENT_DIM, HIDDEN_DIM
+
+ # Initialize the BERT tokenizer
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+
+ def clean_image(image, threshold=0.75):
+     """
+     Clean up the image by setting pixels with opacity <= threshold to 0% opacity
+     and pixels above the threshold to full opacity.
+     """
+     np_image = np.array(image)
+     alpha_channel = np_image[:, :, 3]  # View into np_image; edits apply in place
+     alpha_channel[alpha_channel <= int(threshold * 255)] = 0
+     alpha_channel[alpha_channel > int(threshold * 255)] = 255  # Fully opaque
+     return Image.fromarray(np_image)
+
+ def generate_image(model, text_prompt, device, input_image=None, img_control=0.5):
+     # Encode the text prompt using the BERT tokenizer
+     encoded_input = tokenizer(text_prompt, padding=True, truncation=True, return_tensors="pt")
+     input_ids = encoded_input['input_ids'].to(device)
+     attention_mask = encoded_input['attention_mask'].to(device)
+
+     # Generate the text encoding
+     with torch.no_grad():
+         text_encoding = model.text_encoder(input_ids, attention_mask)
+
+     # Sample from the latent space
+     z = torch.randn(1, LATENT_DIM).to(device)
+
+     # Generate the image
+     with torch.no_grad():
+         generated_image = model.decode(z, text_encoding)
+
+     if input_image is not None:
+         input_image = input_image.convert("RGBA").resize((16, 16), resample=Image.NEAREST)
+         input_image = transforms.ToTensor()(input_image).unsqueeze(0).to(device)
+         # ToTensor() yields values in [0, 1]; rescale to [-1, 1] so the blend
+         # matches the decoder output range implied by the rescaling below
+         input_image = input_image * 2 - 1
+         generated_image = img_control * input_image + (1 - img_control) * generated_image
+
+     # Convert the generated tensor to a PIL Image
+     generated_image = generated_image.squeeze(0).cpu()
+     generated_image = (generated_image + 1) / 2  # Rescale from [-1, 1] to [0, 1]
+     generated_image = generated_image.clamp(0, 1)
+     generated_image = transforms.ToPILImage()(generated_image)
+
+     return generated_image
+
+ def main():
+     parser = argparse.ArgumentParser(description="Generate an image from a text prompt using the trained CVAE model(s).")
+     parser.add_argument("--prompt", type=str, help="Text prompt for image generation")
+     parser.add_argument("--prompt_file", type=str, help="File containing prompts, one per line")
+     parser.add_argument("--output", type=str, default="generated_images", help="Output directory or file for generated images")
+     parser.add_argument("--model_paths", type=str, nargs='*', help="Paths to the trained model(s)")
+     parser.add_argument("--model_path", type=str, help="Path to a single trained model")
+     parser.add_argument("--clean", action="store_true", help="Clean up the image by removing low-opacity pixels")
+     parser.add_argument("--size", type=int, default=16, help="Size of the generated image")
+     parser.add_argument("--input_image", type=str, help="Path to the input image for img2img generation")
+     parser.add_argument("--img_control", type=float, default=0.5, help="Control how much the input image influences the output (0 to 1)")
+     args = parser.parse_args()
+
+     if not args.prompt and not args.prompt_file:
+         parser.error("Either --prompt or --prompt_file must be provided")
+
+     if args.model_paths and args.model_path:
+         parser.error("Specify either --model_paths or --model_path, not both")
+
+     if not args.model_paths and not args.model_path:
+         parser.error("Either --model_paths or --model_path must be provided")
+
+     model_paths = args.model_paths if args.model_paths else [args.model_path]
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     # Treat --output as a directory if it already exists as one or has no
+     # file extension (e.g. the default "generated_images"); otherwise treat
+     # it as a single output file
+     is_folder_output = os.path.isdir(args.output) or not os.path.splitext(args.output)[1]
+
+     if is_folder_output:
+         # Ensure the output directory exists
+         os.makedirs(args.output, exist_ok=True)
+
+     # Load the input image if provided
+     input_image = None
+     if args.input_image:
+         input_image = Image.open(args.input_image).convert("RGBA")
+
+     # Process a single prompt or a batch of prompts
+     if args.prompt:
+         prompts = [args.prompt]
+     else:
+         with open(args.prompt_file, 'r') as f:
+             prompts = [line.strip() for line in f if line.strip()]
+
+     for model_path in model_paths:
+         # Initialize the model
+         text_encoder = TextEncoder(hidden_size=HIDDEN_DIM, output_size=HIDDEN_DIM)
+         model = CVAE(text_encoder).to(device)
+
+         # Load the trained weights
+         model.load_state_dict(torch.load(model_path, map_location=device))
+         model.eval()
+
+         model_name = os.path.splitext(os.path.basename(model_path))[0]
+
+         for i, prompt in enumerate(prompts):
+             start_time = time.time()  # Start timing the generation
+
+             # Generate an image from the prompt
+             generated_image = generate_image(model, prompt, device, input_image, args.img_control)
+
+             generation_time = time.time() - start_time
+
+             # Clean up the image if the flag is set
+             if args.clean:
+                 generated_image = clean_image(generated_image)
+
+             # Resize the generated image
+             generated_image = generated_image.resize((args.size, args.size), resample=Image.NEAREST)
+
+             if not is_folder_output:
+                 # Save the generated image to the specified file
+                 output_file = args.output
+             else:
+                 # Save the generated image to the output directory; the prompt
+                 # text is used verbatim in the filename
+                 output_file = os.path.join(args.output, f"{model_name}_{prompt}_{i:03d}.png")
+
+             generated_image.save(output_file)
+             print(f"Generated image for prompt '{prompt}' using model '{model_name}' saved as {output_file}")
+             print(f"Generation time: {generation_time:.4f} seconds")
+
+ if __name__ == "__main__":
+     main()
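
A minimal usage sketch, assuming train.py sits next to app.py and exposes CVAE, TextEncoder, LATENT_DIM and HIDDEN_DIM as imported above, and that a trained checkpoint exists (the file names cvae.pth and sword.png below are hypothetical). From the command line: python app.py --prompt "a red sword" --model_path cvae.pth --output sword.png --size 32 --clean. The same flow can be driven programmatically:

    # Hypothetical example; mirrors the loading logic in main()
    import torch
    from train import CVAE, TextEncoder, HIDDEN_DIM
    from app import generate_image

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CVAE(TextEncoder(hidden_size=HIDDEN_DIM, output_size=HIDDEN_DIM)).to(device)
    model.load_state_dict(torch.load("cvae.pth", map_location=device))  # hypothetical checkpoint
    model.eval()
    generate_image(model, "a red sword", device).save("sword.png")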