"""Gradio app that ranks the countries in which a DNA sequence was most likely engineered.

Three classifiers are selectable from the UI: a GENA-LM BERT encoder with a
logistic-regression head, and two one-hot-encoded CNNs with 8k / 16k context.
Access is gated by a username list and a shared password read from the
environment (``USERNAME``, ``PASSWORD``).
"""
import base64
import functools
import hmac
import io
import logging
import os

import gradio as gr
import huggingface_hub
import matplotlib.pyplot as plt
import pandas as pd
import tangermeme
import torch
import torch.nn as nn
import transformers
from huggingface_hub import hf_hub_download, login
from matplotlib.figure import Figure
from tangermeme.utils import one_hot_encode
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer

import model_archs
from model_archs import MLP, BertClassifier, LogisticRegressionTorch, Pool2BN, SimpleCNN

logger = logging.getLogger(__name__)

N_UNIQUE_CLASSES = 38          # number of output classes shared by every model head
MIN_SEQUENCE_LENGTH = 300      # shortest sequence accepted by the UI
TOP_K = 5                      # how many predictions are reported / plotted
VALID_NUCLEOTIDES = frozenset('ACTGN')
# Fixed input length each CNN is fed: longer sequences are truncated, shorter ones padded.
CNN_CONTEXT_LENGTHS = {'CNN-m2-8k-context': 8000, 'CNN-m4-16k-context': 16000}
GENA_BASE_MODEL = 'AIRI-Institute/gena-lm-bert-base-lastln-t2t'

# Substring -> display alias; the first matching substring wins.
_LABEL_ALIASES = (("KOREA", "KOREA"), ("KINGDOM", "UK"), ("RUSSIAN", "RUSSIA"))


def _shorten_label(label):
    """Return the short display alias for a class label, or the label unchanged."""
    for needle, alias in _LABEL_ALIASES:
        if needle in label:
            return alias
    return label


# Load label mapping (a local, trusted pickle shipped with the app) and invert it
# so model output indices map back to (shortened) country labels.
label_to_int = pd.read_pickle('label_to_int.pkl')
int_to_label = {v: _shorten_label(k) for k, v in label_to_int.items()}


def _load_gena_bert():
    """Build the GENA-LM BERT classifier and its tokenizer with fine-tuned weights.

    Raises:
        ValueError: if ``HUGGINGFACE_TOKEN`` is not set (checked before any download).
    """
    token = os.getenv('HUGGINGFACE_TOKEN')
    if token is None:
        raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
    base_model = AutoModel.from_pretrained(GENA_BASE_MODEL, trust_remote_code=True, output_hidden_states=True)
    tokenizer = AutoTokenizer.from_pretrained(GENA_BASE_MODEL, trust_remote_code=True)
    metadata_features = 0
    # 768 presumably matches the encoder's hidden size (BERT-base) — TODO confirm.
    log_reg = LogisticRegressionTorch(input_dim=768 + metadata_features, output_dim=N_UNIQUE_CLASSES)
    login(token=token)
    file_path = hf_hub_download(
        repo_id="mawairon/noo_test",
        filename="gena-blastln-bs33-lr4e-05-S168.pth",
        token=token,
    )
    weights = torch.load(file_path, map_location=torch.device('cpu'))
    base_model.load_state_dict(weights['model_state_dict'])
    log_reg.load_state_dict(weights['log_reg_state_dict'])
    model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
    model.eval()
    return model, tokenizer


def _build_cnn_m2():
    """Build the (untrained) 8k-context CNN; nesting must match the checkpoint keys."""
    hidden_dim = 2048
    seq_encoder = SimpleCNN(18, hidden_dim, additional_layer=False)
    head = torch.nn.Sequential(
        torch.nn.Dropout(0.5),
        MLP([hidden_dim * 2, N_UNIQUE_CLASSES]),
    )
    return torch.nn.Sequential(seq_encoder, head)


def _build_cnn_m4():
    """Build the (untrained) 16k-context CNN; nesting must match the checkpoint keys."""
    width = 768
    seq_encoder = nn.Sequential(
        nn.Conv1d(4, width, 7, padding=3),
        nn.ReLU(),
        nn.BatchNorm1d(width),
        # NOTE(review): ResNet1d was used without any import; assumed to live in
        # model_archs next to Pool2BN — confirm.
        model_archs.ResNet1d(width, [(3, width // 2, 1)] * 1, dropout=None, dilation=7),
        nn.ReLU(),
        Pool2BN(width),
    )
    head = torch.nn.Sequential(
        torch.nn.Dropout(0.5),  # for DEEPLIFT comment out
        MLP([width * 2, N_UNIQUE_CLASSES]),
    )
    return torch.nn.Sequential(seq_encoder, head)


@functools.lru_cache(maxsize=3)
def load_model(model_name: str):
    """Return ``(model, tokenizer)`` for *model_name*, ready for inference.

    Results are cached so weights are downloaded/deserialized once per process
    rather than on every request. Every returned model is in eval mode, which
    disables Dropout and freezes BatchNorm statistics. ``tokenizer`` is ``None``
    for the CNN models.

    Raises:
        ValueError: for an unknown model name or a missing HUGGINGFACE_TOKEN.
    """
    if model_name == 'GENA-Bert':
        return _load_gena_bert()
    if model_name == 'CNN-m2-8k-context':
        model = _build_cnn_m2()
        weights_file = 'CNN_1stGEAC_m2_best.pth'
    elif model_name == 'CNN-m4-16k-context':
        model = _build_cnn_m4()
        weights_file = 'NOO_CNN_1stGEAC_m4_16kcw_best.pth'
    else:
        raise ValueError("Invalid model name")
    model.load_state_dict(torch.load(weights_file, map_location=torch.device('cpu')))
    model.eval()
    return model, None


def _is_authorized(username, password):
    """Check credentials against ``USERNAME`` (comma-separated) and ``PASSWORD``.

    Access is denied when either variable is unset, instead of crashing.
    """
    # NOTE(review): USERNAME is also set by some operating systems (e.g. Windows);
    # a dedicated variable name would be less surprising.
    allowed = {name.strip() for name in os.getenv('USERNAME', '').split(',') if name.strip()}
    expected = os.getenv('PASSWORD')
    if not allowed or expected is None or username not in allowed:
        return False
    # Constant-time comparison so response timing does not leak the password.
    return hmac.compare_digest((password or '').encode('utf-8'), expected.encode('utf-8'))


def _get_logits(seq, model_name, model, tokenizer):
    """Run one forward pass of *model* on *seq* and return its raw logits."""
    if model_name == 'GENA-Bert':
        inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=512,
                           return_tensors="pt", return_token_type_ids=False)
        with torch.no_grad():
            return model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
    if model_name in CNN_CONTEXT_LENGTHS:
        length = CNN_CONTEXT_LENGTHS[model_name]
        # Truncate to the model's context, then right-pad short sequences with 'N'.
        seq = seq[:length].ljust(length, 'N')
        input_tensor = one_hot_encode(seq).unsqueeze(0).float()  # add a batch dimension
        with torch.no_grad():
            return model(input_tensor)
    raise ValueError("Invalid model name")


def _plot_top_predictions(labels, probs):
    """Render a horizontal bar chart of the top predictions as a base64 PNG string."""
    # An OO Figure (not pyplot) keeps no global state, so nothing leaks across requests.
    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    ax.barh(labels, probs, color='skyblue')
    ax.set_xlabel('Probability')
    ax.set_title('Assuming this sequence was genetically engineered,\n the 5 most likely countries in which it was engineered are:')
    ax.invert_yaxis()  # highest probability on top
    with io.BytesIO() as buf:
        fig.savefig(buf, format='png')
        return base64.b64encode(buf.getvalue()).decode('utf-8')


def analyze_dna(username, password, sequence, model_name):
    """Gradio handler: classify *sequence* with *model_name*.

    Returns:
        ``(result, html)`` where ``result`` is a list of ``(label, probability)``
        pairs for the top predictions and ``html`` is an ``<img>`` tag embedding the
        bar chart. On any failure ``result`` is ``{"error": message}`` and
        ``html`` is ``""``.
    """
    if not _is_authorized(username, password):
        return {"error": "Invalid username or password"}, ""

    try:
        # Remove all whitespace characters; accept lowercase bases too.
        sequence = "".join(sequence.split()).upper()
        if not set(sequence) <= VALID_NUCLEOTIDES:
            return {"error": "Sequence contains invalid characters"}, ""
        if len(sequence) < MIN_SEQUENCE_LENGTH:
            return {"error": f"Sequence needs to be at least {MIN_SEQUENCE_LENGTH} nucleotides long"}, ""

        model, tokenizer = load_model(model_name)
        # TODO: for long inputs, consider averaging logits over circular shifts of
        # the sequence (an earlier, disabled experiment).
        logits = _get_logits(sequence, model_name, model, tokenizer)

        probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
        top_indices = sorted(range(len(probabilities)), key=probabilities.__getitem__, reverse=True)[:TOP_K]
        top_labels = [int_to_label[i] for i in top_indices]
        top_probs = [probabilities[i] for i in top_indices]
        result = list(zip(top_labels, top_probs))

        image_base64 = _plot_top_predictions(top_labels, top_probs)
        return result, f'<img src="data:image/png;base64,{image_base64}" />'
    except Exception as e:
        # UI boundary: log the traceback, show only the message to the user.
        logger.exception("DNA analysis failed")
        return {"error": str(e)}, ""


# Check if the current version of Gradio supports HTML output
try:
    html_output = gr.HTML
except AttributeError:
    # Fallback if HTML is not supported
    html_output = gr.Textbox

# Create a Gradio interface
demo = gr.Interface(
    fn=analyze_dna,
    inputs=[
        gr.Textbox(label="Username"),
        gr.Textbox(label="Password", type="password"),
        gr.Textbox(label="DNA Sequence"),
        gr.Dropdown(label="Model", choices=[
            "GENA-Bert",
            "CNN-m2-8k-context",
            "CNN-m4-16k-context",
        ]),
    ],
    outputs=[
        gr.JSON(),
        html_output(),  # HTML (or Textbox fallback) showing the probability plot
    ],
)

if __name__ == "__main__":
    # Launch the interface
    demo.launch()