NOOTestspace / app.py
mawairon's picture
Update app.py
66990c3 verified
raw
history blame
4.98 kB
import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import io
import base64
# Class-name <-> class-index lookups for the classifier output.
# The pickle is assumed to hold a {label_name: label_index} mapping
# (TODO confirm against the code that produced it). read_pickle unpickles
# the file, so it must come from a trusted source.
label_to_int = pd.read_pickle('label_to_int.pkl')
# Inverse lookup: predicted class index -> label name.
int_to_label = {index: name for name, index in label_to_int.items()}
class LogisticRegressionTorch(nn.Module):
    """Logistic-regression classification head with normalised inputs.

    Features first pass through a ``BatchNorm1d`` layer and are then mapped
    to ``output_dim`` unnormalised class scores (logits) by one linear layer.

    The attribute names ``batch_norm`` and ``linear`` (and their registration
    order) define this module's ``state_dict`` keys, which saved checkpoints
    are loaded against, so they must not be renamed.

    Args:
        input_dim: Number of input features per example.
        output_dim: Number of classes (width of the logit vector).
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        # Registered before ``linear`` so the state_dict key order is unchanged.
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Normalise ``x`` and return the linear layer's logits.

        In this file ``x`` is a ``(batch, input_dim)`` tensor, giving
        ``(batch, output_dim)`` logits.
        """
        normalized = self.batch_norm(x)
        return self.linear(normalized)
class BertClassifier(nn.Module):
    """Sequence classifier: an encoder followed by a classification head.

    The last-layer hidden state of the first token of each sequence
    (presumably the [CLS] token) is passed to ``classifier`` to obtain
    logits over ``num_labels`` classes.

    Args:
        bert_model: Encoder module. It is called with
            ``output_hidden_states=True`` and must return an object exposing
            ``hidden_states``.
        classifier: Head mapping the pooled 768-wide vector to class logits.
        num_labels: Number of target classes; used to reshape the logits when
            computing the loss.
    """

    def __init__(self, bert_model: AutoModel, classifier: LogisticRegressionTorch, num_labels: int):
        super().__init__()
        self.bert = bert_model
        self.classifier = classifier
        self.num_labels = num_labels

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None,
                token_type_ids: torch.Tensor = None, labels: torch.Tensor = None):
        """Encode ``input_ids`` and classify each sequence.

        Args:
            input_ids: Token ids; the first dimension is the batch.
            attention_mask: Optional mask forwarded to the encoder.
            token_type_ids: Accepted for interface compatibility but NOT
                forwarded to the encoder.
            labels: Optional class indices. When given, a cross-entropy loss
                is computed.

        Returns:
            ``(loss, logits)``. ``loss`` is ``None`` when ``labels`` is ``None``.
        """
        encoded = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # Last layer, first token of every sequence in the batch.
        last_layer = encoded.hidden_states[-1]
        pooled_output = last_layer[:, 0, :]
        # Sanity check on the pooled width. NOTE(review): 768 is hard-coded and
        # must match the encoder's hidden size.
        assert pooled_output.shape == (input_ids.shape[0], 768), f"Expected shape ({input_ids.shape[0]}, 768), but got {pooled_output.shape}"

        logits = self.classifier(pooled_output)

        if labels is None:
            return None, logits
        # Flatten to (N, num_labels) vs (N,) as CrossEntropyLoss expects.
        loss = nn.CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
        return loss, logits
# Load the Hugging Face model and tokenizer
MODEL_NAME = 'AIRI-Institute/gena-lm-bert-base-lastln-t2t'
metadata_features = 0  # width of extra metadata features in the classifier input (none)
N_UNIQUE_CLASSES = 38  # number of classes the saved classifier head predicts
base_model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Initialize the classifier
input_size = 768 + metadata_features  # featurizer output size + metadata size
log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)
# Load Weights
model_weights_path = 'model/gena-blastln-bs33-lr4e-05-S168.pth'
# NOTE: torch.load unpickles the checkpoint, so the file must come from a
# trusted source.
weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
base_model.load_state_dict(weights['model_state_dict'])
log_reg.load_state_dict(weights['log_reg_state_dict'])
# Creating Model
model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
# This app only runs inference, so switch every submodule to eval mode.
# The freshly constructed head is in training mode. There, BatchNorm1d
# normalises with batch statistics instead of the loaded running statistics,
# and it rejects the single-sequence batches the app sends.
model.eval()
# Define a function to process the DNA sequence
def analyze_dna(sequence):
    """Classify a DNA sequence and report the five most likely labels.

    Args:
        sequence: DNA sequence text from the Gradio textbox.

    Returns:
        A ``(result, html)`` pair.
        - ``result`` is a list of ``(label_name, probability)`` tuples for the
          top 5 classes, highest probability first.
        - ``html`` is an ``<img>`` tag embedding a bar chart of those
          probabilities as a base64 PNG.
    """
    # Preprocess the input sequence (truncated / padded to 512 tokens)
    inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
    print("Tokenization done.")
    # Inference only: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        _, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
    print("Forward pass done.")
    # One input sequence -> a batch of one. Take row 0 explicitly (rather than
    # squeeze()) so this is always a flat list of per-class probabilities.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()
    print("Probabilities done.")
    # Get the top 5 most likely classes (stable sort keeps ties in index order)
    top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
    top_5_probs = [probabilities[i] for i in top_5_indices]
    # Map indices to label names
    top_5_labels = [int_to_label[i] for i in top_5_indices]
    result = list(zip(top_5_labels, top_5_probs))
    return result, _bar_chart_html(top_5_labels, top_5_probs)


def _bar_chart_html(labels, probs):
    """Render a horizontal bar chart of ``probs`` as an inline base64 ``<img>`` tag."""
    # Use the figure/axes objects directly rather than pyplot's implicit
    # "current figure", which is shared global state.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.barh(labels, probs, color='skyblue')
        ax.set_xlabel('Probability')
        ax.set_title('Top 5 Most Likely Labels')
        ax.invert_yaxis()  # Highest probabilities at the top
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each request would leak one figure.
        plt.close(fig)
    return f'<img src="data:image/png;base64,{image_base64}" />'
# Gradio UI: a text box in; the top-5 (label, probability) list rendered as
# JSON plus the bar-chart HTML out.
demo = gr.Interface(
    fn=analyze_dna,
    inputs="text",
    outputs=["json", "html"],
)
# Start serving the interface.
demo.launch()