# Page-header residue from the Hugging Face Spaces listing (status shown: "Sleeping").
import gradio as gr | |
import transformers | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel | |
import torch | |
import torch.nn as nn | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import io | |
import base64 | |
import os | |
import huggingface_hub | |
from huggingface_hub import hf_hub_download, login | |
import model_archs | |
from model_archs import BertClassifier, LogisticRegressionTorch, SimpleCNN, MLP, Pool2BN | |
import tangermeme | |
from tangermeme.utils import one_hot_encode | |
# Load the label-name -> integer-id mapping that ships next to this script.
label_to_int = pd.read_pickle('label_to_int.pkl')


def _short_label(name):
    """Return the abbreviated display form of a label name.

    Names containing "KOREA", "KINGDOM" or "RUSSIAN" (checked in that order)
    collapse to "KOREA", "UK" and "RUSSIA"; any other name is returned as-is.
    """
    if "KOREA" in name:
        return "KOREA"
    if "KINGDOM" in name:
        return "UK"
    if "RUSSIAN" in name:
        return "RUSSIA"
    return name


# Inverse lookup (integer id -> display label) with the abbreviations applied.
int_to_label = {idx: _short_label(name) for name, idx in label_to_int.items()}
def _load_gena_bert(num_classes, metadata_features=0):
    """Build the GENA-LM BERT classifier and load its fine-tuned weights on CPU."""
    # Fail fast: without a token the private checkpoint cannot be fetched, so
    # there is no point downloading the backbone first.
    token = os.getenv('HUGGINGFACE_TOKEN')
    if token is None:
        raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")

    checkpoint_name = 'AIRI-Institute/gena-lm-bert-base-lastln-t2t'
    base_model = AutoModel.from_pretrained(
        checkpoint_name, trust_remote_code=True, output_hidden_states=True
    )
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_name, trust_remote_code=True)

    # 768 is the input width the classification head was built with.
    input_size = 768 + metadata_features
    log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=num_classes)

    login(token=token)
    file_path = hf_hub_download(
        repo_id="mawairon/noo_test",
        filename="gena-blastln-bs33-lr4e-05-S168.pth",
        token=token,  # `use_auth_token` is the deprecated spelling of this argument
    )
    weights = torch.load(file_path, map_location=torch.device('cpu'))
    base_model.load_state_dict(weights['model_state_dict'])
    log_reg.load_state_dict(weights['log_reg_state_dict'])

    model = BertClassifier(base_model, log_reg, num_labels=num_classes)
    model.eval()
    return model, tokenizer


def _load_cnn(num_classes):
    """Build the CNN classifier and load its weights on CPU, in eval mode."""
    hidden_dim = 2048
    model_seq = SimpleCNN(18, hidden_dim, additional_layer=False)
    new_head = torch.nn.Sequential(
        torch.nn.Dropout(0.5),
        MLP([hidden_dim * 2, num_classes]),
    )
    model = torch.nn.Sequential(model_seq, new_head)

    # NOTE(review): the checkpoint is read from the filesystem root; confirm
    # that this absolute path is where the deployment places the file.
    # map_location keeps loading consistent with the gena-bert branch and
    # avoids a failure if the checkpoint was saved from a GPU.
    weights = torch.load('/CNN_1stGEAC_m2_best.pth', map_location=torch.device('cpu'))
    model.load_state_dict(weights)

    # Freshly built modules start in training mode; without eval() the
    # Dropout(0.5) layer would randomly zero activations on every prediction.
    model.eval()
    return model, None


def load_model(model_name: str):
    """Instantiate the requested classifier with its trained weights.

    Args:
        model_name: Either 'gena-bert' or 'CNN'.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is in eval mode; the
        tokenizer is ``None`` for the CNN, which does not use one.

    Raises:
        ValueError: If ``model_name`` is not recognised, or if 'gena-bert' is
            requested while HUGGINGFACE_TOKEN is not set.
    """
    N_UNIQUE_CLASSES = 38

    if model_name == 'gena-bert':
        return _load_gena_bert(N_UNIQUE_CLASSES)
    if model_name == 'CNN':
        return _load_cnn(N_UNIQUE_CLASSES)
    raise ValueError("Invalid model name")
def _credentials_ok(username, password):
    """Return True when the credentials match the USERNAME / PASSWORD env vars.

    USERNAME holds a comma-separated list of accepted user names. Missing
    configuration or a non-string password is treated as a failed login
    instead of raising.
    """
    allowed_users = os.getenv('USERNAME')
    expected_password = os.getenv('PASSWORD')
    if allowed_users is None or expected_password is None:
        return False
    if not isinstance(password, str):
        return False
    return username in allowed_users.split(',') and password == expected_password


def _get_logits(seq, model_name, model, tokenizer):
    """Run a single forward pass of ``model`` on ``seq`` and return its output."""
    if model_name == 'gena-bert':
        inputs = tokenizer(
            seq,
            truncation=True,
            padding='max_length',
            max_length=512,
            return_tensors="pt",
            return_token_type_ids=False,
        )
        with torch.no_grad():
            return model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
            )
    if model_name == 'CNN':
        SEQUENCE_LENGTH = 8000
        # Truncate to the fixed length, then right-pad shorter inputs with 'N'.
        fixed_length_seq = seq[:SEQUENCE_LENGTH].ljust(SEQUENCE_LENGTH, 'N')
        input_tensor = one_hot_encode(fixed_length_seq).unsqueeze(0)
        with torch.no_grad():
            return model(input_tensor)
    raise ValueError("Invalid model name")


def _top5_chart_html(labels, probs):
    """Render a horizontal bar chart and return it as an inline <img> tag."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.barh(labels, probs, color='skyblue')
        ax.set_xlabel('Probability')
        ax.set_title('Assuming this sequence was genetically engineered,\n the 5 most likely countries in which it was engineered are:')
        # Put the most probable label at the top of the chart.
        ax.invert_yaxis()
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            png_bytes = buf.getvalue()
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each request would leak one figure.
        plt.close(fig)
    image_base64 = base64.b64encode(png_bytes).decode('utf-8')
    return f'<img src="data:image/png;base64,{image_base64}" />'


def analyze_dna(username, password, sequence, model_name):
    """Classify a DNA sequence and report the five most probable labels.

    Args:
        username: User name checked against the USERNAME environment variable.
        password: Password checked against the PASSWORD environment variable.
        sequence: Nucleotide string; whitespace is stripped and letters are
            upper-cased before validation. Only A, C, T, G and N are accepted
            and at least 300 nucleotides are required.
        model_name: 'gena-bert' or 'CNN'.

    Returns:
        A ``(result, html)`` tuple. On success ``result`` is a list of
        ``(label, probability)`` pairs sorted by decreasing probability and
        ``html`` is an ``<img>`` tag embedding the bar chart. On any failure
        ``result`` is ``{"error": message}`` and ``html`` is an empty string.
    """
    if not _credentials_ok(username, password):
        return {"error": "Invalid username or password"}, ""

    try:
        # Remove all whitespace and accept lower-case input.
        sequence = "".join(sequence.split()).upper()

        if not set(sequence) <= set('ACTGN'):
            return {"error": "Sequence contains invalid characters"}, ""
        if len(sequence) < 300:
            return {"error": "Sequence needs to be at least 300 nucleotides long"}, ""

        model, tokenizer = load_model(model_name)

        # NOTE: a disabled draft that averaged logits over circular shifts of
        # long sequences used to live here; only a single pass is performed.
        logits_avg = _get_logits(sequence, model_name, model, tokenizer)

        probabilities = torch.nn.functional.softmax(logits_avg, dim=-1).squeeze().tolist()
        top_5_indices = sorted(
            range(len(probabilities)),
            key=probabilities.__getitem__,
            reverse=True,
        )[:5]
        top_5_probs = [probabilities[i] for i in top_5_indices]
        top_5_labels = [int_to_label[i] for i in top_5_indices]
        result = list(zip(top_5_labels, top_5_probs))

        return result, _top5_chart_html(top_5_labels, top_5_probs)
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing the app.
        return {"error": str(e)}, ""
# Pick the component class used to display the chart: gr.HTML when this Gradio
# build provides it, otherwise a plain Textbox that shows the raw markup.
try:
    html_output = gr.HTML
except AttributeError:
    html_output = gr.Textbox

# Create the Gradio interface. The four inputs map positionally onto
# analyze_dna(username, password, sequence, model_name), and the two outputs
# onto its (result, html) return value.
demo = gr.Interface(
    fn=analyze_dna,
    inputs=[
        gr.Textbox(label="Username"),
        gr.Textbox(label="Password", type="password"),
        gr.Textbox(label="DNA Sequence"),
        gr.Dropdown(label="Model", choices=["gena-bert", "CNN"]),
    ],
    outputs=[
        gr.JSON(),
        html_output(),  # use the component selected above so the fallback takes effect
    ],
)

# Launch only when executed as a script so that importing this module
# (for tests or tooling) does not start a web server.
if __name__ == "__main__":
    demo.launch()