Spaces:
Sleeping
Sleeping
File size: 5,225 Bytes
c1e6692 088c2ad 996a1ec 3f3c29c dc7d693 66990c3 dc7d693 66990c3 3f3c29c 66990c3 3f3c29c 996a1ec 3f3c29c 66990c3 3f3c29c 996a1ec 3f3c29c 996a1ec 3f3c29c 5f8dde1 3f3c29c 66990c3 3f3c29c 996a1ec 3f3c29c 66990c3 3f3c29c db9f444 3f3c29c c1e6692 daf9507 996a1ec 9b849f4 1f65033 5f8dde1 dc3cae8 5f8dde1 513b115 1f65033 5f8dde1 996a1ec 17c6a2f 04805af 996a1ec 17c6a2f 1f65033 04805af 1f65033 66990c3 dc7d693 66990c3 dc7d693 66990c3 5f8dde1 66990c3 5f8dde1 c1e6692 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import io
import base64
# Label vocabulary. The pickle is assumed (per the original note, not verified
# here) to hold a dictionary of the form {label_name: label_index}.
label_to_int = pd.read_pickle('label_to_int.pkl')
# Reverse lookup: class index -> label name, used to name predicted classes.
int_to_label = {index: name for name, index in label_to_int.items()}
class LogisticRegressionTorch(nn.Module):
    """Classification head: batch normalisation followed by one linear layer.

    The output is raw logits; no softmax or sigmoid is applied here.
    """

    def __init__(self, input_dim: int, output_dim: int):
        """Create the head.

        Args:
            input_dim: Number of features in each input vector.
            output_dim: Number of logits produced for each input vector.
        """
        super().__init__()
        # The attribute names are also the state-dict key prefixes, so they
        # must remain exactly ``batch_norm`` and ``linear``.
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Normalise ``x`` and project it onto ``output_dim`` logits."""
        return self.linear(self.batch_norm(x))
class BertClassifier(nn.Module):
    """Sequence classifier made of an encoder and a classification head.

    The last-layer hidden state of the first token is taken as the sequence
    representation and fed to ``classifier`` to produce logits.
    """

    def __init__(self, bert_model: nn.Module, classifier: nn.Module,
                 num_labels: int, hidden_size: int = 768):
        """Wire the encoder and the head together.

        Args:
            bert_model: Encoder. It is called as
                ``bert_model(input_ids, attention_mask=..., output_hidden_states=True)``
                and its result must expose ``hidden_states``.
            classifier: Head mapping the pooled representation to logits
                (e.g. ``LogisticRegressionTorch``).
            num_labels: Number of classes; used to reshape logits for the loss.
            hidden_size: Expected width of the encoder's hidden states. The
                default of 768 reproduces the previously hard-coded check.
        """
        super().__init__()
        self.bert = bert_model
        self.classifier = classifier
        self.num_labels = num_labels
        self.hidden_size = hidden_size

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None,
                token_type_ids: torch.Tensor = None, labels: torch.Tensor = None):
        """Run the encoder and the head.

        Args:
            input_ids: Token ids; the first dimension is the batch.
            attention_mask: Optional mask forwarded to the encoder.
            token_type_ids: Accepted so callers may pass it, but it is NOT
                forwarded to the encoder.
            labels: Optional class indices. When given, a cross-entropy loss
                is computed.

        Returns:
            ``(loss, logits)``; ``loss`` is ``None`` when ``labels`` is ``None``.

        Raises:
            AssertionError: If the pooled representation does not have shape
                ``(batch, hidden_size)``.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)

        # Last layer, first token of every sequence in the batch.
        pooled_output = outputs.hidden_states[-1][:, 0, :]

        # Internal sanity check on the encoder output, not input validation.
        batch_size = input_ids.shape[0]
        assert pooled_output.shape == (batch_size, self.hidden_size), (
            f"Expected shape ({batch_size}, {self.hidden_size}), but got {pooled_output.shape}"
        )

        # TODO: the original code carried a bare "to-do later!" note at this
        # point; the intended follow-up is not stated anywhere in this file.

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            criterion = nn.CrossEntropyLoss()
            loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))

        return loss, logits
# Build the inference model: pretrained encoder + classification head, with
# fine-tuned weights loaded from the checkpoint named by $MODEL_PATH.
import os

# Number of extra (non-encoder) features appended to the encoder output.
metadata_features = 0
# Number of classes predicted by the classification head.
N_UNIQUE_CLASSES = 38

# Hugging Face model id shared by the encoder and its tokenizer.
_PRETRAINED_MODEL_ID = 'AIRI-Institute/gena-lm-bert-base-lastln-t2t'
base_model = AutoModel.from_pretrained(_PRETRAINED_MODEL_ID, trust_remote_code=True, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained(_PRETRAINED_MODEL_ID, trust_remote_code=True)

# Initialize the classifier
input_size = 768 + metadata_features  # featurizer output size + metadata size
log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)

# Get the checkpoint path from the environment variable and fail early, with a
# clear message, when it is missing.
model_path = os.getenv('MODEL_PATH')
if not model_path:
    raise RuntimeError("Environment variable MODEL_PATH must be set to the model weights file")

# NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
weights = torch.load(model_path, map_location=torch.device('cpu'))
base_model.load_state_dict(weights['model_state_dict'])
log_reg.load_state_dict(weights['log_reg_state_dict'])

# Assemble the full model and switch it to evaluation mode for inference.
model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
model.eval()
def _render_probability_chart(labels, probs):
    """Return an HTML ``<img>`` tag embedding a bar chart of ``probs`` as a base64 PNG."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.barh(labels, probs, color='skyblue')
        ax.set_xlabel('Probability')
        ax.set_title('Top 5 Most Likely Labels')
        ax.invert_yaxis()  # Highest probabilities at the top
        # Render the figure to an in-memory PNG.
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    finally:
        # pyplot keeps figures alive until they are closed explicitly; without
        # this, every request would leave one more figure in memory.
        plt.close(fig)
    return f'<img src="data:image/png;base64,{image_base64}" />'


def analyze_dna(sequence):
    """Classify a DNA sequence and report the five most likely labels.

    Args:
        sequence: String made only of the characters ``A``, ``C``, ``T``,
            ``G`` and ``N`` (upper case), at least 300 characters long.

    Returns:
        A pair ``(result, html)``: ``result`` is a list of
        ``(label_name, probability)`` tuples sorted by decreasing probability
        (at most five entries), and ``html`` is an ``<img>`` tag embedding a
        bar chart of the same data.

    Raises:
        ValueError: If the sequence contains other characters or is shorter
            than 300 characters.
    """
    # Explicit raises rather than ``assert``: assertions are stripped when
    # Python runs with -O, which would silently disable input validation.
    if not all(nucleotide in 'ACTGN' for nucleotide in sequence):
        raise ValueError("Sequence contains invalid characters")
    if len(sequence) < 300:
        raise ValueError("Sequence needs to be at least 300 nucleotides long")

    # Preprocess the input sequence
    inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
    print("Tokenization done.")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        _, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
    print("Forward pass done.")

    # Convert logits to probabilities; drop only the batch axis (batch size 1).
    probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze(0).tolist()
    print("Probabilities done.")

    # Rank classes by probability. The sort is stable, so ties keep index order.
    ranked = sorted(enumerate(probabilities), key=lambda pair: pair[1], reverse=True)[:5]
    top_5_labels = [int_to_label[index] for index, _ in ranked]
    top_5_probs = [prob for _, prob in ranked]

    # Output as a list of tuples (label_name, probability)
    result = list(zip(top_5_labels, top_5_probs))
    return result, _render_probability_chart(top_5_labels, top_5_probs)
# Gradio UI: a single free-text input is passed to analyze_dna; its two return
# values are shown in a JSON panel and an HTML panel respectively.
demo = gr.Interface(
    fn=analyze_dna,
    inputs="text",
    outputs=["json", "html"],
)
# Start serving the app.
demo.launch()
|