Spaces:
Running
Running
import os | |
import sys | |
import torch | |
from torch.utils.data import Dataset | |
import json | |
import numpy as np | |
from torch.utils.data.dataloader import default_collate | |
import time | |
class ESMDataset(Dataset): | |
def __init__(self, pdb_root, ann_paths, chain="A"): | |
""" | |
protein (string): Root directory of protein (e.g. coco/images/) | |
ann_root (string): directory to store the annotation file | |
""" | |
self.pdb_root = pdb_root | |
self.annotation = json.load(open(ann_paths, "r")) | |
self.pdb_ids = {} | |
self.chain = chain | |
def __len__(self): | |
return len(self.annotation) | |
def __getitem__(self, index): | |
ann = self.annotation[index] | |
protein_embedding = '{}.pt'.format(ann["pdb_id"]) | |
protein_embedding_path = os.path.join(self.pdb_root, protein_embedding) | |
protein_embedding = torch.load(protein_embedding_path, map_location=torch.device('cpu')) | |
protein_embedding.requires_grad = False | |
caption = ann["caption"] | |
return { | |
"text_input": caption, | |
"encoder_out": protein_embedding, | |
"chain": self.chain, | |
"pdb_id": ann["pdb_id"] | |
} | |
def collater(self, samples): | |
max_len_protein_dim0 = -1 | |
for pdb_json in samples: | |
pdb_embeddings = pdb_json["encoder_out"] | |
shape_dim0 = pdb_embeddings.shape[0] | |
max_len_protein_dim0 = max(max_len_protein_dim0, shape_dim0) | |
for pdb_json in samples: | |
pdb_embeddings = pdb_json["encoder_out"] | |
shape_dim0 = pdb_embeddings.shape[0] | |
pad1 = ((0, max_len_protein_dim0 - shape_dim0), (0, 0), (0, 0)) | |
arr1_padded = np.pad(pdb_embeddings, pad1, mode='constant', ) | |
pdb_json["encoder_out"] = arr1_padded | |
return default_collate(samples) |