Spaces:
Sleeping
Sleeping
File size: 6,533 Bytes
c1e6692 088c2ad 996a1ec 3f3c29c dc7d693 66990c3 9bf3f2b 65391f8 6dddebe ace289f dc7d693 9bf3f2b 66990c3 3f3c29c db842cf 306f08b 3f3c29c ace289f 3f3c29c db9f444 1f65033 ace289f b1eabde 5935bca c214f12 5935bca 6dddebe 356d0ee ace289f b1eabde 356d0ee fe4ceb1 dc3cae8 356d0ee fe4ceb1 356d0ee ace289f b1eabde 356d0ee ad5e55a 9bf3f2b 356d0ee a948bde 356d0ee 9bf3f2b fe4ceb1 5f8dde1 6dddebe 5935bca ace289f 5935bca ace289f 6dddebe 5f8dde1 c1e6692 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import io
import base64
import os
import huggingface_hub
from huggingface_hub import hf_hub_download, login
import model_archs
from model_archs import BertClassifier, LogisticRegressionTorch, SimpleCNN, MLP, Pool2BN
import tangermeme
from tangermeme import one_hot_encode
# Load the label -> integer mapping from disk and invert it so that a
# predicted class index can be turned back into a human-readable label.
label_to_int = pd.read_pickle('label_to_int.pkl')
int_to_label = {index: name for name, index in label_to_int.items()}

# Shorten verbose country names. Each entry is (substring to look for,
# replacement label); the first matching entry wins, so the order matters.
_LABEL_ABBREVIATIONS = (
    ("KOREA", "KOREA"),
    ("KINGDOM", "UK"),
    ("RUSSIAN", "RUSSIA"),
)
# Only values of existing keys are reassigned, so the dict does not change
# size while it is being iterated.
for index, name in int_to_label.items():
    for needle, short_name in _LABEL_ABBREVIATIONS:
        if needle in name:
            int_to_label[index] = short_name
            break
def load_model(model_name: str):
    """Build the requested classifier and, when it has one, its tokenizer.

    Args:
        model_name: ``'gena-bert'`` or ``'CNN'``.

    Returns:
        A ``(model, tokenizer)`` tuple with the model in eval mode.
        ``tokenizer`` is ``None`` for the CNN, which takes no tokenizer.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names, or
            if ``'gena-bert'`` is requested while the ``HUGGINGFACE_TOKEN``
            environment variable is unset.
    """
    metadata_features = 0
    N_UNIQUE_CLASSES = 38

    if model_name == 'gena-bert':
        # Fail fast on a missing token, before any model download starts.
        token = os.getenv('HUGGINGFACE_TOKEN')
        if token is None:
            raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
        login(token=token)

        base_model = AutoModel.from_pretrained(
            'AIRI-Institute/gena-lm-bert-base-lastln-t2t',
            trust_remote_code=True,
            output_hidden_states=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            'AIRI-Institute/gena-lm-bert-base-lastln-t2t',
            trust_remote_code=True,
        )
        input_size = 768 + metadata_features
        log_reg = LogisticRegressionTorch(
            input_dim=input_size, output_dim=N_UNIQUE_CLASSES
        )

        # `token` replaces the deprecated `use_auth_token` keyword.
        file_path = hf_hub_download(
            repo_id="mawairon/noo_test",
            filename="gena-blastln-bs33-lr4e-05-S168.pth",
            token=token,
        )
        weights = torch.load(file_path, map_location=torch.device('cpu'))
        base_model.load_state_dict(weights['model_state_dict'])
        log_reg.load_state_dict(weights['log_reg_state_dict'])

        model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
        model.eval()
        return model, tokenizer

    if model_name == 'CNN':
        hidden_dim = 2048
        # The original sized the head with len(set(y_train)), but no y_train
        # exists in this file. N_UNIQUE_CLASSES is used instead, assuming the
        # CNN predicts the same label set as the BERT model -- TODO confirm.
        model_seq = SimpleCNN(18, hidden_dim, additional_layer=False)
        new_head = torch.nn.Sequential(
            torch.nn.Dropout(0.5),
            MLP([hidden_dim * 2, N_UNIQUE_CLASSES]),
        )
        model = torch.nn.Sequential(model_seq, new_head)
        # NOTE(review): no trained weights are loaded for the CNN here, so it
        # predicts with its initial parameters -- load a checkpoint before
        # relying on this branch.
        # eval() disables Dropout so inference is deterministic.
        model.eval()
        return model, None

    # Raise instead of returning a dict: callers unpack the result into
    # (model, tokenizer), which a one-key dict cannot satisfy.
    raise ValueError("Invalid model name")
# Length to which sequences are truncated / padded for the CNN.
_CNN_SEQUENCE_LENGTH = 8000


def _credentials_ok(username, password):
    """Return True if the username/password pair matches the environment.

    Allowed usernames come from the comma-separated ``USERNAME`` variable and
    the shared password from ``PASSWORD``. Access is denied when either
    variable is unset.
    """
    import hmac  # local import: the file's import block does not provide it

    allowed_raw = os.getenv('USERNAME')
    expected_password = os.getenv('PASSWORD')
    if allowed_raw is None or expected_password is None:
        return False
    # Drop empty entries so a blank username can never be considered valid.
    allowed_names = [name for name in allowed_raw.split(',') if name]
    if username not in allowed_names:
        return False
    # Constant-time comparison avoids leaking the password through timing.
    return hmac.compare_digest(
        (password or "").encode("utf-8"),
        expected_password.encode("utf-8"),
    )


def _compute_logits(model, tokenizer, seq, model_name):
    """Run ``model`` on ``seq`` without gradients and return its output."""
    if model_name == 'gena-bert':
        inputs = tokenizer(
            seq,
            truncation=True,
            padding='max_length',
            max_length=512,
            return_tensors="pt",
            return_token_type_ids=False,
        )
        with torch.no_grad():
            return model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
            )
    if model_name == 'CNN':
        # Truncate, then right-pad with 'N' to a fixed length.
        padded = seq[:_CNN_SEQUENCE_LENGTH].ljust(_CNN_SEQUENCE_LENGTH, 'N')
        # NOTE(review): assumes one_hot_encode returns an integer tensor
        # without a batch axis and that the CNN expects a float batch --
        # confirm against tangermeme and model_archs.SimpleCNN.
        encoded = one_hot_encode(padded).unsqueeze(0).float()
        with torch.no_grad():
            return model(encoded)
    raise ValueError("Invalid model name")


def _render_top5_chart(labels, probs):
    """Draw the horizontal bar chart and return it as an inline <img> tag."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.barh(labels, probs, color='skyblue')
        ax.set_xlabel('Probability')
        ax.set_title('Assuming this sequence was genetically engineered,\n the 5 most likely countries in which it was engineered are:')
        ax.invert_yaxis()
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    finally:
        # pyplot keeps every figure alive until closed; without this each
        # request would leak a figure in the long-running server.
        plt.close(fig)
    return f'<img src="data:image/png;base64,{image_base64}" />'


def analyze_dna(username, password, sequence, model_name):
    """Authenticate the caller and predict the top-5 labels for a sequence.

    Args:
        username: name checked against the ``USERNAME`` environment variable.
        password: password checked against the ``PASSWORD`` variable.
        sequence: DNA sequence; whitespace is stripped and letters are
            upper-cased before validation.
        model_name: model identifier forwarded to ``load_model``.

    Returns:
        A ``(payload, html)`` tuple. On success ``payload`` is a list of up
        to five ``(label, probability)`` pairs, most probable first, and
        ``html`` is an ``<img>`` tag embedding the bar chart. On any failure
        ``payload`` is ``{"error": message}`` and ``html`` is ``""``.
    """
    if not _credentials_ok(username, password):
        return {"error": "Invalid username or password"}, ""

    # Top-level request boundary: any failure is reported to the UI as an
    # error payload rather than crashing the handler.
    try:
        # Remove all whitespace and accept lower-case input.
        sequence = "".join((sequence or "").split()).upper()

        if not set(sequence) <= set('ACTGN'):
            return {"error": "Sequence contains invalid characters"}, ""
        if len(sequence) < 300:
            return {"error": "Sequence needs to be at least 300 nucleotides long"}, ""

        model, tokenizer = load_model(model_name)

        # NOTE: averaging logits over circular shifts of long sequences was
        # drafted here previously but never enabled; only a single forward
        # pass is performed.
        logits_avg = _compute_logits(model, tokenizer, sequence, model_name)

        probabilities = torch.nn.functional.softmax(logits_avg, dim=-1).squeeze().tolist()
        top_5_indices = sorted(
            range(len(probabilities)),
            key=probabilities.__getitem__,
            reverse=True,
        )[:5]
        top_5_probs = [probabilities[i] for i in top_5_indices]
        top_5_labels = [int_to_label[i] for i in top_5_indices]
        result = list(zip(top_5_labels, top_5_probs))

        return result, _render_top5_chart(top_5_labels, top_5_probs)
    except Exception as e:
        return {"error": str(e)}, ""
# Build the Gradio UI. The four inputs are passed positionally to
# analyze_dna(username, password, sequence, model_name); its two return
# values fill the JSON and HTML outputs in order.
demo = gr.Interface(
    fn=analyze_dna,
    inputs=[
        gr.Textbox(label="Username"),
        gr.Textbox(label="Password", type="password"),
        gr.Textbox(label="DNA Sequence"),
        gr.Dropdown(label="Model", choices=["gena-bert", "CNN"]),
    ],
    # Explicit component instances replace the "json" / "HTML" string
    # shortcuts: the upper-case "HTML" is not the documented lowercase
    # shortcut form and may not resolve in every Gradio release.
    outputs=[gr.JSON(), gr.HTML()],
)

# Launch the interface
demo.launch()
|