dnouri committed
Commit 9274888 · 1 Parent(s): 19aaf96

Add pyt model from Kipoi examples


Source revision:
https://github.com/kipoi/kipoi/tree/6b5460c1cd1ba9667c23b7cb029640116147646b/example/models/pyt

.gitattributes CHANGED
@@ -25,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zstandard filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.fa filter=lfs diff=lfs merge=lfs -text
dataloader.py ADDED
@@ -0,0 +1,77 @@
+ """DeepSEA dataloader
+ """
+ import numpy as np
+ import pandas as pd
+ import pybedtools
+ from pybedtools import BedTool
+ from kipoi.data import Dataset
+ from kipoi.metadata import GenomicRanges
+ from kipoiseq.extractors import FastaStringExtractor
+ import linecache
+ from kipoiseq.transforms.functional import one_hot_dna
+
+ # --------------------------------------------
+
+
+ class BedToolLinecache(BedTool):
+     """Fast BedTool accessor by Ziga Avsec
+
+     A plain BedTool loops through the whole file to get the
+     line of interest, so access is O(n).
+     """
+
+     def __getitem__(self, idx):
+         line = linecache.getline(self.fn, idx + 1)
+         return pybedtools.create_interval_from_list(line.strip().split("\t"))
+
+
+ class SeqDataset(Dataset):
+     """
+     Args:
+         intervals_file: bed3 file containing intervals
+         fasta_file: file path; Genome sequence
+         target_file: file path; path to the targets in the csv format
+     """
+
+     def __init__(self, intervals_file, fasta_file, target_file=None, use_linecache=False):
+
+         # intervals
+         if use_linecache:
+             self.bt = BedToolLinecache(intervals_file)
+         else:
+             self.bt = BedTool(intervals_file)
+         self.fasta_file = fasta_file
+         self.fasta_extractor = None
+
+         # Targets
+         if target_file is not None:
+             self.targets = pd.read_csv(target_file)
+         else:
+             self.targets = None
+
+     def __len__(self):
+         return len(self.bt)
+
+     def __getitem__(self, idx):
+         if self.fasta_extractor is None:
+             self.fasta_extractor = FastaStringExtractor(self.fasta_file)
+
+         interval = self.bt[idx]
+
+         # Intervals need to be 1000bp wide
+         assert interval.stop - interval.start == 1000
+
+         if self.targets is not None:
+             y = self.targets.iloc[idx].values
+         else:
+             y = {}
+
+         # Run the fasta extractor
+         seq = one_hot_dna(self.fasta_extractor.extract(interval), dtype=np.float32)  # TODO: Remove additional dtype after kipoiseq gets a new release
+         return {
+             "inputs": seq,
+             "targets": y,
+             "metadata": {
+                 "ranges": GenomicRanges.from_interval(interval)
+             }
+         }
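
For reference, a minimal sketch of exercising this dataloader directly against the example files added below (illustrative; assumes kipoi, kipoiseq, and pybedtools are installed and the LFS-tracked fasta has been fetched):

from dataloader import SeqDataset

# Paths taken from the `example` fields in dataloader.yaml
ds = SeqDataset(intervals_file="example_files/intervals.tsv",
                fasta_file="example_files/hg38_chr22.fa")

sample = ds[0]                        # extract and one-hot encode the first interval
print(sample["inputs"].shape)         # (1000, 4), as declared in the output_schema
print(sample["metadata"]["ranges"])   # GenomicRanges for chr22:4997-5997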
dataloader.yaml ADDED
@@ -0,0 +1,45 @@
+ defined_as: dataloader.SeqDataset
+ args:
+   intervals_file:
+     doc: bed6 file with columns `chrom start end id score strand`
+     example: example_files/intervals.tsv
+   fasta_file:
+     doc: Reference genome sequence
+     example: example_files/hg38_chr22.fa
+   target_file:
+     doc: path to the targets (.tsv) file
+     optional: True
+   use_linecache:
+     doc: if True, use linecache https://docs.python.org/3/library/linecache.html to access bed file rows
+     optional: True
+ info:
+   authors:
+     - name: Lara Urban
+       github: LaraUrban
+     - name: Ziga Avsec
+       github: avsecz
+   doc: Dataloader for the DeepSEA model.
+ dependencies:
+   conda:
+     - python
+     - numpy
+     - pandas
+     - cython
+   pip:
+     - cython
+     - pybedtools
+ output_schema:
+   inputs:
+     name: input
+     shape: (1000, 4)
+     special_type: DNASeq
+     doc: DNA sequence
+     associated_metadata: ranges
+   targets:
+     name: epigen_mod
+     shape: (1, )
+     doc: Specific epigenetic feature class (multi-task binary classification)
+   metadata:
+     ranges:
+       type: GenomicRanges
+       doc: Ranges describing inputs.seq
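
The same dataloader can also be instantiated through the kipoi API; a hedged sketch, assuming kipoi's get_dataloader_factory accepts the local-directory source used elsewhere in kipoi:

import kipoi

# "." points at the directory containing dataloader.yaml
Dl = kipoi.get_dataloader_factory(".", source="dir")
dl = Dl(intervals_file="example_files/intervals.tsv",
        fasta_file="example_files/hg38_chr22.fa")
for batch in dl.batch_iter(batch_size=4):   # batches of 4 one-hot sequences
    print(batch["inputs"].shape)            # (4, 1000, 4)
    break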
example_files/hg38_chr22.fa ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:03397edb6bea565057a4f6f643daaec4399d8a8429eec15ee1e2845fad850fe6
+ size 50818476
example_files/hg38_chr22.fa.fai ADDED
@@ -0,0 +1 @@
+ chr22 50818468 7 50818468 50818469
example_files/intervals.tsv ADDED
@@ -0,0 +1,14 @@
+ chr22 4997 5997 1 0 +
+ chr22 5330 6330 2 0 -
+ chr22 6728 7728 3 0 -
+ chr22 3482 4482 4 0 +
+ chr22 7989 8989 5 0 +
+ chr22 8136 9136 6 0 +
+ chr22 3617 4617 7 0 -
+ chr22 7887 8887 8 0 +
+ chr22 8428 9428 9 0 +
+ chr22 9444 10444 10 0 +
+ chr22 41 1041 11 0 +
+ chr22 5777 6777 12 0 +
+ chr22 7084 8084 13 0 +
+ chr22 5725 6725 14 0 +
example_files/test.json ADDED
@@ -0,0 +1,3 @@
+ {"intervals_file": "intervals.tsv",
+  "fasta_file": "hg38_chr22.fa"
+ }
expected.pred.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:27ec18f672e5afc40381b6874c885b68f7c7ee805064d926e1dd91614bc66590
+ size 183002
model.yaml ADDED
@@ -0,0 +1,32 @@
+ defined_as: kipoi.model.PyTorchModel
+ args:
+   module_file: model_files/pyt.py
+   module_obj: simple_model
+   weights: model_files/only_weights.pth
+ default_dataloader: . # path to the directory
+ info:
+   authors:
+     - name: Roman Kreuzhuber
+       github: krrome
+   doc: Simple testing model for PyTorch
+ dependencies:
+   conda:
+     - pytorch::pytorch>=0.2.0
+ schema:
+   inputs:
+     name: input
+     shape: (1000, 4)
+     special_type: DNASeq
+     doc: DNA sequence
+     # associated_metadata: ranges # --> has to be defined in dataloader.yaml.
+     # This field is ignored in model.yaml.
+   targets:
+     shape: (1, )
+     doc: Predicted binding strength
+     column_labels:
+       - some_probability
+ test:
+   expect:
+     url: https://s3.eu-central-1.amazonaws.com/kipoi-models/predictions/example/models/pyt/expected.pred.h5
+     md5: d6d0779a7bdfb1301c76a59defd293ed
+   precision_decimal: 6
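
End to end, model plus dataloader are typically driven through kipoi's pipeline; a minimal sketch, assuming kipoi's local-directory source and that the LFS-tracked weights are present:

import kipoi

# Loads model.yaml and the default_dataloader from this directory
model = kipoi.get_model(".", source="dir")

# Predict on the example files declared in dataloader.yaml
preds = model.pipeline.predict_example()
print(preds.shape)   # one some_probability value per interval in intervals.tsv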
model_files/full_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fb7034af1f15bf0dc242a41645b3bc781486964818fb6710f1fc78e1ca34b12b
+ size 1607392
model_files/only_weights.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6b963f8eb38ae2ceab5a24d08a437f12b8fa94cb2cd5046be0475712e24d28ed
+ size 1601416
model_files/pyt.py ADDED
@@ -0,0 +1,75 @@
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+
+
+ class Flatten(torch.nn.Module):
+     # https://gist.github.com/VoVAllen/5531c78a2d3f1ff3df772038bca37a83
+
+     def __init__(self):
+         super(Flatten, self).__init__()
+
+     def forward(self, x):
+         return x.view(x.size(0), -1)
+
+
+ def get_model():
+     # N is batch size; D_in is input dimension;
+     # H is hidden dimension; D_out is output dimension.
+     D_in, H, D_out = 4000, 100, 1
+
+     model = torch.nn.Sequential(
+         Flatten(),
+         torch.nn.Linear(D_in, H),
+         torch.nn.ReLU(),
+         torch.nn.Linear(H, D_out),
+         torch.nn.Sigmoid(),
+     )
+     return model
+
+ simple_model = get_model()
+
+
+ def generate_example_model():
+     # get model
+     model = get_model()
+
+     # define loss function
+     loss_func = torch.nn.MSELoss()
+
+     # define optimizer
+     optimizer = torch.optim.SGD(model.parameters(), lr=1e-1)
+
+     minibatch_size = 10
+     np.random.seed(0)
+     x = torch.Tensor(50, 1000, 4).uniform_(0, 1)
+     y = torch.Tensor(50).uniform_(0, 1)
+
+     for epoch in tqdm(range(10)):
+         for mbi in tqdm(range(np.ceil(x.size()[0] / minibatch_size).astype(int))):
+             minibatch = x[(mbi * minibatch_size):min(((mbi + 1) * minibatch_size), x.size()[0])]
+             target = torch.autograd.Variable(y[(mbi * minibatch_size):min(((mbi + 1) * minibatch_size), x.size()[0])])
+             model.zero_grad()
+
+             # forward pass
+             out = model(torch.autograd.Variable(minibatch))
+
+             # backward pass
+             L = loss_func(out, target)  # calculate loss
+             L.backward()  # calculate gradients
+             optimizer.step()  # make an update step
+
+     torch.save(model, "model_files/full_model.pth")
+     torch.save(model.state_dict(), "model_files/only_weights.pth")
+
+ ## To comply with OldPyTorchModel
+ def get_model_w_weights():
+     model = get_model()
+     model.load_state_dict(torch.load("model_files/only_weights.pth"))
+     return model
+
+ def test_same_weights(dict1, dict2):
+     for k in dict1:
+         assert np.all(dict1[k].numpy() == dict2[k].numpy())
+
+ # test_same_weights(model.state_dict(), model_2.state_dict())
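
For illustration, a short sketch that checks the (1000, 4) -> (1,) shape contract from model.yaml on a dummy batch (assumes the repository root is on sys.path so model_files.pyt is importable):

import torch
from model_files.pyt import simple_model

x = torch.zeros(2, 1000, 4)   # two dummy one-hot DNA sequences
out = simple_model(x)         # Flatten -> Linear(4000, 100) -> ReLU -> Linear(100, 1) -> Sigmoid
print(out.shape)              # torch.Size([2, 1])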
pyt.py ADDED
@@ -0,0 +1 @@
+ model_files/pyt.py
wrong.pred.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:178dccd292a4f787012e70dc91fc9377f85ec88637484802900480654223ec7e
+ size 183002