EmaadKhwaja
commited on
Commit
•
2eb6d66
1
Parent(s):
860c3d7
add models
Browse files- app.py +3 -5
- models/HPA_480.ckpt +3 -0
- models/HPA_480.yaml +25 -0
- models/HPA_Finetuned_480.ckpt +3 -0
- models/HPA_Finetuned_480.yaml +25 -0
- models/nucleus_vqgan.yaml +33 -0
- models/threshold_vqgan.yaml +33 -0
- requirements.txt +0 -1
app.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import gradio as gr
|
2 |
-
from huggingface_hub import hf_hub_download
|
3 |
from prediction import run_image_prediction
|
4 |
import torch
|
5 |
import torchvision.transforms as T
|
@@ -9,10 +8,9 @@ from matplotlib import pyplot as plt
|
|
9 |
|
10 |
|
11 |
def gradio_demo(model_name, sequence_input, nucleus_image, protein_image):
|
12 |
-
model =
|
13 |
-
config =
|
14 |
-
|
15 |
-
hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
|
16 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
17 |
|
18 |
if "Finetuned" in model_name:
|
|
|
1 |
import gradio as gr
|
|
|
2 |
from prediction import run_image_prediction
|
3 |
import torch
|
4 |
import torchvision.transforms as T
|
|
|
8 |
|
9 |
|
10 |
def gradio_demo(model_name, sequence_input, nucleus_image, protein_image):
|
11 |
+
model = f"models/{model_name}.ckpt"
|
12 |
+
config = f"models/{model_name}.yaml"
|
13 |
+
|
|
|
14 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
|
16 |
if "Finetuned" in model_name:
|
models/HPA_480.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:14b445eedda4f46e67c7d848df3301363e9034d752591a1275b47eaf483e0cac
|
3 |
+
size 4731482483
|
models/HPA_480.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
learning_rate: 0.0003
|
3 |
+
target: celle_main.CELLE_trainer
|
4 |
+
params:
|
5 |
+
ckpt_path: model.ckpt
|
6 |
+
condition_model_path:
|
7 |
+
condition_config_path: nucleus_vqgan.yaml
|
8 |
+
vqgan_model_path:
|
9 |
+
vqgan_config_path: threshold_vqgan.yaml
|
10 |
+
image_key: threshold
|
11 |
+
num_images: 2
|
12 |
+
dim: 480
|
13 |
+
num_text_tokens: 33
|
14 |
+
text_seq_len: 1000
|
15 |
+
depth: 68
|
16 |
+
heads: 16
|
17 |
+
dim_head: 64
|
18 |
+
attn_dropout: 0.1
|
19 |
+
ff_dropout: 0.1
|
20 |
+
attn_types: full
|
21 |
+
rotary_emb: true
|
22 |
+
fixed_embedding: true
|
23 |
+
text_embedding: esm2
|
24 |
+
loss_img_weight: 1
|
25 |
+
loss_cond_weight: 1
|
models/HPA_Finetuned_480.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:152dfd30c8adfe83868ac0ffea4d075d25d11f24369f34482f79a3db4524d1be
|
3 |
+
size 4731482675
|
models/HPA_Finetuned_480.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
learning_rate: 0.0003
|
3 |
+
target: celle_main.CELLE_trainer
|
4 |
+
params:
|
5 |
+
ckpt_path: model.ckpt
|
6 |
+
condition_model_path:
|
7 |
+
condition_config_path: nucleus_vqgan.yaml
|
8 |
+
vqgan_model_path:
|
9 |
+
vqgan_config_path: threshold_vqgan.yaml
|
10 |
+
image_key: threshold
|
11 |
+
num_images: 2
|
12 |
+
dim: 480
|
13 |
+
num_text_tokens: 33
|
14 |
+
text_seq_len: 1000
|
15 |
+
depth: 68
|
16 |
+
heads: 16
|
17 |
+
dim_head: 64
|
18 |
+
attn_dropout: 0.1
|
19 |
+
ff_dropout: 0.1
|
20 |
+
attn_types: full
|
21 |
+
rotary_emb: true
|
22 |
+
fixed_embedding: true
|
23 |
+
text_embedding: esm2
|
24 |
+
loss_img_weight: 1
|
25 |
+
loss_cond_weight: 1
|
models/nucleus_vqgan.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 4.5e-06
|
3 |
+
target: taming.models.vqgan.VQModel
|
4 |
+
params:
|
5 |
+
image_key: nucleus
|
6 |
+
ckpt_path:
|
7 |
+
embed_dim: 256
|
8 |
+
n_embed: 512
|
9 |
+
ddconfig:
|
10 |
+
double_z: false
|
11 |
+
z_channels: 256
|
12 |
+
resolution: 256
|
13 |
+
in_channels: 1
|
14 |
+
out_ch: 1
|
15 |
+
ch: 128
|
16 |
+
ch_mult:
|
17 |
+
- 1
|
18 |
+
- 1
|
19 |
+
- 2
|
20 |
+
- 2
|
21 |
+
- 4
|
22 |
+
num_res_blocks: 2
|
23 |
+
attn_resolutions:
|
24 |
+
- 16
|
25 |
+
dropout: 0.0
|
26 |
+
lossconfig:
|
27 |
+
target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
|
28 |
+
params:
|
29 |
+
disc_conditional: false
|
30 |
+
disc_in_channels: 1
|
31 |
+
disc_start: 50000
|
32 |
+
disc_weight: 0.2
|
33 |
+
codebook_weight: 1.0
|
models/threshold_vqgan.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 4.5e-05
|
3 |
+
target: taming.models.vqgan.VQModel
|
4 |
+
params:
|
5 |
+
image_key: threshold
|
6 |
+
ckpt_path:
|
7 |
+
embed_dim: 256
|
8 |
+
n_embed: 512
|
9 |
+
ddconfig:
|
10 |
+
double_z: false
|
11 |
+
z_channels: 256
|
12 |
+
resolution: 256
|
13 |
+
in_channels: 1
|
14 |
+
out_ch: 1
|
15 |
+
ch: 128
|
16 |
+
ch_mult:
|
17 |
+
- 1
|
18 |
+
- 1
|
19 |
+
- 2
|
20 |
+
- 2
|
21 |
+
- 4
|
22 |
+
num_res_blocks: 2
|
23 |
+
attn_resolutions:
|
24 |
+
- 16
|
25 |
+
dropout: 0.0
|
26 |
+
lossconfig:
|
27 |
+
target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
|
28 |
+
params:
|
29 |
+
disc_conditional: false
|
30 |
+
disc_in_channels: 1
|
31 |
+
disc_start: 100000
|
32 |
+
disc_weight: 0.2
|
33 |
+
codebook_weight: 1.0
|
requirements.txt
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
--extra-index-url https://download.pytorch.org/whl/cu113
|
2 |
torch
|
3 |
torchvision
|
4 |
-
huggingface_hub
|
5 |
gradio
|
6 |
OmegaConf
|
7 |
axial-positional-embedding
|
|
|
1 |
--extra-index-url https://download.pytorch.org/whl/cu113
|
2 |
torch
|
3 |
torchvision
|
|
|
4 |
gradio
|
5 |
OmegaConf
|
6 |
axial-positional-embedding
|