"""Gradio app that predicts the country of origin of an engineered DNA sequence.

The app authenticates the caller against the ``USERNAME`` / ``PASSWORD``
environment variables, validates the submitted sequence, runs it through the
selected model and returns the five most likely labels together with a bar
chart embedded as an HTML ``<img>`` tag.
"""
import base64
import functools
import hmac
import io
import os

import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import huggingface_hub
from huggingface_hub import hf_hub_download, login
import tangermeme
from tangermeme import one_hot_encode

import model_archs
from model_archs import BertClassifier, LogisticRegressionTorch, SimpleCNN, MLP, Pool2BN

# Number of output classes of both classifiers.
N_UNIQUE_CLASSES = 38

# Symbols accepted in a submitted sequence.
_VALID_NUCLEOTIDES = frozenset('ACTGN')
_MIN_SEQUENCE_LENGTH = 300

# Fixed input length used by the CNN branch.
_CNN_SEQUENCE_LENGTH = 8000
# NOTE(review): the original code referenced an undefined ``pad_char``. 'N' is
# chosen because it is the only non-base symbol the validator accepts;
# presumably tangermeme's one_hot_encode treats it as "ignore" -- confirm.
_CNN_PAD_CHAR = 'N'


def _normalize_label(label):
    """Collapse verbose country names to the short form shown in the UI."""
    if "KOREA" in label:
        return "KOREA"
    if "KINGDOM" in label:
        return "UK"
    if "RUSSIAN" in label:
        return "RUSSIA"
    return label


# Load label mapping.
# NOTE: read_pickle executes arbitrary code from the file; this is only safe
# because 'label_to_int.pkl' ships with the app and is not user supplied.
label_to_int = pd.read_pickle('label_to_int.pkl')
int_to_label = {
    index: _normalize_label(label) for label, index in label_to_int.items()
}


@functools.lru_cache(maxsize=None)
def load_model(model_name: str):
    """Build the requested model and return ``(model, tokenizer)``.

    Results are memoised per ``model_name`` so the Hugging Face login, the
    checkpoint download and the weight loading run once per process instead
    of once per request.

    Args:
        model_name: Either ``'gena-bert'`` or ``'CNN'``.

    Returns:
        A ``(model, tokenizer)`` tuple. ``tokenizer`` is ``None`` for the CNN.

    Raises:
        ValueError: If ``model_name`` is unknown, or if ``'gena-bert'`` is
            requested while ``HUGGINGFACE_TOKEN`` is not set.
    """
    metadata_features = 0

    if model_name == 'gena-bert':
        base_model = AutoModel.from_pretrained(
            'AIRI-Institute/gena-lm-bert-base-lastln-t2t',
            trust_remote_code=True,
            output_hidden_states=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            'AIRI-Institute/gena-lm-bert-base-lastln-t2t',
            trust_remote_code=True,
        )

        input_size = 768 + metadata_features
        log_reg = LogisticRegressionTorch(
            input_dim=input_size, output_dim=N_UNIQUE_CLASSES
        )

        token = os.getenv('HUGGINGFACE_TOKEN')
        if token is None:
            raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
        login(token=token)

        # ``token=`` replaces the deprecated ``use_auth_token=`` argument.
        file_path = hf_hub_download(
            repo_id="mawairon/noo_test",
            filename="gena-blastln-bs33-lr4e-05-S168.pth",
            token=token,
        )
        weights = torch.load(file_path, map_location=torch.device('cpu'))

        base_model.load_state_dict(weights['model_state_dict'])
        log_reg.load_state_dict(weights['log_reg_state_dict'])

        model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
        model.eval()
        return model, tokenizer

    if model_name == 'CNN':
        hidden_dim = 2048

        model_seq = SimpleCNN(18, hidden_dim, additional_layer=False)
        # The original sized the head with ``len(set(y_train))``, but no
        # ``y_train`` exists in this file; the class count is used instead.
        new_head = torch.nn.Sequential(
            torch.nn.Dropout(0.5),
            MLP([hidden_dim * 2, N_UNIQUE_CLASSES]),
        )
        model = torch.nn.Sequential(model_seq, new_head)
        # Disable dropout so inference is deterministic.
        model.eval()
        # NOTE(review): no trained weights are loaded for the CNN anywhere in
        # this file, so this network is randomly initialised.
        return model, None

    # Raising (instead of returning a dict) keeps the two-tuple return
    # contract; analyze_dna converts the exception into an error payload.
    raise ValueError("Invalid model name")


def _credentials_ok(username, password):
    """Return True when the username/password match the environment config.

    Fails closed when ``USERNAME`` or ``PASSWORD`` is unset or empty, and
    compares the password in constant time.
    """
    allowed = [name for name in os.getenv('USERNAME', '').split(',') if name]
    expected = os.getenv('PASSWORD')
    if not expected or username not in allowed:
        return False
    return hmac.compare_digest(
        (password or '').encode('utf-8'), expected.encode('utf-8')
    )


def _get_logits(model, tokenizer, seq, model_name):
    """Run ``seq`` through ``model`` and return the raw logits."""
    if model_name == 'gena-bert':
        inputs = tokenizer(
            seq,
            truncation=True,
            padding='max_length',
            max_length=512,
            return_tensors="pt",
            return_token_type_ids=False,
        )
        with torch.no_grad():
            return model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
            )

    if model_name == 'CNN':
        # Truncate, then pad, so every input has exactly the same length.
        seq = seq[:_CNN_SEQUENCE_LENGTH].ljust(_CNN_SEQUENCE_LENGTH, _CNN_PAD_CHAR)
        encoded = one_hot_encode(seq)
        # NOTE(review): SimpleCNN's expected input shape/dtype is not visible
        # here -- confirm whether a batch dimension or float cast is needed.
        with torch.no_grad():
            return model(encoded)

    raise ValueError("Invalid model name")


def _render_bar_chart(labels, probs):
    """Draw the top-k bar chart and return it as a base64-encoded PNG."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.barh(labels, probs, color='skyblue')
        ax.set_xlabel('Probability')
        ax.set_title('Assuming this sequence was genetically engineered,\n the 5 most likely countries in which it was engineered are:')
        ax.invert_yaxis()  # most likely label on top

        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            return base64.b64encode(buf.getvalue()).decode('utf-8')
    finally:
        # Figures are otherwise kept alive by pyplot and leak per request.
        plt.close(fig)


def analyze_dna(username, password, sequence, model_name):
    """Classify a DNA sequence and return the top-5 labels plus a chart.

    Args:
        username: Login name; must be listed in the comma-separated
            ``USERNAME`` environment variable.
        password: Must equal the ``PASSWORD`` environment variable.
        sequence: DNA sequence made of ``A``, ``C``, ``T``, ``G``, ``N``;
            whitespace is removed before validation.
        model_name: ``'gena-bert'`` or ``'CNN'``.

    Returns:
        A ``(result, html)`` pair. On success ``result`` is a list of
        ``(label, probability)`` tuples and ``html`` is an ``<img>`` tag with
        the chart embedded as a data URI. On failure ``result`` is
        ``{"error": message}`` and ``html`` is ``""``.
    """
    if not _credentials_ok(username, password):
        return {"error": "Invalid username or password"}, ""

    try:
        # Remove all whitespace characters
        sequence = "".join((sequence or "").split())

        if not _VALID_NUCLEOTIDES.issuperset(sequence):
            return {"error": "Sequence contains invalid characters"}, ""
        if len(sequence) < _MIN_SEQUENCE_LENGTH:
            return {"error": "Sequence needs to be at least 300 nucleotides long"}, ""

        model, tokenizer = load_model(model_name)

        # NOTE: averaging logits over shifted copies of long sequences was
        # sketched here previously but is not enabled.
        logits_avg = _get_logits(model, tokenizer, sequence, model_name)

        probabilities = torch.nn.functional.softmax(logits_avg, dim=-1).squeeze().tolist()
        top_5_indices = sorted(
            range(len(probabilities)),
            key=probabilities.__getitem__,
            reverse=True,
        )[:5]
        top_5_probs = [probabilities[i] for i in top_5_indices]
        top_5_labels = [int_to_label[i] for i in top_5_indices]

        result = list(zip(top_5_labels, top_5_probs))
        image_base64 = _render_bar_chart(top_5_labels, top_5_probs)

        # The original returned an empty f-string here, discarding the image.
        return result, f'<img src="data:image/png;base64,{image_base64}">'

    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        return {"error": str(e)}, ""


# Create a Gradio interface
demo = gr.Interface(
    fn=analyze_dna,
    inputs=[
        gr.Textbox(label="Username"),
        gr.Textbox(label="Password", type="password"),
        gr.Textbox(label="DNA Sequence"),
        gr.Dropdown(label="Model", choices=["gena-bert", "CNN"]),
    ],
    outputs=["json", "HTML"],
)

# Launch the interface
demo.launch()