import torch
from Bio.PDB.Structure import Structure

from hexviz.attention import (ModelType, get_attention, get_sequences,
                              get_structure, unidirectional_avg_filtered)
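
# Smoke tests for hexviz.attention: fetching PDB structures, extracting chain
# sequences, computing attention tensors for each supported model type, and
# averaging attention unidirectionally over residue pairs.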


def test_get_structure():
    pdb_id = "2I62"
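    # Presumably fetched from the PDB over the network and parsed with Bio.PDB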
    structure = get_structure(pdb_id)

    assert structure is not None
    assert isinstance(structure, Structure)

def test_get_sequences():
    pdb_id = "1AKE"
    structure = get_structure(pdb_id)

    sequences = get_sequences(structure)

    assert sequences is not None
    assert len(sequences) == 2

    A, B = sequences
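    # 1AKE (E. coli adenylate kinase) has two chains, A and B, whose sequence
    # begins Met-Arg-Ile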
    assert A[:3] == ["M", "R", "I"]


def test_get_attention_zymctrl():
    result = get_attention("GGG", model_type=ModelType.ZymCTRL)

    assert result is not None
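    # Shape is presumably (layers, heads, seq_len, seq_len): 36 layers and
    # 16 heads for ZymCTRL over the 3-residue input "GGG"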
    assert result.shape == torch.Size([36, 16, 3, 3])

def test_get_attention_zymctrl_long_chain():
    structure = get_structure(pdb_code="6A5J") # 13 residues long

    sequences = get_sequences(structure)

    result = get_attention(sequences[0], model_type=ModelType.ZymCTRL)

    assert result is not None
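    # Same 36 x 16 layer/head grid, now over the 13-residue chain from 6A5J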
    assert result.shape == torch.Size([36, 16, 13, 13])

def test_get_attention_tape():
    structure = get_structure(pdb_code="6A5J") # 13 residues long
    sequences = get_sequences(structure)

    result = get_attention(sequences[0], model_type=ModelType.TAPE_BERT)

    assert result is not None
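    # 12 layers and 12 heads, consistent with TAPE's BERT-base style model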
    assert result.shape == torch.Size([12, 12, 13, 13])

def test_get_attention_prot_bert():
    result = get_attention("GGG", model_type=ModelType.PROT_BERT)

    assert result is not None
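    # 30 layers and 16 heads, matching Rostlab's ProtBert configuration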
    assert result.shape == torch.Size([30, 16, 3, 3])

def test_unidirectional_avg_filtered():
    # 1 layer, 1 head, 4-residue attention tensor
    attention = torch.tensor([[[[1, 2, 3, 4],
                                [2, 5, 6, 7],
                                [3, 6, 8, 9],
                                [4, 7, 9, 11]]]], dtype=torch.float32)

    result = unidirectional_avg_filtered(attention, 0, 0, 0)

    assert result is not None
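    # Averaging attn[i][j] with attn[j][i] presumably yields one value per
    # unordered residue pair (self-pairs included): n * (n + 1) / 2 = 10 for n = 4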
    assert len(result) == 10

    attention = torch.tensor([[[[1, 2, 3],
                                [2, 5, 6],
                                [4, 7, 91]]]], dtype=torch.float32)

    result = unidirectional_avg_filtered(attention, 0, 0, 0)

    assert len(result) == 6
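

# A minimal reference sketch of the averaging these counts imply, assuming
# (not confirmed by this file) that unidirectional_avg_filtered averages
# attn[i][j] with attn[j][i] for one (layer, head) slice and drops pairs whose
# average falls below a threshold. With a zero threshold all n * (n + 1) / 2
# unordered pairs survive (10 for n = 4, 6 for n = 3), even for an asymmetric
# slice. The helper below is hypothetical and is not part of hexviz.
def _reference_unidirectional_avg(attn_slice: torch.Tensor, threshold: float):
    n = attn_slice.shape[0]
    return [
        (attn_slice[i, j] + attn_slice[j, i]) / 2
        for i in range(n)
        for j in range(i, n)
        if (attn_slice[i, j] + attn_slice[j, i]) / 2 >= threshold
    ]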