Upload AI_models/FOREcasT/load_data.py with huggingface_hub
AI_models/FOREcasT/load_data.py ADDED (+251 -0)
import torch.nn.functional as F
import numpy as np
import torch
from ..config import get_config

args = get_config(config_file="config_FOREcasT.ini")

# Enumerate every candidate outcome: all deletion windows (left, right) spanning the cut site,
# up to FOREcasT_MAX_DEL_SIZE bases, followed by the 20 single- and double-base insertions.
lefts = np.concatenate([
    np.arange(-DEL_SIZE, 1)
    for DEL_SIZE in range(args.FOREcasT_MAX_DEL_SIZE, -1, -1)
] + [np.zeros(20, np.int64)])
rights = np.concatenate([
    np.arange(0, DEL_SIZE + 1)
    for DEL_SIZE in range(args.FOREcasT_MAX_DEL_SIZE, -1, -1)
] + [np.zeros(20, np.int64)])
inss = (args.FOREcasT_MAX_DEL_SIZE + 2) * (args.FOREcasT_MAX_DEL_SIZE + 1) // 2 * [""] + ["A", "C", "G", "T", "AA", "AC", "AG", "AT", "CA", "CC", "CG", "CT", "GA", "GC", "GG", "GT", "TA", "TC", "TG", "TT"]

feature_DelSize = []
for left, right, ins_seq in zip(lefts, rights, inss):
    dsize = right - left
    feature_DelSize.append(
        (len(ins_seq) == 0) & torch.tensor([True, dsize == 1, dsize >= 2 and dsize < 4, dsize >= 4 and dsize < 8, dsize >= 8 and dsize < 13, dsize >= 13])
    )
feature_DelSize = torch.stack(feature_DelSize)

feature_InsSize = torch.tensor([
    [len(ins_seq) > 0, len(ins_seq) == 1, len(ins_seq) == 2]
    for ins_seq in inss
])

# Deletion-location bins for the left/right boundaries. The four intermediate left bins mirror the
# right bins; the comparison directions in the uploaded file (e.g. left > -2 and left <= -5) could
# never be True, so they are assumed to be flipped here.
feature_DelLoc = []
for left, right, ins_seq in zip(lefts, rights, inss):
    if len(ins_seq) > 0:
        feature_DelLoc.append([False] * 18)
        continue
    feature_DelLoc.append([
        left == 0, left == -1, left == -2, left < -2 and left >= -5, left < -5 and left >= -9, left < -9 and left >= -14, left < -14 and left >= -29, left < -29, left >= 1, right == 0, right == 1, right == 2, right > 2 and right <= 5, right > 5 and right <= 9, right > 9 and right <= 14, right > 14 and right <= 29, right < 0, right > 30
    ])
feature_DelLoc = torch.tensor(feature_DelLoc)

feature_InsSeq = torch.cat([
    torch.full(((args.FOREcasT_MAX_DEL_SIZE + 2) * (args.FOREcasT_MAX_DEL_SIZE + 1) // 2, 20), False),
    torch.eye(20, dtype=torch.bool)
])

feature_InsLoc = []
for left, ins_seq in zip(lefts, inss):
    if len(ins_seq) == 0:
        feature_InsLoc.append([False] * 5)
        continue
    feature_InsLoc.append([
        left == 0, left == -1, left == -2, left < -2, left >= 1
    ])
feature_InsLoc = torch.tensor(feature_InsLoc)

def get_feature_LocalCutSiteSequence(ref, cut):
    # One-hot encode the 9 bases around the cut site; ASCII % 5 clipped at 3 maps A->0, G->1, C->2, T->3.
    return F.one_hot(
        torch.from_numpy(
            (np.frombuffer(ref[cut - 5:cut + 4].encode(), dtype=np.int8) % 5).clip(max=3).astype(np.int64)
        ),
        num_classes=4
    ).flatten()

def get_feature_LocalCutSiteSeqMatches(ref, cut):
    # Repeat the bases just around the cut so each can be compared, position by position, with the
    # upstream window; the element-wise product of the two one-hot encodings flags matching bases.
    offset1_bases = ref[cut - 2] + ref[cut - 1] * 2 + ref[cut] * 3 + ref[cut + 1] * 4
    offset2_bases = ref[cut - 3] + ref[cut - 3:cut - 1] + ref[cut - 3:cut] + ref[cut - 3:cut + 1]
    return (
        F.one_hot(
            torch.from_numpy(
                (np.frombuffer(offset1_bases.encode(), dtype=np.int8) % 5).clip(max=3).astype(np.int64)
            ),
            num_classes=4
        ).flatten() *
        F.one_hot(
            torch.from_numpy(
                (np.frombuffer(offset2_bases.encode(), dtype=np.int8) % 5).clip(max=3).astype(np.int64)
            ),
            num_classes=4
        ).flatten()
    )

def get_feature_LocalRelativeSequence(ref, cut, left, right, ins_seq):
    if len(ins_seq) > 0:
        return torch.zeros(48, dtype=torch.int64)
    return torch.cat([
        F.one_hot(
            torch.from_numpy(
                (np.frombuffer(ref[cut + left - 3:cut + left + 3].encode(), dtype=np.int8) % 5).clip(max=3).astype(np.int64)
            ),
            num_classes=4
        ).flatten(),
        F.one_hot(
            torch.from_numpy(
                (np.frombuffer(ref[cut + right - 3:cut + right + 3].encode(), dtype=np.int8) % 5).clip(max=3).astype(np.int64)
            ),
            num_classes=4
        ).flatten()
    ])

def get_feature_SeqMatches(ref, cut, left, right, ins_seq):
    if len(ins_seq) > 0:
        return torch.zeros(72, dtype=torch.int64)
    return F.one_hot(
        torch.from_numpy(
            (np.frombuffer(ref[cut + left - 3:cut + left + 3].encode(), dtype=np.int8)[:, None] == np.frombuffer(ref[cut + right - 3:cut + right + 3].encode(), dtype=np.int8)).astype(np.int64)
        ),
        num_classes=2
    ).flatten()

def get_feature_I1or2Rpt(ref, cut, left, ins_seq):
    if len(ins_seq) == 0:
        return torch.full((4,), False)
    return torch.tensor([ins_seq == ref[cut - 1], len(ins_seq) == 1 and ins_seq != ref[cut - 1], ins_seq == (ref[cut - 1] * 2), len(ins_seq) == 2 and ins_seq != (ref[cut - 1] * 2)]).logical_and(torch.tensor(left == 0))

def getLeftMH(ref, cut, left, right, mh_max=16):
    # Microhomology length scanning leftwards from the two deletion boundaries; the *_1 value is
    # the length reached after tolerating a single mismatch.
    left_mh = None
    for i in range(1, mh_max + 2):
        if i > mh_max or ref[cut + left - i] != ref[cut + right - i]:
            if left_mh is None:
                left_mh = i - 1
            else:
                left_mh_1 = i - 1
                break
    if left_mh == mh_max:
        left_mh_1 = mh_max
    return left_mh, left_mh_1

def getRightMH(ref, cut, left, right, mh_max=16):
    # Same as getLeftMH, scanning rightwards from the two deletion boundaries.
    right_mh = None
    for i in range(0, mh_max + 1):
        if i >= mh_max or ref[cut + left + i] != ref[cut + right + i]:
            if right_mh is None:
                right_mh = i
            else:
                right_mh_1 = i
                break
    if right_mh == mh_max:
        right_mh_1 = mh_max
    return right_mh, right_mh_1

def get_feature_microhomology(ref, cut, left, right, ins_seq):
    if len(ins_seq) > 0:
        return [False] * 21
    left_mh, left_mh_1 = getLeftMH(ref, cut, left, right)
    right_mh, right_mh_1 = getRightMH(ref, cut, left, right)
    return [
        left_mh == 1,
        right_mh == 1,
        left_mh == 2,
        right_mh == 2,
        left_mh == 3,
        right_mh == 3,
        left_mh_1 == 3,
        right_mh_1 == 3,
        left_mh >= 4 and left_mh < 7,
        right_mh >= 4 and right_mh < 7,
        left_mh_1 >= 4 and left_mh_1 < 7,
        right_mh_1 >= 4 and right_mh_1 < 7,
        left_mh >= 7 and left_mh < 11,
        right_mh >= 7 and right_mh < 11,
        left_mh_1 >= 7 and left_mh_1 < 11,
        right_mh_1 >= 7 and right_mh_1 < 11,
        left_mh >= 11 and left_mh < 16,
        right_mh >= 11 and right_mh < 16,
        left_mh_1 >= 11 and left_mh_1 < 16,
        right_mh_1 >= 11 and right_mh_1 < 16,
        # "No microhomology" flag; the uploaded expression relied on and/or precedence and is
        # assumed to mean this conjunction.
        (left_mh == 0 or left_mh >= 16) and (right_mh == 0 or right_mh >= 16) and (left_mh_1 == 0 or left_mh_1 >= 16) and (right_mh_1 == 0 or right_mh_1 >= 16)
    ]

def features_pairwise(features1, features2):
    # Outer product of two feature vectors along the last dimension, flattened into
    # pairwise interaction features.
    return (features1.unsqueeze(-1) * features2.unsqueeze(-2)).flatten(start_dim=-2)

# Features that depend only on the candidate mutation (not on the reference sequence)
# are precomputed once and shared across all examples.
feature_fix = torch.cat([
    features_pairwise(feature_DelSize, feature_DelLoc),
    feature_InsSize,
    feature_DelSize,
    feature_DelLoc,
    feature_InsLoc,
    feature_InsSeq
], dim=-1).to(torch.float32).unsqueeze(0)
feature_InsSize_DelSize = torch.cat([
    feature_InsSize,
    feature_DelSize
], dim=-1).to(torch.float32)
feature_DelSize_DelLoc = torch.cat([
    feature_DelSize,
    feature_DelLoc
], dim=-1).to(torch.float32)

@torch.no_grad()
def data_collector(examples, output_count=True):
    # For each example, build the sequence-dependent features of every candidate mutation and
    # concatenate them with the precomputed fixed features.
    features_var = []
    if output_count:
        counts = []
    for example in examples:
        feature_I1or2Rpt, feature_LocalCutSiteSequence, feature_LocalCutSiteSeqMatches, feature_LocalRelativeSequence, feature_SeqMatches, feature_microhomology = [], [], [], [], [], []
        for left, right, ins_seq in zip(lefts, rights, inss):
            feature_I1or2Rpt.append(get_feature_I1or2Rpt(example["ref"], example["cut"], left, ins_seq))
            feature_LocalCutSiteSequence.append(get_feature_LocalCutSiteSequence(example["ref"], example["cut"]))
            feature_LocalCutSiteSeqMatches.append(get_feature_LocalCutSiteSeqMatches(example["ref"], example["cut"]))
            feature_LocalRelativeSequence.append(get_feature_LocalRelativeSequence(example["ref"], example["cut"], left, right, ins_seq))
            feature_SeqMatches.append(get_feature_SeqMatches(example["ref"], example["cut"], left, right, ins_seq))
            feature_microhomology.append(get_feature_microhomology(example["ref"], example["cut"], left, right, ins_seq))
        feature_I1or2Rpt = torch.stack(feature_I1or2Rpt)
        feature_LocalCutSiteSequence = torch.stack(feature_LocalCutSiteSequence)
        feature_LocalCutSiteSeqMatches = torch.stack(feature_LocalCutSiteSeqMatches)
        feature_LocalRelativeSequence = torch.stack(feature_LocalRelativeSequence)
        feature_SeqMatches = torch.stack(feature_SeqMatches)
        feature_microhomology = torch.tensor(feature_microhomology)
        features_var.append(torch.cat([
            features_pairwise(feature_LocalCutSiteSequence, feature_InsSize_DelSize),
            features_pairwise(
                torch.cat([
                    feature_microhomology,
                    feature_LocalRelativeSequence
                ], dim=-1),
                feature_DelSize_DelLoc
            ),
            features_pairwise(
                torch.cat([
                    feature_LocalCutSiteSeqMatches,
                    feature_SeqMatches
                ], dim=-1),
                feature_DelSize
            ),
            features_pairwise(
                torch.cat([
                    feature_InsSeq,
                    feature_LocalCutSiteSequence,
                    feature_LocalCutSiteSeqMatches
                ], dim=-1),
                feature_I1or2Rpt
            ),
            feature_I1or2Rpt,
            feature_LocalCutSiteSequence,
            feature_LocalCutSiteSeqMatches,
            feature_LocalRelativeSequence,
            feature_SeqMatches,
            feature_microhomology
        ], dim=-1).to(torch.float32))
        if output_count:
            counts.append(example["count"])
    features = torch.cat([feature_fix.expand(len(examples), -1, -1), torch.stack(features_var)], dim=-1)
    if output_count:
        return {
            "feature": features,
            "count": torch.tensor(counts)
        }
    return {
        "feature": features
    }
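
For reference, a minimal usage sketch of data_collector. The keys "ref", "cut", and "count" come from the code above; the concrete values and sequence length are made up, and the per-candidate interpretation of "count" is an assumption based on how counts are collated alongside the per-candidate features.

# Hypothetical input; values are illustrative only.
examples = [{
    "ref": "ACGT" * 50,               # reference sequence around the target site (hypothetical)
    "cut": 100,                       # 0-based cut-site index into ref
    "count": [0] * len(lefts),        # assumed: one observed count per candidate mutation
}]
batch = data_collector(examples)      # pass output_count=False when counts are unavailable
print(batch["feature"].shape)         # (len(examples), len(lefts), num_features)
print(batch["count"].shape)           # (len(examples), len(lefts))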