# hexviz/protention/attention.py
from io import StringIO
from urllib import request

import torch
from Bio.PDB import PDBParser
from Bio.PDB.Structure import Structure
from Bio.SeqUtils import seq1
from transformers import T5EncoderModel, T5Tokenizer


def get_structure(pdb_code: str) -> Structure:
    """
    Fetch a structure from the RCSB PDB by its four-letter code and parse it.
    """
    pdb_url = f"https://files.rcsb.org/download/{pdb_code}.pdb"
    pdb_data = request.urlopen(pdb_url).read().decode("utf-8")
    file = StringIO(pdb_data)
    parser = PDBParser()
    structure = parser.get_structure(pdb_code, file)
    return structure
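

# Usage sketch (assumption: "1AO6" is an arbitrary example PDB code, not part
# of the original module):
#
#   structure = get_structure("1AO6")
#   chain_a = structure[0]["A"]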


def get_protT5() -> tuple[T5Tokenizer, T5EncoderModel]:
    """
    Load the ProtT5-XL half-precision encoder and its tokenizer from Rostlab.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer = T5Tokenizer.from_pretrained(
        "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
    )
    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(
        device
    )
    # The checkpoint is stored in float16; keep half precision on GPU but cast
    # to full precision on CPU, where float16 inference is not supported.
    model = model.float() if device.type == "cpu" else model.half()
    return tokenizer, model
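

# Input format sketch (per the Rostlab ProtT5 model card; the example sequence
# is arbitrary): residues are passed space-separated, with rare amino acids
# (U, Z, O, B) typically mapped to "X" before tokenization:
#
#   tokenizer(" ".join("MKTAYIAKQR"), return_tensors="pt")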


def get_attention(
    pdb_code: str, chain_ids: list[str], layer: int, head: int, min_attn: float = 0.2
):
    """
    Get the ProtT5 attention between residues of the given chains for one
    layer and head, keeping only pairs with weight >= min_attn.
    """
    # Fetch the structure and read the sequence off the selected chains,
    # skipping heteroatom and water residues
    structure = get_structure(pdb_code)
    residues = [
        r for c in structure[0] if c.id in chain_ids for r in c if r.id[0] == " "
    ]
    sequence = "".join(seq1(r.get_resname()) for r in residues)
    # Tokenize in ProtT5's space-separated format and run the encoder
    tokenizer, model = get_protT5()
    inputs = tokenizer(" ".join(sequence), return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    # attentions holds one tensor per layer, shape (batch, heads, len, len);
    # select the requested head and drop the trailing </s> special token
    attention = outputs.attentions[layer][0, head, :-1, :-1]
    # Keep residue index pairs whose attention clears the threshold
    pairs = (attention >= min_attn).nonzero().tolist()
    return [(i, j, attention[i, j].item()) for i, j in pairs]
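

# Minimal usage sketch (assumption: the PDB code, chain, layer, and head below
# are arbitrary example values, not part of the original module)
if __name__ == "__main__":
    pairs = get_attention("1AO6", chain_ids=["A"], layer=5, head=0)
    print(f"{len(pairs)} residue pairs with attention above the threshold")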