Spaces:
Sleeping
Sleeping
File size: 3,840 Bytes
c1e6692 088c2ad 996a1ec 3f3c29c dc7d693 66990c3 9bf3f2b dc7d693 9bf3f2b 66990c3 3f3c29c 66990c3 3f3c29c 996a1ec 3f3c29c 66990c3 3f3c29c 9bf3f2b 3f3c29c 9bf3f2b 996a1ec 3f3c29c 9bf3f2b 3f3c29c 9bf3f2b 3f3c29c 9bf3f2b 3f3c29c 9bf3f2b 3f3c29c 9bf3f2b 3f3c29c 9bf3f2b 3f3c29c 9bf3f2b 3f3c29c 9bf3f2b db9f444 9bf3f2b 1f65033 5f8dde1 356d0ee 9bf3f2b dc3cae8 356d0ee 9bf3f2b 356d0ee 9bf3f2b 356d0ee 9bf3f2b 356d0ee 9bf3f2b 356d0ee 5f8dde1 66990c3 5f8dde1 c1e6692 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import io
import base64
import os
# Load the label mapping from a pickle file in the working directory.
# NOTE: read_pickle unpickles the file, so it must come from a trusted source.
label_to_int = pd.read_pickle('label_to_int.pkl')
# Inverted mapping (value -> key), used to turn a predicted class index back
# into its label.
int_to_label = {class_id: label for label, class_id in label_to_int.items()}
class LogisticRegressionTorch(nn.Module):
    """Classification head: 1-D batch normalisation followed by a linear layer.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output scores (one per class).
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        # Attribute names are also the state-dict key prefixes, so they must
        # stay stable for saved checkpoints to load.
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Normalise ``x`` and project it to raw (un-softmaxed) scores."""
        return self.linear(self.batch_norm(x))
class BertClassifier(nn.Module):
    """Encoder plus classification head, wrapped as a single module.

    Args:
        bert_model: Encoder called to produce hidden states.
        classifier: Head applied to the first-token embedding.
        num_labels: Number of target classes (stored, not used in forward).
    """

    def __init__(self, bert_model: AutoModel, classifier: LogisticRegressionTorch, num_labels: int):
        super().__init__()
        self.bert = bert_model
        self.classifier = classifier
        self.num_labels = num_labels

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None):
        """Return the head's logits for the given token ids."""
        encoder_out = self.bert(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # Take the final entry of hidden_states and keep only position 0 of
        # each sequence as the sequence-level embedding.
        first_token_embedding = encoder_out.hidden_states[-1][:, 0, :]
        return self.classifier(first_token_embedding)
def load_model():
    """Build the classifier and its tokenizer from a saved checkpoint.

    The checkpoint location is read from the ``MODEL_PATH`` environment
    variable. The checkpoint must contain two state dicts, stored under the
    keys ``'model_state_dict'`` (encoder weights) and ``'log_reg_state_dict'``
    (classification-head weights).

    Returns:
        tuple: ``(model, tokenizer)``; ``model`` is a ``BertClassifier``
        already switched to eval mode, with weights mapped to the CPU.

    Raises:
        RuntimeError: If ``MODEL_PATH`` is unset or empty.
    """
    # Fail fast with a clear message, before the pretrained encoder is
    # fetched. Without this check a missing variable only surfaces later as
    # an obscure error from inside torch.load(None).
    model_weights_path = os.getenv('MODEL_PATH')
    if not model_weights_path:
        raise RuntimeError(
            "MODEL_PATH environment variable is not set; "
            "it must point to the model checkpoint file"
        )

    metadata_features = 0
    N_UNIQUE_CLASSES = 38
    pretrained_name = 'AIRI-Institute/gena-lm-bert-base-lastln-t2t'

    base_model = AutoModel.from_pretrained(
        pretrained_name,
        trust_remote_code=True,
        output_hidden_states=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(pretrained_name, trust_remote_code=True)

    # Head input width: 768 encoder features plus any extra metadata features.
    input_size = 768 + metadata_features
    log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)

    # NOTE(security): torch.load unpickles the file; only point MODEL_PATH at
    # a checkpoint from a trusted source.
    weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
    base_model.load_state_dict(weights['model_state_dict'])
    log_reg.load_state_dict(weights['log_reg_state_dict'])

    model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
    # Inference only: eval mode makes batch norm use its running statistics.
    model.eval()
    return model, tokenizer
# Load the model and tokenizer once, at import time; analyze_dna reads these
# module-level names on every call.
model, tokenizer = load_model()
def _render_top_labels_chart(labels, probs):
    """Render a horizontal bar chart of ``probs`` and return it as an HTML <img> tag."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.barh(labels, probs, color='skyblue')
        ax.set_xlabel('Probability')
        ax.set_title('Top 5 Most Likely Labels')
        # Put the first (most likely) label at the top of the chart.
        ax.invert_yaxis()
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # each request would leave one more figure in memory.
        plt.close(fig)
    return f'<img src="data:image/png;base64,{image_base64}" />'


def analyze_dna(sequence):
    """Classify a DNA sequence and report the five most likely labels.

    Args:
        sequence: Nucleotide string. Only the upper-case characters
            ``A``, ``C``, ``T``, ``G`` and ``N`` are accepted, and at least
            300 characters are required.

    Returns:
        tuple: ``(result, html)``. On success ``result`` is a list of
        ``(label, probability)`` pairs sorted by descending probability and
        ``html`` is an ``<img>`` tag embedding a bar chart as base64 PNG.
        On invalid input or any failure, ``result`` is an error message
        string and ``html`` is ``""``; no exception escapes.
    """
    try:
        if not all(nucleotide in 'ACTGN' for nucleotide in sequence):
            return "Error: Sequence contains invalid characters", ""
        if len(sequence) < 300:
            return "Error: Sequence needs to be at least 300 nucleotides long", ""

        # Pad/truncate to a fixed 512 tokens.
        inputs = tokenizer(
            sequence,
            truncation=True,
            padding='max_length',
            max_length=512,
            return_tensors="pt",
            return_token_type_ids=False,
        )
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            logits = model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
            )
        # squeeze() drops the size-1 batch dimension (a single sequence is
        # passed in), leaving one probability per class.
        probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()

        top_5_indices = sorted(
            range(len(probabilities)),
            key=probabilities.__getitem__,
            reverse=True,
        )[:5]
        top_5_probs = [probabilities[i] for i in top_5_indices]
        top_5_labels = [int_to_label[i] for i in top_5_indices]
        result = list(zip(top_5_labels, top_5_probs))

        return result, _render_top_labels_chart(top_5_labels, top_5_probs)
    except Exception as e:
        # UI boundary: report the failure as text rather than crash the app.
        return str(e), ""
# Gradio UI: one text input fed to analyze_dna, whose two return values are
# shown as a JSON panel and an HTML panel.
demo = gr.Interface(
    fn=analyze_dna,
    inputs="text",
    outputs=["json", "html"],
)
# Start serving the interface.
demo.launch()
|