|
import multiprocessing |
|
import gradio as gr |
|
import torch |
|
from omnigenome import OmniGenomeModelForRNADesign |
|
import RNA |
|
import tempfile |
|
import os |
|
|
|
|
|
# Load the pretrained OmniGenome RNA-design model once, at import time, and put
# it on the GPU when PyTorch reports one is available, otherwise on the CPU.
# NOTE(review): the __main__ block forces the 'spawn' start method. Under spawn,
# every child process re-imports this module (as __mp_main__) and would run this
# load again. Confirm whether run_rna_design starts worker processes; if it does,
# consider guarding this load so only the parent process performs it.
model = OmniGenomeModelForRNADesign(model_path="anonymous8/OmniGenome-186M")
if torch.cuda.is_available():
    model.to("cuda")
else:
    model.to("cpu")
|
|
|
|
|
|
|
def _check_dot_bracket(structure):
    """Return an error message if *structure* is not balanced dot-bracket, else None."""
    invalid = set(structure) - set("().")
    if invalid:
        return (
            f"Invalid characters in structure: {''.join(sorted(invalid))}. "
            "Use only '(', ')' and '.'."
        )
    depth = 0
    for char in structure:
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth < 0:
                return "Unbalanced structure: ')' without a matching '('."
    if depth:
        return "Unbalanced structure: '(' without a matching ')'."
    return None


def design_rna(target_structure):
    """Design an RNA sequence for a target secondary structure (Gradio callback).

    Args:
        target_structure: Target structure in dot-bracket notation, as typed by
            the user. Surrounding whitespace is ignored; ``None`` is treated as
            empty input.

    Returns:
        A ``(text, plot_path)`` pair. On success ``text`` is the designed
        sequence and ``plot_path`` is the SVG produced by ``plot_rna_structure``.
        On invalid input, or when the model returns no candidate, ``text`` is a
        human-readable message and ``plot_path`` is ``None``.
    """
    max_length = 50  # online-demo limit; keep in sync with the UI description
    # Strip once up front so the length check, the model and the plot all see
    # the same string. Previously only the model call was stripped, so padded
    # input could be rejected, or plotted against a structure longer than the
    # designed sequence.
    structure = (target_structure or "").strip()

    if not 0 < len(structure) <= max_length:
        return (
            f"The online demo only supports RNA structures with 1 to {max_length} characters.",
            None,
        )
    error = _check_dot_bracket(structure)
    if error is not None:
        return error, None

    best_sequences = model.run_rna_design(
        structure=structure,
        mutation_ratio=0.5,
        num_population=50,
        num_generation=100,
    )

    # NOTE(review): the original indexed [0], implying a sequence of candidates.
    # A bare string is accepted defensively so it is not truncated to its first
    # character; confirm the return type of run_rna_design.
    if isinstance(best_sequences, str):
        best_sequence = best_sequences
    elif best_sequences:
        best_sequence = best_sequences[0]
    else:
        return "No RNA sequence could be designed for this structure.", None

    plot_path = plot_rna_structure(best_sequence, structure)
    return best_sequence, plot_path
|
|
|
|
|
|
|
def plot_rna_structure(sequence, structure):
    """Render *sequence* drawn with *structure* to an SVG file and return its path.

    The file is created under the system temp directory with a ``.svg`` suffix
    and is never deleted here, so the returned path stays valid for the caller
    (the Gradio image output).
    """
    fd, svg_path = tempfile.mkstemp(suffix=".svg")
    # Only the unique filename is needed here; close our descriptor before
    # handing the path to ViennaRNA's plotting routine.
    os.close(fd)
    RNA.svg_rna_plot(sequence, structure, svg_path)
    return svg_path
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Force the 'spawn' start method for any child processes created later,
    # presumably because CUDA state cannot be safely inherited by a forked
    # child -- TODO confirm against how run_rna_design uses multiprocessing.
    multiprocessing.set_start_method('spawn', force=True)

    # Build the web UI: header text, a structure input, the designed-sequence
    # output and its SVG plot, all wired to design_rna via the Submit button.
    with gr.Blocks() as iface:
        gr.Markdown("# RNA Design with OmniGenome")
        gr.Markdown(
            "Enter a target RNA secondary structure to generate a designed RNA sequence and visualize its structure. "
            # Trailing space added: these literals are concatenated, and the
            # original rendered "...resource shortage.For larger structures".
            "Please note that the online demo only supports RNA structures with 1 to 50 bases due to computational resource shortage. "
            "For larger structures, please run the model locally."
        )
        gr.Markdown("""

### Example RNA Structures:

- `(((((......)))))`

- `((((((.((((....))))))).)))..........`

- `((....)).((....))`

- `.(((((((((((...)))))....)))))).`

- `..((((((((.....))))((((.....))))))))..`

""")

        with gr.Column():
            target_structure_input = gr.Textbox(
                label="Target RNA Secondary Structure",
                placeholder="Enter RNA structure here, e.g., (((((......)))))"
            )
            output_sequence = gr.Textbox(label="Designed RNA Sequence")
            # type="filepath": design_rna returns the path of the SVG it wrote.
            output_plot = gr.Image(type="filepath", label="RNA Structure Plot")

            submit_button = gr.Button("Submit")
            submit_button.click(
                design_rna,
                inputs=target_structure_input,
                outputs=[output_sequence, output_plot]
            )

    iface.launch()
|
|