wondervictor committed on
Commit
876dc56
1 Parent(s): fc81a43
.gitignore CHANGED
@@ -154,6 +154,11 @@ dmypy.json
 # Cython debug symbols
 cython_debug/
 
+*.safetensors
+*.lock
+*.bin
+*.pt
+*.json
 # PyCharm
 # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
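
Note: the added patterns keep large model artifacts (weights, lock files, metadata) out of version control. A minimal sketch of how these globs match file names, using Python's fnmatch as a rough stand-in for .gitignore semantics on simple extension patterns:

# Sketch: fnmatch approximates .gitignore globbing for plain "*.ext" patterns.
import fnmatch

ignored_patterns = ["*.safetensors", "*.lock", "*.bin", "*.pt", "*.json"]

def is_ignored(filename: str) -> bool:
    # True if the file matches any of the newly ignored patterns.
    return any(fnmatch.fnmatch(filename, p) for p in ignored_patterns)

print(is_ignored("vq_ds16_t2i.pt"))  # True
print(is_ignored("app.py"))          # False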
app.py CHANGED
@@ -18,6 +18,7 @@ DESCRIPTION = "# [ControlAR: Controllable Image Generation with Autoregressive M
 SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
 model = Model()
 device = "cuda"
+model.to(device)
 with gr.Blocks(css="style.css") as demo:
     gr.Markdown(DESCRIPTION)
     gr.DuplicateButton(
@@ -26,8 +27,8 @@ with gr.Blocks(css="style.css") as demo:
         visible=SHOW_DUPLICATE_BUTTON,
     )
     with gr.Tabs():
-        with gr.TabItem("Depth"):
-            create_demo_depth(model.process_depth)
+        # with gr.TabItem("Depth"):
+        #     create_demo_depth(model.process_depth)
         with gr.TabItem("Canny"):
             create_demo_canny(model.process_canny)
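
Note: model.to(device) now runs once at import time, so weights reach the GPU before Gradio starts serving instead of on the first request. A minimal, self-contained sketch of that startup ordering (TinyModel is illustrative, and the cpu fallback exists only to keep the sketch runnable anywhere):

# Sketch: eager device placement at startup, before any request arrives.
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

model = TinyModel()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)                        # one-time transfer at startup
print(next(model.parameters()).device)  # confirms where the weights live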
app_canny.py CHANGED
@@ -104,18 +104,18 @@ def create_demo(process):
             canny_low_threshold,
             canny_high_threshold,
         ]
-        prompt.submit(
-            fn=randomize_seed_fn,
-            inputs=[seed, randomize_seed],
-            outputs=seed,
-            queue=False,
-            api_name=False,
-        ).then(
-            fn=process,
-            inputs=inputs,
-            outputs=result,
-            api_name=False,
-        )
+        # prompt.submit(
+        #     fn=randomize_seed_fn,
+        #     inputs=[seed, randomize_seed],
+        #     outputs=seed,
+        #     queue=False,
+        #     api_name=False,
+        # ).then(
+        #     fn=process,
+        #     inputs=inputs,
+        #     outputs=result,
+        #     api_name=False,
+        # )
         run_button.click(
             fn=randomize_seed_fn,
             inputs=[seed, randomize_seed],
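
Note: both the commented-out prompt.submit chain and the surviving run_button.click use Gradio's two-step event pattern: re-roll the seed first, then run generation via .then(). A minimal self-contained sketch of that pattern, with illustrative component names and a dummy process function:

# Sketch: chaining two events so the seed is finalized before generation runs.
import random
import gradio as gr

def randomize_seed_fn(seed: int, randomize: bool) -> int:
    return random.randint(0, 2**31 - 1) if randomize else int(seed)

def process(seed: int) -> str:
    return f"generated with seed {int(seed)}"

with gr.Blocks() as demo:
    seed = gr.Slider(0, 2**31 - 1, step=1, label="Seed")
    randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
    result = gr.Textbox(label="Result")
    run_button = gr.Button("Run")
    run_button.click(
        fn=randomize_seed_fn,            # step 1: maybe re-roll the seed
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=process,                      # step 2: generate with the final seed
        inputs=seed,
        outputs=result,
        api_name=False,
    )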
autoregressive/models/gpt_t2i.py CHANGED
@@ -375,8 +375,6 @@ class Transformer(nn.Module):
         # Zero-out output layers:
         nn.init.constant_(self.output.weight, 0)
 
-
-
     def _init_weights(self, module):
         std = self.config.initializer_range
         if isinstance(module, nn.Linear):
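
Note: the surrounding code zeroes the output head with nn.init.constant_ while _init_weights draws Linear weights from a normal distribution with std = config.initializer_range. A minimal sketch of that recipe; the body after the isinstance check is an assumption about the usual GPT-style init, not the repo's exact code:

# Sketch: zero-init the output head, normal-init other Linear layers.
import torch.nn as nn

def init_weights(module: nn.Module, std: float = 0.02) -> None:
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)  # assumed normal init
        if module.bias is not None:
            module.bias.data.zero_()

output_head = nn.Linear(512, 16384)
nn.init.constant_(output_head.weight, 0)  # zero-out output layer, as above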
checkpoints/flan-t5-xl/flan-t5-xl/spiece.model ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
+size 791656
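
Note: the added spiece.model is a Git LFS pointer, not the SentencePiece binary itself; LFS resolves the sha256 oid to the real 791,656-byte file at checkout. A minimal sketch of reading such a pointer:

# Sketch: each pointer line is "key value" with keys version, oid, size.
def parse_lfs_pointer(text: str) -> dict:
    return dict(line.split(" ", 1) for line in text.strip().splitlines())

with open("checkpoints/flan-t5-xl/flan-t5-xl/spiece.model") as f:
    pointer = parse_lfs_pointer(f.read())
print(pointer["oid"], pointer["size"])  # sha256:d60acb... 791656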
language/t5.py CHANGED
@@ -18,7 +18,7 @@ class T5Embedder:
 
     def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True,
                  t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None, model_max_length=120):
-        self.device = torch.device(device)
+        self.device = torch.device('cuda:0')
         self.torch_dtype = torch_dtype or torch.bfloat16
         if t5_model_kwargs is None:
             t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}
@@ -53,6 +53,7 @@
         print(tokenizer_path)
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
         self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
+        self.model.to('cuda')
         self.model_max_length = model_max_length
 
     def get_text_embeddings(self, texts):
@@ -72,11 +73,12 @@
         text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask']
 
         with torch.no_grad():
+            print("t5:", self.model.device)
             text_encoder_embs = self.model(
-                input_ids=text_tokens_and_mask['input_ids'].to(self.device),
-                attention_mask=text_tokens_and_mask['attention_mask'].to(self.device),
+                input_ids=text_tokens_and_mask['input_ids'].to(self.model.device),
+                attention_mask=text_tokens_and_mask['attention_mask'].to(self.model.device),
             )['last_hidden_state'].detach()
-        return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.device)
+        return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.model.device)
 
     def text_preprocessing(self, text):
         if self.use_text_preprocessing:
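
Note: after this change, input tensors follow the model's actual placement (self.model.device) rather than the separately tracked self.device, so the two can no longer drift apart. A minimal self-contained sketch of the same encode path; the google/flan-t5-xl hub id is an assumption standing in for the local checkpoints/flan-t5-xl directory:

# Sketch: tokenize, move tensors to the model's device, and encode.
import torch
from transformers import AutoTokenizer, T5EncoderModel

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
model = T5EncoderModel.from_pretrained("google/flan-t5-xl").eval().to("cuda")

tokens = tokenizer(["a photo of a cat"], padding="max_length", max_length=120,
                   truncation=True, return_tensors="pt")
with torch.no_grad():
    embs = model(
        input_ids=tokens["input_ids"].to(model.device),
        attention_mask=tokens["attention_mask"].to(model.device),
    )["last_hidden_state"]
print(embs.shape, embs.device)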
model.py CHANGED
@@ -40,7 +40,7 @@ class Model:
 
     def __init__(self):
         self.device = torch.device(
-            "cuda:0" if torch.cuda.is_available() else "cpu")
+            "cuda:0")
         self.base_model_id = ""
         self.task_name = ""
         self.vq_model = self.load_vq()
@@ -48,12 +48,17 @@
         self.gpt_model_canny = self.load_gpt(condition_type='canny')
         self.gpt_model_depth = self.load_gpt(condition_type='depth')
         self.get_control_canny = CannyDetector()
-        self.get_control_depth = MidasDetector(device=self.device)
+        self.get_control_depth = MidasDetector('cuda')
+
+    def to(self, device):
+        self.gpt_model_canny.to('cuda')
+        print(next(self.gpt_model_canny.adapter.parameters()).device)
+        # print(self.gpt_model_canny.device)
 
     def load_vq(self):
         vq_model = VQ_models["VQ-16"](codebook_size=16384,
                                       codebook_embed_dim=8)
-        vq_model.to(self.device)
+        vq_model.to('cuda')
         vq_model.eval()
         checkpoint = torch.load(f"checkpoints/vq_ds16_t2i.pt",
                                 map_location="cpu")
@@ -71,7 +76,7 @@
             cls_token_num=120,
             model_type='t2i',
             condition_type=condition_type,
-        ).to(device=self.device, dtype=precision)
+        ).to(device='cuda', dtype=precision)
 
         model_weight = load_file(gpt_ckpt)
         gpt_model.load_state_dict(model_weight, strict=False)
@@ -82,7 +87,7 @@
     def load_t5(self):
         precision = torch.bfloat16
         t5_model = T5Embedder(
-            device=self.device,
+            device="cuda",
             local_cache=False,
            cache_dir='checkpoints/flan-t5-xl',
            dir_or_name='flan-t5-xl',
@@ -134,6 +139,7 @@
         c_emb_masks = new_emb_masks
         qzshape = [len(c_indices), 8, H // 16, W // 16]
         t1 = time.time()
+        print(caption_embs.device)
         index_sample = generate(
             self.gpt_model_canny,
             c_indices,
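
Note: the new Model.to ignores its device argument and moves only gpt_model_canny. A device-respecting variant would move every submodule; the sketch below assumes the attributes created in __init__, and that the T5Embedder returned by load_t5 is stored as t5_model (that assignment is outside the hunk):

# Sketch: move all submodules of a Model-like object to one device.
import torch

def move_model(m, device):
    m.device = torch.device(device)
    m.vq_model.to(device)
    m.gpt_model_canny.to(device)
    m.gpt_model_depth.to(device)
    m.t5_model.model.to(device)  # assumption: T5Embedder exposes .model
    return m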