# NOOTestspace / app.py
# mawairon's picture
# Update app.py
# 93aa4b0 verified
import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import io
import base64
import os
import huggingface_hub
from huggingface_hub import hf_hub_download, login
import model_archs
from model_archs import BertClassifier, LogisticRegressionTorch, SimpleCNN, MLP, Pool2BN, ResNet1d
import tangermeme
from tangermeme.utils import one_hot_encode
# Load the label -> class-index mapping saved alongside the app.
# NOTE: read_pickle unpickles the file, so it must only ever point at a trusted artifact.
label_to_int = pd.read_pickle('label_to_int.pkl')

# Invert the mapping to class-index -> label, collapsing a few long country
# names to short display names. The first matching keyword wins.
int_to_label = {}
for full_name, class_index in label_to_int.items():
    display_name = full_name
    for keyword, short_name in (("KOREA", "KOREA"), ("KINGDOM", "UK"), ("RUSSIAN", "RUSSIA")):
        if keyword in full_name:
            display_name = short_name
            break
    int_to_label[class_index] = display_name
def _load_gena_bert(num_classes):
    """Build the GENA-LM BERT classifier and load its fine-tuned weights from the Hub."""
    # Fail fast on a missing token before any (slow) model download is attempted.
    token = os.getenv('HUGGINGFACE_TOKEN')
    if token is None:
        raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")

    checkpoint_name = 'AIRI-Institute/gena-lm-bert-base-lastln-t2t'
    base_model = AutoModel.from_pretrained(checkpoint_name, trust_remote_code=True, output_hidden_states=True)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_name, trust_remote_code=True)

    # 768 is the classifier input width used by the checkpoint
    # (presumably the encoder hidden size -- confirm); no metadata features are appended.
    metadata_features = 0
    input_size = 768 + metadata_features
    log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=num_classes)

    login(token=token)
    file_path = hf_hub_download(
        repo_id="mawairon/noo_test",
        filename="gena-blastln-bs33-lr4e-05-S168.pth",
        token=token,  # `use_auth_token` is the deprecated spelling of this argument
    )
    # NOTE(security): torch.load unpickles the checkpoint; only load trusted files.
    weights = torch.load(file_path, map_location=torch.device('cpu'))
    base_model.load_state_dict(weights['model_state_dict'])
    log_reg.load_state_dict(weights['log_reg_state_dict'])

    model = BertClassifier(base_model, log_reg, num_labels=num_classes)
    model.eval()
    return model, tokenizer


def _load_cnn_m2(num_classes):
    """Build the 8k-context SimpleCNN classifier and load its local checkpoint."""
    hidden_dim = 2048
    model_seq = SimpleCNN(18, hidden_dim, additional_layer=False)
    new_head = torch.nn.Sequential(
        torch.nn.Dropout(0.5),
        MLP([hidden_dim * 2, num_classes])
    )
    model = torch.nn.Sequential(model_seq, new_head)
    # NOTE(security): torch.load unpickles the checkpoint; only load trusted files.
    weights = torch.load('CNN_1stGEAC_m2_best.pth', map_location=torch.device('cpu'))
    model.load_state_dict(weights)
    # Inference only: switch Dropout/BatchNorm to eval behaviour, like the BERT branch does.
    model.eval()
    return model, None


def _load_cnn_m4(num_classes):
    """Build the 16k-context ResNet1d classifier and load its local checkpoint."""
    width = 768
    model_seq = nn.Sequential(
        nn.Conv1d(4, width, 7, padding=3),
        nn.ReLU(),
        nn.BatchNorm1d(width),
        ResNet1d(width, [(3, width // 2, 1)] * 1, dropout=None, dilation=7),
        nn.ReLU(),
        Pool2BN(width),
    )
    new_head = torch.nn.Sequential(
        torch.nn.Dropout(0.5),  ## for DEEPLIFT comment out
        MLP([width * 2, num_classes])
    )
    joined_model = torch.nn.Sequential(model_seq, new_head)
    # NOTE(security): torch.load unpickles the checkpoint; only load trusted files.
    weights = torch.load('NOO_CNN_1stGEAC_m4_16kcw_best.pth', map_location=torch.device('cpu'))
    joined_model.load_state_dict(weights)
    # Inference only: switch Dropout/BatchNorm to eval behaviour, like the BERT branch does.
    joined_model.eval()
    return joined_model, None


def load_model(model_name: str):
    """Instantiate one of the supported classifiers with its trained weights on CPU.

    Args:
        model_name: One of 'GENA-Bert', 'CNN-m2-8k-context' or 'CNN-m4-16k-context'.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is returned in eval mode.
        ``tokenizer`` is ``None`` for the CNN models.

    Raises:
        ValueError: If ``model_name`` is not recognised, or if 'GENA-Bert' is
            requested while the HUGGINGFACE_TOKEN environment variable is unset.
    """
    # Every model is built with 38 output classes.
    n_unique_classes = 38

    if model_name == 'GENA-Bert':
        return _load_gena_bert(n_unique_classes)
    if model_name == 'CNN-m2-8k-context':
        return _load_cnn_m2(n_unique_classes)
    if model_name == 'CNN-m4-16k-context':
        return _load_cnn_m4(n_unique_classes)
    raise ValueError("Invalid model name")
def _is_authorized(username, password):
    """Return True only when the credentials match the USERNAME / PASSWORD env vars."""
    import hmac

    raw_usernames = os.getenv('USERNAME')
    env_password = os.getenv('PASSWORD')
    # Fail closed when the deployment is missing either variable.
    if raw_usernames is None or env_password is None:
        return False
    if not isinstance(password, str):
        return False
    if username not in raw_usernames.split(','):
        return False
    # Constant-time comparison so response timing does not leak the password.
    return hmac.compare_digest(password.encode('utf-8'), env_password.encode('utf-8'))


def _get_logits(model, tokenizer, sequence, model_name):
    """Run a single forward pass of ``model`` on ``sequence`` and return the logits."""
    if model_name == 'GENA-Bert':
        inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
        with torch.no_grad():
            return model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

    if 'CNN' in model_name:
        # The context window is encoded in the model name.
        if '8k' in model_name:
            sequence_length = 8000
        elif '16k' in model_name:
            sequence_length = 16000
        else:
            raise ValueError("Invalid model name")
        # Truncate to the context window, then right-pad short inputs with 'N'.
        sequence = sequence[:sequence_length].ljust(sequence_length, 'N')
        input_tensor = one_hot_encode(sequence).unsqueeze(0).float()
        print(f'shape of input tensor{input_tensor.shape}')
        model.eval()
        with torch.no_grad():
            return model(input_tensor)

    raise ValueError("Invalid model name")


def _render_bar_chart(labels, probs):
    """Draw a horizontal bar chart of ``probs`` and return it as an inline HTML <img> tag."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.barh(labels, probs, color='skyblue')
        ax.set_xlabel('Probability')
        ax.set_title('Assuming this sequence was genetically engineered,\n the 5 most likely countries in which it was engineered are:')
        # Highest probability at the top.
        ax.invert_yaxis()
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    finally:
        # pyplot keeps every figure alive until it is closed; without this each
        # request would leak one figure.
        plt.close(fig)
    return f'<img src="data:image/png;base64,{image_base64}" />'


def analyze_dna(username, password, sequence, model_name):
    """Authenticate the caller, classify a DNA sequence and report the top-5 labels.

    Args:
        username: Must be one of the comma-separated names in the USERNAME env var.
        password: Must equal the PASSWORD env var.
        sequence: Nucleotide string; whitespace is ignored and case is normalised.
            Only A, C, G, T and N are accepted, and at least 300 are required.
        model_name: Model identifier accepted by ``load_model``.

    Returns:
        A ``(result, html)`` tuple. On success ``result`` is a list of
        ``(label, probability)`` pairs, most likely first, and ``html`` is an
        ``<img>`` tag with a bar chart of them. On any failure ``result`` is
        ``{"error": message}`` and ``html`` is the empty string.
    """
    if not _is_authorized(username, password):
        return {"error": "Invalid username or password"}, ""

    try:
        # Remove all whitespace characters and accept lower-case nucleotides.
        sequence = "".join(sequence.split()).upper()

        if not set(sequence) <= set('ACTGN'):
            return {"error": "Sequence contains invalid characters"}, ""
        if len(sequence) < 300:
            return {"error": "Sequence needs to be at least 300 nucleotides long"}, ""

        model, tokenizer = load_model(model_name)

        # NOTE: averaging logits over 1000-nt circular shifts of long sequences
        # was sketched here previously but is not enabled; a single pass is used.
        logits_avg = _get_logits(model, tokenizer, sequence, model_name)

        probabilities = torch.nn.functional.softmax(logits_avg, dim=-1).squeeze().tolist()
        top_5_indices = sorted(range(len(probabilities)), key=probabilities.__getitem__, reverse=True)[:5]
        top_5_probs = [probabilities[i] for i in top_5_indices]
        top_5_labels = [int_to_label[i] for i in top_5_indices]
        result = list(zip(top_5_labels, top_5_probs))

        return result, _render_bar_chart(top_5_labels, top_5_probs)
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing the request.
        return {"error": str(e)}, ""
# Pick the component class for the chart output: gr.HTML when this Gradio
# version provides it, otherwise fall back to a plain Textbox (the raw <img>
# markup is then shown as text).
html_output = getattr(gr, "HTML", gr.Textbox)

# Create a Gradio interface: four inputs mirroring analyze_dna's parameters,
# and two outputs mirroring its (result, html) return value.
demo = gr.Interface(
    fn=analyze_dna,
    inputs=[
        gr.Textbox(label="Username"),
        gr.Textbox(label="Password", type="password"),
        gr.Textbox(label="DNA Sequence"),
        gr.Dropdown(label="Model", choices=[
            "GENA-Bert",
            "CNN-m2-8k-context",
            "CNN-m4-16k-context"
        ])
    ],
    outputs=[
        gr.JSON(),      # top-5 (label, probability) pairs or an {"error": ...} dict
        html_output()   # bar chart; uses the fallback selected above when needed
    ]
)

# Launch the interface
demo.launch()