import os
import io
import base64

import gradio as gr
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel

# label_to_int is assumed to be a dictionary mapping {label_name: label_index}
label_to_int = pd.read_pickle('label_to_int.pkl')
int_to_label = {v: k for k, v in label_to_int.items()}
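
# A purely hypothetical illustration of the expected pickle contents (the real
# file ships with the Space; its 38 labels are not reproduced here):
#   label_to_int = {'Escherichia_coli': 0, 'Homo_sapiens': 1, ...}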

class LogisticRegressionTorch(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        # Normalize the pooled BERT features before the linear projection
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = self.batch_norm(x)
        out = self.linear(x)
        return out

class BertClassifier(nn.Module):
    def __init__(self, bert_model: AutoModel, classifier: LogisticRegressionTorch, num_labels: int):
        super().__init__()
        self.bert = bert_model  # a pre-trained GENA-LM BERT model
        self.classifier = classifier
        self.num_labels = num_labels

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None,
                token_type_ids: torch.Tensor = None, labels: torch.Tensor = None):
        # Run the BERT model and keep all hidden states
        outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)

        # From the last layer, take the hidden state of the first ([CLS]) token of each sequence
        pooled_output = outputs.hidden_states[-1][:, 0, :]
        assert pooled_output.shape == (input_ids.shape[0], 768), \
            f"Expected shape ({input_ids.shape[0]}, 768), but got {pooled_output.shape}"
        # to-do later!

        # Pass the pooled output to the classifier to get the logits
        logits = self.classifier(pooled_output)

        # Compute the loss if labels are provided (CrossEntropyLoss for classification)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            pred = logits.view(-1, self.num_labels)
            observed = labels.view(-1)
            loss = loss_fct(pred, observed)

        # Return the loss and logits
        return loss, logits
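
# Shape summary for BertClassifier.forward (hidden size 768, per the assert above):
#   input_ids      (batch, seq_len)
#   pooled_output  (batch, 768)         -- [CLS] hidden state from the last layer
#   logits         (batch, num_labels)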

# Load the Hugging Face model and tokenizer
metadata_features = 0
N_UNIQUE_CLASSES = 38
base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)

# Initialize the classifier
input_size = 768 + metadata_features  # featurizer output size + metadata size
log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)

# Load the fine-tuned weights from the path given by the MODEL_PATH environment variable
model_weights_path = os.getenv('MODEL_PATH')
weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
base_model.load_state_dict(weights['model_state_dict'])
log_reg.load_state_dict(weights['log_reg_state_dict'])

# Assemble the full classifier and switch to inference mode
model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
model.eval()
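
# eval() matters here: BatchNorm1d switches to its stored running statistics,
# which also makes single-sequence (batch size 1) inference possible. A quick,
# purely hypothetical smoke test:
#   ids = torch.randint(0, tokenizer.vocab_size, (2, 512))
#   _, logits = model(input_ids=ids, attention_mask=torch.ones_like(ids))
#   logits.shape  # -> torch.Size([2, 38])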

def analyze_dna(sequence):
    try:
        # Check that the sequence contains only valid nucleotide characters
        if not all(nucleotide in 'ACTGN' for nucleotide in sequence):
            raise ValueError("Sequence contains invalid characters")

        # Check that the sequence is at least 300 nucleotides long
        if len(sequence) < 300:
            raise ValueError("Sequence needs to be at least 300 nucleotides long")

        # Tokenize the input sequence
        inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)

        # Get model predictions without tracking gradients
        with torch.no_grad():
            _, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

        # Convert logits to probabilities
        probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()

        # Get the top 5 most likely classes
        top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
        top_5_probs = [probabilities[i] for i in top_5_indices]
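
        # Note: torch.topk would give the same top 5 directly on the tensor, e.g.:
        #   probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()
        #   top_probs, top_indices = torch.topk(probs, k=5)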

        # Map indices to label names
        top_5_labels = [int_to_label[i] for i in top_5_indices]

        # Prepare the output as a list of (label_name, probability) tuples
        result = list(zip(top_5_labels, top_5_probs))

        # Plot a horizontal bar chart of the top 5 probabilities
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.barh(top_5_labels, top_5_probs, color='skyblue')
        ax.set_xlabel('Probability')
        ax.set_title('Top 5 Most Likely Labels')
        ax.invert_yaxis()  # Highest probabilities at the top

        # Save the plot to an in-memory PNG, then close the figure to avoid leaking it
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        plt.close(fig)
        buf.seek(0)
        image_base64 = base64.b64encode(buf.read()).decode('utf-8')
        buf.close()

        return result, f'<img src="data:image/png;base64,{image_base64}" />'
    except ValueError as e:
        # Return the error message and an empty plot
        return str(e), ""

# Create a Gradio interface
demo = gr.Interface(fn=analyze_dna, inputs="text", outputs=["json", "html"])
# Launch the interface
demo.launch()
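
# For local testing outside of Spaces, one could instead launch with, e.g.:
#   demo.launch(server_name="0.0.0.0", server_port=7860)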