File size: 3,840 Bytes
c1e6692
088c2ad
996a1ec
3f3c29c
 
dc7d693
66990c3
 
 
9bf3f2b
dc7d693
9bf3f2b
66990c3
 
3f3c29c
 
66990c3
3f3c29c
996a1ec
3f3c29c
 
 
 
 
 
 
 
66990c3
3f3c29c
9bf3f2b
3f3c29c
 
 
9bf3f2b
996a1ec
3f3c29c
 
9bf3f2b
3f3c29c
9bf3f2b
 
 
3f3c29c
9bf3f2b
 
3f3c29c
9bf3f2b
 
3f3c29c
9bf3f2b
 
3f3c29c
9bf3f2b
 
3f3c29c
9bf3f2b
 
3f3c29c
9bf3f2b
db9f444
9bf3f2b
1f65033
5f8dde1
356d0ee
 
9bf3f2b
dc3cae8
356d0ee
9bf3f2b
356d0ee
 
9bf3f2b
356d0ee
 
 
 
 
 
 
 
 
 
 
9bf3f2b
356d0ee
 
 
 
 
 
 
 
 
9bf3f2b
356d0ee
5f8dde1
 
66990c3
5f8dde1
 
c1e6692
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import io
import base64
import os

# Class-label -> integer-id mapping, unpickled from the current working directory.
label_to_int = pd.read_pickle('label_to_int.pkl')
# Inverse lookup (integer id -> class label) for turning model output indices into names.
int_to_label = {class_id: class_name for class_name, class_id in label_to_int.items()}

class LogisticRegressionTorch(nn.Module):
    """Batch-normalized linear layer that maps feature vectors to raw logits.

    No softmax/sigmoid is applied here; callers receive unnormalized scores.

    Args:
        input_dim: Number of features per input vector.
        output_dim: Number of output logits.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.batch_norm = nn.BatchNorm1d(input_dim)
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Normalize ``x`` with batch norm, then project it to logits."""
        return self.linear(self.batch_norm(x))

class BertClassifier(nn.Module):
    """Encoder plus classification head.

    The encoder is run with hidden states enabled; the last hidden-state
    tensor is sliced at position 0 of its second axis and that vector is
    handed to ``classifier`` to produce logits.

    Args:
        bert_model: Encoder module called with ``input_ids`` and ``attention_mask``.
        classifier: Head applied to the extracted embedding.
        num_labels: Number of target classes (stored, not used in ``forward``).
    """

    def __init__(self, bert_model: AutoModel, classifier: LogisticRegressionTorch, num_labels: int):
        super().__init__()
        self.num_labels = num_labels
        self.bert = bert_model
        self.classifier = classifier

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None):
        """Return classification logits for a batch of token ids."""
        encoder_out = self.bert(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        final_layer = encoder_out.hidden_states[-1]
        # Position 0 of each sequence -- presumably the [CLS] token; confirm against the tokenizer.
        first_token_embedding = final_layer[:, 0, :]
        return self.classifier(first_token_embedding)

def load_model():
    """Build the classifier and tokenizer and restore their trained weights.

    The checkpoint path is read from the ``MODEL_PATH`` environment variable.
    The checkpoint must contain ``'model_state_dict'`` (encoder weights) and
    ``'log_reg_state_dict'`` (classification-head weights).

    Returns:
        A ``(model, tokenizer)`` tuple; ``model`` is a ``BertClassifier``
        already switched to eval mode.

    Raises:
        RuntimeError: If ``MODEL_PATH`` is unset or empty.
    """
    # Fail fast, before any model download, with an actionable message
    # instead of letting ``torch.load(None)`` raise something obscure.
    model_weights_path = os.getenv('MODEL_PATH')
    if not model_weights_path:
        raise RuntimeError(
            "The MODEL_PATH environment variable must point to the model checkpoint file"
        )

    metadata_features = 0
    N_UNIQUE_CLASSES = 38
    pretrained_name = 'AIRI-Institute/gena-lm-bert-base-lastln-t2t'

    base_model = AutoModel.from_pretrained(pretrained_name, trust_remote_code=True, output_hidden_states=True)
    tokenizer = AutoTokenizer.from_pretrained(pretrained_name, trust_remote_code=True)

    # 768 is presumably the encoder's hidden size -- confirm against the model config.
    input_size = 768 + metadata_features
    log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)

    # NOTE(security): torch.load unpickles the file; MODEL_PATH must only ever
    # reference a trusted checkpoint.
    weights = torch.load(model_weights_path, map_location=torch.device('cpu'))

    base_model.load_state_dict(weights['model_state_dict'])
    log_reg.load_state_dict(weights['log_reg_state_dict'])

    model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
    model.eval()

    return model, tokenizer

# Build the model and tokenizer once at import time; analyze_dna() reads these
# module-level globals on every call instead of reloading them.
model, tokenizer = load_model()

# Characters accepted in an input sequence (after upper-casing).
_VALID_NUCLEOTIDES = frozenset('ACTGN')


def _render_top_labels_chart(labels, probs):
    """Draw a horizontal bar chart of ``probs`` and return it as an HTML ``<img>`` tag.

    The PNG is embedded as a base64 data URI. The figure is always closed so
    repeated calls do not accumulate open Matplotlib figures.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.barh(labels, probs, color='skyblue')
        ax.set_xlabel('Probability')
        ax.set_title('Top 5 Most Likely Labels')
        # Operate on this figure's axes directly rather than pyplot's global
        # "current" axes, so the highest-probability bar ends up on top.
        ax.invert_yaxis()

        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    finally:
        plt.close(fig)

    return f'<img src="data:image/png;base64,{image_base64}" />'


def analyze_dna(sequence):
    """Classify a DNA sequence and report the five most probable labels.

    Whitespace (including line breaks from pasted multi-line input) is
    removed and the sequence is upper-cased before validation.

    Args:
        sequence: Nucleotide string; after normalization it may contain only
            ``A``, ``C``, ``T``, ``G`` and ``N`` and must be at least 300
            characters long.

    Returns:
        A 2-tuple. On success: a list of ``(label, probability)`` tuples in
        descending probability order, and an HTML ``<img>`` tag holding a bar
        chart of them. On failure: an error message string and ``""``.
    """
    try:
        sequence = ''.join(sequence.split()).upper()

        if not set(sequence) <= _VALID_NUCLEOTIDES:
            return "Error: Sequence contains invalid characters", ""

        if len(sequence) < 300:
            return "Error: Sequence needs to be at least 300 nucleotides long", ""

        inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

        probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
        ranked = sorted(enumerate(probabilities), key=lambda pair: pair[1], reverse=True)[:5]
        top_5_labels = [int_to_label[index] for index, _ in ranked]
        top_5_probs = [prob for _, prob in ranked]
        result = list(zip(top_5_labels, top_5_probs))

        return result, _render_top_labels_chart(top_5_labels, top_5_probs)

    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing the app.
        return str(e), ""

# Gradio UI: one free-text input; results shown as a JSON panel plus an HTML panel.
demo = gr.Interface(
    fn=analyze_dna,
    inputs="text",
    outputs=["json", "html"],
)

# Start serving the app.
demo.launch()