hdallatorre committed
Commit 821d743
Parent(s): d5caa84
Update README.md

README.md CHANGED
````diff
@@ -40,10 +40,9 @@ import torch
 # Import the tokenizer and the model
 tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
 model = AutoModelForMaskedLM.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
-
 # Create a dummy dna sequence and tokenize it
-sequences = [
-tokens_ids = tokenizer.batch_encode_plus(sequences, return_tensors="pt")["input_ids"]
+sequences = ["ATTCCGATTCCGATTCCG", "ATTTCTCTCTCTCTCTGAGATCGATCGATCGAT"]
+tokens_ids = tokenizer.batch_encode_plus(sequences, return_tensors="pt", padding="max_length", max_length = max_length)["input_ids"]
 
 # Compute the embeddings
 attention_mask = tokens_ids != tokenizer.pad_token_id
@@ -59,8 +58,11 @@ embeddings = torch_outs['hidden_states'][-1].detach().numpy()
 print(f"Embeddings shape: {embeddings.shape}")
 print(f"Embeddings per token: {embeddings}")
 
+# Add embed dimension axis
+attention_mask = torch.unsqueeze(attention_mask, dim=-1)
+
 # Compute mean embeddings per sequence
-mean_sequence_embeddings = torch.sum(attention_mask
+mean_sequence_embeddings = torch.sum(attention_mask*embeddings, axis=-2)/torch.sum(attention_mask, axis=1)
 print(f"Mean sequence embeddings: {mean_sequence_embeddings}")
 ```
 
````
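The hunks above show only the changed regions of the README's usage example; in particular, the new `batch_encode_plus` call references a `max_length` variable whose definition lies outside the diff, and the forward pass between the two hunks is untouched by this commit. A minimal end-to-end sketch of the updated snippet, with those missing pieces filled in as assumptions (`max_length = tokenizer.model_max_length`, a plain masked-LM forward with `output_hidden_states=True`, and the last hidden state kept as a torch tensor rather than converted to NumPy), might look like:

```python
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

# Import the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
model = AutoModelForMaskedLM.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")

# Assumption: pad all sequences to the model's maximum input length so the two
# dummy sequences of different lengths can be batched; the diff uses `max_length`
# without showing where it is defined.
max_length = tokenizer.model_max_length

# Create a dummy dna sequence and tokenize it (new lines from this commit)
sequences = ["ATTCCGATTCCGATTCCG", "ATTTCTCTCTCTCTCTGAGATCGATCGATCGAT"]
tokens_ids = tokenizer.batch_encode_plus(
    sequences, return_tensors="pt", padding="max_length", max_length=max_length
)["input_ids"]

# Compute the embeddings
attention_mask = tokens_ids != tokenizer.pad_token_id
# Assumption: a standard masked-LM forward pass (not part of this commit),
# requesting hidden states so the last layer can serve as token embeddings.
torch_outs = model(tokens_ids, attention_mask=attention_mask, output_hidden_states=True)
embeddings = torch_outs["hidden_states"][-1].detach()  # (batch, tokens, embed_dim)

print(f"Embeddings shape: {embeddings.shape}")

# Add embed dimension axis so the padding mask broadcasts over embed_dim
attention_mask = torch.unsqueeze(attention_mask, dim=-1)  # (batch, tokens, 1)

# Compute mean embeddings per sequence: sum only over real (non-pad) tokens,
# then divide by the number of real tokens in each sequence
mean_sequence_embeddings = torch.sum(attention_mask * embeddings, dim=-2) / torch.sum(attention_mask, dim=1)
print(f"Mean sequence embeddings: {mean_sequence_embeddings}")
```

Keeping the embeddings as a tensor is only a convenience for this sketch so the masked mean stays in torch; the broadcasted mask zeroes out padding positions before the sum, and dividing by `torch.sum(attention_mask, dim=1)` normalizes by each sequence's real token count.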