hexviz/tests/test_attention.py
aksell's picture
Rename project protention -> hexviz
28d4b1e
raw
history blame
1.8 kB
import torch
from Bio.PDB.Structure import Structure
from transformers import T5EncoderModel, T5Tokenizer
from hexviz.attention import (ModelType, get_attention, get_protT5,
get_sequences, get_structure,
unidirectional_sum_filtered)
def test_get_structure():
    """Check that ``get_structure`` yields a Bio.PDB ``Structure`` for "1AKE"."""
    fetched = get_structure("1AKE")

    # Both checks are kept so a ``None`` return fails on the first assertion.
    assert fetched is not None
    assert isinstance(fetched, Structure)
def test_get_sequences():
    """Check the sequences extracted from the "1AKE" structure.

    Expects exactly two sequences, the first of which starts with the
    residues M, R, I.
    """
    chains = get_sequences(get_structure("1AKE"))

    assert chains is not None
    assert len(chains) == 2

    # Only the first sequence is inspected; the second is unpacked so the
    # two-element shape is still enforced.
    first_chain, _second_chain = chains
    assert first_chain[:3] == ["M", "R", "I"]
def test_get_protT5():
    """Check that ``get_protT5`` returns a ``(T5Tokenizer, T5EncoderModel)`` tuple."""
    loaded = get_protT5()

    assert loaded is not None
    assert isinstance(loaded, tuple)

    # Unpacking also enforces that the tuple has exactly two members.
    prot_tokenizer, prot_model = loaded
    assert isinstance(prot_tokenizer, T5Tokenizer)
    assert isinstance(prot_model, T5EncoderModel)
def test_get_attention_tape():
    """Check the shape of the attention returned for "1AKE" with TAPE_BERT."""
    expected_shape = torch.Size([12, 12, 456, 456])

    attention = get_attention("1AKE", model=ModelType.TAPE_BERT)

    assert attention is not None
    assert attention.shape == expected_shape
def test_get_unidirection_sum_filtered():
    """Check the number of entries ``unidirectional_sum_filtered`` returns.

    Two small attention tensors are passed with the three trailing
    arguments all set to 0; the 4x4 case must yield 10 entries and the
    3x3 case must yield 6.
    NOTE(review): 10 and 6 match n*(n+1)/2 for n = 4 and n = 3, so the result
    presumably holds one entry per unordered residue pair, self-pairs
    included - confirm against the implementation.
    """
    # 1 head, 1 layer, 4 residues long attention tensor
    four_residue_matrix = [
        [1, 2, 3, 4],
        [2, 5, 6, 7],
        [3, 6, 8, 9],
        [4, 7, 9, 11],
    ]
    attention = torch.tensor([[four_residue_matrix]], dtype=torch.float32)
    filtered = unidirectional_sum_filtered(attention, 0, 0, 0)
    assert filtered is not None
    assert len(filtered) == 10

    # Same layout with 3 residues.
    three_residue_matrix = [
        [1, 2, 3],
        [2, 5, 6],
        [4, 7, 91],
    ]
    attention = torch.tensor([[three_residue_matrix]], dtype=torch.float32)
    filtered = unidirectional_sum_filtered(attention, 0, 0, 0)
    assert len(filtered) == 6