AmelieSchreiber committed on
Commit 8294a2b
1 Parent(s): 4a0c4d7

Update README.md

Files changed (1)
  1. README.md +38 -1
README.md CHANGED
@@ -44,7 +44,44 @@ Validation Macro Recall: 0.9966
  ```

  ## Using the model
- First, download the file `go-basic.obo` [from here](https://huggingface.co/datasets/AmelieSchreiber/cafa_5)
+
+ First, download the `train_sequences.fasta` file and the `train_terms.tsv` file, and provide the local paths in the code below:
+
+ ```python
+ import os
+ import numpy as np
+ import torch
+ from transformers import AutoTokenizer, EsmForSequenceClassification, AdamW
+ from torch.nn.functional import binary_cross_entropy_with_logits
+ from sklearn.model_selection import train_test_split
+ from sklearn.metrics import f1_score, precision_score, recall_score
+ # from accelerate import Accelerator
+ from Bio import SeqIO
+
+ # Step 1: Data Preprocessing (Replace with your local paths)
+ fasta_file = "/Users/amelieschreiber/.cursor-tutor/projects/python/cafa5/cafa-5-protein-function-prediction/Train/train_sequences.fasta"
+ tsv_file = "/Users/amelieschreiber/.cursor-tutor/projects/python/cafa5/cafa-5-protein-function-prediction/Train/train_terms.tsv"
+
+ fasta_data = {}
+ tsv_data = {}
+
+ # Parse the FASTA file into a dict mapping protein ID -> sequence
+ for record in SeqIO.parse(fasta_file, "fasta"):
+     fasta_data[record.id] = str(record.seq)
+
+ # Read the TSV, keyed by the first column (protein ID), storing the remaining columns
+ with open(tsv_file, 'r') as f:
+     for line in f:
+         parts = line.strip().split("\t")
+         tsv_data[parts[0]] = parts[1:]
+
+ # tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
+ seq_length = 1022
+ # tokenized_data = tokenizer(list(fasta_data.values()), padding=True, truncation=True, return_tensors="pt", max_length=seq_length)
+
+ unique_terms = list(set(term for terms in tsv_data.values() for term in terms))
+ ```
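
For the multi-label fine-tuning that this README builds toward, the per-protein terms are typically encoded as multi-hot vectors over the label vocabulary. The following is a minimal sketch of one way to do that, continuing from the variables defined above; the names `term_to_index`, `protein_ids`, and `label_matrix` are illustrative and not part of the original script:

```python
import numpy as np

# Illustrative sketch: build a multi-hot label matrix whose rows follow the
# protein order in fasta_data and whose columns correspond to the entries
# collected in unique_terms.
term_to_index = {term: i for i, term in enumerate(unique_terms)}
protein_ids = list(fasta_data.keys())

label_matrix = np.zeros((len(protein_ids), len(unique_terms)), dtype=np.float32)
for row, pid in enumerate(protein_ids):
    for term in tsv_data.get(pid, []):  # terms recorded for this protein, if any
        label_matrix[row, term_to_index[term]] = 1.0

print(label_matrix.shape)  # (number of proteins, size of label vocabulary)
```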
+
+
+ Second, download the file `go-basic.obo` [from here](https://huggingface.co/datasets/AmelieSchreiber/cafa_5)
  and store the file locally, then provide the local path in the code below:

  ```python