from huggingface_hub import HfApi, ModelFilter
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers.tokenization_utils_base import BatchEncoding
from transformers.modeling_outputs import MaskedLMOutput

# Function to fetch suitable ESM models from HuggingFace Hub
def get_models() -> list[str]:
    """Fetch suitable ESM models from HuggingFace Hub."""
    if not any(
        out := [
            m.modelId for m in HfApi().list_models(
                filter=ModelFilter(
                    author="facebook", model_name="esm", task="fill-mask"
                ), 
                sort="lastModified", 
                direction=-1
            )
        ]
    ):
        raise RuntimeError("Error while retrieving models from HuggingFace Hub")
    return out

# Class to wrap ESM models
class Model:
    """Wrapper for ESM models."""
    def __init__(self, model_name: str = ""):
        """Load selected model and tokenizer."""
        self.model_name = model_name
        if model_name:
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.batch_converter = AutoTokenizer.from_pretrained(model_name)
            self.alphabet = self.batch_converter.get_vocab()
            # Check if CUDA is available and if so, use it
            if torch.cuda.is_available():
                self.model = self.model.cuda()

    def tokenise(self, input: str) -> BatchEncoding:
        """Convert input string to batch of tokens."""
        return self.batch_converter(input, return_tensors="pt")

    def __call__(self, batch_tokens: torch.Tensor, **kwargs) -> MaskedLMOutput:
        """Run model on batch of tokens."""
        return self.model(batch_tokens, **kwargs)

    def __getitem__(self, key: str) -> int:
        """Get token ID from character."""
        return self.alphabet[key]

    def run_model(self, data):
        """Run model on data."""
        def label_row(row, token_probs):
            """Label row with score."""
            # Extract wild type, index and mutant type from the row
            wt, idx, mt = row[0], int(row[1:-1])-1, row[-1]
            # Calculate the score as the difference between the token
            # probabilities of the mutant type and the wild type
            score = token_probs[0, 1+idx, self[mt]] - token_probs[0, 1+idx, self[wt]]
            return score.item()

        # Tokenise the sequence data and move the tokens to the model's device
        batch_tokens = self.tokenise(data.seq).input_ids.to(self.model.device)

        # Calculate the token probabilities without updating the model parameters
        with torch.no_grad():
            token_probs = torch.log_softmax(self(batch_tokens).logits, dim=-1)
        # Store the token probabilities in the data
        data.token_probs = token_probs.cpu().numpy()

        # If the scoring strategy starts with "masked-marginals"
        if data.scoring_strategy.startswith("masked-marginals"):
            all_token_probs = []
            # For each token in the batch
            for i in range(batch_tokens.size()[1]):
                # If the token is in the list of residues
                if i in data.resi:
                    # Clone the batch tokens and mask the current token
                    batch_tokens_masked = batch_tokens.clone()
                    batch_tokens_masked[0, i] = self['<mask>']
                    # Calculate the masked token probabilities
                    with torch.no_grad():
                        masked_token_probs = torch.log_softmax(
                            self(batch_tokens_masked).logits, dim=-1
                        )
                else:
                    # If the token is not in the list of residues, use the original token probabilities
                    masked_token_probs = token_probs
                # Append the token probabilities to the list
                all_token_probs.append(masked_token_probs[:, i])
            # Concatenate all token probabilities
            token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)

        # Apply the label_row function to each row of the substitutions dataframe
        data.out[self.model_name] = data.sub.apply(
            lambda row: label_row(
                row['0'],
                token_probs,
            ),
            axis=1,
        )