File size: 20,244 Bytes
39406f0
 
 
 
 
 
 
 
 
 
d1ca73b
 
39406f0
 
 
 
 
 
 
 
d1ca73b
 
39406f0
 
 
d1ca73b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39406f0
 
d1ca73b
 
 
 
 
 
 
 
 
 
 
 
39406f0
d1ca73b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39406f0
 
d1ca73b
 
 
 
 
 
 
 
 
 
 
39406f0
d1ca73b
39406f0
 
 
 
 
 
 
d1ca73b
39406f0
 
 
 
 
d1ca73b
39406f0
 
 
d1ca73b
39406f0
 
 
 
d1ca73b
39406f0
 
 
d1ca73b
39406f0
d1ca73b
39406f0
d1ca73b
39406f0
 
 
 
 
 
 
 
d1ca73b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39406f0
 
 
 
cad5ffa
 
39406f0
 
cad5ffa
39406f0
cad5ffa
39406f0
cad5ffa
 
 
 
39406f0
 
 
 
 
 
 
 
 
94f42df
39406f0
 
 
d1ca73b
39406f0
 
d1ca73b
94f42df
d1ca73b
 
 
 
 
 
 
 
94f42df
d1ca73b
 
94f42df
d1ca73b
 
 
 
94f42df
d1ca73b
 
 
 
 
 
 
 
 
 
39406f0
d1ca73b
 
 
 
 
39406f0
 
 
 
 
 
 
 
 
 
d1ca73b
 
39406f0
d1ca73b
39406f0
d1ca73b
 
39406f0
d1ca73b
 
 
 
39406f0
d1ca73b
 
 
39406f0
d1ca73b
 
 
39406f0
 
 
 
 
 
 
 
d1ca73b
39406f0
d1ca73b
39406f0
 
d1ca73b
39406f0
d1ca73b
 
39406f0
 
d1ca73b
 
 
 
39406f0
d1ca73b
 
39406f0
d1ca73b
 
39406f0
d1ca73b
 
 
 
39406f0
d1ca73b
 
 
39406f0
 
d1ca73b
 
39406f0
d1ca73b
39406f0
 
d1ca73b
39406f0
d1ca73b
39406f0
d1ca73b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39406f0
d1ca73b
39406f0
 
 
 
f0548df
39406f0
d1ca73b
39406f0
 
 
 
ca7782f
d1ca73b
39406f0
 
 
f0548df
 
 
 
 
39406f0
 
 
 
 
 
 
 
 
d1ca73b
f0548df
d1ca73b
39406f0
 
 
 
d1ca73b
39406f0
 
d1ca73b
 
39406f0
 
adb2f18
39406f0
 
f0548df
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
import pandas as pd
from IPython.display import clear_output
import torch
from transformers import EsmForSequenceClassification, AdamW, AutoTokenizer
from torch.utils.data import DataLoader, TensorDataset, random_split
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
import numpy as np
import seaborn as sns
from sklearn.model_selection import train_test_split
import matplotlib
matplotlib.use('Agg')  # Use the non-interactive Agg backend
import matplotlib.pyplot as plt
import pickle
import torch.nn.functional as F
import gradio as gr
import io
from PIL import Image
import Bio
from Bio import SeqIO
from Bio.Blast import NCBIXML
import subprocess
import zipfile
import os

# Class index -> family label for the GT-A family classifier ("GTA_fam.pth").
# The position of each label in the list is the index emitted by the model,
# so the order must not be changed.
GTA_fam_dict = dict(enumerate([
    "GT116", "GT12", "GT13", "GT14", "GT15", "GT16", "GT17",
    "GT2-clade1", "GT2-clade2", "GT2-clade3", "GT2-clade4", "GT2-clade5",
    "GT2-related",
    "GT21", "GT24", "GT25", "GT27", "GT31", "GT32", "GT34",
    "GT40", "GT43", "GT45", "GT49", "GT54", "GT55",
    "GT6", "GT60", "GT62", "GT64", "GT67",
    "GT7", "GT75", "GT77", "GT78",
    "GT8", "GT81", "GT82", "GT84", "GT88",
    "GT92",
]))


# Class index -> donor sugar label for the GT-A donor classifier
# ("GTA_don.pth"). List position is the model's output index.
GTA_don_dict = dict(enumerate([
    "N-Acetyl Galactosamine",
    "N-Acetyl Glucosamine",
    "Arabinose",
    "Galactose",
    "Galacturonic Acid",
    "Glucose",
    "Glucuronic Acid",
    "Mannose",
    "Rhamnose",
    "Xylose",
]))

# Class index -> family label for the GT-B family classifier ("GTB_fam.pth").
# List position is the model's output index, so the order must be preserved.
GTB_fam_dict = dict(enumerate([
    "GT1", "GT10", "GT104", "GT11", "GT18", "GT19",
    "GT20", "GT23", "GT28",
    "GT3", "GT30", "GT35", "GT37", "GT38",
    "GT4", "GT41",
    "GT5", "GT52",
    "GT63", "GT65", "GT68",
    "GT70", "GT72",
    "GT80",
    "GT9", "GT90", "GT99",
]))


# Class index -> donor sugar label for the GT-B donor classifier
# ("GTB_don.pth"). List position is the model's output index.
GTB_don_dict = dict(enumerate([
    "Fucose",
    "Galactose",
    "N-Acetyl Galactosamine",
    "Glucuronic Acid",
    "N-Acetyl Glucosamine",
    "Glucose",
    "Mannose",
    "Other",
    "Xylose",
]))

# Module-level tokenizer shared by process_family_sequence and
# process_donor_sequence. It is loaded for the same checkpoint name that
# main_function_single passes to EsmForSequenceClassification.from_pretrained.
# NOTE(review): the trailing comment presumably records an alternative, larger
# checkpoint that was tried - confirm before switching.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D") #facebook/esm2_t33_650M_UR50D

# Compact source rows for the family reference table:
# (family name, mechanism, clade). Row order is the dict's insertion order.
_GT_FAMILY_ROWS = [
    ("GT40", "Inverting", "N/A"),
    ("GT16", "Inverting", "6"),
    ("GT27", "Retaining", "5"),
    ("GT55", "Retaining", "2"),
    ("GT25", "Inverting", "6"),
    ("GT2", "Inverting", "2"),
    ("GT84", "Retaining", "1"),
    ("GT13", "Inverting", "6"),
    ("GT67", "Inverting", "8"),
    ("GT82", "Inverting", "7"),
    ("GT24", "Retaining", "9"),
    ("GT81", "Retaining", "2"),
    ("GT49", "Inverting", "N/A"),
    ("GT34", "Retaining", "N/A"),
    ("GT45", "Retaining", "N/A"),
    ("GT32", "Retaining", "N/A"),
    ("GT88", "Retaining", "9"),
    ("GT21", "Inverting", "1"),
    ("GT54", "Inverting", "6"),
    ("GT6", "Retaining", "N/A"),
    ("GT7", "Inverting", "5"),
    ("GT64", "Retaining", "N/A"),
    ("GT78", "Retaining", "2"),
    ("GT12", "Inverting", "N/A"),
    ("GT31", "Inverting", "8"),
    ("GT62", "Retaining", "3"),
    ("GT8", "Retaining", "N/A"),
    ("GT15", "Retaining", "8"),
    ("GT43", "Inverting", "N/A"),
    ("GT60", "Retaining", "5"),
    ("GT14", "Inverting", "7"),
    ("GT17", "Inverting", "7"),
    ("GT77", "Retaining", "9"),
    ("GT75", "Inverting", "N/A"),
]

# Family name -> descriptive attributes, consumed by get_family_info.
# Every entry has fold 'A' and an empty alternative name. The stored
# 'CAZy Name' is space-padded to 4 characters and 'Clade' to 3 characters,
# exactly as in the hand-written table this replaces.
glycosyltransferase_db = {
    family: {
        'CAZy Name': family.ljust(4),
        'Alternative Name': '',
        'Fold': 'A',
        'Mechanism': mechanism,
        'Clade': clade.ljust(3),
        'More Info': 'http://www.cazy.org/' + family + '.html',
    }
    for family, mechanism, clade in _GT_FAMILY_ROWS
}

def parse_blast_output_for_best_evalue(output_file):
    """Return the lowest E-value found in a BLAST XML result file.

    Args:
        output_file: Path to a BLAST XML (``-outfmt 5``) file that holds the
            result of a single query.

    Returns:
        The smallest ``expect`` value over every HSP of every alignment, or
        ``None`` when the record contains no alignments / no HSPs.
    """
    with open(output_file) as result_handle:
        blast_record = NCBIXML.read(result_handle)

    # Scan every HSP instead of trusting hsps[0] of the first hit: this avoids
    # an IndexError on a hit without HSPs and does not rely on report ordering.
    evalues = [
        hsp.expect
        for alignment in blast_record.alignments
        for hsp in alignment.hsps
    ]
    if not evalues:
        # No match at all; the caller substitutes a very large E-value.
        return None

    best_evalue = min(evalues)
    print(best_evalue)
    return best_evalue

def run_local_blast(sequence, database):
    """Run ``blastp`` for one protein sequence and return its best E-value.

    The query and the XML report live in a private temporary directory, so
    concurrent calls cannot overwrite each other's files and nothing is left
    behind when ``blastp`` or the parser fails.

    Args:
        sequence: Amino-acid sequence (no FASTA header) used as the query.
        database: Path/name of the BLAST protein database passed to ``-db``.

    Returns:
        The best E-value reported by ``parse_blast_output_for_best_evalue``,
        or ``None`` when BLAST reports no hit.

    Raises:
        subprocess.CalledProcessError: ``blastp`` exited with a non-zero status.
        FileNotFoundError: the ``blastp`` executable is not on ``PATH``.
    """
    # Local import keeps this block self-contained; tempfile is stdlib.
    import tempfile

    with tempfile.TemporaryDirectory(prefix="blast_") as work_dir:
        query_file = os.path.join(work_dir, "temp_query.fasta")
        output_file = os.path.join(work_dir, "blast_results.xml")

        with open(query_file, "w") as file:
            file.write(">Query\n" + sequence)

        blast_cmd = [
            "blastp",
            "-query", query_file,
            "-db", database,
            "-out", output_file,
            "-outfmt", "5",  # Output format 5 is XML
            "-evalue", "1e-2"  # Hits above this E-value are not reported
        ]

        # Argument list (no shell), so the sequence/database are never
        # interpreted by a shell.
        subprocess.run(blast_cmd, check=True)

        # Parse before leaving the ``with`` block: the directory and both
        # files are removed on exit, including on error.
        return parse_blast_output_for_best_evalue(output_file)


def get_family_info(family_name, db=None):
    """Format the reference entry of a glycosyltransferase family as Markdown.

    Args:
        family_name: Key to look up, e.g. ``"GT40"``.
        db: Optional mapping of family name -> attribute dict. Defaults to the
            module-level ``glycosyltransferase_db``.

    Returns:
        One ``**Label:** value`` line per attribute, joined with Markdown hard
        line breaks (two spaces + newline). The ``More Info`` attribute is
        rendered as a clickable link. Returns ``""`` when the family has no
        entry in the table.
    """
    if db is None:
        db = glycosyltransferase_db
    family_info = db.get(family_name, {})

    rendered = []
    for key, value in family_info.items():
        label = key.replace("_", " ")
        # Only title-case all-lowercase keys; str.title() would otherwise turn
        # the existing "CAZy Name" key into "Cazy Name".
        if label.islower():
            label = label.title()

        if label.lower() == "more info":
            # The table stores a single URL string; also accept a list of URLs.
            urls = [value] if isinstance(value, str) else list(value)
            links = "  ".join(
                "[{0}]({0})".format(str(url).strip()) for url in urls
            )
            rendered.append("**{}:** {}".format(label, links))
        else:
            # Table values carry alignment padding ('GT2 ', '6  '); strip it.
            rendered.append("**{}:** {}".format(label, str(value).strip()))

    return "  \n".join(rendered)


def fig_to_img(fig):
    """Render a matplotlib figure to an in-memory PNG and open it with PIL.

    Args:
        fig: Figure object exposing ``savefig``.

    Returns:
        The ``PIL.Image`` opened from the PNG bytes.
    """
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format='png', bbox_inches='tight')
    # Rewind so PIL reads the PNG from its first byte.
    png_buffer.seek(0)
    return Image.open(png_buffer)

def preprocess_protein_sequence(protein_fasta):
    """Validate user input and choose the GT-A or GT-B model pair via BLAST.

    Args:
        protein_fasta: Raw text from the UI - a bare sequence or one FASTA
            record (optional ``>`` header line followed by sequence lines).

    Returns:
        ``(protein_sequence, model_fam, model_don, error_msg)``. On success
        ``error_msg`` is ``None`` and the sequence is upper-cased with all
        whitespace removed. On failure the first three items are ``None`` and
        ``error_msg`` is a user-facing message.
    """
    # splitlines() + strip() tolerate Windows line endings and the leading or
    # trailing blanks that come with copy/pasted text.
    lines = [line.strip() for line in (protein_fasta or "").splitlines()]
    headers = [line for line in lines if line.startswith('>')]
    if len(headers) > 1:
        return None, None, None, "Multiple fasta sequences detected. Please upload a fasta file with only one sequence."

    # Drop every whitespace character and normalise case: the residue check
    # below is case-insensitive, and downstream tools get one canonical form.
    protein_sequence = ''.join(
        ''.join(line.split()) for line in lines if not line.startswith('>')
    ).upper()

    if not protein_sequence:
        # Without this guard an empty query would be handed to blastp.
        return None, None, None, "No protein sequence found. Please enter a protein sequence or a single fasta record."

    valid_characters = set("ACDEFGHIKLMNPQRSTVWY")
    if any(char not in valid_characters for char in protein_sequence):
        return None, None, None, "Invalid protein sequence. It contains characters that are not one of the 20 standard amino acids."

    print("Running Blast.")

    gta_db_path = "blast_data/GTA/GTA.db"
    gtb_db_path = "blast_data/GTB/GTB.db"

    # A missing hit is mapped to a huge E-value so that it always loses the
    # comparison and always fails the threshold test.
    no_hit_evalue = 1e+100

    evalue_gta = run_local_blast(protein_sequence, gta_db_path)
    evalue_gta = evalue_gta if evalue_gta is not None else no_hit_evalue

    evalue_gtb = run_local_blast(protein_sequence, gtb_db_path)
    evalue_gtb = evalue_gtb if evalue_gtb is not None else no_hit_evalue
    print("E-value GT-A:", evalue_gta, "E-value GT-B:", evalue_gtb)
    print("Blast finished running. Checking sequence against known data.")

    # Reject sequences that match neither database well enough.
    if evalue_gta > 1e-2 and evalue_gtb > 1e-2:
        return None, None, None, "**Warning:** The sequence does not appear to be a GT-A or GT-B. Please ensure you are submitting a sequence from these families."

    # The fold with the strictly better (lower) E-value wins; ties go to GT-B.
    fold = "GTA" if evalue_gta < evalue_gtb else "GTB"
    model_fam = fold + "_fam.pth"
    model_don = fold + "_don.pth"
    print("Selected model for family:", model_fam, "and donor:", model_don)

    return protein_sequence, model_fam, model_don, None



def process_family_sequence(protein_sequence, modelfam, label_dict):
    """Predict the GT family of one sequence and plot the top probabilities.

    Args:
        protein_sequence: Amino-acid sequence; tokenised with the module-level
            ``tokenizer`` and truncated to 512 tokens.
        modelfam: Sequence-classification model whose output exposes
            ``.logits``.
        label_dict: Mapping of class index -> family label.

    Returns:
        ``(decoded_label, image, None, info)`` where ``image`` is a PIL image
        of a horizontal bar chart of the (up to) five most probable classes
        and ``info`` is the Markdown from ``get_family_info``, prefixed with a
        warning when the sequence is shorter than 100 residues.
    """
    encoded_input = tokenizer([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")

    with torch.no_grad():
        outputfam = modelfam(encoded_input["input_ids"], attention_mask=encoded_input["attention_mask"])
        logitsfam = outputfam.logits
        probabilitiesfam = F.softmax(logitsfam, dim=1)

    # The batch holds exactly one sequence, so work on row 0.
    predicted_label_index_fam = int(torch.argmax(logitsfam, dim=1)[0])
    decoded_label_fam = label_dict.get(predicted_label_index_fam, "Unknown Label")

    family_info = get_family_info(decoded_label_fam)

    class_probs = probabilitiesfam[0]
    top_k = min(5, class_probs.numel())
    top_probs, top_labels = torch.topk(class_probs, top_k)
    top_decoded_labels = [label_dict.get(label, "Unknown") for label in top_labels.tolist()]
    top_percentages = [prob * 100 for prob in top_probs.tolist()]

    # Draw on this figure's own axes rather than through pyplot's implicit
    # "current figure", and close the figure only after it has been rendered.
    figfam, ax = plt.subplots(figsize=(10, 5))
    try:
        y_posfam = np.arange(len(top_decoded_labels))
        ax.barh(y_posfam, top_percentages, align='center', alpha=0.5)
        ax.set_yticks(y_posfam)
        ax.set_yticklabels(top_decoded_labels)
        ax.set_xlabel('Probability (%)')
        ax.set_title(f'Top {top_k} Family Class Probabilities')
        ax.set_xlim(0, 100)
        img = fig_to_img(figfam)
    finally:
        plt.close(figfam)

    if len(protein_sequence) < 100:
        # f-string: the original emitted the literal text "{family_info}".
        return decoded_label_fam, img, None, f"**Warning:** The sequence is relatively short. Fragmentary and partial sequences may result in incorrect predictions. \n\n {family_info}"

    return decoded_label_fam, img, None, family_info


def process_donor_sequence(protein_sequence, modeldon, label_dict):
    """Predict the donor sugar of one sequence and plot the top probabilities.

    Args:
        protein_sequence: Amino-acid sequence; tokenised with the module-level
            ``tokenizer`` and truncated to 512 tokens.
        modeldon: Sequence-classification model whose output exposes
            ``.logits``.
        label_dict: Mapping of class index -> donor label.

    Returns:
        Always the 3-tuple ``(decoded_label, image, None)`` that
        ``main_function_single`` unpacks. ``image`` is a PIL image of a
        horizontal bar chart of the (up to) three most probable classes.
    """
    encoded_input = tokenizer([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")

    with torch.no_grad():
        outputdon = modeldon(encoded_input["input_ids"], attention_mask=encoded_input["attention_mask"])
        logitsdon = outputdon.logits
        probabilitiesdon = F.softmax(logitsdon, dim=1)

    # The batch holds exactly one sequence, so work on row 0.
    predicted_label_index_don = int(torch.argmax(logitsdon, dim=1)[0])
    decoded_label_don = label_dict.get(predicted_label_index_don, "Unknown Label")

    class_probs = probabilitiesdon[0]
    top_k = min(3, class_probs.numel())
    top_probs, top_labels = torch.topk(class_probs, top_k)
    top_decoded_labels = [label_dict.get(label, "Unknown") for label in top_labels.tolist()]
    top_percentages = [prob * 100 for prob in top_probs.tolist()]

    # Draw on this figure's own axes rather than through pyplot's implicit
    # "current figure", and close the figure only after it has been rendered.
    figdon, ax = plt.subplots(figsize=(10, 5))
    try:
        y_posdon = np.arange(len(top_decoded_labels))
        ax.barh(y_posdon, top_percentages, align='center', alpha=0.5)
        ax.set_yticks(y_posdon)
        ax.set_yticklabels(top_decoded_labels)
        ax.set_xlabel('Probability (%)')
        ax.set_title(f'Top {top_k} Donor Class Probabilities')
        ax.set_xlim(0, 100)
        img = fig_to_img(figdon)
    finally:
        plt.close(figdon)

    # The former short-sequence branch returned four values (with an
    # unformatted "{family_info}" placeholder), which made the caller's
    # three-way unpacking fail. The short-sequence warning is produced by
    # process_family_sequence instead.
    return decoded_label_don, img, None

# Classifiers already loaded in this process, keyed by checkpoint path, so
# that repeated predictions do not rebuild and reload the same model.
_LOADED_MODELS = {}


def _load_classifier(checkpoint_path, num_labels):
    """Return the eval-mode CPU classifier for *checkpoint_path*, loading it once."""
    model = _LOADED_MODELS.get(checkpoint_path)
    if model is None:
        model = EsmForSequenceClassification.from_pretrained("facebook/esm2_t12_35M_UR50D", num_labels=num_labels)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        # strict=False is kept from the original code, but missing weights are
        # reported instead of silently staying at their initial values.
        load_result = model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys:
            print("Warning: weights missing from", checkpoint_path, ":", load_result.missing_keys)
        model.eval()
        model.to('cpu')
        _LOADED_MODELS[checkpoint_path] = model
    return model


def main_function_single(sequence):
    """Run the full single-sequence prediction used by the "Predict" button.

    Args:
        sequence: Raw text from the UI (bare sequence or one FASTA record).

    Returns:
        ``(family_label, family_img, family_info, donor_label, donor_img)``.
        When preprocessing fails, the error message is returned in the
        ``family_info`` slot and every other item is ``None``.
    """
    # Initial preprocessing including BLAST-based model selection
    protein_sequence, model_fam_path, model_don_path, error_msg = preprocess_protein_sequence(sequence)
    if error_msg:
        print(error_msg)
        return None, None, error_msg, None, None

    # Number of output classes and index -> label table for each checkpoint.
    model_config = {
        "GTA_fam.pth": {"num_labels": 41, "label_dict": GTA_fam_dict},
        "GTB_fam.pth": {"num_labels": 27, "label_dict": GTB_fam_dict},
        "GTA_don.pth": {"num_labels": 10, "label_dict": GTA_don_dict},
        "GTB_don.pth": {"num_labels": 9, "label_dict": GTB_don_dict},
    }

    config_fam = model_config[model_fam_path]
    config_don = model_config[model_don_path]
    model_fam = _load_classifier(model_fam_path, config_fam["num_labels"])
    model_don = _load_classifier(model_don_path, config_don["num_labels"])

    # Pass the label dictionary along with the model to the processing functions
    family_label, family_img, _, family_info = process_family_sequence(protein_sequence, model_fam, config_fam["label_dict"])
    donor_label, donor_img, _ = process_donor_sequence(protein_sequence, model_don, config_don["label_dict"])

    return family_label, family_img, family_info, donor_label, donor_img


# ---------------------------------------------------------------------------
# Gradio user interface.
# NOTE(review): gr.inputs.*, gr.outputs.* and .style(...) look like legacy
# Gradio APIs - confirm they exist in the gradio version this app is pinned to.
# ---------------------------------------------------------------------------

# The two image outputs are created up front so the Predict button can list
# them as outputs; they are placed in the layout later via .render().
prediction_imagefam = gr.outputs.Image(type='pil', label="Family prediction graph")
prediction_imagedonor = gr.outputs.Image(type='pil', label="Donor prediction graph")

with gr.Blocks() as app:
    gr.Markdown("# Glydentify (alpha v0.5)")

    with gr.Tab("Single Sequence Prediction"):
        # Top row: sequence input on the left, example + text results on the right.
        with gr.Row().style(equal_height=True):
            with gr.Column():
                sequence = gr.inputs.Textbox(lines=16, placeholder='Enter Protein Sequence Here...', label="Protein Sequence")
                # explanation_checkbox = gr.inputs.Checkbox(label="Show Explanation", default=False)
            with gr.Column():
                with gr.Accordion("Example:"):
                    # NOTE(review): "\>" inside a non-raw string is an
                    # unrecognised escape sequence - consider a raw string.
                    gr.Markdown("""
                                \>sp|Q9Y5Z6|B3GT1_HUMAN Beta-1,3-galactosyltransferase 1 OS=Homo sapiens OX=9606 GN=B3GALT1 PE=1 SV=1  
                                MASKVSCLYVLTVVCWASALWYLSITRPTSSYTGSKPFSHLTVARKNFTFGNIRTRPINPHSFEFLINEPNKCEKNIPFLVILIST  
                                THKEFDARQAIRETWGDENNFKGIKIATLFLLGKNADPVLNQMVEQESQIFHDIIVEDFIDSYHNLTLKTLMGMRWVATFCSK  
                                AKYVMKTDSDIFVNMDNLIYKLLKPSTKPRRRYFTGYVINGGPIRDVRSKWYMPRDLYPDSNYPPFCSGTGYIFSADVAELIYK  
                                TSLHTRLLHLEDVYVGLCLRKLGIHPFQNSGFNHWKMAYSLCRYRRVITVHQISPEEMHRIWNDMSSKKHLRC  
                                """)
                family_prediction = gr.outputs.Textbox(label="Predicted family")
                donor_prediction = gr.outputs.Textbox(label="Predicted donor")
                # Receives family_info, or the error message when preprocessing fails.
                info_markdown = gr.Markdown()

        # Predict button
        with gr.Row().style(equal_height=True):
            with gr.Column():
                predict_button = gr.Button("Predict")
                # Output order must match the 5-tuple returned by main_function_single.
                predict_button.click(main_function_single, inputs=[sequence],
                                     outputs=[family_prediction, prediction_imagefam, info_markdown,
                                              donor_prediction, prediction_imagedonor])

        # Family & Donor Section
        with gr.Row().style(equal_height=True):
            with gr.Column():
                with gr.Accordion("Family Prediction:"):
                    prediction_imagefam.render() # = gr.outputs.Image(type='pil', label="Family prediction graph")
            with gr.Column():
                with gr.Accordion("Donor Prediction:"):          
                    prediction_imagedonor.render() # = gr.outputs.Image(type='pil', label="Donor prediction graph")
    

# Starts the web server at import time; show_error=True is passed to launch().
app.launch(show_error=True)