yangliz5 committed on
Commit
49ed5db
·
1 Parent(s): 73fe683

feat: Add DeepChopper Gradio app for DNA sequence analysis

Browse files
Files changed (2) hide show
  1. app.py +187 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import multiprocessing
2
+ from functools import partial
3
+ from pathlib import Path
4
+
5
+ import gradio as gr
6
+ import lightning
7
+ import torch
8
+ from datasets import Dataset
9
+ from torch.utils.data import DataLoader
10
+
11
+ import deepchopper
12
+ from deepchopper.deepchopper import default, encode_qual, remove_intervals_and_keep_left, smooth_label_region
13
+ from deepchopper.models.llm import (
14
+ tokenize_and_align_labels_and_quals,
15
+ )
16
+ from deepchopper.utils import (
17
+ summary_predict,
18
+ )
19
+
20
+
21
def parse_fq_record(text: str):
    """Parse FASTQ text and yield one dictionary per record.

    The text is stripped, split into lines and consumed four lines at a time
    (header, sequence, separator, quality).

    Args:
        text: One or more FASTQ records of four lines each.

    Yields:
        A dict with the keys ``"id"`` (the header line, unchanged), ``"seq"``,
        ``"qual"`` (the quality line passed through ``encode_qual`` together
        with ``default.KMER_SIZE``) and ``"target"`` (always ``[0, 0]``).

    Raises:
        ValueError: If the text holds no lines, if the number of lines is not
            a multiple of four, or if a record's sequence and quality strings
            differ in length.
    """
    # splitlines() also copes with Windows line endings; split("\n") would
    # leave a trailing "\r" attached to every field.
    lines = text.strip().splitlines()
    if not lines:
        msg = "FASTQ input is empty"
        raise ValueError(msg)
    if len(lines) % 4:
        msg = f"FASTQ input must have 4 lines per record, got {len(lines)} line(s)"
        raise ValueError(msg)

    for start in range(0, len(lines), 4):
        record_id, seq, _, qual = lines[start : start + 4]
        # Explicit check instead of `assert`, which disappears under `python -O`.
        if len(seq) != len(qual):
            msg = (
                f"Record {record_id!r}: sequence length {len(seq)} "
                f"does not match quality length {len(qual)}"
            )
            raise ValueError(msg)

        yield {
            "id": record_id,
            "seq": seq,
            "qual": encode_qual(qual, default.KMER_SIZE),
            "target": [0, 0],
        }
35
+
36
+
37
def load_dataset(text: str, tokenizer):
    """Build the raw and the tokenized dataset for a piece of FASTQ text.

    Args:
        text: FASTQ text handed to ``parse_fq_record`` as its ``text`` argument.
        tokenizer: Tokenizer passed to ``tokenize_and_align_labels_and_quals``;
            its ``max_len_single_sentence`` is used as ``max_length``.

    Returns:
        A ``(dataset, tokenized_dataset)`` pair: the generated dataset in
        ``"torch"`` format, and the mapped dataset with the ``id``, ``seq``,
        ``qual`` and ``target`` columns removed.
    """
    raw_dataset = Dataset.from_generator(parse_fq_record, gen_kwargs={"text": text})
    raw_dataset = raw_dataset.with_format("torch")

    tokenize = partial(
        tokenize_and_align_labels_and_quals,
        tokenizer=tokenizer,
        max_length=tokenizer.max_len_single_sentence,
    )
    # One mapping process per CPU reported by the OS.
    worker_count = multiprocessing.cpu_count()
    tokenized = raw_dataset.map(tokenize, num_proc=worker_count)  # type: ignore
    tokenized = tokenized.remove_columns(["id", "seq", "qual", "target"])

    return raw_dataset, tokenized
49
+
50
+
51
def predict(
    text: str,
    smooth_window_size: int = 21,
    min_interval_size: int = 13,
    approved_interval_number: int = 20,
    max_process_intervals: int = 8,  # default is 4
    batch_size: int = 1,
    num_workers: int = 1,
):
    """Run DeepChopper on FASTQ text and collect the smoothed predicted intervals.

    Args:
        text: FASTQ text forwarded to ``load_dataset``.
        smooth_window_size: Forwarded to ``smooth_label_region``.
        min_interval_size: Forwarded to ``smooth_label_region``.
        approved_interval_number: Forwarded to ``smooth_label_region``.
        max_process_intervals: A record with more smoothed intervals than this
            is skipped.
        batch_size: Batch size of the prediction ``DataLoader``.
        num_workers: Worker count of the prediction ``DataLoader``.

    Returns:
        A ``(smooth_interval_json, highlighted_text)`` pair.
        ``smooth_interval_json`` is a list of ``{"start": ..., "end": ...}``
        dicts, one per smoothed interval. ``highlighted_text`` is a list of
        ``(subsequence, label)`` tuples in sorted interval order, where the
        label is ``"ada"`` for a smoothed predicted interval and ``None``
        otherwise. Both lists are empty when a record has no smoothed interval
        or more than ``max_process_intervals`` of them.

    Raises:
        ValueError: If prediction does not return exactly one batch.
    """
    tokenizer = deepchopper.models.llm.load_tokenizer_from_hyena_model(model_name="hyenadna-small-32k-seqlen")
    dataset, tokenized_dataset = load_dataset(text, tokenizer)

    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # DataLoader rejects persistent_workers=True when num_workers is 0.
        persistent_workers=num_workers > 0,
    )
    model = deepchopper.DeepChopper.from_pretrained("yangliz5/deepchopper")

    # Use the GPU when CUDA is available and fall back to the CPU otherwise
    # (the condition used to be inverted).
    accelerator = "gpu" if torch.cuda.is_available() else "cpu"
    trainer = lightning.pytorch.trainer.Trainer(
        accelerator=accelerator,
        # "auto" rather than -1: Lightning's CPU accelerator requires a
        # positive device count, while "auto" works for both accelerators.
        devices="auto",
        deterministic=False,
        logger=False,
    )

    predicts = trainer.predict(model=model, dataloaders=dataloader, return_predictions=True)

    # Explicit check instead of `assert`, which disappears under `python -O`.
    if len(predicts) != 1:
        msg = f"Expected exactly one prediction batch, got {len(predicts)}"
        raise ValueError(msg)

    smooth_interval_json: list[dict[str, int]] = []
    highlighted_text: list[tuple[str, str | None]] = []

    for idx, preds in enumerate(predicts):
        true_prediction, _true_label = summary_predict(predictions=preds[0], labels=preds[1])

        seq = dataset[idx]["seq"]

        smooth_predict_targets = smooth_label_region(
            true_prediction[0], smooth_window_size, min_interval_size, approved_interval_number
        )

        if not smooth_predict_targets or len(smooth_predict_targets) > max_process_intervals:
            continue

        # Merge the kept intervals with the predicted ones and sort them so the
        # highlighted pieces follow the order of the sequence.
        _selected_seqs, selected_intervals = remove_intervals_and_keep_left(seq, smooth_predict_targets)
        total_intervals = sorted(selected_intervals + smooth_predict_targets)

        smooth_interval_json.extend({"start": i[0], "end": i[1]} for i in smooth_predict_targets)

        highlighted_text.extend(
            (seq[interval[0] : interval[1]], "ada" if interval in smooth_predict_targets else None)
            for interval in total_intervals
        )
    return smooth_interval_json, highlighted_text
105
+
106
+
107
def process_input(text: str | None, file: str | None):
    """Run the prediction for the submitted text or uploaded file.

    When a file is given it takes precedence over the text, and only its first
    four lines are read.

    Args:
        text: FASTQ text typed or pasted by the user, possibly empty or ``None``.
        file: Path of the uploaded file, or ``None`` when nothing was uploaded.

    Returns:
        The ``(smooth_interval_json, highlighted_text)`` pair produced by
        ``predict``, or two empty lists when neither input was provided.
    """
    if not text and not file:
        gr.Warning("Both text and file are empty")
        # Nothing to analyse: return empty outputs instead of falling through
        # to predict() with no text.
        return [], []

    if file:
        max_lines = 4
        with Path(file).open() as f:
            # zip() stops once range() is exhausted, so at most max_lines
            # lines are read from the file.
            text = "".join(line for _, line in zip(range(max_lines), f))

    return predict(text=text)
124
+
125
+
126
def create_gradio_app():
    """Assemble the DeepChopper Gradio interface and return it.

    The layout is a header, a row with an input column (text box, file upload,
    "Analyze" button) and an output column (JSON view, highlighted text), one
    clickable example for the text box, and a footer. The button is wired to
    ``process_input``.

    Returns:
        The ``gr.Blocks`` instance; it is not launched here.
    """
    # NOTE(review): only the header, sequence and separator lines of the example
    # record are present here; the quality line appears to be missing in this
    # view of the file - restore it from the repository before relying on it.
    example = (
        "@1065:1135|393d635c-64f0-41ed-8531-12174d8efb28+f6a60069-1fcf-4049-8e7c-37523b4e273f\n"
        "GCAGCTATGAATGCAAGGCCACAAGGTGGATGGAAGAGTTGTGGAACCAAAGAGCTGTCTTCCAGAGAAGATTTCGAGATAAGTCGCCCATCAGTGAACAAGATATTGTTGGTGGCATTTGATGAGAACGTTCCAAGATTATTGACAGATTAGTGAAAAGTAAGATTGAAATCATGACTGACCGTAAGTGGCAAGAAAGGGCTTTTGCCTTTGTAACCTTTGACGACCATGACTCCGTGGATAAGATTGTCATTCAGAATACCATACTGTGAATGGCCACATCTTTATTGTGAAGTTAGAAAAGCCCTGTCAAAGCAAGAGATGAATCAGTGCTTCTCCAGCCAAAGAGGTCGAAGTGGTTCTGGAAACTTTGGTGGTGGTCGTGGAGGTGGTTTCGGTGGGAATGACAACTCGGTCGTGGAGGAAACTTCAGTGGTCGTGGTGGCTTTGGTGGCAGCCGTGGTGGTGGTGGATATGGTGGCAGTGGGGATGGCTATAATGGATTTGGTAATGATGGAAGCAATTTGGAGGTGGTGGAAGCTACAATGATTTTGGGAATTACAACAATCAGTCTTCAAATTTTGGACCCCTAGGAGGAAATTTTGGTAGAAGCTCTGGCCCCATGGCGGTGGAGGCCAAATACTTTTGCAAACCACGAAACCAAGGTGGCTATGGCGGTCCAGCAGCAGCAGTAGCTATGGCAGTGGCAGAAGATTTTAATTAGGAAACAAAGCTTAGCAGGAGAGGAGAGCCAGAGAAGTGACAGGGAAGTACAGGTTACAACAGATTTGTGAACTCAGCCCAAGCACAGTGGTGGCAGGGCCTAGCTGCTACAAAGAAGACATGTTTTAGACAAATACTCATGTGTATGGGCAAAACTTGAGGACTGTATTTGTGACTAACTGTATAACAGGTTATTTTAGTTTCTGTTTGTGGAAAGTGTAAAGCATTCCAACAAAGGTTTTTAATGTAGATTTTTTTTTTTGCACCCCATGCTGTTGATTTGCTAAATGTAACAGTCTGATCGTGACGCTGAATAAATGTCTTTTTTAAAAAAAAAAAAAAGCTCCCTCCCATCCCCTGCTGCTAACTGATCCCATTATATCTAACCTGCCCCCCCATATCACCTGCTCCCGAGCTACCTAAGAACAGCTAAAAGAGCACACCCGCATGTAGCAAAATAGTGGGAAGATTATAGGTAGAGGCGACAAACCTACCGAGCCTGGTGATAGCTGGTTGTCCTAGATAGAATCTTAGTTCAACTTTAAATTTGCCCACAGAACCCTCTAAATCCCCTTGTAAATTTAACTGTTAGTCCAAAGAGGAACAGCTCTTTGGACACTAGGAAAAAACCTTGTAGAGAGTAAAAAATCAACACCCA\n"
        "+\n"
    )

    custom_css = """
    .header { text-align: center; margin-bottom: 30px; }
    .footer { text-align: center; margin-top: 30px; font-size: 0.8em; color: #666; }
    """

    header_html = """
            <div class="header">
                <h1>🧬 DeepChopper: DNA Sequence Analysis</h1>
                <p>Analyze DNA sequences and detect artificial sequences</p>
            </div>
            """

    footer_html = """
            <div class="footer">
                <p>DeepChopper - Powered by AI for DNA sequence analysis</p>
            </div>
            """

    with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as app:
        gr.HTML(header_html)

        with gr.Row():
            # Left column: everything the user provides.
            with gr.Column(scale=1):
                sequence_box = gr.Textbox(
                    label="Input DNA Sequence",
                    placeholder="Paste your DNA sequence here...",
                    lines=10,
                )
                upload_box = gr.File(label="Or upload a FASTQ file")
                analyze_button = gr.Button("Analyze", variant="primary")

            # Right column: the two views filled by process_input.
            with gr.Column(scale=1):
                regions_view = gr.JSON(label="Detected Artificial Regions")
                sequence_view = gr.HighlightedText(label="Highlighted Sequence")

        analyze_button.click(
            fn=process_input,
            inputs=[sequence_box, upload_box],
            outputs=[regions_view, sequence_view],
        )

        gr.Examples(examples=[[example]], inputs=[sequence_box])

        gr.HTML(footer_html)

    return app
178
+
179
+
180
def main():
    """Build the DeepChopper Gradio app and launch it."""
    create_gradio_app().launch()
184
+
185
+
186
+ if __name__ == "__main__":
187
+ main()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch>=2.1.0
2
+ lightning>=2.1.2
3
+ datasets>=2.17.1
4
+ deepchopper>=1.0.1