Spaces:
Sleeping
Sleeping
import gradio as gr | |
import transformers | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel | |
import torch | |
import torch.nn as nn | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import io | |
import base64 | |
# Label vocabulary saved at training time. Assumed (not checked here) to map each
# label name to its integer class index, i.e. {label_name: label_index}.
# NOTE(review): pd.read_pickle unpickles the file, which can execute arbitrary code;
# only ship a trusted label_to_int.pkl alongside the app.
label_to_int = pd.read_pickle('label_to_int.pkl')
# Reverse lookup used to turn predicted class indices back into label names.
int_to_label = {class_index: label_name for label_name, class_index in label_to_int.items()}
class LogisticRegressionTorch(nn.Module):
    """Logistic-regression style classification head.

    Normalises the incoming features with ``BatchNorm1d`` and then applies a
    single ``Linear`` layer, returning raw (un-softmaxed) class scores.

    The attribute names ``batch_norm`` and ``linear`` define the state_dict keys,
    so they must stay as-is for previously saved weights to load.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return logits for features ``x`` (typically shaped (batch, input_dim))."""
        return self.linear(self.batch_norm(x))
class BertClassifier(nn.Module):
    """Pretrained encoder followed by a separate classification head.

    The head is applied to the final-layer hidden state of the first token
    (position 0) of every sequence in the batch.
    """

    def __init__(self, bert_model: AutoModel, classifier: LogisticRegressionTorch, num_labels: int):
        super().__init__()
        self.bert = bert_model  # pre-trained encoder; must return `hidden_states` when asked
        self.classifier = classifier  # maps the 768-dim first-token vector to class logits
        self.num_labels = num_labels

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None,
                token_type_ids: torch.Tensor = None, labels: torch.Tensor = None):
        """Encode ``input_ids`` and classify the first-token representation.

        Args:
            input_ids: Batch-first token ids (``input_ids.shape[0]`` is the batch size).
            attention_mask: Optional mask, forwarded to the encoder.
            token_type_ids: Accepted for interface compatibility but NOT forwarded
                to the encoder; it has no effect.
            labels: Optional target class indices. When given, a cross-entropy
                loss over ``num_labels`` classes is computed.

        Returns:
            ``(loss, logits)``; ``loss`` is ``None`` when ``labels`` is ``None``.

        Raises:
            AssertionError: if the first-token vector is not shaped (batch, 768).
                (Being an ``assert``, this check disappears under ``python -O``.)
        """
        batch_size = input_ids.shape[0]
        encoder_out = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)

        # hidden_states[-1] is the last layer's output; [:, 0, :] picks token 0 of each sequence.
        cls_vector = encoder_out.hidden_states[-1][:, 0, :]
        assert cls_vector.shape == (batch_size, 768), f"Expected shape ({batch_size}, 768), but got {cls_vector.shape}"

        logits = self.classifier(cls_vector)

        if labels is None:
            return None, logits

        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
        return loss, logits
# Load the Hugging Face encoder and tokenizer.
MODEL_NAME = 'AIRI-Institute/gena-lm-bert-base-lastln-t2t'
metadata_features = 0  # size of extra metadata features added to the classifier input (none are used)
N_UNIQUE_CLASSES = 38  # number of output classes of the trained head
# NOTE(review): presumably must equal len(label_to_int); confirm against the training run.
base_model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Initialize the classifier head.
input_size = 768 + metadata_features  # featurizer output size + metadata size
log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)

# Load the fine-tuned weights (mapped onto the CPU) for both the encoder and the head.
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
model_weights_path = 'model/gena-blastln-bs33-lr4e-05-S168.pth'
weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
base_model.load_state_dict(weights['model_state_dict'])
log_reg.load_state_dict(weights['log_reg_state_dict'])

# Assemble the full model and switch it to inference mode. A freshly constructed
# LogisticRegressionTorch starts in training mode, where BatchNorm1d rejects a batch
# of a single sequence (as the app sends) and would keep updating its running stats;
# eval() also disables dropout in the encoder.
model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
model.eval()
# Define a function to process the DNA sequence.
def analyze_dna(sequence):
    """Classify a DNA sequence and report its five most likely labels.

    Args:
        sequence: Raw sequence text from the Gradio textbox. It is tokenized with
            truncation and padding to exactly 512 tokens, so anything beyond the
            512-token limit is ignored.

    Returns:
        ``(result, html)`` where ``result`` is a list of up to five
        ``(label_name, probability)`` tuples in descending probability order, and
        ``html`` is an ``<img>`` tag embedding a bar chart of them as a base64 PNG.
    """
    # Preprocess the input sequence (a single string -> batch of one).
    inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
    print("Tokenization done.")

    # Inference only: eval() keeps BatchNorm1d/dropout in inference behaviour (it is
    # cheap and idempotent), and no_grad() avoids building an autograd graph.
    model.eval()
    with torch.no_grad():
        _, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
    print("Forward pass done.")

    # Convert the single row of logits to probabilities.
    probabilities = torch.nn.functional.softmax(logits[0], dim=-1).tolist()
    print("Probabilities done.")

    # Top 5 class indices by probability (stable sort keeps index order on ties).
    top_5_indices = sorted(range(len(probabilities)), key=probabilities.__getitem__, reverse=True)[:5]
    top_5_probs = [probabilities[i] for i in top_5_indices]
    top_5_labels = [int_to_label[i] for i in top_5_indices]

    # Output as a list of (label_name, probability) tuples.
    result = list(zip(top_5_labels, top_5_probs))
    return result, _top_labels_chart_html(top_5_labels, top_5_probs)


def _top_labels_chart_html(labels, probs):
    """Render a horizontal bar chart of ``probs`` per label as an inline ``<img>`` tag.

    Uses the figure/axes objects directly rather than pyplot's "current figure",
    and always closes the figure, so repeated calls do not accumulate open figures.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.barh(labels, probs, color='skyblue')
        ax.set_xlabel('Probability')
        ax.set_title('Top 5 Most Likely Labels')
        ax.invert_yaxis()  # Highest probabilities at the top
        # Save plot to a PNG image in memory.
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    finally:
        plt.close(fig)
    return f'<img src="data:image/png;base64,{image_base64}" />'
# Gradio UI: one text box in; the top-5 (label, probability) list rendered as
# JSON plus the bar chart rendered as HTML out.
demo = gr.Interface(
    fn=analyze_dna,
    inputs="text",
    outputs=["json", "html"],
)

# Start the web app.
demo.launch()