Spaces: Running on Zero
wondervictor committed · Commit 876dc56
Parent(s): fc81a43
update
Files changed:
- .gitignore +5 -0
- app.py +3 -2
- app_canny.py +12 -12
- autoregressive/models/gpt_t2i.py +0 -2
- checkpoints/flan-t5-xl/flan-t5-xl/spiece.model +3 -0
- language/t5.py +6 -4
- model.py +11 -5
.gitignore
CHANGED
@@ -154,6 +154,11 @@ dmypy.json
 # Cython debug symbols
 cython_debug/
 
+*.safetensors
+*.lock
+*.bin
+*.pt
+*.json
 # PyCharm
 # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
app.py
CHANGED
@@ -18,6 +18,7 @@ DESCRIPTION = "# [ControlAR: Controllable Image Generation with Autoregressive M
 SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
 model = Model()
 device = "cuda"
+model.to(device)
 with gr.Blocks(css="style.css") as demo:
     gr.Markdown(DESCRIPTION)
     gr.DuplicateButton(
@@ -26,8 +27,8 @@ with gr.Blocks(css="style.css") as demo:
         visible=SHOW_DUPLICATE_BUTTON,
     )
     with gr.Tabs():
-        with gr.TabItem("Depth"):
-            create_demo_depth(model.process_depth)
+        # with gr.TabItem("Depth"):
+        #     create_demo_depth(model.process_depth)
         with gr.TabItem("Canny"):
             create_demo_canny(model.process_canny)
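Since this Space runs on ZeroGPU ("Running on Zero"), the model is built and moved to CUDA at import time. A hedged sketch of the usual companion pattern, assuming the standard Hugging Face `spaces` package; `demo_process` is a hypothetical stand-in, not a function from this commit:

# Hedged sketch, assuming the standard `spaces` package on ZeroGPU,
# where the GPU is attached per request and GPU-touching callbacks
# are typically wrapped with @spaces.GPU.
import spaces
import torch

@spaces.GPU
def demo_process(prompt: str) -> str:
    # Inside a @spaces.GPU function, CUDA is available on ZeroGPU.
    x = torch.zeros(1, device="cuda")
    return f"{prompt!r} ran on {x.device}"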
app_canny.py
CHANGED
@@ -104,18 +104,18 @@ def create_demo(process):
         canny_low_threshold,
         canny_high_threshold,
     ]
-    prompt.submit(
-        fn=randomize_seed_fn,
-        inputs=[seed, randomize_seed],
-        outputs=seed,
-        queue=False,
-        api_name=False,
-    ).then(
-        fn=process,
-        inputs=inputs,
-        outputs=result,
-        api_name=False,
-    )
+    # prompt.submit(
+    #     fn=randomize_seed_fn,
+    #     inputs=[seed, randomize_seed],
+    #     outputs=seed,
+    #     queue=False,
+    #     api_name=False,
+    # ).then(
+    #     fn=process,
+    #     inputs=inputs,
+    #     outputs=result,
+    #     api_name=False,
+    # )
     run_button.click(
         fn=randomize_seed_fn,
         inputs=[seed, randomize_seed],
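For reference, a self-contained sketch of the seed-then-process event chain that app_canny.py keeps on run_button (and now disables for prompt submission). The names randomize_seed_fn and process come from the diff; the toy process body and component set are illustrative:

# Minimal runnable Gradio example of .click(...).then(...) chaining:
# randomize the seed first, then run the main function with that seed.
import random
import gradio as gr

def randomize_seed_fn(seed: int, randomize: bool) -> int:
    return random.randint(0, 2**31 - 1) if randomize else seed

def process(prompt: str, seed: int) -> str:
    return f"{prompt} (seed={seed})"  # toy stand-in for image generation

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    seed = gr.Slider(0, 2**31 - 1, step=1, label="Seed")
    randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
    result = gr.Textbox(label="Result")
    run_button = gr.Button("Run")
    run_button.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=process,
        inputs=[prompt, seed],
        outputs=result,
        api_name=False,
    )

demo.launch()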
autoregressive/models/gpt_t2i.py
CHANGED
@@ -375,8 +375,6 @@ class Transformer(nn.Module):
         # Zero-out output layers:
         nn.init.constant_(self.output.weight, 0)
 
-
-
     def _init_weights(self, module):
         std = self.config.initializer_range
         if isinstance(module, nn.Linear):
checkpoints/flan-t5-xl/flan-t5-xl/spiece.model
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
+size 791656
language/t5.py
CHANGED
@@ -18,7 +18,7 @@ class T5Embedder:
 
     def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True,
                  t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None, model_max_length=120):
-        self.device = torch.device(device)
+        self.device = torch.device('cuda:0')
         self.torch_dtype = torch_dtype or torch.bfloat16
         if t5_model_kwargs is None:
            t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}
@@ -53,6 +53,7 @@ class T5Embedder:
         print(tokenizer_path)
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
         self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
+        self.model.to('cuda')
         self.model_max_length = model_max_length
 
     def get_text_embeddings(self, texts):
@@ -72,11 +73,12 @@
         text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask']
 
         with torch.no_grad():
+            print("t5:", self.model.device)
             text_encoder_embs = self.model(
-                input_ids=text_tokens_and_mask['input_ids'].to(self.device),
-                attention_mask=text_tokens_and_mask['attention_mask'].to(self.device),
+                input_ids=text_tokens_and_mask['input_ids'].to(self.model.device),
+                attention_mask=text_tokens_and_mask['attention_mask'].to(self.model.device),
             )['last_hidden_state'].detach()
-        return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.device)
+        return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.model.device)
 
     def text_preprocessing(self, text):
         if self.use_text_preprocessing:
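The net effect is that both the T5 weights and the tokenized inputs land on the same CUDA device, with inputs routed via self.model.device rather than the stored self.device. A minimal sketch of that pattern; google/flan-t5-small is an illustrative public checkpoint (the Space uses flan-t5-xl), and the CUDA move is guarded so the sketch runs anywhere:

# Sketch of the patched T5Embedder flow: load encoder, pin to CUDA,
# send tokenized inputs to the model's own device.
import torch
from transformers import AutoTokenizer, T5EncoderModel

tok = AutoTokenizer.from_pretrained("google/flan-t5-small")
enc = T5EncoderModel.from_pretrained("google/flan-t5-small").eval()
if torch.cuda.is_available():
    enc.to("cuda")  # matches the added `self.model.to('cuda')`

batch = tok(["a photo of a cat"], return_tensors="pt",
            padding="max_length", truncation=True, max_length=120)
with torch.no_grad():
    embs = enc(input_ids=batch["input_ids"].to(enc.device),
               attention_mask=batch["attention_mask"].to(enc.device),
               ).last_hidden_state
print(embs.shape, embs.device)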
model.py
CHANGED
@@ -40,7 +40,7 @@ class Model:
 
     def __init__(self):
         self.device = torch.device(
-            "cuda:0" if torch.cuda.is_available() else "cpu")
+            "cuda:0")
         self.base_model_id = ""
         self.task_name = ""
         self.vq_model = self.load_vq()
@@ -48,12 +48,17 @@
         self.gpt_model_canny = self.load_gpt(condition_type='canny')
         self.gpt_model_depth = self.load_gpt(condition_type='depth')
         self.get_control_canny = CannyDetector()
-        self.get_control_depth = MidasDetector(self.device)
+        self.get_control_depth = MidasDetector('cuda')
+
+    def to(self, device):
+        self.gpt_model_canny.to('cuda')
+        print(next(self.gpt_model_canny.adapter.parameters()).device)
+        # print(self.gpt_model_canny.device)
 
     def load_vq(self):
         vq_model = VQ_models["VQ-16"](codebook_size=16384,
                                       codebook_embed_dim=8)
-        vq_model.to(self.device)
+        vq_model.to('cuda')
         vq_model.eval()
         checkpoint = torch.load(f"checkpoints/vq_ds16_t2i.pt",
                                 map_location="cpu")
@@ -71,7 +76,7 @@
             cls_token_num=120,
             model_type='t2i',
             condition_type=condition_type,
-        ).to(device=self.device, dtype=precision)
+        ).to(device='cuda', dtype=precision)
 
         model_weight = load_file(gpt_ckpt)
         gpt_model.load_state_dict(model_weight, strict=False)
@@ -82,7 +87,7 @@
     def load_t5(self):
         precision = torch.bfloat16
         t5_model = T5Embedder(
-            device=self.device,
+            device="cuda",
             local_cache=False,
            cache_dir='checkpoints/flan-t5-xl',
            dir_or_name='flan-t5-xl',
@@ -134,6 +139,7 @@
         c_emb_masks = new_emb_masks
         qzshape = [len(c_indices), 8, H // 16, W // 16]
         t1 = time.time()
+        print(caption_embs.device)
         index_sample = generate(
             self.gpt_model_canny,
             c_indices,
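Taken together, model.py now hardcodes 'cuda' at every placement site instead of threading self.device through. A self-contained sketch of the pattern, with a hypothetical TinyPipeline standing in for the real Model and its VQ/GPT/T5 submodules; the CPU fallback is only so the sketch runs anywhere:

# Toy version of the commit's device-placement pattern.
import torch
import torch.nn as nn

class TinyPipeline:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.encoder = nn.Linear(8, 8)  # stand-in for a heavy submodule
        self.encoder.to(self.device)    # eager placement, as in load_vq

    def to(self, device):
        # Mirrors the new Model.to: push the submodule on demand.
        self.encoder.to(device)

    @torch.no_grad()
    def run(self, x: torch.Tensor) -> torch.Tensor:
        # Route inputs via the module's actual device, as t5.py now does
        # with `.to(self.model.device)`, so inputs cannot mismatch weights.
        dev = next(self.encoder.parameters()).device
        return self.encoder(x.to(dev))

pipe = TinyPipeline()
pipe.to(pipe.device)
print(pipe.run(torch.randn(2, 8)).device)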