File size: 4,201 Bytes
c1e6692
088c2ad
996a1ec
3f3c29c
 
 
 
 
 
 
 
 
 
 
996a1ec
3f3c29c
 
 
 
 
 
 
 
 
 
996a1ec
3f3c29c
 
 
 
 
 
 
 
 
996a1ec
3f3c29c
996a1ec
 
3f3c29c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1e6692
5f8dde1
3f3c29c
 
 
 
996a1ec
3f3c29c
 
 
 
 
 
 
9ab99fb
3f3c29c
 
 
 
c1e6692
daf9507
996a1ec
83fe210
daf9507
5f8dde1
 
 
513b115
 
17c6a2f
5f8dde1
996a1ec
17c6a2f
 
04805af
 
996a1ec
17c6a2f
 
04805af
 
 
 
 
 
 
996a1ec
5f8dde1
 
04805af
5f8dde1
 
c1e6692
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
import torch
import torch.nn as nn


class LogisticRegressionTorch(nn.Module):
    """Batch-normalised linear classification head.

    Each input row is normalised with ``nn.BatchNorm1d`` and then projected
    by a single ``nn.Linear`` layer. The raw linear outputs are returned;
    no softmax or sigmoid is applied inside this module.
    """

    def __init__(self, input_dim: int, output_dim: int):
        """Create the normalisation and projection layers.

        Args:
            input_dim: Number of features in each input row.
            output_dim: Number of values produced per row.
        """
        super().__init__()
        # The attribute names are also the state-dict key prefixes, so they
        # must stay as they are for saved weights to load.
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Return ``linear(batch_norm(x))`` for the batch ``x``."""
        return self.linear(self.batch_norm(x))

class BertClassifier(nn.Module):
    """Encoder followed by a classification head.

    The encoder is called with ``output_hidden_states=True``; the last
    hidden-state tensor's first-token vector is used as the pooled
    representation and handed to ``classifier`` to produce logits.
    """

    def __init__(self,
                 bert_model: AutoModel,
                 classifier: LogisticRegressionTorch,
                 num_labels: int,
                 hidden_size: int = 768):
        """Store the sub-modules and shape metadata.

        Args:
            bert_model: Callable encoder whose output exposes a
                ``hidden_states`` sequence of tensors.
            classifier: Module mapping the pooled vector to logits.
            num_labels: Number of classes; used to reshape logits when a
                loss is computed.
            hidden_size: Expected width of the pooled vector. Defaults to
                768, the value that was previously hard-coded.
        """
        super().__init__()
        self.bert = bert_model
        self.classifier = classifier
        self.num_labels = num_labels
        self.hidden_size = hidden_size

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None,
                token_type_ids: torch.Tensor = None, labels: torch.Tensor = None):
        """Run the encoder and the classification head.

        Args:
            input_ids: Token ids; the first dimension is the batch.
            attention_mask: Optional mask forwarded to the encoder.
            token_type_ids: Accepted for signature compatibility only; it is
                not forwarded to the encoder.
            labels: Optional class indices. When given, a cross-entropy loss
                is computed against the logits.

        Returns:
            A ``(loss, logits)`` tuple. ``loss`` is ``None`` when ``labels``
            is not provided.

        Raises:
            ValueError: If the pooled vector is not of shape
                ``(batch, hidden_size)``.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)

        # Last entry of hidden_states, first token of every sequence.
        pooled_output = outputs.hidden_states[-1][:, 0, :]

        # Explicit raise instead of ``assert`` so the check is not stripped
        # when Python runs with -O.
        expected_shape = (input_ids.shape[0], self.hidden_size)
        if tuple(pooled_output.shape) != expected_shape:
            raise ValueError(
                f"Expected shape ({expected_shape[0]}, {expected_shape[1]}), but got {pooled_output.shape}"
            )

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Flatten to (N, num_labels) vs (N,) as CrossEntropyLoss expects.
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return loss, logits



# ---- Model setup: runs once when the module is loaded ----

# No extra metadata features are concatenated to the encoder output.
metadata_features = 0
# Number of output classes of the classification head.
N_UNIQUE_CLASSES = 38

# Encoder and tokenizer are loaded under the same model identifier.
# NOTE(review): trust_remote_code=True is passed to both loaders - confirm
# the model repository is trusted.
base_model = AutoModel.from_pretrained(
    'AIRI-Institute/gena-lm-bert-base-lastln-t2t',
    trust_remote_code=True,
    output_hidden_states=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    'AIRI-Institute/gena-lm-bert-base-lastln-t2t',
    trust_remote_code=True,
)

# Classification head: input width is 768 plus the metadata feature count.
input_size = 768 + metadata_features
log_reg = LogisticRegressionTorch(
    input_dim=input_size,
    output_dim=N_UNIQUE_CLASSES,
)

# Read the checkpoint onto the CPU and restore both sub-modules from it.
# NOTE(review): torch.load unpickles the file; only use checkpoints from a
# trusted source.
model_weights_path = 'gena-blastln-bs33-lr4e-05-S168.pth'
weights = torch.load(
    model_weights_path,
    map_location=torch.device('cpu'),
)

base_model.load_state_dict(weights['model_state_dict'])
log_reg.load_state_dict(weights['log_reg_state_dict'])

# Combine encoder and head, then switch to evaluation mode for inference.
model = BertClassifier(
    base_model,
    log_reg,
    num_labels=N_UNIQUE_CLASSES,
)
model.eval()

# Define a function to process the DNA sequence
def analyze_dna(sequence: str) -> list:
    """Classify a DNA sequence and return its five most probable classes.

    The sequence is tokenized (truncated and padded to 512 tokens), passed
    through the module-level ``model`` with gradient tracking disabled, and
    the logits are converted to probabilities with a softmax.

    Args:
        sequence: Raw sequence text to classify.

    Returns:
        A list of up to five ``(class_index, probability)`` tuples ordered
        from most to least probable.
    """
    # Preprocess the input sequence
    inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)

    print("tokenization done.")

    # Inference only: skip building the autograd graph to save memory/time.
    with torch.no_grad():
        _, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

    print("Forward pass done.")

    # Drop only the batch axis so a single-class head still yields a list.
    probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze(0).tolist()

    print("Probabilities, done.")

    # Stable descending sort keeps tied classes in ascending index order.
    ranked = sorted(enumerate(probabilities), key=lambda pair: pair[1], reverse=True)

    return ranked[:5]

# Web UI: one text input wired to analyze_dna, result shown as JSON.
demo = gr.Interface(
    fn=analyze_dna,
    inputs="text",
    outputs="json",
)

# Start serving the interface.
demo.launch()