ljw20180420 committed on
Commit
b553c4d
·
verified ·
1 Parent(s): 32ba7fe

Upload AI_models/FOREcasT/load_data.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. AI_models/FOREcasT/load_data.py +251 -0
AI_models/FOREcasT/load_data.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn.functional as F
2
+ import numpy as np
3
+ import torch
4
+ from ..config import get_config
5
+
6
# Configuration is read once at import time; FOREcasT_MAX_DEL_SIZE bounds the
# deletion sizes enumerated below.
args = get_config(config_file="config_FOREcasT.ini")

# Candidate outcomes are enumerated deletions first (largest size first), then
# the 20 insertions. For a deletion of size d the (left, right) pairs are
# (-d, 0), (-d + 1, 1), ..., (0, d); insertions use left = right = 0.
lefts = np.concatenate(
    [
        np.arange(-DEL_SIZE, 1)
        for DEL_SIZE in range(args.FOREcasT_MAX_DEL_SIZE, -1, -1)
    ]
    + [np.zeros(20, np.int64)]
)
rights = np.concatenate(
    [
        np.arange(DEL_SIZE + 1)
        for DEL_SIZE in range(args.FOREcasT_MAX_DEL_SIZE, -1, -1)
    ]
    + [np.zeros(20, np.int64)]
)
# Inserted sequence per candidate: "" for every deletion row, followed by the
# four single bases and the sixteen dinucleotides in A, C, G, T order.
inss = (
    [""] * ((args.FOREcasT_MAX_DEL_SIZE + 2) * (args.FOREcasT_MAX_DEL_SIZE + 1) // 2)
    + list("ACGT")
    + [first + second for first in "ACGT" for second in "ACGT"]
)
17
+
18
# Deletion-size indicators, one row per candidate: always-on flag, size 1,
# 2-3, 4-7, 8-12 and >= 13. Rows that carry an inserted sequence are all False.
feature_DelSize = []
for left, right, ins_seq in zip(lefts, rights, inss):
    dsize = right - left
    feature_DelSize.append(
        torch.tensor([
            True,
            dsize == 1,
            2 <= dsize < 4,
            4 <= dsize < 8,
            8 <= dsize < 13,
            dsize >= 13,
        ]) & (len(ins_seq) == 0)
    )
feature_DelSize = torch.stack(feature_DelSize)
25
+
26
# Insertion-size indicators, one row per candidate: any insertion, exactly one
# base inserted, exactly two bases inserted.
feature_InsSize = torch.tensor([
    [ins_len > 0, ins_len == 1, ins_len == 2]
    for ins_len in map(len, inss)
])
30
+
31
# Deletion-location indicators, one row of 18 flags per candidate: nine bins
# for the left edge followed by nine bins for the right edge. Rows that carry
# an inserted sequence are all False.
#
# Fixes relative to the previous version:
#   * the left range bins were written as `left > -2 and left <= -5` (and the
#     three bins after it in the same form), which can never be true; they now
#     cover -5..-3, -9..-6, -14..-10 and -29..-15;
#   * the last right bin was `right > 30`, leaving right == 30 in no bin; it
#     is now `right > 29`, mirroring `left < -29`.
# NOTE(review): these columns used to be constant False, so weights fitted
# against the old feature matrix should be re-fitted - confirm with the
# training pipeline.
feature_DelLoc = []
for left, right, ins_seq in zip(lefts, rights, inss):
    if len(ins_seq) > 0:
        feature_DelLoc.append([False] * 18)
        continue
    # Plain Python ints so that every flag below is a plain bool.
    left, right = int(left), int(right)
    feature_DelLoc.append([
        # left edge
        left == 0,
        left == -1,
        left == -2,
        -5 <= left < -2,
        -9 <= left < -5,
        -14 <= left < -9,
        -29 <= left < -14,
        left < -29,
        left >= 1,
        # right edge
        right == 0,
        right == 1,
        right == 2,
        2 < right <= 5,
        5 < right <= 9,
        9 < right <= 14,
        14 < right <= 29,
        right < 0,
        right > 29,
    ])
feature_DelLoc = torch.tensor(feature_DelLoc)
40
+
41
# Inserted-sequence identity: deletion rows are all False, and each of the 20
# insertion rows has exactly one True on the diagonal.
feature_InsSeq = torch.cat([
    torch.zeros(
        ((args.FOREcasT_MAX_DEL_SIZE + 2) * (args.FOREcasT_MAX_DEL_SIZE + 1) // 2, 20),
        dtype=torch.bool,
    ),
    torch.eye(20, dtype=torch.bool),
])
45
+
46
# Insertion-location indicators (five bins on `left`), one row per candidate.
# Rows without an inserted sequence are all False.
feature_InsLoc = []
for left, ins_seq in zip(lefts, inss):
    if ins_seq:
        feature_InsLoc.append(
            [left == 0, left == -1, left == -2, left < -2, left >= 1]
        )
    else:
        feature_InsLoc.append([False] * 5)
feature_InsLoc = torch.tensor(feature_InsLoc)
55
+
56
def get_feature_LocalCutSiteSequence(ref, cut):
    """One-hot encode the bases of ``ref[cut - 5:cut + 4]``.

    Each character is mapped to a class index by taking its ASCII code modulo
    5 and capping the result at 3 (so 'A' -> 0, 'G' -> 1, 'C' -> 2, 'T' -> 3).

    Returns:
        A flat int64 tensor holding four one-hot entries per character of the
        slice (36 entries when the slice contains nine characters).
    """
    window = ref[cut - 5:cut + 4]
    ascii_codes = np.frombuffer(window.encode(), dtype=np.int8)
    class_index = (ascii_codes % 5).clip(max=3).astype(np.int64)
    one_hot = F.one_hot(torch.from_numpy(class_index), num_classes=4)
    return one_hot.flatten()
63
+
64
def get_feature_LocalCutSiteSeqMatches(ref, cut):
    """Flag, per base class, which pairs of positions near the cut are equal.

    Ten position pairs are compared: each of ``cut - 2``, ``cut - 1``, ``cut``
    and ``cut + 1`` against every position from ``cut - 3`` up to the one just
    before it. Both members of a pair are one-hot encoded (ASCII code modulo
    5, capped at 3) and the encodings are multiplied, so an entry is 1 only
    when both positions hold the base of that class.

    Returns:
        A flat int64 tensor of length 40 (10 pairs x 4 classes).
    """

    def _one_hot(bases):
        # Same base -> class mapping as the other sequence features.
        codes = np.frombuffer(bases.encode(), dtype=np.int8)
        classes = (codes % 5).clip(max=3).astype(np.int64)
        return F.one_hot(torch.from_numpy(classes), num_classes=4).flatten()

    # The later position of every pair, repeated once per earlier partner.
    later_bases = "".join(
        ref[cut + offset] * repeat
        for repeat, offset in enumerate((-2, -1, 0, 1), start=1)
    )
    # The matching earlier partners, all starting at cut - 3.
    earlier_bases = "".join([
        ref[cut - 3],
        ref[cut - 3:cut - 1],
        ref[cut - 3:cut],
        ref[cut - 3:cut + 1],
    ])
    return _one_hot(later_bases) * _one_hot(earlier_bases)
81
+
82
def get_feature_LocalRelativeSequence(ref, cut, left, right, ins_seq):
    """One-hot encode the six bases around each deletion edge.

    For the left edge (``cut + left``) and then the right edge
    (``cut + right``), the slice ``ref[edge - 3:edge + 3]`` is encoded with
    the ASCII-code-modulo-5 (capped at 3) class mapping.

    Returns:
        A flat int64 tensor: the left-edge encoding followed by the right-edge
        encoding (48 entries when both slices hold six characters). Candidates
        with a non-empty ``ins_seq`` get 48 zeros.
    """
    if ins_seq:
        return torch.zeros(48, dtype=torch.int64)

    encoded_edges = []
    for edge in (cut + left, cut + right):
        codes = np.frombuffer(ref[edge - 3:edge + 3].encode(), dtype=np.int8)
        classes = (codes % 5).clip(max=3).astype(np.int64)
        encoded_edges.append(
            F.one_hot(torch.from_numpy(classes), num_classes=4).flatten()
        )
    return torch.cat(encoded_edges)
99
+
100
def get_feature_SeqMatches(ref, cut, left, right, ins_seq):
    """Compare every base around the left edge with every base around the right.

    The six characters of ``ref[cut + left - 3:cut + left + 3]`` are compared
    pairwise with the six characters of
    ``ref[cut + right - 3:cut + right + 3]``; each comparison is one-hot
    encoded as (differs, equal).

    Returns:
        A flat int64 tensor of length 72 (6 x 6 comparisons x 2 classes).
        Candidates with a non-empty ``ins_seq`` get 72 zeros.
    """
    if ins_seq:
        return torch.zeros(72, dtype=torch.int64)

    left_edge = cut + left
    right_edge = cut + right
    left_codes = np.frombuffer(ref[left_edge - 3:left_edge + 3].encode(), dtype=np.int8)
    right_codes = np.frombuffer(ref[right_edge - 3:right_edge + 3].encode(), dtype=np.int8)
    # Column vector against row vector broadcasts to the full comparison grid.
    equal_grid = (left_codes[:, None] == right_codes).astype(np.int64)
    return F.one_hot(torch.from_numpy(equal_grid), num_classes=2).flatten()
109
+
110
def get_feature_I1or2Rpt(ref, cut, left, ins_seq):
    """Flag whether a 1- or 2-base insertion repeats the base before the cut.

    The four flags are, in order: single base equal to ``ref[cut - 1]``,
    single base different from it, two bases equal to that base doubled, two
    bases different from that. All flags are forced to False unless
    ``left == 0``.

    Returns:
        A bool tensor of length 4; all False when ``ins_seq`` is empty.
    """
    if not ins_seq:
        return torch.full((4,), False)

    prev_base = ref[cut - 1]
    doubled = prev_base * 2
    ins_len = len(ins_seq)
    flags = torch.tensor([
        ins_seq == prev_base,
        ins_len == 1 and ins_seq != prev_base,
        ins_seq == doubled,
        ins_len == 2 and ins_seq != doubled,
    ])
    return flags & torch.tensor(left == 0)
114
+
115
def getLeftMH(ref, cut, left, right, mh_max=16):
    """Measure how far ``ref`` agrees to the left of the two deletion edges.

    Walks ``i = 1, 2, ...`` comparing ``ref[cut + left - i]`` with
    ``ref[cut + right - i]``.

    Returns:
        ``(left_mh, left_mh_1)``: ``left_mh`` is ``i - 1`` at the first
        disagreement and ``left_mh_1`` is ``i - 1`` at the second one; running
        past ``mh_max`` counts as a disagreement, so both values are capped at
        ``mh_max``.
    """
    left_mh = None
    for step in range(1, mh_max + 2):
        # Only index into ref while still inside the mh_max window.
        if step <= mh_max and ref[cut + left - step] == ref[cut + right - step]:
            continue
        if left_mh is not None:
            left_mh_1 = step - 1
            break
        left_mh = step - 1
    if left_mh == mh_max:
        # The window ended without any disagreement, so the loop never
        # reached a second stop.
        left_mh_1 = mh_max
    return left_mh, left_mh_1
127
+
128
def getRightMH(ref, cut, left, right, mh_max=16):
    """Measure how far ``ref`` agrees to the right of the two deletion edges.

    Walks ``i = 0, 1, ...`` comparing ``ref[cut + left + i]`` with
    ``ref[cut + right + i]``.

    Returns:
        ``(right_mh, right_mh_1)``: ``right_mh`` is ``i`` at the first
        disagreement and ``right_mh_1`` is ``i`` at the second one; reaching
        ``mh_max`` counts as a disagreement, so both values are capped at
        ``mh_max``.
    """
    right_mh = None
    for step in range(0, mh_max + 1):
        # Only index into ref while still inside the mh_max window.
        if step < mh_max and ref[cut + left + step] == ref[cut + right + step]:
            continue
        if right_mh is not None:
            right_mh_1 = step
            break
        right_mh = step
    if right_mh == mh_max:
        # The window ended without any disagreement, so the loop never
        # reached a second stop.
        right_mh_1 = mh_max
    return right_mh, right_mh_1
140
+
141
def get_feature_microhomology(ref, cut, left, right, ins_seq):
    """Bin the left/right microhomology lengths of a deletion into 21 flags.

    The lengths come from ``getLeftMH`` and ``getRightMH``. Flags 0-5 mark
    ``left_mh`` / ``right_mh`` equal to 1, 2 and 3; flags 6-7 mark
    ``left_mh_1`` / ``right_mh_1`` equal to 3; flags 8-19 mark the ranges
    4-6, 7-10 and 11-15 for all four values; flag 20 is a catch-all.

    Returns:
        A list of 21 bools; all False when ``ins_seq`` is non-empty.
    """
    if ins_seq:
        return [False] * 21

    left_mh, left_mh_1 = getLeftMH(ref, cut, left, right)
    right_mh, right_mh_1 = getRightMH(ref, cut, left, right)

    flags = [
        left_mh == 1,
        right_mh == 1,
        left_mh == 2,
        right_mh == 2,
        left_mh == 3,
        right_mh == 3,
        left_mh_1 == 3,
        right_mh_1 == 3,
    ]
    for low, high in ((4, 7), (7, 11), (11, 16)):
        flags.extend([
            low <= left_mh < high,
            low <= right_mh < high,
            low <= left_mh_1 < high,
            low <= right_mh_1 < high,
        ])
    # NOTE(review): the parentheses below only spell out how Python already
    # evaluates the original unparenthesised expression (`and` binds tighter
    # than `or`). It looks as if "(x == 0 or x >= 16)" for each of the four
    # values joined by `and` may have been intended - confirm before changing,
    # since it alters the feature matrix.
    flags.append(
        left_mh == 0
        or (left_mh >= 16 and right_mh == 0)
        or (right_mh >= 16 and left_mh_1 == 0)
        or (left_mh_1 >= 16 and right_mh_1 == 0)
        or right_mh_1 >= 16
    )
    return flags
169
+
170
def features_pairwise(features1, features2):
    """Return all pairwise products of two feature vectors.

    The last axis of ``features1`` is multiplied against the last axis of
    ``features2`` (leading axes broadcast), and the resulting two trailing
    axes are flattened into one, with the ``features2`` index varying fastest.
    """
    outer = features1[..., :, None] * features2[..., None, :]
    return outer.flatten(start_dim=-2)
172
+
173
# Features that do not depend on the reference sequence, built once at import
# time. feature_fix gets a leading axis of size 1 so it can be expanded over a
# batch of examples.
feature_fix = torch.cat(
    (
        features_pairwise(feature_DelSize, feature_DelLoc),
        feature_InsSize,
        feature_DelSize,
        feature_DelLoc,
        feature_InsLoc,
        feature_InsSeq,
    ),
    dim=-1,
).float()[None]
# Float copies of the fixed blocks that get crossed with per-example features.
feature_InsSize_DelSize = torch.cat(
    (feature_InsSize, feature_DelSize), dim=-1
).float()
feature_DelSize_DelLoc = torch.cat(
    (feature_DelSize, feature_DelLoc), dim=-1
).float()
189
+
190
@torch.no_grad()
def data_collector(examples, output_count=True):
    """Build the feature tensor for a batch of examples.

    Args:
        examples: iterable of mappings; each must provide ``"ref"`` and
            ``"cut"``, plus ``"count"`` when ``output_count`` is true.
            (assumes ``ref`` is a str and ``cut`` an int index into it -
            confirm against the dataset)
        output_count: when true, also return the examples' ``"count"`` values.

    Returns:
        A dict with ``"feature"``: ``feature_fix`` expanded over the batch and
        concatenated on the last axis with the per-example features; and, when
        ``output_count`` is true, ``"count"``: ``torch.tensor`` of the counts.
    """
    features_var = []
    counts = []
    for example in examples:
        ref, cut = example["ref"], example["cut"]

        # These two features depend only on (ref, cut), not on the candidate,
        # so compute them once per example instead of once per candidate row.
        cut_site_sequence = get_feature_LocalCutSiteSequence(ref, cut)
        cut_site_seq_matches = get_feature_LocalCutSiteSeqMatches(ref, cut)

        i1or2_rpt_rows, relative_seq_rows, seq_match_rows, mh_rows = [], [], [], []
        for left, right, ins_seq in zip(lefts, rights, inss):
            i1or2_rpt_rows.append(get_feature_I1or2Rpt(ref, cut, left, ins_seq))
            relative_seq_rows.append(
                get_feature_LocalRelativeSequence(ref, cut, left, right, ins_seq)
            )
            seq_match_rows.append(
                get_feature_SeqMatches(ref, cut, left, right, ins_seq)
            )
            mh_rows.append(
                get_feature_microhomology(ref, cut, left, right, ins_seq)
            )

        feature_I1or2Rpt = torch.stack(i1or2_rpt_rows)
        feature_LocalRelativeSequence = torch.stack(relative_seq_rows)
        feature_SeqMatches = torch.stack(seq_match_rows)
        feature_microhomology = torch.tensor(mh_rows)

        # Repeat the per-example cut-site features on every candidate row.
        candidate_count = feature_I1or2Rpt.shape[0]
        feature_LocalCutSiteSequence = cut_site_sequence.expand(candidate_count, -1)
        feature_LocalCutSiteSeqMatches = cut_site_seq_matches.expand(candidate_count, -1)

        features_var.append(torch.cat([
            features_pairwise(feature_LocalCutSiteSequence, feature_InsSize_DelSize),
            features_pairwise(
                torch.cat([
                    feature_microhomology,
                    feature_LocalRelativeSequence,
                ], dim=-1),
                feature_DelSize_DelLoc,
            ),
            features_pairwise(
                torch.cat([
                    feature_LocalCutSiteSeqMatches,
                    feature_SeqMatches,
                ], dim=-1),
                feature_DelSize,
            ),
            features_pairwise(
                torch.cat([
                    feature_InsSeq,
                    feature_LocalCutSiteSequence,
                    feature_LocalCutSiteSeqMatches,
                ], dim=-1),
                feature_I1or2Rpt,
            ),
            feature_I1or2Rpt,
            feature_LocalCutSiteSequence,
            feature_LocalCutSiteSeqMatches,
            feature_LocalRelativeSequence,
            feature_SeqMatches,
            feature_microhomology,
        ], dim=-1).to(torch.float32))

        if output_count:
            counts.append(example["count"])

    features = torch.cat(
        [feature_fix.expand(len(examples), -1, -1), torch.stack(features_var)],
        dim=-1,
    )
    if output_count:
        return {
            "feature": features,
            "count": torch.tensor(counts),
        }
    return {
        "feature": features,
    }