NOOTestspace / app.py
mawairon's picture
Update app.py
9bf3f2b verified
raw
history blame
3.84 kB
import base64
import io
import os
from typing import Optional

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
# Read the pickled label -> integer mapping from the working directory, then
# build the reverse lookup (integer -> label) from its items.
label_to_int = pd.read_pickle('label_to_int.pkl')
int_to_label = {index: label for label, index in label_to_int.items()}
class LogisticRegressionTorch(nn.Module):
    """Batch-normalised linear layer that emits raw (un-normalised) scores.

    Args:
        input_dim: Number of features in each input row.
        output_dim: Number of scores produced for each input row.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        # The attribute names determine the state-dict keys, so they are kept
        # exactly as they were, and in the same registration order.
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Apply batch normalisation to ``x`` and then the linear projection."""
        return self.linear(self.batch_norm(x))
class BertClassifier(nn.Module):
    """Encoder followed by a classification head; returns one row of logits per input.

    The head receives the vector at position 0 (along the second axis) of the
    last entry in the encoder's ``hidden_states``.

    Args:
        bert_model: Encoder module. It is invoked as
            ``bert_model(input_ids, attention_mask=..., output_hidden_states=True)``
            and its result must expose a ``hidden_states`` sequence. Annotated
            as ``nn.Module`` because ``AutoModel`` is only a factory: objects
            returned by ``AutoModel.from_pretrained`` are instances of a
            concrete model class, never of ``AutoModel`` itself.
        classifier: Module that maps the pooled vector to class logits.
        num_labels: Number of target classes. Stored on the instance; not
            read by ``forward``.
    """

    def __init__(self, bert_model: nn.Module, classifier: nn.Module, num_labels: int) -> None:
        super().__init__()
        self.bert = bert_model
        self.classifier = classifier
        self.num_labels = num_labels

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the encoder and classify its first-position vector.

        Args:
            input_ids: Token ids passed positionally to the encoder.
            attention_mask: Optional mask forwarded to the encoder unchanged.

        Returns:
            Whatever ``classifier`` returns for the pooled vectors (the logits).
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # Last hidden-state entry, first position of every sequence
        # (presumably the [CLS] token -- confirm against the tokenizer).
        pooled_output = outputs.hidden_states[-1][:, 0, :]
        return self.classifier(pooled_output)
def load_model():
    """Assemble the classifier from the pretrained encoder and a local checkpoint.

    The encoder and tokenizer are loaded by name with ``trust_remote_code=True``.
    The checkpoint file named by the ``MODEL_PATH`` environment variable is
    loaded onto the CPU and must contain the keys ``'model_state_dict'``
    (encoder weights) and ``'log_reg_state_dict'`` (classification-head weights).

    Returns:
        A ``(model, tokenizer)`` tuple, where ``model`` is a ``BertClassifier``
        already switched to evaluation mode.

    Raises:
        RuntimeError: If ``MODEL_PATH`` is unset or empty.
    """
    # Validate the configuration first so a missing variable fails fast with a
    # clear message, instead of surfacing as an obscure error inside
    # ``torch.load`` after the encoder has already been fetched.
    model_weights_path = os.getenv('MODEL_PATH')
    if not model_weights_path:
        raise RuntimeError(
            "The MODEL_PATH environment variable must point to the model checkpoint file"
        )

    pretrained_name = 'AIRI-Institute/gena-lm-bert-base-lastln-t2t'
    metadata_features = 0
    num_classes = 38

    base_model = AutoModel.from_pretrained(pretrained_name, trust_remote_code=True, output_hidden_states=True)
    tokenizer = AutoTokenizer.from_pretrained(pretrained_name, trust_remote_code=True)

    # 768 is presumably the encoder's hidden size -- confirm against its config.
    input_size = 768 + metadata_features
    log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=num_classes)

    # NOTE(security): torch.load unpickles the file, so MODEL_PATH must only
    # ever reference a trusted checkpoint. If the checkpoint holds nothing but
    # tensors, consider passing ``weights_only=True``.
    weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
    base_model.load_state_dict(weights['model_state_dict'])
    log_reg.load_state_dict(weights['log_reg_state_dict'])

    model = BertClassifier(base_model, log_reg, num_labels=num_classes)
    model.eval()  # inference only: fixes batch-norm statistics and disables dropout
    return model, tokenizer
# Build the model and tokenizer once, at import time; analyze_dna reads these
# module-level globals on every call.
model, tokenizer = load_model()
def _render_top_labels_plot(labels, probs):
    """Draw a horizontal bar chart of ``probs`` by ``labels`` and return it as an HTML ``<img>`` tag."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.barh(labels, probs, color='skyblue')
        ax.set_xlabel('Probability')
        ax.set_title('Top 5 Most Likely Labels')
        # Operate on this figure's own axes rather than pyplot's "current"
        # axes, which is global state shared by every caller.
        ax.invert_yaxis()
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    finally:
        # pyplot keeps every figure it creates until it is closed; without
        # this, each request would leave one more figure in memory.
        plt.close(fig)
    return f'<img src="data:image/png;base64,{image_base64}" />'


def analyze_dna(sequence):
    """Classify a DNA sequence and report the five most probable labels.

    Args:
        sequence: Text made up only of the characters ``A``, ``C``, ``T``,
            ``G`` and ``N`` (upper case), at least 300 characters long.

    Returns:
        A 2-tuple. On success: a list of ``(label, probability)`` pairs sorted
        by descending probability, and an HTML ``<img>`` tag embedding a bar
        chart of those pairs. On failure (invalid input or any exception): an
        error message string and an empty string.
    """
    try:
        if not all(nucleotide in 'ACTGN' for nucleotide in sequence):
            return "Error: Sequence contains invalid characters", ""
        if len(sequence) < 300:
            return "Error: Sequence needs to be at least 300 nucleotides long", ""

        inputs = tokenizer(
            sequence,
            truncation=True,
            padding='max_length',
            max_length=512,
            return_tensors="pt",
            return_token_type_ids=False,
        )
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
        probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()

        top_5_indices = sorted(
            range(len(probabilities)),
            key=probabilities.__getitem__,
            reverse=True,
        )[:5]
        top_5_probs = [probabilities[i] for i in top_5_indices]
        top_5_labels = [int_to_label[i] for i in top_5_indices]
        result = list(zip(top_5_labels, top_5_probs))

        return result, _render_top_labels_plot(top_5_labels, top_5_probs)
    except Exception as e:
        # UI boundary: report the failure as text rather than crashing the app.
        return str(e), ""
# Wire analyze_dna into a Gradio interface: one text input, and a JSON output
# plus an HTML output matching the two values analyze_dna returns.
demo = gr.Interface(
    fn=analyze_dna,
    inputs="text",
    outputs=["json", "html"],
)

# Start serving the interface.
demo.launch()