# NOOTestspace / app.py
# Author: mawairon
# Last commit: 7680154 "Update app.py" (verified), 7.14 kB
import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import io
import base64
import os
import huggingface_hub
from huggingface_hub import hf_hub_download, login
import model_archs
from model_archs import BertClassifier, LogisticRegressionTorch, SimpleCNN, MLP, Pool2BN
import tangermeme
from tangermeme.utils import one_hot_encode
# Load the label mapping (label name -> integer class id) from the pickle shipped
# with the app. Pickle files can execute code on load, so only trusted files belong here.
label_to_int = pd.read_pickle('label_to_int.pkl')

# Substring -> display-name rewrites, tried in order; the first matching fragment wins.
_LABEL_ALIASES = (
    ("KOREA", "KOREA"),
    ("KINGDOM", "UK"),
    ("RUSSIAN", "RUSSIA"),
)


def _shorten_label(label):
    """Return the short display name for *label*, or *label* itself if no alias applies."""
    for fragment, alias in _LABEL_ALIASES:
        if fragment in label:
            return alias
    return label


# Invert the mapping (class id -> label name), shortening the names as we go.
int_to_label = {index: _shorten_label(name) for name, index in label_to_int.items()}
# Already-built (model, tokenizer) pairs keyed by model name, so weights are
# downloaded and deserialized once per process instead of once per request.
_MODEL_CACHE = {}


def _load_gena_bert():
    """Build the GENA-LM BERT classifier and its tokenizer from the private checkpoint."""
    metadata_features = 0
    n_unique_classes = 38
    base_model = AutoModel.from_pretrained(
        'AIRI-Institute/gena-lm-bert-base-lastln-t2t',
        trust_remote_code=True,
        output_hidden_states=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        'AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True
    )
    input_size = 768 + metadata_features
    log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=n_unique_classes)

    token = os.getenv('HUGGINGFACE_TOKEN')
    if token is None:
        raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
    login(token=token)
    file_path = hf_hub_download(
        repo_id="mawairon/noo_test",
        filename="gena-blastln-bs33-lr4e-05-S168.pth",
        token=token,  # `use_auth_token` is deprecated in huggingface_hub
    )
    weights = torch.load(file_path, map_location=torch.device('cpu'))
    base_model.load_state_dict(weights['model_state_dict'])
    log_reg.load_state_dict(weights['log_reg_state_dict'])
    model = BertClassifier(base_model, log_reg, num_labels=n_unique_classes)
    model.eval()
    return model, tokenizer


def _load_cnn():
    """Build the SimpleCNN + MLP-head classifier and load its local checkpoint."""
    hidden_dim = 2048
    num_countries = 38
    model_seq = SimpleCNN(18, hidden_dim, additional_layer=False)
    new_head = torch.nn.Sequential(
        torch.nn.Dropout(0.5),
        MLP([hidden_dim * 2, num_countries]),
    )
    model = torch.nn.Sequential(model_seq, new_head)

    weights_path = '/CNN_1stGEAC_m2_best.pth'
    if not os.path.exists(weights_path):
        # NOTE(review): the filesystem-root path above looks unintended; fall back
        # to the checkpoint stored next to this script.
        weights_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'CNN_1stGEAC_m2_best.pth'
        )
    # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts,
    # matching the gena-bert branch.
    weights = torch.load(weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(weights)
    # Switch Dropout to inference behaviour; without this predictions are random.
    model.eval()
    return model, None


def load_model(model_name: str):
    """Return ``(model, tokenizer)`` for *model_name*, building it on first use.

    Args:
        model_name: ``'gena-bert'`` or ``'CNN'``.

    Returns:
        A ``(model, tokenizer)`` tuple with the model in eval mode. ``tokenizer``
        is ``None`` for the CNN, which consumes one-hot encoded sequences directly.
        Repeated calls with the same name return the same cached objects.

    Raises:
        ValueError: if *model_name* is unknown, or if ``HUGGINGFACE_TOKEN`` is
            unset when loading ``'gena-bert'``.
    """
    if model_name not in _MODEL_CACHE:
        if model_name == 'gena-bert':
            _MODEL_CACHE[model_name] = _load_gena_bert()
        elif model_name == 'CNN':
            _MODEL_CACHE[model_name] = _load_cnn()
        else:
            raise ValueError("Invalid model name")
    return _MODEL_CACHE[model_name]
def _credentials_valid(username, password):
    """Return True when *username*/*password* match the USERNAME/PASSWORD env vars.

    USERNAME holds a comma-separated list of accepted user names (surrounding
    spaces and empty entries are ignored); PASSWORD is one shared password.
    A missing variable denies access rather than raising.
    """
    import hmac

    raw_usernames = os.getenv('USERNAME')
    env_password = os.getenv('PASSWORD')
    if raw_usernames is None or env_password is None or not isinstance(password, str):
        return False
    valid_usernames = [name.strip() for name in raw_usernames.split(',') if name.strip()]
    if username not in valid_usernames:
        return False
    # Constant-time comparison so response timing does not leak the password.
    return hmac.compare_digest(password.encode('utf-8'), env_password.encode('utf-8'))


def _clean_sequence(sequence):
    """Remove every whitespace character from *sequence* and upper-case it."""
    return "".join(sequence.split()).upper()


def _get_logits(model, tokenizer, seq, model_name):
    """Run *model* on *seq* without gradients and return its raw output logits."""
    if model_name == 'gena-bert':
        inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=512,
                           return_tensors="pt", return_token_type_ids=False)
        with torch.no_grad():
            return model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
    if model_name == 'CNN':
        sequence_length = 8000
        # Truncate, then right-pad with 'N' to the fixed length the CNN was built for.
        seq = seq[:sequence_length].ljust(sequence_length, 'N')
        # tangermeme's one_hot_encode yields an integer tensor by default; cast to
        # float so it matches the model's floating-point weights.
        input_tensor = one_hot_encode(seq).unsqueeze(0).float()
        with torch.no_grad():
            return model(input_tensor)
    raise ValueError("Invalid model name")


def _top5_chart_html(labels, probs):
    """Render a horizontal bar chart of *labels*/*probs* as an inline-PNG ``<img>`` tag."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        # Plot on numeric positions so labels that map to the same display name
        # (e.g. two countries shortened to "KOREA") get separate bars.
        positions = list(range(len(labels)))
        ax.barh(positions, probs, color='skyblue')
        ax.set_yticks(positions)
        ax.set_yticklabels(labels)
        ax.set_xlabel('Probability')
        ax.set_title('Assuming this sequence was genetically engineered,\n the 5 most likely countries in which it was engineered are:')
        ax.invert_yaxis()  # most likely country on top
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Free the figure; a long-running server would otherwise leak one per request.
        plt.close(fig)
    image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return f'<img src="data:image/png;base64,{image_base64}" />'


def analyze_dna(username, password, sequence, model_name):
    """Authenticate, validate *sequence*, and predict its 5 most likely origin countries.

    Args:
        username: Must be one of the comma-separated names in the USERNAME env var.
        password: Must equal the PASSWORD env var.
        sequence: DNA text over A/C/T/G/N (any case; whitespace is ignored),
            at least 300 nucleotides after cleaning.
        model_name: ``'gena-bert'`` or ``'CNN'`` (see ``load_model``).

    Returns:
        ``(result, html)`` on success, where ``result`` is a list of up to five
        ``(label, probability)`` tuples in descending probability and ``html`` is
        an ``<img>`` tag with the bar chart. On any failure returns
        ``({"error": message}, "")``.
    """
    if not _credentials_valid(username, password):
        return {"error": "Invalid username or password"}, ""
    try:
        sequence = _clean_sequence(sequence)
        if not all(nucleotide in 'ACTGN' for nucleotide in sequence):
            return {"error": "Sequence contains invalid characters"}, ""
        if len(sequence) < 300:
            return {"error": "Sequence needs to be at least 300 nucleotides long"}, ""

        model, tokenizer = load_model(model_name)
        # TODO: long inputs (>3000 nt for gena-bert, >10000 nt for CNN) were once
        # meant to be scored by averaging logits over 1000-nt circular rotations;
        # that path is currently disabled and only a single pass is made.
        logits = _get_logits(model, tokenizer, sequence, model_name)
        probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()

        top_5_indices = sorted(range(len(probabilities)),
                               key=lambda i: probabilities[i], reverse=True)[:5]
        top_5_probs = [probabilities[i] for i in top_5_indices]
        top_5_labels = [int_to_label[i] for i in top_5_indices]
        result = list(zip(top_5_labels, top_5_probs))
        return result, _top5_chart_html(top_5_labels, top_5_probs)
    except Exception as e:
        # UI boundary: report the message to the user, keep the traceback in the logs.
        import traceback
        traceback.print_exc()
        return {"error": str(e)}, ""
# Use Gradio's HTML component for the chart when this installation provides it;
# otherwise fall back to a plain Textbox, which shows the raw <img> markup.
try:
    html_output = gr.HTML
except AttributeError:
    html_output = gr.Textbox

# Create a Gradio interface
demo = gr.Interface(
    fn=analyze_dna,
    inputs=[
        gr.Textbox(label="Username"),
        gr.Textbox(label="Password", type="password"),
        gr.Textbox(label="DNA Sequence"),
        gr.Dropdown(label="Model", choices=[
            "gena-bert",
            "CNN",
        ]),
    ],
    outputs=[
        gr.JSON(),      # top-5 (label, probability) pairs or {"error": ...}
        html_output(),  # the chosen component above, so the fallback actually applies
    ],
)

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    demo.launch()