"""Gradio app that predicts where an engineered DNA sequence was likely engineered.

A GENA-LM BERT encoder with a logistic-regression head (fine-tuned weights pulled
from the Hugging Face Hub) scores the sequence; the top five countries are
returned as JSON together with a bar chart rendered as an inline HTML image.
"""
import base64
import hmac
import io
import logging
import os

import gradio as gr
import huggingface_hub
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import transformers
from huggingface_hub import hf_hub_download, login
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer

logger = logging.getLogger(__name__)

# Hub identifiers for the pretrained encoder and the fine-tuned checkpoint.
_BASE_MODEL_ID = 'AIRI-Institute/gena-lm-bert-base-lastln-t2t'
_CHECKPOINT_REPO_ID = "mawairon/noo_test"
_CHECKPOINT_FILENAME = "gena-blastln-bs33-lr4e-05-S168.pth"

# Inference / input-validation settings.
_MAX_TOKENS = 512            # tokenizer pads/truncates every input to this many tokens
_MIN_SEQUENCE_LENGTH = 300   # shortest accepted sequence, in nucleotides
_ROTATION_THRESHOLD = 3000   # sequences longer than this are scored over several rotations
_ROTATION_STEP = 1000        # nucleotides between successive rotation start points
_TOP_K = 5                   # number of countries reported
_VALID_NUCLEOTIDES = frozenset('ACTGN')


# Load label mapping (label -> class index) saved alongside the model.
# NOTE: pd.read_pickle can execute arbitrary code; this file must come from a trusted source.
label_to_int = pd.read_pickle('label_to_int.pkl')


def _normalize_country(label):
    """Collapse verbose country labels into the short names shown to users."""
    if "KOREA" in label:
        return "KOREA"
    if "KINGDOM" in label:
        return "UK"
    if "RUSSIAN" in label:
        return "RUSSIA"
    return label


# Class index -> display label.
# NOTE(review): several source labels may collapse onto the same display name
# (e.g. two Korea labels -> "KOREA"); they are still ranked as separate classes.
int_to_label = {v: _normalize_country(k) for k, v in label_to_int.items()}


class LogisticRegressionTorch(nn.Module):
    """Linear classification head preceded by batch normalisation of its input.

    Returns raw logits; the caller applies softmax.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        # Attribute names must match the keys in 'log_reg_state_dict' of the checkpoint.
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.batch_norm(x)
        return self.linear(x)


class BertClassifier(nn.Module):
    """Encoder + classification head scoring the first-token embedding of the last layer."""

    def __init__(self, bert_model: nn.Module, classifier: LogisticRegressionTorch, num_labels: int):
        super().__init__()
        self.bert = bert_model
        self.classifier = classifier
        self.num_labels = num_labels

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # Embedding at token position 0 of the final hidden layer.
        pooled_output = outputs.hidden_states[-1][:, 0, :]
        return self.classifier(pooled_output)


def load_model():
    """Build the classifier and load its fine-tuned weights from the Hugging Face Hub.

    Returns:
        ``(model, tokenizer)``: the ``BertClassifier`` in eval mode (weights mapped
        to CPU) and the matching tokenizer.

    Raises:
        ValueError: if the ``HUGGINGFACE_TOKEN`` environment variable is not set.
    """
    # Check the credential first so a missing token fails before any large download.
    token = os.getenv('HUGGINGFACE_TOKEN')
    if token is None:
        raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")

    metadata_features = 0
    N_UNIQUE_CLASSES = 38  # must match the output size stored in the checkpoint

    # NOTE: trust_remote_code executes Python code shipped with the Hub repository.
    base_model = AutoModel.from_pretrained(_BASE_MODEL_ID, trust_remote_code=True, output_hidden_states=True)
    tokenizer = AutoTokenizer.from_pretrained(_BASE_MODEL_ID, trust_remote_code=True)

    # 768 is presumably the encoder's hidden size — must match the checkpoint.
    input_size = 768 + metadata_features
    log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)

    login(token=token)
    file_path = hf_hub_download(
        repo_id=_CHECKPOINT_REPO_ID,
        filename=_CHECKPOINT_FILENAME,
        token=token,  # replaces the deprecated `use_auth_token`
    )
    # NOTE: torch.load may unpickle arbitrary objects; the checkpoint repo must be trusted.
    weights = torch.load(file_path, map_location=torch.device('cpu'))
    base_model.load_state_dict(weights['model_state_dict'])
    log_reg.load_state_dict(weights['log_reg_state_dict'])

    model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
    model.eval()  # BatchNorm must use its running statistics at inference time
    return model, tokenizer


model, tokenizer = load_model()


def _is_authorized(username, password):
    """Check credentials against the USERNAME (comma-separated) and PASSWORD env vars.

    Fails closed: if either variable is unset, nobody is authorised.
    NOTE(review): on Windows, USERNAME is also the OS login name — consider renaming.
    """
    raw_usernames = os.getenv('USERNAME')
    env_password = os.getenv('PASSWORD')
    if raw_usernames is None or env_password is None:
        return False
    valid_usernames = {name.strip() for name in raw_usernames.split(',') if name.strip()}
    user_ok = username in valid_usernames
    # Constant-time comparison; bytes so non-ASCII passwords are accepted.
    password_ok = hmac.compare_digest((password or "").encode('utf-8'), env_password.encode('utf-8'))
    return user_ok and password_ok


def _clean_sequence(sequence):
    """Remove every whitespace character and upper-case the nucleotides."""
    return "".join((sequence or "").split()).upper()


def _get_logits(seq):
    """Run one tokenised (padded/truncated to _MAX_TOKENS) sequence through the model."""
    inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=_MAX_TOKENS,
                       return_tensors="pt", return_token_type_ids=False)
    with torch.no_grad():
        return model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])


def _predict_logits(sequence):
    """Return class logits; long sequences are averaged over several rotations.

    Because the tokenizer truncates to _MAX_TOKENS, each rotation exposes a
    different leading window of a long sequence to the model.
    """
    if len(sequence) <= _ROTATION_THRESHOLD:
        return _get_logits(sequence)
    num_shifts = len(sequence) // _ROTATION_STEP
    logits_sum = None
    for i in range(num_shifts):
        offset = i * _ROTATION_STEP
        logits = _get_logits(sequence[offset:] + sequence[:offset])
        logits_sum = logits if logits_sum is None else logits_sum + logits
    return logits_sum / num_shifts


def _plot_to_base64(labels, probs):
    """Render a horizontal bar chart of the predictions as a base64-encoded PNG."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.barh(labels, probs, color='skyblue')
        ax.set_xlabel('Probability')
        ax.set_title('Assuming this sequence was genetically engineered,\n the 5 most likely countries in which it was engineered are:')
        ax.invert_yaxis()  # highest probability at the top
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        return base64.b64encode(buf.getvalue()).decode('utf-8')
    finally:
        plt.close(fig)  # avoid leaking one figure per request in a long-running server


def analyze_dna(username, password, sequence):
    """Gradio handler: authenticate, validate the sequence and predict its origin.

    Args:
        username: Must be one of the comma-separated names in the USERNAME env var.
        password: Must equal the PASSWORD env var.
        sequence: DNA text over A/C/G/T/N; whitespace is ignored and lowercase is accepted.

    Returns:
        ``(json_output, html_output)``. On success, ``json_output`` is a list of
        ``(country, probability)`` pairs, highest first, and ``html_output`` is an
        ``<img>`` tag with the bar chart. On failure, ``json_output`` is
        ``{"error": message}`` and ``html_output`` is ``""``.
    """
    if not _is_authorized(username, password):
        return {"error": "Invalid username or password"}, ""

    try:
        sequence = _clean_sequence(sequence)
        if not set(sequence) <= _VALID_NUCLEOTIDES:
            return {"error": "Sequence contains invalid characters"}, ""
        if len(sequence) < _MIN_SEQUENCE_LENGTH:
            return {"error": "Sequence needs to be at least 300 nucleotides long"}, ""

        logits = _predict_logits(sequence)
        probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
        top_indices = sorted(range(len(probabilities)), key=probabilities.__getitem__, reverse=True)[:_TOP_K]
        top_labels = [int_to_label[i] for i in top_indices]
        top_probs = [probabilities[i] for i in top_indices]
        result = list(zip(top_labels, top_probs))

        image_base64 = _plot_to_base64(top_labels, top_probs)
        html = f'<img src="data:image/png;base64,{image_base64}" alt="Top {_TOP_K} predicted countries" />'
        return result, html
    except Exception as e:
        # Top-level boundary for the UI: report the error to the user, keep a server-side trace.
        logger.exception("DNA analysis failed")
        return {"error": str(e)}, ""


# Create a Gradio interface
demo = gr.Interface(
    fn=analyze_dna,
    inputs=[
        gr.Textbox(label="Username"),
        gr.Textbox(label="Password", type="password"),
        gr.Textbox(label="DNA Sequence"),
    ],
    outputs=["json", "html"],
)

# Launch the interface
demo.launch()