Spaces:
Sleeping
Sleeping
File size: 8,514 Bytes
c1e6692 088c2ad 996a1ec 3f3c29c dc7d693 66990c3 9bf3f2b 65391f8 6dddebe ace289f 93aa4b0 ace289f 5492a5f dc7d693 9bf3f2b 66990c3 3f3c29c db842cf 306f08b 3f3c29c ace289f 3a67180 ace289f 7f58142 ace289f 7680154 ace289f 7680154 ace289f 3a67180 b7c4acd ace289f 7f58142 5a0bdbe 7f58142 3a67180 5a0bdbe 7f58142 ace289f 8c5a0b0 ace289f 3f3c29c db9f444 1f65033 ace289f b1eabde 5935bca c214f12 5935bca 6dddebe 356d0ee ace289f b1eabde 356d0ee fe4ceb1 dc3cae8 356d0ee fe4ceb1 356d0ee ace289f 93aa4b0 8c5a0b0 809ae7d 8c5a0b0 93aa4b0 ace289f 7f58142 8c5a0b0 16b0032 8c5a0b0 93aa4b0 ace289f 93aa4b0 ace289f 8c5a0b0 93aa4b0 ace289f 8c5a0b0 ace289f 809ae7d ace289f b312d27 b1eabde 356d0ee ad5e55a 9bf3f2b 356d0ee a948bde 356d0ee 9bf3f2b fe4ceb1 5f8dde1 5dca1cc 5f8dde1 6dddebe 5935bca ace289f 16b0032 7f58142 ace289f 5935bca 99d9c5e 6dddebe 5f8dde1 5dca1cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 |
import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import io
import base64
import os
import huggingface_hub
from huggingface_hub import hf_hub_download, login
import model_archs
from model_archs import BertClassifier, LogisticRegressionTorch, SimpleCNN, MLP, Pool2BN, ResNet1d
import tangermeme
from tangermeme.utils import one_hot_encode
# Load the label -> class-index mapping shipped with the Space (presumably a
# dict pickled at training time) and invert it so predicted indices can be
# turned back into human-readable labels.
label_to_int = pd.read_pickle('label_to_int.pkl')

# Substring -> short display name. The first matching substring wins, so the
# tuple order mirrors the precedence of the checks it replaces.
_LABEL_ALIASES = (
    ("KOREA", "KOREA"),
    ("KINGDOM", "UK"),
    ("RUSSIAN", "RUSSIA"),
)


def _shorten_label(label):
    """Return the display alias for *label*, or *label* itself when none matches."""
    return next((alias for needle, alias in _LABEL_ALIASES if needle in label), label)


int_to_label = {index: _shorten_label(name) for name, index in label_to_int.items()}
# Number of output classes of every model head. It must match the number of
# labels in label_to_int.pkl (TODO confirm when retraining).
_N_CLASSES = 38

# Hub repo holding the pretrained GENA-LM base model and tokenizer.
_GENA_BASE_REPO = 'AIRI-Institute/gena-lm-bert-base-lastln-t2t'


def _load_gena_bert(n_classes, metadata_features=0):
    """Build the GENA-LM BERT + logistic-regression classifier and its tokenizer.

    The fine-tuned checkpoint is downloaded from the Hub using the
    HUGGINGFACE_TOKEN environment variable, which must be set.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode.

    Raises:
        ValueError: if HUGGINGFACE_TOKEN is not set.
    """
    # Check the token before downloading the (large) base model, so a
    # misconfigured Space fails immediately.
    token = os.getenv('HUGGINGFACE_TOKEN')
    if token is None:
        raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")

    base_model = AutoModel.from_pretrained(
        _GENA_BASE_REPO, trust_remote_code=True, output_hidden_states=True
    )
    tokenizer = AutoTokenizer.from_pretrained(_GENA_BASE_REPO, trust_remote_code=True)
    # 768 is presumably the base model's hidden size; extra metadata features
    # (none at present) would widen the classifier input.
    log_reg = LogisticRegressionTorch(
        input_dim=768 + metadata_features, output_dim=n_classes
    )

    login(token=token)
    file_path = hf_hub_download(
        repo_id="mawairon/noo_test",
        filename="gena-blastln-bs33-lr4e-05-S168.pth",
        token=token,  # `use_auth_token` is deprecated in huggingface_hub
    )
    weights = torch.load(file_path, map_location=torch.device('cpu'))
    # The checkpoint stores the base model and the head under separate keys.
    base_model.load_state_dict(weights['model_state_dict'])
    log_reg.load_state_dict(weights['log_reg_state_dict'])

    model = BertClassifier(base_model, log_reg, num_labels=n_classes)
    model.eval()
    return model, tokenizer


def _load_cnn_m2(n_classes):
    """Build the 8k-context SimpleCNN classifier and load its local checkpoint (eval mode)."""
    hidden_dim = 2048
    model_seq = SimpleCNN(18, hidden_dim, additional_layer=False)
    # The head input is 2 * hidden_dim; presumably the encoder concatenates
    # two pooled summaries. TODO confirm against model_archs.SimpleCNN.
    new_head = nn.Sequential(
        nn.Dropout(0.5),
        MLP([hidden_dim * 2, n_classes]),
    )
    # Keep this exact nesting: the checkpoint's state_dict keys depend on it.
    model = nn.Sequential(model_seq, new_head)
    weights = torch.load('CNN_1stGEAC_m2_best.pth', map_location=torch.device('cpu'))
    model.load_state_dict(weights)
    model.eval()
    return model


def _load_cnn_m4(n_classes):
    """Build the 16k-context ResNet1d-based classifier and load its local checkpoint (eval mode)."""
    width = 768
    model_seq = nn.Sequential(
        nn.Conv1d(4, width, 7, padding=3),  # 4 input channels = one-hot nucleotide planes
        nn.ReLU(),
        nn.BatchNorm1d(width),
        ResNet1d(width, [(3, width // 2, 1)] * 1, dropout=None, dilation=7),
        nn.ReLU(),
        Pool2BN(width),
    )
    new_head = nn.Sequential(
        # For DeepLIFT runs this was meant to be commented out. Prefer
        # nn.Identity() instead, so the MLP keeps index 1 and the checkpoint's
        # state_dict keys still line up.
        nn.Dropout(0.5),
        MLP([width * 2, n_classes]),
    )
    # Keep this exact nesting: the checkpoint's state_dict keys depend on it.
    model = nn.Sequential(model_seq, new_head)
    weights = torch.load('NOO_CNN_1stGEAC_m4_16kcw_best.pth', map_location=torch.device('cpu'))
    model.load_state_dict(weights)
    model.eval()
    return model


def load_model(model_name: str):
    """Instantiate one of the supported classifiers with its trained weights.

    Args:
        model_name: one of 'GENA-Bert', 'CNN-m2-8k-context' or
            'CNN-m4-16k-context'.

    Returns:
        ``(model, tokenizer)``. The model is loaded on CPU and in eval mode.
        ``tokenizer`` is ``None`` for the CNN models, which take one-hot
        encoded input instead of tokens.

    Raises:
        ValueError: for an unknown ``model_name``, or when HUGGINGFACE_TOKEN
            is not set for 'GENA-Bert'.
    """
    if model_name == 'GENA-Bert':
        return _load_gena_bert(_N_CLASSES)
    if model_name == 'CNN-m2-8k-context':
        return _load_cnn_m2(_N_CLASSES), None
    if model_name == 'CNN-m4-16k-context':
        return _load_cnn_m4(_N_CLASSES), None
    raise ValueError("Invalid model name")
# Alphabet accepted by the input validator.
_VALID_NUCLEOTIDES = frozenset('ACTGN')
# Shortest cleaned sequence the models are asked to classify.
_MIN_SEQUENCE_LENGTH = 300
# How many of the most probable labels are reported and plotted.
_TOP_K = 5


def _credentials_valid(username, password):
    """Return True iff the credentials match the USERNAME/PASSWORD env vars.

    USERNAME holds a comma-separated allow-list and PASSWORD a single shared
    password. A missing variable denies access instead of raising.
    """
    import hmac  # constant-time comparison for the shared password

    # Drop empty entries so e.g. USERNAME="alice," does not authorise "".
    valid_usernames = {name for name in (os.getenv('USERNAME') or '').split(',') if name}
    env_password = os.getenv('PASSWORD')
    if env_password is None or username not in valid_usernames:
        return False
    # Compare bytes: compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest((password or '').encode('utf-8'), env_password.encode('utf-8'))


def _clean_sequence(sequence):
    """Remove every whitespace character from *sequence* and upper-case it."""
    return ''.join((sequence or '').split()).upper()


def _get_logits(model, tokenizer, sequence, model_name):
    """Run *model* on *sequence* and return its raw logits (no gradients).

    GENA-Bert gets at most 512 tokens. The CNN models get the sequence
    truncated or 'N'-padded to their fixed context length and one-hot encoded.
    """
    if model_name == 'GENA-Bert':
        inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512,
                           return_tensors="pt", return_token_type_ids=False)
        with torch.no_grad():
            return model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
    if 'CNN' in model_name:
        if '8k' in model_name:
            sequence_length = 8000
        elif '16k' in model_name:
            sequence_length = 16000
        else:
            raise ValueError("Invalid model name")
        # Truncate, then right-pad with 'N' (unknown base) to the fixed context length.
        fixed = sequence[:sequence_length].ljust(sequence_length, 'N')
        input_tensor = one_hot_encode(fixed).unsqueeze(0).float()  # add a batch axis
        model.eval()
        with torch.no_grad():
            return model(input_tensor)
    raise ValueError("Invalid model name")


def _render_top_k_chart(labels, probs):
    """Render a horizontal bar chart of *labels*/*probs* as an inline-PNG <img> tag."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.barh(labels, probs, color='skyblue')
        ax.set_xlabel('Probability')
        ax.set_title('Assuming this sequence was genetically engineered,\n the 5 most likely countries in which it was engineered are:')
        ax.invert_yaxis()  # most likely label on top
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    finally:
        # pyplot keeps every figure alive until closed; without this a
        # long-running server leaks one figure per request.
        plt.close(fig)
    return f'<img src="data:image/png;base64,{image_base64}" />'


def analyze_dna(username, password, sequence, model_name):
    """Gradio handler: authenticate, validate the sequence, classify it, and plot the top 5.

    Args:
        username, password: checked against the USERNAME / PASSWORD env vars.
        sequence: raw DNA text. Whitespace is ignored and case-insensitive
            A/C/G/T/N are accepted.
        model_name: a name accepted by ``load_model``.

    Returns:
        ``(result, html)``. On success ``result`` is a list of
        ``(label, probability)`` pairs and ``html`` is an inline-PNG ``<img>``
        tag. On any failure ``result`` is ``{"error": message}`` and ``html``
        is ``""``.
    """
    if not _credentials_valid(username, password):
        return {"error": "Invalid username or password"}, ""
    try:
        sequence = _clean_sequence(sequence)
        if not set(sequence) <= _VALID_NUCLEOTIDES:
            return {"error": "Sequence contains invalid characters"}, ""
        if len(sequence) < _MIN_SEQUENCE_LENGTH:
            return {"error": f"Sequence needs to be at least {_MIN_SEQUENCE_LENGTH} nucleotides long"}, ""

        model, tokenizer = load_model(model_name)
        logits = _get_logits(model, tokenizer, sequence, model_name)
        probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()

        top_indices = sorted(range(len(probabilities)), key=probabilities.__getitem__, reverse=True)[:_TOP_K]
        top_probs = [probabilities[i] for i in top_indices]
        top_labels = [int_to_label[i] for i in top_indices]
        return list(zip(top_labels, top_probs)), _render_top_k_chart(top_labels, top_probs)
    except Exception as e:
        # UI boundary: report any model/plotting failure to the user instead of crashing.
        return {"error": str(e)}, ""
# Use Gradio's HTML component for the chart when this Gradio version has one;
# otherwise fall back to a plain textbox that shows the raw <img> markup.
try:
    html_output = gr.HTML
except AttributeError:
    html_output = gr.Textbox

# Build the web UI. Inputs: credentials, raw DNA sequence and model choice.
# Outputs: the top-5 result as JSON plus the bar-chart HTML.
demo = gr.Interface(
    fn=analyze_dna,
    inputs=[
        gr.Textbox(label="Username"),
        gr.Textbox(label="Password", type="password"),
        gr.Textbox(label="DNA Sequence"),
        gr.Dropdown(label="Model", choices=[
            "GENA-Bert",
            "CNN-m2-8k-context",
            "CNN-m4-16k-context",
        ]),
    ],
    outputs=[
        gr.JSON(),
        # Use the component chosen above; hard-coding gr.HTML() would defeat the fallback.
        html_output(),
    ],
)

# Launch the interface
demo.launch()