Last commit b7a40ba by Caleb Ellington: "update with major refactor"
metadata
datasets:
  - genbio-ai/rna-downstream-tasks
base_model:
  - genbio-ai/rnafm-1.6b

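LoRA fine-tuned checkpoint of genbio-ai/rnafm-1.6b for splice site prediction (acceptor sites, as the repository name indicates), fine-tuned on genbio-ai/rna-downstream-tasks.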

How to Use

Download model

```python
from pathlib import Path

from huggingface_hub import snapshot_download

model_name = "genbio-ai/rnafm-1.6b-csp-acceptor-ckpt"
# Files land in ~/genbio_models/genbio-ai/rnafm-1.6b-csp-acceptor-ckpt
genbio_models_path = Path.home().joinpath('genbio_models', model_name)
genbio_models_path.mkdir(parents=True, exist_ok=True)
snapshot_download(repo_id=model_name, local_dir=genbio_models_path)
```
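
If you rerun these snippets often, or work on a compute node without network access, a small wrapper can skip the Hub entirely once the checkpoint is on disk. The sketch below assumes the repository contains model.ckpt (as the loading step expects) and reuses the same ~/genbio_models layout; ensure_checkpoint and its download_fn parameter are hypothetical names, not part of huggingface_hub or ModelGenerator.

```python
from pathlib import Path
from typing import Callable, Optional


def ensure_checkpoint(
    model_name: str,
    root: Optional[Path] = None,
    download_fn: Optional[Callable[[str, Path], None]] = None,
    filename: str = "model.ckpt",
) -> Path:
    """Return the local checkpoint path, downloading the repo only if it is missing."""
    root = Path.home() / "genbio_models" if root is None else Path(root)
    # "org/name" nests as root/org/name, matching the download snippet
    local_dir = root.joinpath(model_name)
    ckpt_path = local_dir / filename
    if not ckpt_path.exists():
        local_dir.mkdir(parents=True, exist_ok=True)
        if download_fn is None:
            # Imported lazily so the helper works offline once the file exists
            from huggingface_hub import snapshot_download

            def _hub_download(repo_id: str, target: Path) -> None:
                snapshot_download(repo_id=repo_id, local_dir=target)

            download_fn = _hub_download
        download_fn(model_name, local_dir)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"{filename} not found in {local_dir}")
    return ckpt_path
```

With the defaults, ensure_checkpoint("genbio-ai/rnafm-1.6b-csp-acceptor-ckpt") returns the same path as genbio_models_path.joinpath('model.ckpt'), so it can be passed straight to load_from_checkpoint. The injectable download_fn keeps the helper testable without network access.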

Load model for inference

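```python
import torch
from modelgenerator.tasks import SequenceClassification

# genbio_models_path is defined in the download step above
ckpt_path = genbio_models_path.joinpath('model.ckpt')
# Non-strict loading: mismatched state-dict keys do not raise an error
model = SequenceClassification.load_from_checkpoint(ckpt_path, strict_loading=False).eval()

# Toy placeholder inputs; replace with your own sequences
collated_batch = model.transform({"sequences": ["ACGT", "AGCT"]})
with torch.no_grad():  # inference only, no gradient tracking
    logits = model(collated_batch)
print(logits)
print(torch.argmax(logits, dim=-1))  # predicted class index per sequence
```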
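
To score more than a handful of sequences, you might run the model in batches and turn each row of logits into a class probability. The following is a minimal sketch in plain Python, assuming logits has shape (batch, num_classes) as the argmax call above suggests. predict, chunks and logits_fn are hypothetical helpers, and this card does not state which class index corresponds to an acceptor site, so check the training configuration before interpreting the output.

```python
import math
from typing import Callable, Iterator, List, Sequence, Tuple


def softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def chunks(items: Sequence[str], size: int) -> Iterator[Sequence[str]]:
    """Yield consecutive slices of at most `size` items."""
    if size < 1:
        raise ValueError("size must be >= 1")
    for start in range(0, len(items), size):
        yield items[start:start + size]


def predict(
    sequences: Sequence[str],
    logits_fn: Callable[[Sequence[str]], List[List[float]]],
    batch_size: int = 8,
) -> List[Tuple[int, float]]:
    """Return (predicted class index, its probability) for each sequence."""
    results: List[Tuple[int, float]] = []
    for batch in chunks(sequences, batch_size):
        for row in logits_fn(batch):
            probs = softmax(row)
            best = max(range(len(probs)), key=probs.__getitem__)
            results.append((best, probs[best]))
    return results


# Wiring it to the checkpoint loaded above (requires torch and modelgenerator):
#   def logits_fn(batch):
#       with torch.no_grad():
#           return model(model.transform({"sequences": list(batch)})).tolist()
#   print(predict(["ACGT", "AGCT"], logits_fn))
```

Keeping the model call behind logits_fn separates batching and post-processing from the model itself, so batch_size can be tuned to available GPU memory without touching the scoring logic.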