import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import io
import base64
import os

# Load the label mapping and build the inverse map (class index -> label name)
label_to_int = pd.read_pickle('label_to_int.pkl')
int_to_label = {v: k for k, v in label_to_int.items()}


class LogisticRegressionTorch(nn.Module):
    """Batch-normalised linear classification head."""

    def __init__(self, input_dim: int, output_dim: int):
        super(LogisticRegressionTorch, self).__init__()
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = self.batch_norm(x)
        out = self.linear(x)
        return out


class BertClassifier(nn.Module):
    """GENA-LM backbone followed by the logistic-regression head."""

    def __init__(self, bert_model: AutoModel, classifier: LogisticRegressionTorch, num_labels: int):
        super(BertClassifier, self).__init__()
        self.bert = bert_model
        self.classifier = classifier
        self.num_labels = num_labels

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None):
        outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # Pool by taking the [CLS] token embedding from the last hidden layer
        pooled_output = outputs.hidden_states[-1][:, 0, :]
        logits = self.classifier(pooled_output)
        return logits


def load_model():
    metadata_features = 0
    N_UNIQUE_CLASSES = 38

    base_model = AutoModel.from_pretrained(
        'AIRI-Institute/gena-lm-bert-base-lastln-t2t',
        trust_remote_code=True,
        output_hidden_states=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        'AIRI-Institute/gena-lm-bert-base-lastln-t2t',
        trust_remote_code=True,
    )

    input_size = 768 + metadata_features
    log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)

    # Restore both the backbone and the head from a single checkpoint, on CPU
    model_weights_path = os.getenv('MODEL_PATH')
    weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
    base_model.load_state_dict(weights['model_state_dict'])
    log_reg.load_state_dict(weights['log_reg_state_dict'])

    model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
    model.eval()
    return model, tokenizer


model, tokenizer = load_model()


def analyze_dna(sequence):
    try:
        # Validate the input before running inference
        if not all(nucleotide in 'ACTGN' for nucleotide in sequence):
            return "Error: Sequence contains invalid characters", ""
        if len(sequence) < 300:
            return "Error: Sequence needs to be at least 300 nucleotides long", ""

        inputs = tokenizer(
            sequence,
            truncation=True,
            padding='max_length',
            max_length=512,
            return_tensors="pt",
            return_token_type_ids=False,
        )
        with torch.no_grad():
            logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
        probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()

        # Pick the five most probable classes
        top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
        top_5_probs = [probabilities[i] for i in top_5_indices]
        top_5_labels = [int_to_label[i] for i in top_5_indices]
        result = list(zip(top_5_labels, top_5_probs))

        # Render the top-5 probabilities as a horizontal bar chart
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.barh(top_5_labels, top_5_probs, color='skyblue')
        ax.set_xlabel('Probability')
        ax.set_title('Top 5 Most Likely Labels')
        ax.invert_yaxis()

        # Encode the figure as base64 so it can be returned as inline HTML
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        plt.close(fig)  # free the figure; repeated calls leak memory otherwise
        buf.seek(0)
        image_base64 = base64.b64encode(buf.read()).decode('utf-8')
        buf.close()

        return result, f'<img src="data:image/png;base64,{image_base64}" />'
    except Exception as e:
        return str(e), ""


# Create a Gradio interface
demo = gr.Interface(fn=analyze_dna, inputs="text", outputs=["json", "html"])

# Launch the interface (guarded so the module can be imported without starting the server)
if __name__ == "__main__":
    demo.launch()
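
# --- Checkpoint format sketch ---
# load_model() expects the file at MODEL_PATH to be a single torch checkpoint
# holding both state dicts. A training script would save it along these lines;
# the filename 'checkpoint.pt' and the variable names are illustrative only:
#
#   torch.save(
#       {
#           'model_state_dict': base_model.state_dict(),    # GENA-LM backbone
#           'log_reg_state_dict': log_reg.state_dict(),     # BatchNorm + linear head
#       },
#       'checkpoint.pt',
#   )
#
# The app then picks it up via:  export MODEL_PATH=checkpoint.pt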
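
# --- Optional smoke test (a sketch, not part of the app's UI path) ---
# Exercises analyze_dna directly, assuming MODEL_PATH is set and the synthetic
# sequence clears the 300-nucleotide minimum. Because demo.launch() above is
# guarded, this module can be imported and the function called from a shell.
def _smoke_test(length: int = 400):
    import random
    sequence = ''.join(random.choice('ACGT') for _ in range(length))
    top5, img_html = analyze_dna(sequence)
    print(top5)  # [(label, probability), ...] for the five most likely classes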