Rahul8827 committed on
Commit
e08d8b4
1 Parent(s): c13f57e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -40
app.py CHANGED
@@ -10,7 +10,7 @@ from glide_text2im.model_creation import (
10
  )
11
  has_cuda = th.cuda.is_available()
12
  device = th.device('cpu' if not has_cuda else 'cuda')
13
- # Create base model.
14
  options = model_and_diffusion_defaults()
15
  options['use_fp16'] = has_cuda
16
  options['timestep_respacing'] = '100' # use 100 diffusion steps for fast sampling
@@ -21,7 +21,8 @@ if has_cuda:
21
  model.to(device)
22
  model.load_state_dict(load_checkpoint('base', device))
23
  print('total base parameters', sum(x.numel() for x in model.parameters()))
24
- # Create upsampler model.
 
25
  options_up = model_and_diffusion_defaults_upsampler()
26
  options_up['use_fp16'] = has_cuda
27
  options_up['timestep_respacing'] = 'fast27' # use 27 diffusion steps for very fast sampling
@@ -38,31 +39,15 @@ def show_images(batch: th.Tensor):
38
  reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
39
  display(Image.fromarray(reshaped.numpy()))
40
  # Sampling parameters
41
- prompt = "an oil painting of a corgi"
42
  batch_size = 1
43
  guidance_scale = 3.0
44
 
45
  # Tune this parameter to control the sharpness of 256x256 images.
46
  # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
47
  upsample_temp = 0.997
48
-
49
- samples = diffusion.p_sample_loop(
50
- model_fn,
51
- (full_batch_size, 3, options["image_size"], options["image_size"]),
52
- device=device,
53
- clip_denoised=True,
54
- progress=True,
55
- model_kwargs=model_kwargs,
56
- cond_fn=None,
57
- )[:batch_size]
58
- model.del_cache()
59
-
60
- # Show the output
61
- show_images(samples)
62
-
63
-
64
  import gradio as gr
65
- def generate_upsampled_image_from_text(prompt):
66
  # Set the prompt text
67
  prompt = prompt
68
 
@@ -71,43 +56,54 @@ def generate_upsampled_image_from_text(prompt):
71
  ##############################
72
 
73
  # Create the text tokens to feed to the model.
74
- tokens = model_up.tokenizer.encode(prompt)
75
- tokens, mask = model_up.tokenizer.padded_tokens_and_mask(
76
- tokens, options_up['text_ctx']
77
  )
78
 
79
- # Create the model conditioning dict.
80
- model_kwargs = dict(
81
- # Low-res image to upsample.
82
- low_res=((samples + 1) * 127.5).round() / 127.5 - 1,
 
83
 
84
- # Text tokens
 
85
  tokens=th.tensor(
86
- [tokens] * batch_size, device=device
87
  ),
88
  mask=th.tensor(
89
- [mask] * batch_size,
90
  dtype=th.bool,
91
  device=device,
92
  ),
93
  )
94
 
 
 
 
 
 
 
 
 
 
 
 
95
  # Sample from the base model.
96
- model_up.del_cache()
97
- up_shape = (batch_size, 3, options_up["image_size"], options_up["image_size"])
98
- up_samples = diffusion_up.ddim_sample_loop(
99
- model_up,
100
- up_shape,
101
- noise=th.randn(up_shape, device=device) * upsample_temp,
102
  device=device,
103
  clip_denoised=True,
104
  progress=True,
105
  model_kwargs=model_kwargs,
106
  cond_fn=None,
107
  )[:batch_size]
108
- model_up.del_cache()
109
 
110
  # Show the output
111
- show_images(up_samples)
112
- demo = gr.Interface(fn =generate_upsampled_image_from_text,inputs ="text",outputs ="image")
113
- demo.launch()
 
10
  )
11
  has_cuda = th.cuda.is_available()
12
  device = th.device('cpu' if not has_cuda else 'cuda')
13
+ # Create a base model.
14
  options = model_and_diffusion_defaults()
15
  options['use_fp16'] = has_cuda
16
  options['timestep_respacing'] = '100' # use 100 diffusion steps for fast sampling
 
21
  model.to(device)
22
  model.load_state_dict(load_checkpoint('base', device))
23
  print('total base parameters', sum(x.numel() for x in model.parameters()))
24
+
25
+ # Create an upsampler model.
26
  options_up = model_and_diffusion_defaults_upsampler()
27
  options_up['use_fp16'] = has_cuda
28
  options_up['timestep_respacing'] = 'fast27' # use 27 diffusion steps for very fast sampling
 
39
  reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
40
  display(Image.fromarray(reshaped.numpy()))
41
  # Sampling parameters
42
+ prompt = ""
43
  batch_size = 1
44
  guidance_scale = 3.0
45
 
46
  # Tune this parameter to control the sharpness of 256x256 images.
47
  # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
48
  upsample_temp = 0.997
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  import gradio as gr
50
+ def generate_image_from_text(prompt):
51
  # Set the prompt text
52
  prompt = prompt
53
 
 
56
  ##############################
57
 
58
  # Create the text tokens to feed to the model.
59
+ tokens = model.tokenizer.encode(prompt)
60
+ tokens, mask = model.tokenizer.padded_tokens_and_mask(
61
+ tokens, options['text_ctx']
62
  )
63
 
64
+ # Create the classifier-free guidance tokens (empty)
65
+ full_batch_size = batch_size * 2
66
+ uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
67
+ [], options['text_ctx']
68
+ )
69
 
70
+ # Pack the tokens together into model kwargs.
71
+ model_kwargs = dict(
72
  tokens=th.tensor(
73
+ [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
74
  ),
75
  mask=th.tensor(
76
+ [mask] * batch_size + [uncond_mask] * batch_size,
77
  dtype=th.bool,
78
  device=device,
79
  ),
80
  )
81
 
82
+ # Create a classifier-free guidance sampling function
83
+ def model_fn(x_t, ts, **kwargs):
84
+ half = x_t[: len(x_t) // 2]
85
+ combined = th.cat([half, half], dim=0)
86
+ model_out = model(combined, ts, **kwargs)
87
+ eps, rest = model_out[:, :3], model_out[:, 3:]
88
+ cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
89
+ half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
90
+ eps = th.cat([half_eps, half_eps], dim=0)
91
+ return th.cat([eps, rest], dim=1)
92
+
93
  # Sample from the base model.
94
+ model.del_cache()
95
+ samples = diffusion.p_sample_loop(
96
+ model_fn,
97
+ (full_batch_size, 3, options["image_size"], options["image_size"]),
 
 
98
  device=device,
99
  clip_denoised=True,
100
  progress=True,
101
  model_kwargs=model_kwargs,
102
  cond_fn=None,
103
  )[:batch_size]
104
+ model.del_cache()
105
 
106
  # Show the output
107
+ show_images(samples)
108
+ demo = gr.Interface(fn =generate_image_from_text,inputs ="text",outputs ="image")
109
+ demo.launch()