datasets:
- genbio-ai/rna-downstream-tasks
base_model:
- genbio-ai/rnafm-1.6b
LoRA fine-tuned checkpoint of `genbio-ai/rnafm-1.6b` for splice site prediction.
## How to Use
### Download model
```python
from huggingface_hub import snapshot_download
from pathlib import Path

model_name = "genbio-ai/rnafm-1.6b-csp-acceptor-ckpt"
# Store the snapshot under ~/genbio_models/<repo id>
genbio_models_path = Path.home().joinpath('genbio_models', model_name)
genbio_models_path.mkdir(parents=True, exist_ok=True)
snapshot_download(repo_id=model_name, local_dir=genbio_models_path)
```
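Before loading, it can help to confirm that the checkpoint file actually landed in the download directory. The following is a minimal sketch, not part of the original instructions: `find_checkpoint` is a hypothetical helper, and the default filename `model.ckpt` is taken from the loading step in this card.

```python
from pathlib import Path


def find_checkpoint(model_dir, filename="model.ckpt"):
    """Return the checkpoint path inside model_dir, or raise if it is missing.

    Hypothetical helper for illustration; the filename default mirrors the
    'model.ckpt' used in the loading step of this card.
    """
    ckpt = Path(model_dir) / filename
    if not ckpt.is_file():
        raise FileNotFoundError(f"{ckpt} not found; re-run the download step")
    return ckpt
```

Calling `find_checkpoint(genbio_models_path)` after the download returns the path to pass to the loader, and fails early with a clear message if the snapshot is incomplete.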
### Load model for inference
```python
import torch
from modelgenerator.tasks import SequenceClassification

# Reuses genbio_models_path from the download step
ckpt_path = genbio_models_path.joinpath('model.ckpt')
model = SequenceClassification.load_from_checkpoint(ckpt_path, strict_loading=False).eval()

# Convert raw sequences into a collated, model-ready batch
collated_batch = model.transform({"sequences": ["ACGT", "AGCT"]})
with torch.no_grad():  # inference only, no gradients needed
    logits = model(collated_batch)
print(logits)
print(torch.argmax(logits, dim=-1))  # predicted class index per sequence
```
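The `argmax` above yields only a class index. If a confidence score is also useful, the logits can be passed through a softmax. The sketch below shows the computation in plain Python on made-up placeholder logits, assuming a `[batch, num_classes]` layout with two classes. Neither the class count nor the index-to-label mapping is stated in this card, so treat both as assumptions.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of logits."""
    m = max(row)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


# Placeholder logits for two sequences and two classes.
# These values are made up for illustration; they are NOT model output.
example_logits = [[2.0, 0.0], [-1.0, 1.0]]

probs = [softmax(row) for row in example_logits]
# Index of the largest probability in each row (same result as argmax on logits)
preds = [max(range(len(p)), key=p.__getitem__) for p in probs]

for p, k in zip(probs, preds):
    print(f"predicted class {k} with probability {p[k]:.3f}")
```

With real model output, `torch.softmax(logits, dim=-1)` performs the same computation directly on the tensor.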