EmaadKhwaja committed
Commit 2eb6d66
Parent: 860c3d7

add models

app.py CHANGED
@@ -1,5 +1,4 @@
 import gradio as gr
-from huggingface_hub import hf_hub_download
 from prediction import run_image_prediction
 import torch
 import torchvision.transforms as T
@@ -9,10 +8,9 @@ from matplotlib import pyplot as plt
 
 
 def gradio_demo(model_name, sequence_input, nucleus_image, protein_image):
-    model = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="model.ckpt")
-    config = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="config.yaml")
-    hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
-    hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
+    model = f"models/{model_name}.ckpt"
+    config = f"models/{model_name}.yaml"
+
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     if "Finetuned" in model_name:
models/HPA_480.ckpt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:14b445eedda4f46e67c7d848df3301363e9034d752591a1275b47eaf483e0cac
+size 4731482483
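These three lines are a Git LFS pointer, not the weights themselves; the ~4.7 GB checkpoint is fetched by LFS at checkout. A hedged sketch for verifying that a clone actually has the payload (the expected size comes from the pointer's size line):

```python
import os

path = "models/HPA_480.ckpt"
expected = 4731482483  # bytes, from the `size` line of the LFS pointer above

actual = os.path.getsize(path)
if actual != expected:
    # A file of ~130 bytes is the pointer itself, meaning Git LFS was not
    # active during checkout; `git lfs pull` fetches the real checkpoint.
    print(f"{path}: {actual} bytes on disk, expected {expected}")
```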
models/HPA_480.yaml ADDED
@@ -0,0 +1,25 @@
+model:
+  learning_rate: 0.0003
+  target: celle_main.CELLE_trainer
+  params:
+    ckpt_path: model.ckpt
+    condition_model_path:
+    condition_config_path: nucleus_vqgan.yaml
+    vqgan_model_path:
+    vqgan_config_path: threshold_vqgan.yaml
+    image_key: threshold
+    num_images: 2
+    dim: 480
+    num_text_tokens: 33
+    text_seq_len: 1000
+    depth: 68
+    heads: 16
+    dim_head: 64
+    attn_dropout: 0.1
+    ff_dropout: 0.1
+    attn_types: full
+    rotary_emb: true
+    fixed_embedding: true
+    text_embedding: esm2
+    loss_img_weight: 1
+    loss_cond_weight: 1
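The config follows the target/params convention used by taming-transformers (which this repo depends on): target is a dotted import path and params are constructor kwargs. A sketch of how such a config is typically resolved; this helper is a generic reimplementation, not necessarily the one in this repo, and celle_main must be importable for the commented line to work:

```python
import importlib
from omegaconf import OmegaConf

def instantiate_from_config(config):
    # Split "celle_main.CELLE_trainer" into module and class name,
    # import the module, and construct the class with `params`.
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", {}))

cfg = OmegaConf.load("models/HPA_480.yaml")
# trainer = instantiate_from_config(cfg.model)  # needs celle_main on PYTHONPATH
```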
models/HPA_Finetuned_480.ckpt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:152dfd30c8adfe83868ac0ffea4d075d25d11f24369f34482f79a3db4524d1be
+size 4731482675
models/HPA_Finetuned_480.yaml ADDED
@@ -0,0 +1,25 @@
+model:
+  learning_rate: 0.0003
+  target: celle_main.CELLE_trainer
+  params:
+    ckpt_path: model.ckpt
+    condition_model_path:
+    condition_config_path: nucleus_vqgan.yaml
+    vqgan_model_path:
+    vqgan_config_path: threshold_vqgan.yaml
+    image_key: threshold
+    num_images: 2
+    dim: 480
+    num_text_tokens: 33
+    text_seq_len: 1000
+    depth: 68
+    heads: 16
+    dim_head: 64
+    attn_dropout: 0.1
+    ff_dropout: 0.1
+    attn_types: full
+    rotary_emb: true
+    fixed_embedding: true
+    text_embedding: esm2
+    loss_img_weight: 1
+    loss_cond_weight: 1
models/nucleus_vqgan.yaml ADDED
@@ -0,0 +1,33 @@
+model:
+  base_learning_rate: 4.5e-06
+  target: taming.models.vqgan.VQModel
+  params:
+    image_key: nucleus
+    ckpt_path:
+    embed_dim: 256
+    n_embed: 512
+    ddconfig:
+      double_z: false
+      z_channels: 256
+      resolution: 256
+      in_channels: 1
+      out_ch: 1
+      ch: 128
+      ch_mult:
+      - 1
+      - 1
+      - 2
+      - 2
+      - 4
+      num_res_blocks: 2
+      attn_resolutions:
+      - 16
+      dropout: 0.0
+    lossconfig:
+      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
+      params:
+        disc_conditional: false
+        disc_in_channels: 1
+        disc_start: 50000
+        disc_weight: 0.2
+        codebook_weight: 1.0
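A quick sanity check on the geometry ddconfig implies: in the taming VQGAN encoder, each ch_mult entry after the first adds a 2x downsampling stage, so a 256x256 input is compressed to a 16x16 latent grid, which is exactly where attn_resolutions: [16] switches on attention. The arithmetic:

```python
resolution = 256
ch_mult = [1, 1, 2, 2, 4]  # from ddconfig above

downsamples = len(ch_mult) - 1                # 4 halvings: 256 -> 128 -> 64 -> 32 -> 16
latent_size = resolution // 2 ** downsamples
print(latent_size)                            # 16, matching attn_resolutions: [16]
```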
models/threshold_vqgan.yaml ADDED
@@ -0,0 +1,33 @@
+model:
+  base_learning_rate: 4.5e-05
+  target: taming.models.vqgan.VQModel
+  params:
+    image_key: threshold
+    ckpt_path:
+    embed_dim: 256
+    n_embed: 512
+    ddconfig:
+      double_z: false
+      z_channels: 256
+      resolution: 256
+      in_channels: 1
+      out_ch: 1
+      ch: 128
+      ch_mult:
+      - 1
+      - 1
+      - 2
+      - 2
+      - 4
+      num_res_blocks: 2
+      attn_resolutions:
+      - 16
+      dropout: 0.0
+    lossconfig:
+      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
+      params:
+        disc_conditional: false
+        disc_in_channels: 1
+        disc_start: 100000
+        disc_weight: 0.2
+        codebook_weight: 1.0
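These two files are the condition_config_path (nucleus) and vqgan_config_path (threshold) referenced by the CELL-E configs above; those reference them by bare filename, so the loader presumably resolves them relative to the model directory (an assumption). Architecturally the two codebooks are identical; a short script confirming that only the learning rate, image key, and discriminator start step differ:

```python
from omegaconf import OmegaConf

nucleus = OmegaConf.load("models/nucleus_vqgan.yaml").model
threshold = OmegaConf.load("models/threshold_vqgan.yaml").model

print(nucleus.base_learning_rate, threshold.base_learning_rate)  # 4.5e-06 vs 4.5e-05
print(nucleus.params.image_key, threshold.params.image_key)      # nucleus vs threshold
print(nucleus.params.lossconfig.params.disc_start,
      threshold.params.lossconfig.params.disc_start)             # 50000 vs 100000
```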
requirements.txt CHANGED
@@ -1,7 +1,6 @@
 --extra-index-url https://download.pytorch.org/whl/cu113
 torch
 torchvision
-huggingface_hub
 gradio
 OmegaConf
 axial-positional-embedding