"""Gradio demo: zero-shot masked-nucleotide prediction with PlantRNA-FM."""

import html

import gradio as gr
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Checkpoint used for both the tokenizer and the masked-LM head; loaded once
# at import time so every request reuses the same objects.
model_name = "yangheng/PlantRNA-FM"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)


def predict_rna(sequence: str) -> str:
    """Predict the token at every mask position of *sequence*.

    Args:
        sequence: RNA sequence text containing one or more occurrences of
            the tokenizer's mask token.

    Returns:
        The highest-scoring token for each mask position, in order of
        occurrence, joined by single spaces. If the input contains no mask
        token, a short explanatory message is returned instead.
    """
    inputs = tokenizer(sequence, return_tensors="pt")

    # Locate the mask positions: torch.where returns (row, column) index
    # tensors; the column tensor holds the positions within the sequence.
    mask_token_index = torch.where(
        inputs.input_ids == tokenizer.mask_token_id
    )[1]

    if mask_token_index.numel() == 0:
        # Nothing to predict; say so instead of returning a blank result.
        return f"No {tokenizer.mask_token} token found in the input sequence."

    with torch.no_grad():
        outputs = model(**inputs)

    # Only the first (and only) sequence in the batch is inspected.
    mask_token_logits = outputs.logits[0, mask_token_index, :]
    predicted_token_ids = torch.argmax(mask_token_logits, dim=-1)

    # Hand the tokenizer plain Python ints rather than a tensor.
    predicted_tokens = tokenizer.convert_ids_to_tokens(
        predicted_token_ids.tolist()
    )
    return " ".join(predicted_tokens)


input_text = gr.Textbox(
    lines=2,
    placeholder=(
        f"Input RNA Sequence with {tokenizer.mask_token}, e.g., "
        "AAAGAGTCATATACGATATTGTCGACCGTGGAGAGAGAAGAATGTACGATTGGAGT"
    ),
)
output_text = gr.Textbox()

app = gr.Interface(
    fn=predict_rna,
    inputs=input_text,
    outputs=output_text,
    title="Zero-shot PlantFM-RNA MNM Inference",
    # The description is rendered as Markdown/HTML, so the mask token is
    # escaped to keep its angle brackets from being read as an HTML tag.
    description=(
        "Zero-shot PlantFM-RNA MNM Inference: Predicts only the "
        f"{html.escape(tokenizer.mask_token)} tokens."
    ),
)

if __name__ == "__main__":
    app.launch()