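"""Gradio app that predicts the likely country of origin of an engineered DNA sequence.

Users authenticate against the USERNAME/PASSWORD environment variables, paste a DNA
sequence, and choose between a GENA-LM BERT classifier ('gena-bert') and a simple CNN
('CNN'). The app returns the top-5 predicted labels with their probabilities and a
bar-chart image rendered as HTML.
"""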
import io
import base64
import os

import gradio as gr
import torch
import pandas as pd
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download, login
from model_archs import BertClassifier, LogisticRegressionTorch, SimpleCNN, MLP
from tangermeme.utils import one_hot_encode  # one_hot_encode is exposed in tangermeme.utils

# Load label mapping
label_to_int = pd.read_pickle('label_to_int.pkl')
int_to_label = {v: k for k, v in label_to_int.items()}

# Collapse verbose country names in the labels to short forms (KOREA, UK, RUSSIA)
for k, v in int_to_label.items():
    if "KOREA" in v:
        int_to_label[k] = "KOREA"
    elif "KINGDOM" in v:
        int_to_label[k] = "UK"
    elif "RUSSIAN" in v:
        int_to_label[k] = "RUSSIA"



def load_model(model_name: str):
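    """Build the requested classifier and return (model, tokenizer).

    'gena-bert' downloads fine-tuned weights from the Hugging Face Hub (this requires
    the HUGGINGFACE_TOKEN environment variable); 'CNN' builds a SimpleCNN + MLP head
    with no tokenizer and, as written, no pretrained weights.
    """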
    metadata_features = 0
    N_UNIQUE_CLASSES = 38  
    
    if model_name == 'gena-bert':
        base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
        tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)
    
        input_size = 768 + metadata_features
        log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)
    
        token = os.getenv('HUGGINGFACE_TOKEN')
        if token is None:
            raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
    
        login(token=token)
        file_path = hf_hub_download(
            repo_id="mawairon/noo_test",
            filename="gena-blastln-bs33-lr4e-05-S168.pth",
            token=token  # hf_hub_download's use_auth_token argument is deprecated in favor of token
        )
        weights = torch.load(file_path, map_location=torch.device('cpu'))
    
        base_model.load_state_dict(weights['model_state_dict'])
        log_reg.load_state_dict(weights['log_reg_state_dict'])
    
        model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
        model.eval()
    
        return model, tokenizer

    elif model_name == 'CNN':
        hidden_dim = 2048
        num_labs = N_UNIQUE_CLASSES  # the training labels (y_train) are not available here

        # NOTE: no pretrained weights are loaded for the CNN; it is returned as initialized.
        model_seq = SimpleCNN(18, hidden_dim, additional_layer=False)
        new_head = torch.nn.Sequential(
            torch.nn.Dropout(0.5),
            MLP([hidden_dim * 2, num_labs])
        )

        model = torch.nn.Sequential(
            model_seq,
            new_head
        )
        model.eval()
        return model, None

    else:
        raise ValueError(f"Invalid model name: {model_name}")
        



def analyze_dna(username, password, sequence, model_name):
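    """Validate credentials and the input sequence, run the selected model, and return
    (top-5 predictions, an HTML <img> tag containing a base64-encoded bar chart).
    """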
    
    valid_usernames = os.getenv('USERNAME', '').split(',')  # default to '' so a missing env var doesn't crash
    env_password = os.getenv('PASSWORD')
    
    if username not in valid_usernames or password != env_password:
        return {"error": "Invalid username or password"}, ""

    try:
        
        # Remove all whitespace characters and normalize to upper case
        sequence = "".join(sequence.split()).upper()
        
        if not all(nucleotide in 'ACTGN' for nucleotide in sequence):
            return {"error": "Sequence contains invalid characters"}, ""

        if len(sequence) < 300:
            return {"error": "Sequence needs to be at least 300 nucleotides long"}, ""

        model, tokenizer = load_model(model_name)

        def get_logits(seq):  # model, tokenizer and model_name come from the enclosing scope
            if model_name == 'gena-bert':
                inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
                with torch.no_grad():
                    logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
                return logits

            elif model_name == 'CNN':
                # Truncate and right-pad to the fixed input length
                SEQUENCE_LENGTH = 8000
                seq = seq[:SEQUENCE_LENGTH].ljust(SEQUENCE_LENGTH, 'N')

                # One-hot encode and add a batch dimension. Padding with 'N'
                # (which one_hot_encode leaves as an all-zero column) and a
                # (1, 4, length) input shape are assumptions about what SimpleCNN expects.
                x = one_hot_encode(seq).unsqueeze(0).float()
                with torch.no_grad():
                    logits = model(x)
                return logits
                

        # if (len(sequence) > 3000 and model_name == 'gena-bert') or (len(sequence) > 10000 and model_name == 'CNN'):
        #     num_shifts = len(sequence) // 1000
        #     logits_sum = None
        #     for i in range(num_shifts):
        #         shifted_sequence = sequence[i*1000:] + sequence[:i*1000]
        #         logits = get_logits(shifted_sequence)
        #         if logits_sum is None:
        #             logits_sum = logits
        #         else:
        #             logits_sum += logits
        #     logits_avg = logits_sum / num_shifts
        # else:
        logits_avg = get_logits(sequence)

        probabilities = torch.nn.functional.softmax(logits_avg, dim=-1).squeeze().tolist()
        top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
        top_5_probs = [probabilities[i] for i in top_5_indices]
        top_5_labels = [int_to_label[i] for i in top_5_indices]
        result = [(label, prob) for label, prob in zip(top_5_labels, top_5_probs)]

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.barh(top_5_labels, top_5_probs, color='skyblue')
        ax.set_xlabel('Probability')
        ax.set_title('Assuming this sequence was genetically engineered,\n the 5 most likely countries in which it was engineered are:')
        ax.invert_yaxis()  # most probable label at the top

        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        image_base64 = base64.b64encode(buf.read()).decode('utf-8')
        buf.close()
        plt.close(fig)  # free the figure so repeated requests don't leak memory

        return result, f'<img src="data:image/png;base64,{image_base64}" />'

    except Exception as e:
        return {"error": str(e)}, ""

# Create a Gradio interface
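# The app expects USERNAME (a comma-separated list of allowed users), PASSWORD, and,
# for the gena-bert model, HUGGINGFACE_TOKEN to be set in the environment
# (e.g. as Space secrets).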
demo = gr.Interface(
    fn=analyze_dna,
    inputs=[
        gr.Textbox(label="Username"),
        gr.Textbox(label="Password", type="password"),
        gr.Textbox(label="DNA Sequence"),
        gr.Dropdown(label="Model", choices=[
            "gena-bert",
            "CNN"
        ])
    ],
    outputs=["json", "html"]
)

# Launch the interface
demo.launch()