Add pyt model from Kipoi examples
Browse files
Source revision:
https://github.com/kipoi/kipoi/tree/6b5460c1cd1ba9667c23b7cb029640116147646b/example/models/pyt
- .gitattributes +1 -0
- dataloader.py +77 -0
- dataloader.yaml +45 -0
- example_files/hg38_chr22.fa +3 -0
- example_files/hg38_chr22.fa.fai +1 -0
- example_files/intervals.tsv +14 -0
- example_files/test.json +3 -0
- expected.pred.h5 +3 -0
- model.yaml +32 -0
- model_files/full_model.pth +3 -0
- model_files/only_weights.pth +3 -0
- model_files/pyt.py +75 -0
- pyt.py +1 -0
- wrong.pred.h5 +3 -0
.gitattributes
CHANGED
@@ -25,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.fa filter=lfs diff=lfs merge=lfs -text
|
dataloader.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""DeepSEA dataloader
|
2 |
+
"""
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
import pybedtools
|
6 |
+
from pybedtools import BedTool
|
7 |
+
from kipoi.data import Dataset
|
8 |
+
from kipoi.metadata import GenomicRanges
|
9 |
+
from kipoiseq.extractors import FastaStringExtractor
|
10 |
+
import linecache
|
11 |
+
from kipoiseq.transforms.functional import one_hot_dna
|
12 |
+
|
13 |
+
# --------------------------------------------
|
14 |
+
|
15 |
+
|
16 |
+
class BedToolLinecache(BedTool):
|
17 |
+
"""Fast BedTool accessor by Ziga Avsec
|
18 |
+
|
19 |
+
Normal BedTools loops through the whole file to get the
|
20 |
+
line of interest. Hence the access is O(n)
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __getitem__(self, idx):
|
24 |
+
line = linecache.getline(self.fn, idx + 1)
|
25 |
+
return pybedtools.create_interval_from_list(line.strip().split("\t"))
|
26 |
+
|
27 |
+
|
28 |
+
class SeqDataset(Dataset):
|
29 |
+
"""
|
30 |
+
Args:
|
31 |
+
intervals_file: bed3 file containing intervals
|
32 |
+
fasta_file: file path; Genome sequence
|
33 |
+
target_file: file path; path to the targets in the csv format
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self, intervals_file, fasta_file, target_file=None, use_linecache=False):
|
37 |
+
|
38 |
+
# intervals
|
39 |
+
if use_linecache:
|
40 |
+
self.bt = BedToolLinecache(intervals_file)
|
41 |
+
else:
|
42 |
+
self.bt = BedTool(intervals_file)
|
43 |
+
self.fasta_file = fasta_file
|
44 |
+
self.fasta_extractor = None
|
45 |
+
|
46 |
+
# Targets
|
47 |
+
if target_file is not None:
|
48 |
+
self.targets = pd.read_csv(target_file)
|
49 |
+
else:
|
50 |
+
self.targets = None
|
51 |
+
|
52 |
+
def __len__(self):
|
53 |
+
return len(self.bt)
|
54 |
+
|
55 |
+
def __getitem__(self, idx):
|
56 |
+
if self.fasta_extractor is None:
|
57 |
+
self.fasta_extractor = FastaStringExtractor(self.fasta_file)
|
58 |
+
|
59 |
+
interval = self.bt[idx]
|
60 |
+
|
61 |
+
# Intervals need to be 1000bp wide
|
62 |
+
assert interval.stop - interval.start == 1000
|
63 |
+
|
64 |
+
if self.targets is not None:
|
65 |
+
y = self.targets.iloc[idx].values
|
66 |
+
else:
|
67 |
+
y = {}
|
68 |
+
|
69 |
+
# Run the fasta extractor
|
70 |
+
seq = one_hot_dna(self.fasta_extractor.extract(interval), dtype=np.float32) # TODO: Remove additional dtype after kipoiseq gets a new release
|
71 |
+
return {
|
72 |
+
"inputs": seq,
|
73 |
+
"targets": y,
|
74 |
+
"metadata": {
|
75 |
+
"ranges": GenomicRanges.from_interval(interval)
|
76 |
+
}
|
77 |
+
}
|
dataloader.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defined_as: dataloader.SeqDataset
|
2 |
+
args:
|
3 |
+
intervals_file:
|
4 |
+
doc: bed6 file with `chrom start end id score strand`
|
5 |
+
example: example_files/intervals.tsv
|
6 |
+
fasta_file:
|
7 |
+
doc: Reference genome sequence
|
8 |
+
example: example_files/hg38_chr22.fa
|
9 |
+
target_file:
|
10 |
+
doc: path to the targets (.csv) file
|
11 |
+
optional: True
|
12 |
+
use_linecache:
|
13 |
+
doc: if True, use linecache https://docs.python.org/3/library/linecache.html to access bed file rows
|
14 |
+
optional: True
|
15 |
+
info:
|
16 |
+
authors:
|
17 |
+
- name: Lara Urban
|
18 |
+
github: LaraUrban
|
19 |
+
- name: Ziga Avsec
|
20 |
+
github: avsecz
|
21 |
+
doc: Dataloader for the DeepSEA model.
|
22 |
+
dependencies:
|
23 |
+
conda:
|
24 |
+
- python
|
25 |
+
- numpy
|
26 |
+
- pandas
|
27 |
+
- cython
|
28 |
+
pip:
|
29 |
+
- cython
|
30 |
+
- pybedtools
|
31 |
+
output_schema:
|
32 |
+
inputs:
|
33 |
+
name: input
|
34 |
+
shape: (1000, 4)
|
35 |
+
special_type: DNASeq
|
36 |
+
doc: DNA sequence
|
37 |
+
associated_metadata: ranges
|
38 |
+
targets:
|
39 |
+
name: epigen_mod
|
40 |
+
shape: (1, )
|
41 |
+
doc: Specific epigenetic feature class (multi-task binary classification)
|
42 |
+
metadata:
|
43 |
+
ranges:
|
44 |
+
type: GenomicRanges
|
45 |
+
doc: Ranges describing inputs.seq
|
example_files/hg38_chr22.fa
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:03397edb6bea565057a4f6f643daaec4399d8a8429eec15ee1e2845fad850fe6
|
3 |
+
size 50818476
|
example_files/hg38_chr22.fa.fai
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
chr22 50818468 7 50818468 50818469
|
example_files/intervals.tsv
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
chr22 4997 5997 1 0 +
|
2 |
+
chr22 5330 6330 2 0 -
|
3 |
+
chr22 6728 7728 3 0 -
|
4 |
+
chr22 3482 4482 4 0 +
|
5 |
+
chr22 7989 8989 5 0 +
|
6 |
+
chr22 8136 9136 6 0 +
|
7 |
+
chr22 3617 4617 7 0 -
|
8 |
+
chr22 7887 8887 8 0 +
|
9 |
+
chr22 8428 9428 9 0 +
|
10 |
+
chr22 9444 10444 10 0 +
|
11 |
+
chr22 41 1041 11 0 +
|
12 |
+
chr22 5777 6777 12 0 +
|
13 |
+
chr22 7084 8084 13 0 +
|
14 |
+
chr22 5725 6725 14 0 +
|
example_files/test.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
{"intervals_file": "intervals.tsv",
|
2 |
+
"fasta_file": "hg38_chr22.fa"
|
3 |
+
}
|
expected.pred.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:27ec18f672e5afc40381b6874c885b68f7c7ee805064d926e1dd91614bc66590
|
3 |
+
size 183002
|
model.yaml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defined_as: kipoi.model.PyTorchModel
|
2 |
+
args:
|
3 |
+
module_file: model_files/pyt.py
|
4 |
+
module_obj: simple_model
|
5 |
+
weights: model_files/only_weights.pth
|
6 |
+
default_dataloader: . # path to the directory
|
7 |
+
info:
|
8 |
+
authors:
|
9 |
+
- name: Roman Kreuzhuber
|
10 |
+
github: krrome
|
11 |
+
doc: Simple testing model for pytorch
|
12 |
+
dependencies:
|
13 |
+
conda:
|
14 |
+
- pytorch::pytorch>=0.2.0
|
15 |
+
schema:
|
16 |
+
inputs:
|
17 |
+
name: input
|
18 |
+
shape: (1000, 4)
|
19 |
+
special_type: DNASeq
|
20 |
+
doc: DNA sequence
|
21 |
+
# associated_metadata: ranges # --> has to be defined in dataloader.yaml.
|
22 |
+
# This field is ignored in model.yaml.
|
23 |
+
targets:
|
24 |
+
shape: (1, )
|
25 |
+
doc: Predicted binding strength
|
26 |
+
column_labels:
|
27 |
+
- some_probability
|
28 |
+
test:
|
29 |
+
expect:
|
30 |
+
url: https://s3.eu-central-1.amazonaws.com/kipoi-models/predictions/example/models/pyt/expected.pred.h5
|
31 |
+
md5: d6d0779a7bdfb1301c76a59defd293ed
|
32 |
+
precision_decimal: 6
|
model_files/full_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fb7034af1f15bf0dc242a41645b3bc781486964818fb6710f1fc78e1ca34b12b
|
3 |
+
size 1607392
|
model_files/only_weights.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6b963f8eb38ae2ceab5a24d08a437f12b8fa94cb2cd5046be0475712e24d28ed
|
3 |
+
size 1601416
|
model_files/pyt.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
|
6 |
+
class Flatten(torch.nn.Module):
|
7 |
+
# https://gist.github.com/VoVAllen/5531c78a2d3f1ff3df772038bca37a83
|
8 |
+
|
9 |
+
def __init__(self):
|
10 |
+
super(Flatten, self).__init__()
|
11 |
+
|
12 |
+
def forward(self, x):
|
13 |
+
return x.view(x.size(0), -1)
|
14 |
+
|
15 |
+
|
16 |
+
def get_model():
|
17 |
+
# N is batch size; D_in is input dimension;
|
18 |
+
# H is hidden dimension; D_out is output dimension.
|
19 |
+
D_in, H, D_out = 4000, 100, 1
|
20 |
+
|
21 |
+
model = torch.nn.Sequential(
|
22 |
+
Flatten(),
|
23 |
+
torch.nn.Linear(D_in, H),
|
24 |
+
torch.nn.ReLU(),
|
25 |
+
torch.nn.Linear(H, D_out),
|
26 |
+
torch.nn.Sigmoid(),
|
27 |
+
)
|
28 |
+
return model
|
29 |
+
|
30 |
+
simple_model = get_model()
|
31 |
+
|
32 |
+
|
33 |
+
def generate_exmaple_model():
|
34 |
+
# get model
|
35 |
+
model = get_model()
|
36 |
+
|
37 |
+
# define loss function
|
38 |
+
loss_func = torch.nn.MSELoss()
|
39 |
+
|
40 |
+
# define optimizer
|
41 |
+
optimizer = torch.optim.SGD(model.parameters(), lr=1e-1)
|
42 |
+
|
43 |
+
minibatch_size = 10
|
44 |
+
np.random.seed(0)
|
45 |
+
x = torch.Tensor(50, 1000, 4).uniform_(0, 1)
|
46 |
+
y = torch.Tensor(50).uniform_(0, 1)
|
47 |
+
|
48 |
+
for epoch in tqdm(range(10)):
|
49 |
+
for mbi in tqdm(range(np.ceil(x.size()[0] / minibatch_size).astype(int))):
|
50 |
+
minibatch = x[(mbi * minibatch_size):min(((mbi + 1) * minibatch_size), x.size()[0])]
|
51 |
+
target = torch.autograd.Variable(y[(mbi * minibatch_size):min(((mbi + 1) * minibatch_size), x.size()[0])])
|
52 |
+
model.zero_grad()
|
53 |
+
|
54 |
+
# forward pass
|
55 |
+
out = model(torch.autograd.Variable(minibatch))
|
56 |
+
|
57 |
+
# backward pass
|
58 |
+
L = loss_func(out, target) # calculate loss
|
59 |
+
L.backward() # calculate gradients
|
60 |
+
optimizer.step() # make an update step
|
61 |
+
|
62 |
+
torch.save(model, "model_files/full_model.pth")
|
63 |
+
torch.save(model.state_dict(), "model_files/only_weights.pth")
|
64 |
+
|
65 |
+
## To comply with OldPyTorchModel
|
66 |
+
def get_model_w_weights():
|
67 |
+
model = get_model()
|
68 |
+
model.load_state_dict(torch.load("model_files/only_weights.pth"))
|
69 |
+
return model
|
70 |
+
|
71 |
+
def test_same_weights(dict1, dict2):
|
72 |
+
for k in dict1:
|
73 |
+
assert np.all(dict1[k].numpy() == dict2[k].numpy())
|
74 |
+
|
75 |
+
# test_same_weights(model.state_dict(), model_2.state_dict())
|
pyt.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
model_files/pyt.py
|
wrong.pred.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:178dccd292a4f787012e70dc91fc9377f85ec88637484802900480654223ec7e
|
3 |
+
size 183002
|