import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import io
import base64

# Assuming label_to_int is a dictionary with {label_name: label_index}
# NOTE(review): read_pickle can execute arbitrary code while unpickling; only ship a trusted file here.
label_to_int = pd.read_pickle('label_to_int.pkl')
int_to_label = {v: k for k, v in label_to_int.items()}


class LogisticRegressionTorch(nn.Module):
    """Classification head: BatchNorm1d over the input features, then a linear layer.

    Returns raw logits; softmax / cross-entropy are applied by callers.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super(LogisticRegressionTorch, self).__init__()
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = self.batch_norm(x)
        out = self.linear(x)
        return out


class BertClassifier(nn.Module):
    """Pretrained encoder plus a classification head.

    The last-layer hidden state of the first token of each sequence is passed
    to ``classifier`` to produce logits.
    """

    def __init__(self, bert_model: AutoModel, classifier: LogisticRegressionTorch, num_labels: int):
        super(BertClassifier, self).__init__()
        self.bert = bert_model  # Assume bert_model is an instance of a pre-trained BertModel
        self.classifier = classifier
        self.num_labels = num_labels

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None,
                token_type_ids: torch.Tensor = None, labels: torch.Tensor = None):
        """Run the encoder and the classifier head.

        Args:
            input_ids: token ids; dim 0 is treated as the batch dimension
                (presumably (batch, seq_len) — confirm for other callers).
            attention_mask: optional mask forwarded to the encoder.
            token_type_ids: accepted for signature compatibility; not forwarded.
            labels: optional class indices; when given, a cross-entropy loss is computed.

        Returns:
            ``(loss, logits)``; ``loss`` is ``None`` when ``labels`` is ``None``.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # Last layer's hidden states, first token position, one vector per batch element.
        pooled_output = outputs.hidden_states[-1][:, 0, :]
        # TODO: 768 is hard-coded to the base encoder's hidden size.
        assert pooled_output.shape == (input_ids.shape[0], 768), f"Expected shape ({input_ids.shape[0]}, 768), but got {pooled_output.shape}"

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return loss, logits


# Load the Hugging Face model and tokenizer
MODEL_NAME = 'AIRI-Institute/gena-lm-bert-base-lastln-t2t'
metadata_features = 0
N_UNIQUE_CLASSES = 38  # must match the output size stored in the checkpoint's classifier head
base_model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Initialize the classifier
input_size = 768 + metadata_features  # featurizer output size + metadata size
log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)

# Load Weights
model_weights_path = 'model/gena-blastln-bs33-lr4e-05-S168.pth'
weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
base_model.load_state_dict(weights['model_state_dict'])
log_reg.load_state_dict(weights['log_reg_state_dict'])
# load_state_dict copied the tensors into the modules; drop the checkpoint copy.
del weights

# Creating Model
model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
# Inference only. nn.Module starts in training mode, where BatchNorm1d rejects a
# batch of a single sample and dropout layers (if any) would randomize outputs.
model.eval()


def _probability_chart_html(labels, probs):
    """Render a horizontal bar chart of ``probs`` and return it as an inline ``<img>`` tag."""
    # Object-oriented Figure API instead of pyplot: no global figure registry,
    # so nothing accumulates across requests and no GUI backend is required.
    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    ax.barh(labels, probs, color='skyblue')
    ax.set_xlabel('Probability')
    ax.set_title('Top 5 Most Likely Labels')
    ax.invert_yaxis()  # Highest probabilities at the top

    with io.BytesIO() as buf:
        fig.savefig(buf, format='png')
        image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return f'<img src="data:image/png;base64,{image_base64}" alt="Top 5 Most Likely Labels"/>'


def analyze_dna(sequence):
    """Classify a DNA sequence and report the most likely labels.

    Args:
        sequence: raw sequence text from the Gradio text box.

    Returns:
        ``(result, html)``. ``result`` is a list of ``(label_name, probability)``
        tuples for the top 5 classes, highest probability first. ``html`` is an
        ``<img>`` tag embedding a bar chart of those probabilities as a base64
        PNG. Blank input returns ``([], "")``.
    """
    if sequence is None or not sequence.strip():
        return [], ""

    # Preprocess the input sequence
    inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512,
                       return_tensors="pt", return_token_type_ids=False)
    print("Tokenization done.")

    # Get model predictions; no autograd graph is needed for inference.
    with torch.no_grad():
        _, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
    print("Forward pass done.")

    # A single string was tokenized, so the batch has one row: take row 0.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()
    print("Probabilities done.")

    # Get the top 5 most likely classes (stable sort keeps index order on ties)
    top_5_indices = sorted(range(len(probabilities)), key=probabilities.__getitem__, reverse=True)[:5]
    top_5_probs = [probabilities[i] for i in top_5_indices]
    top_5_labels = [int_to_label[i] for i in top_5_indices]

    result = list(zip(top_5_labels, top_5_probs))
    return result, _probability_chart_html(top_5_labels, top_5_probs)


# Create a Gradio interface
demo = gr.Interface(fn=analyze_dna, inputs="text", outputs=["json", "html"])

# Launch the interface
demo.launch()