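"""Gradio demo: classify a DNA sequence with a GENA-LM BERT encoder and a
logistic-regression head, and show the top-5 predicted labels as a probability bar chart."""
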
import base64
import io
import os

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

# label_to_int.pkl is expected to hold a dictionary of {label_name: label_index}
label_to_int = pd.read_pickle('label_to_int.pkl')
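# Invert the mapping so predicted class indices can be decoded back to label names,
# e.g. (hypothetical entries) {0: "label_A", 1: "label_B", ...}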
int_to_label = {v: k for k, v in label_to_int.items()}

class LogisticRegressionTorch(nn.Module):
    """Linear classification head: batch normalization followed by a single linear layer."""

    def __init__(self, input_dim: int, output_dim: int):
        super(LogisticRegressionTorch, self).__init__()
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = self.batch_norm(x)
        out = self.linear(x)
        return out

class BertClassifier(nn.Module):
    """Wraps a pre-trained BERT-style encoder and a linear head; returns (loss, logits)."""

    def __init__(self, bert_model: AutoModel, classifier: LogisticRegressionTorch, num_labels: int):
        super(BertClassifier, self).__init__()
        self.bert = bert_model  # Assume bert_model is an instance of a pre-trained BertModel
        self.classifier = classifier
        self.num_labels = num_labels

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None,
                token_type_ids: torch.Tensor = None, labels: torch.Tensor = None):
        # Extract outputs from the BERT model
        outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        
        # Take the last layer's hidden state of the first ([CLS]) token for each sequence in the batch
        pooled_output = outputs.hidden_states[-1][:, 0, :]

        # The base encoder is expected to have hidden size 768; fail fast if that assumption breaks
        assert pooled_output.shape == (input_ids.shape[0], 768), \
            f"Expected shape ({input_ids.shape[0]}, 768), but got {pooled_output.shape}"

        # Pass the pooled output to the classifier to get the logits
        logits = self.classifier(pooled_output)

        # Compute the cross-entropy loss if labels are provided
        loss = None

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            pred = logits.view(-1, self.num_labels)
            observed = labels.view(-1)
            loss = loss_fct(pred, observed)

        # Return the loss and logits
        return loss, logits

# Load the Hugging Face model and tokenizer

metadata_features = 0
N_UNIQUE_CLASSES = 38

base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)
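# Note: trust_remote_code=True lets the GENA-LM checkpoint load its custom modeling code
# from the Hub, and output_hidden_states=True is needed because the classifier below reads
# the last-layer [CLS] hidden state rather than a pooler output.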

# Initialize the classifier
input_size = 768 + metadata_features  # featurizer output size + metadata size
log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)

# Load the fine-tuned weights. The checkpoint is expected to be a dict holding
# 'model_state_dict' (the encoder) and 'log_reg_state_dict' (the classification head).

# Get the checkpoint path from the MODEL_PATH environment variable
model_path = os.getenv('MODEL_PATH')
weights = torch.load(model_path, map_location=torch.device('cpu'))

base_model.load_state_dict(weights['model_state_dict'])
log_reg.load_state_dict(weights['log_reg_state_dict'])

# Assemble the full classifier
model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
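# Run in eval mode: the BatchNorm1d layer in the classification head must use its
# running statistics at inference time rather than per-batch statistics.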
model.eval()

# Define a function to process the DNA sequence
def analyze_dna(sequence):
    """Classify a DNA sequence and return the top-5 labels plus a probability chart."""

    assert all(nucleotide in 'ACTGN' for nucleotide in sequence), "Sequence contains invalid characters"
    assert len(sequence) >= 300, "Sequence needs to be at least 300 nucleotides long"

    # Preprocess the input sequence
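    # Note: the input is padded/truncated to a 512-token window, so only the prefix of
    # very long sequences is actually classified.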
    inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)

    print("Tokenization done.")
    # Get model predictions without tracking gradients (inference only)
    with torch.no_grad():
        _, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

    print("Forward pass done.")
    
    # Convert logits to probabilities
    probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()

    print("Probabilities done.")
    # Get the top 5 most likely classes
    top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
    top_5_probs = [probabilities[i] for i in top_5_indices]
    
    # Map indices to label names
    top_5_labels = [int_to_label[i] for i in top_5_indices]
    
    # Prepare the output as a list of tuples (label_name, probability)
    result = [(label, prob) for label, prob in zip(top_5_labels, top_5_probs)]
    
    # Plot histogram
    
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.barh(top_5_labels, top_5_probs, color='skyblue')
    ax.set_xlabel('Probability')
    ax.set_title('Top 5 Most Likely Labels')
    ax.invert_yaxis()  # Highest probabilities at the top

    # Save the plot to a PNG image in memory, then close the figure so figures
    # do not accumulate across requests
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    image_base64 = base64.b64encode(buf.read()).decode('utf-8')
    buf.close()

    return result, f'<img src="data:image/png;base64,{image_base64}" />'
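
# For reference: analyze_dna("ACGT" * 100) (a hypothetical 400-nucleotide input) returns a
# list of (label, probability) pairs plus an HTML <img> tag embedding the bar chart.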

# Create a Gradio interface
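# The "json" output renders the list of (label, probability) pairs; the "html" output
# renders the base64-embedded chart returned by analyze_dna.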
demo = gr.Interface(fn=analyze_dna, inputs="text", outputs=["json", "html"])

# Launch the interface
demo.launch()