AmelieSchreiber committed commit 8294a2b (1 parent: 4a0c4d7)
Update README.md
README.md CHANGED
@@ -44,7 +44,44 @@ Validation Macro Recall: 0.9966
 ```
 
 ## Using the model
-
+
+First, download the `train_sequences.fasta` file and the `train_terms.tsv` file, and provide the local paths in the code below:
+
+```python
+import os
+import numpy as np
+import torch
+from transformers import AutoTokenizer, EsmForSequenceClassification, AdamW
+from torch.nn.functional import binary_cross_entropy_with_logits
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import f1_score, precision_score, recall_score
+# from accelerate import Accelerator
+from Bio import SeqIO
+
+# Step 1: Data Preprocessing (Replace with your local paths)
+fasta_file = "/Users/amelieschreiber/.cursor-tutor/projects/python/cafa5/cafa-5-protein-function-prediction/Train/train_sequences.fasta"
+tsv_file = "/Users/amelieschreiber/.cursor-tutor/projects/python/cafa5/cafa-5-protein-function-prediction/Train/train_terms.tsv"
+
+fasta_data = {}
+tsv_data = {}
+
+for record in SeqIO.parse(fasta_file, "fasta"):
+    fasta_data[record.id] = str(record.seq)
+
+with open(tsv_file, 'r') as f:
+    for line in f:
+        parts = line.strip().split("\t")
+        tsv_data[parts[0]] = parts[1:]
+
+# tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
+seq_length = 1022
+# tokenized_data = tokenizer(list(fasta_data.values()), padding=True, truncation=True, return_tensors="pt", max_length=seq_length)
+
+unique_terms = list(set(term for terms in tsv_data.values() for term in terms))
+```
+
+
+Second, download the file `go-basic.obo` [from here](https://huggingface.co/datasets/AmelieSchreiber/cafa_5)
 and store the file locally, then provide the local path in the code below:
 
 ```python
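
The preprocessing in the added block stops after collecting `unique_terms`. As a hedged illustration of where it is headed (multi-label classification trained with `binary_cross_entropy_with_logits`), the sketch below tokenizes the sequences and builds multi-hot label vectors. The names `protein_ids`, `term_to_index`, and `label_matrix`, and the 90/10 split, are illustrative assumptions, not code from this commit.

```python
# Sketch only: one possible continuation of the README's preprocessing.
# Assumes fasta_data, tsv_data, unique_terms, and seq_length exist as defined above.
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

# Keep only proteins present in both the FASTA file and the terms TSV.
protein_ids = [pid for pid in fasta_data if pid in tsv_data]
sequences = [fasta_data[pid] for pid in protein_ids]

# Map each GO term to a column and build a multi-hot label matrix.
term_to_index = {term: i for i, term in enumerate(unique_terms)}
label_matrix = np.zeros((len(protein_ids), len(unique_terms)), dtype=np.float32)
for row, pid in enumerate(protein_ids):
    for term in tsv_data[pid]:
        if term in term_to_index:
            label_matrix[row, term_to_index[term]] = 1.0

# Tokenize with the same max_length the README uses for ESM-2.
tokenized = tokenizer(sequences, padding=True, truncation=True,
                      return_tensors="pt", max_length=seq_length)
labels = torch.tensor(label_matrix)

# Split by index so tokenized inputs and label rows stay aligned.
train_idx, val_idx = train_test_split(list(range(len(protein_ids))),
                                      test_size=0.1, random_state=42)
```

Splitting on indices rather than on the tensors themselves keeps `tokenized` and `labels` aligned with the same protein order.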
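
The second code block, which is meant to read `go-basic.obo`, is cut off at the end of this hunk. Purely as a self-contained placeholder, the sketch below parses `[Term]` stanzas from that file into a dictionary of GO IDs, names, and namespaces; it is not the commit's code, and `go_terms` is an assumed structure.

```python
# Sketch only: a minimal parser for the go-basic.obo file referenced above.
# go_terms (GO id -> {"name", "namespace"}) is an illustrative structure,
# not something defined by the README.
obo_file = "go-basic.obo"  # replace with your local path

go_terms = {}
current = None
with open(obo_file, "r") as f:
    for line in f:
        line = line.strip()
        if line == "[Term]":
            current = {}
        elif line == "" and current is not None:
            # Blank line ends a stanza: store it if it has an id.
            if "id" in current:
                go_terms[current["id"]] = {
                    "name": current.get("name", ""),
                    "namespace": current.get("namespace", ""),
                }
            current = None
        elif current is not None and ": " in line:
            key, value = line.split(": ", 1)
            if key in ("id", "name", "namespace"):
                current[key] = value

# Handle a final stanza that is not followed by a blank line.
if current is not None and "id" in current:
    go_terms[current["id"]] = {
        "name": current.get("name", ""),
        "namespace": current.get("namespace", ""),
    }

print(len(go_terms), "GO terms parsed")
```

If the ontology's `is_a` structure is needed rather than just term metadata, a library such as `obonet` can read the same file into a graph.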