File size: 2,131 Bytes
466a8f2
f8402f9
 
a71a737
 
 
 
 
 
 
f8402f9
 
 
d1b48a1
f8402f9
87c0dbc
f8402f9
 
87c0dbc
a71a737
b7ab123
 
 
a71a737
b7ab123
 
 
 
 
 
 
87c0dbc
cfba77f
 
 
 
 
 
a71a737
 
cfba77f
 
a71a737
cfba77f
 
 
 
 
 
a71a737
 
cfba77f
466a8f2
a71a737
cfba77f
466a8f2
cfba77f
466a8f2
 
a71a737
 
58c7b8d
1c423ed
 
 
 
 
 
 
a71a737
3de389d
58c7b8d
a71a737
 
 
58c7b8d
3de389d
58c7b8d
 
 
 
a71a737
58c7b8d
3de389d
58c7b8d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
from Bio.PDB.Structure import Structure

from hexviz.attention import (
    ModelType,
    get_attention,
    get_sequences,
    get_structure,
    unidirectional_avg_filtered,
)


def test_get_structure():
    """Fetching PDB entry 2I62 yields a Biopython ``Structure`` object."""
    fetched = get_structure("2I62")

    assert fetched is not None
    assert isinstance(fetched, Structure)


def test_get_sequences():
    """``get_sequences`` on PDB entry 1AKE returns two chains.

    Only the first chain's leading residues are checked; each chain is
    compared as a list of one-letter residue codes.
    """
    pdb_id = "1AKE"
    structure = get_structure(pdb_id)

    sequences = get_sequences(structure)

    assert sequences is not None
    assert len(sequences) == 2

    # The second chain is deliberately not inspected; bind it to `_` rather
    # than an unused name so that intent is explicit.
    first_chain, _ = sequences
    assert first_chain[:3] == ["M", "R", "I"]


def test_get_attention_zymctrl():
    """ZymCTRL attention for the 3-residue sequence "GGG" has shape (36, 16, 3, 3)."""
    attention = get_attention("GGG", model_type=ModelType.ZymCTRL)
    # NOTE(review): the leading dims presumably are (layers, heads) — confirm
    # against get_attention.
    expected_shape = torch.Size([36, 16, 3, 3])

    assert attention is not None
    assert attention.shape == expected_shape


def test_get_attention_zymctrl_long_chain():
    """ZymCTRL attention over the first chain of 6A5J has shape (36, 16, 13, 13)."""
    structure = get_structure(pdb_code="6A5J")
    sequences = get_sequences(structure)
    sequence = sequences[0]

    # The expected 13x13 attention map relies on this chain being 13 residues
    # long. Checking it here makes a change in the fetched structure fail with
    # a clear message instead of an opaque shape mismatch below.
    assert len(sequence) == 13, f"expected 13 residues, got {len(sequence)}"

    result = get_attention(sequence, model_type=ModelType.ZymCTRL)

    assert result is not None
    assert result.shape == torch.Size([36, 16, 13, 13])


def test_get_attention_tape():
    """TAPE BERT attention over the first chain of 6A5J has shape (12, 12, 13, 13)."""
    structure = get_structure(pdb_code="6A5J")
    sequences = get_sequences(structure)
    sequence = sequences[0]

    # The expected 13x13 attention map relies on this chain being 13 residues
    # long. Checking it here makes a change in the fetched structure fail with
    # a clear message instead of an opaque shape mismatch below.
    assert len(sequence) == 13, f"expected 13 residues, got {len(sequence)}"

    result = get_attention(sequence, model_type=ModelType.TAPE_BERT)

    assert result is not None
    assert result.shape == torch.Size([12, 12, 13, 13])


def test_get_attention_prot_bert():
    """ProtBERT attention for the 3-residue sequence "GGG" has shape (30, 16, 3, 3)."""
    attention = get_attention("GGG", model_type=ModelType.PROT_BERT)
    # NOTE(review): the leading dims presumably are (layers, heads) — confirm
    # against get_attention.
    expected_shape = torch.Size([30, 16, 3, 3])

    assert attention is not None
    assert attention.shape == expected_shape


def test_get_unidirection_avg_filtered():
    """Check how many entries ``unidirectional_avg_filtered`` yields for small inputs.

    Two cases are covered: a 4x4 and a 3x3 attention map, each wrapped in
    leading dimensions of size 1. Both are called with the three trailing
    arguments set to 0.
    """
    # Shape (1, 1, 4, 4): a single 4-residue attention map.
    four_residue_attention = torch.tensor(
        [[[[1, 2, 3, 4], [2, 5, 6, 7], [3, 6, 8, 9], [4, 7, 9, 11]]]],
        dtype=torch.float32,
    )
    four_result = unidirectional_avg_filtered(four_residue_attention, 0, 0, 0)

    assert four_result is not None
    # NOTE(review): 10 == 4*5/2 and 6 == 3*4/2, which suggests one entry per
    # unordered residue pair including the diagonal — confirm against the
    # implementation.
    assert len(four_result) == 10

    # Shape (1, 1, 3, 3): a single 3-residue attention map (not symmetric).
    three_residue_attention = torch.tensor(
        [[[[1, 2, 3], [2, 5, 6], [4, 7, 91]]]], dtype=torch.float32
    )
    three_result = unidirectional_avg_filtered(three_residue_attention, 0, 0, 0)

    assert len(three_result) == 6