Spaces:
Sleeping
Sleeping
File size: 2,266 Bytes
466a8f2 f8402f9 eb9ae1f 3de389d f8402f9 d1b48a1 f8402f9 87c0dbc f8402f9 87c0dbc b7ab123 87c0dbc cfba77f 466a8f2 cfba77f 466a8f2 cfba77f 466a8f2 cfba77f 58c7b8d 1c423ed 3de389d 58c7b8d 3de389d 58c7b8d 58b87e3 58c7b8d 3de389d 58c7b8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import torch
from Bio.PDB.Structure import Structure
from hexviz.attention import (ModelType, get_attention, get_sequences,
get_structure, unidirectional_avg_filtered)
def test_get_structure():
    """get_structure("2I62") returns a Biopython ``Structure`` instance."""
    fetched = get_structure("2I62")
    assert fetched is not None
    assert isinstance(fetched, Structure)
def test_get_sequences():
    """get_sequences on PDB entry 1AKE returns exactly two sequences.

    The first sequence must start with the residues M, R, I. It is compared
    as a list of one-letter strings.
    """
    structure = get_structure("1AKE")
    sequences = get_sequences(structure)
    assert sequences is not None
    assert len(sequences) == 2
    # Only the first sequence is inspected; the second is discarded explicitly
    # rather than bound to an unused name.
    first_sequence, _ = sequences
    assert first_sequence[:3] == ["M", "R", "I"]
def test_get_attention_zymctrl():
    """ZymCTRL attention for the 3-residue sequence "GGG" has shape (36, 16, 3, 3)."""
    # NOTE(review): the leading dims are presumably (layers, heads); confirm.
    expected_shape = torch.Size([36, 16, 3, 3])
    attn = get_attention("GGG", model_type=ModelType.ZymCTRL)
    assert attn is not None
    assert attn.shape == expected_shape
def test_get_attention_zymctrl_long_chain():
    """ZymCTRL attention on the first sequence of PDB 6A5J has shape (36, 16, 13, 13)."""
    # The first 6A5J sequence is 13 residues long, hence the 13x13 trailing dims.
    first_sequence = get_sequences(get_structure(pdb_code="6A5J"))[0]
    attn = get_attention(first_sequence, model_type=ModelType.ZymCTRL)
    assert attn is not None
    assert attn.shape == torch.Size([36, 16, 13, 13])
def test_get_attention_tape():
    """TAPE BERT attention on the first sequence of PDB 6A5J has shape (12, 12, 13, 13)."""
    # The first 6A5J sequence is 13 residues long, hence the 13x13 trailing dims.
    first_sequence = get_sequences(get_structure(pdb_code="6A5J"))[0]
    attn = get_attention(first_sequence, model_type=ModelType.TAPE_BERT)
    assert attn is not None
    assert attn.shape == torch.Size([12, 12, 13, 13])
def test_get_attention_prot_bert():
    """ProtBert attention for the 3-residue sequence "GGG" has shape (30, 16, 3, 3)."""
    expected_shape = torch.Size([30, 16, 3, 3])
    attn = get_attention("GGG", model_type=ModelType.PROT_BERT)
    assert attn is not None
    assert attn.shape == expected_shape
def test_get_unidirection_avg_filtered():
    """Check the result length of unidirectional_avg_filtered on tiny tensors.

    The expected lengths are 10 for 4 residues and 6 for 3 residues. Both
    equal n * (n + 1) // 2. That presumably means one entry per unordered
    residue pair, including self-pairs; confirm against the implementation.
    """
    # Shape (1, 1, 4, 4): 1 head, 1 layer, 4 residues; this matrix is symmetric.
    four_residue = torch.tensor(
        [[[[1, 2, 3, 4],
           [2, 5, 6, 7],
           [3, 6, 8, 9],
           [4, 7, 9, 11]]]],
        dtype=torch.float32,
    )
    # NOTE(review): the three trailing zeros are presumably layer, head and a
    # filter threshold; confirm against unidirectional_avg_filtered.
    out = unidirectional_avg_filtered(four_residue, 0, 0, 0)
    assert out is not None
    assert len(out) == 4 * 5 // 2  # == 10

    # Shape (1, 1, 3, 3); this matrix is asymmetric ([2][0] is 4, [0][2] is 3).
    three_residue = torch.tensor(
        [[[[1, 2, 3],
           [2, 5, 6],
           [4, 7, 91]]]],
        dtype=torch.float32,
    )
    out = unidirectional_avg_filtered(three_residue, 0, 0, 0)
    assert len(out) == 3 * 4 // 2  # == 6
|