# Hugging Face Space page header captured together with the file (not code):
#   NOOTestspace /
#   mawairon's picture
#   dc7d693 verified
#   history blame
#   4.76 kB
import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import pandas as pd
class LogisticRegressionTorch(nn.Module):
    """Batch-normalised linear layer that maps feature vectors to logits.

    The module returns raw scores; no softmax or sigmoid is applied here.

    Args:
        input_dim: Number of features in each input vector.
        output_dim: Number of scores produced per input vector.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        # The attribute names define the state-dict keys
        # ("batch_norm.*", "linear.*"), so they are kept unchanged.
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Normalise ``x`` per feature, then apply the linear projection."""
        return self.linear(self.batch_norm(x))
class BertClassifier(nn.Module):
    """Encoder plus classification head operating on the first-token state.

    Args:
        bert_model: Encoder module; it is called with
            ``output_hidden_states=True`` and must expose ``hidden_states``
            on its output.
        classifier: Head that turns the pooled vector into logits.
        num_labels: Number of classes; used to reshape logits for the loss.
    """

    def __init__(self,
                 bert_model: AutoModel,
                 classifier: LogisticRegressionTorch,
                 num_labels: int):
        super().__init__()
        self.bert = bert_model
        self.classifier = classifier
        self.num_labels = num_labels

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None,
                token_type_ids: torch.Tensor = None, labels: torch.Tensor = None):
        """Run the encoder and the head; optionally compute the loss.

        ``token_type_ids`` is accepted for signature compatibility but is not
        forwarded to the encoder.

        Returns:
            A ``(loss, logits)`` tuple. ``loss`` is ``None`` when ``labels``
            is not given, otherwise the cross-entropy between the logits and
            the flattened labels.
        """
        encoder_out = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)

        # Last layer's hidden states, restricted to the first token of every
        # sequence in the batch.
        final_layer = encoder_out.hidden_states[-1]
        pooled_output = final_layer[:, 0, :]
        # Sanity check: the head is wired for 768-dimensional encoder states.
        assert pooled_output.shape == (input_ids.shape[0], 768), f"Expected shape ({input_ids.shape[0]}, 768), but got {pooled_output.shape}"

        logits = self.classifier(pooled_output)

        if labels is None:
            return None, logits

        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
        return loss, logits
# Load the Hugging Face model and tokenizer
metadata_features = 0
N_UNIQUE_CLASSES = 38  # width of the classifier's output layer
base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)

# Initialize the classifier
input_size = 768 + metadata_features  # featurizer output size + metadata size
log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)

# Load Weights
model_weights_path = 'gena-blastln-bs33-lr4e-05-S168.pth'
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
weights = torch.load(model_weights_path, map_location=torch.device('cpu'))

# Creating Model
model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
# Apply the checkpoint; without this the classification head keeps its random
# initialisation. Assumes the file holds a state_dict saved from a
# BertClassifier -- TODO confirm against the training script.
model.load_state_dict(weights)
# Inference mode: BatchNorm1d must use its running statistics, because in
# training mode it rejects a batch containing a single sample.
model.eval()

# Dictionary to decode model predictions
# (pandas exposes read_pickle; there is no read_pkl.)
# NOTE(review): unpickling executes arbitrary code; keep this file trusted.
label_to_int = pd.read_pickle('label_to_int.pkl')
int_to_label = {v: k for k, v in label_to_int.items()}
# Define a function to process the DNA sequence
def analyze_dna(sequence):
    """Classify a DNA sequence and return its five most probable labels.

    Uses the module-level ``tokenizer``, ``model`` and ``int_to_label``.

    Args:
        sequence: Sequence text; it is truncated/padded to 512 tokens.

    Returns:
        A list of ``[label, probability]`` pairs sorted from most to least
        probable (at most five entries). The value is JSON-serialisable so it
        can be shown by a JSON output component.
    """
    # Preprocess the input sequence
    inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
    print("Tokenization done.")

    # Get model predictions; no gradients are needed for inference.
    with torch.no_grad():
        _, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
    print("Forward pass done.")

    # Convert logits to probabilities and drop the batch dimension.
    probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze(0).tolist()
    print("Probabilities done.")

    # Indices of the five highest probabilities; sorted() is stable, so ties
    # keep the lower class index first.
    top_5_indices = sorted(range(len(probabilities)), key=probabilities.__getitem__, reverse=True)[:5]

    # NOTE(review): the earlier bar-chart code built a Matplotlib figure that
    # was never returned or closed; reintroduce it together with a plot output
    # component if a chart is wanted.
    return [[int_to_label[i], probabilities[i]] for i in top_5_indices]
# Build the Gradio UI: free-text input handled by analyze_dna, JSON output
demo = gr.Interface(
    fn=analyze_dna,
    inputs="text",
    outputs="json",
)
# Launch the interface