# cov-snn-app/predict.py
import os
import json
import random
import torch
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from transformers import RobertaForMaskedLM, RobertaModel, RobertaConfig
# Distance metric for the semantic change score:
# "L1" for Manhattan distance, "L2" for Euclidean distance
DISTANCE = "L2"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# For reproducibility
seed = 42
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
print(f"Loading model from checkpoint...")
sbert_model = SentenceTransformer("model", device=device)
transformer_model = sbert_model[0].auto_model.to(device) # Extract the core transformer model
# Get the configuration of the SentenceTransformer model
sbert_config = transformer_model.config
# Create a custom RobertaConfig with the same parameters
custom_config = RobertaConfig(
    vocab_size=sbert_config.vocab_size,
    hidden_size=sbert_config.hidden_size,
    num_hidden_layers=sbert_config.num_hidden_layers,
    num_attention_heads=sbert_config.num_attention_heads,
    intermediate_size=sbert_config.intermediate_size,
    max_position_embeddings=sbert_config.max_position_embeddings,
    type_vocab_size=sbert_config.type_vocab_size,
    initializer_range=sbert_config.initializer_range,
    layer_norm_eps=sbert_config.layer_norm_eps,
    hidden_dropout_prob=sbert_config.hidden_dropout_prob,
    attention_probs_dropout_prob=sbert_config.attention_probs_dropout_prob,
)
# Initialize RobertaForMaskedLM and RobertaModel with the custom configuration
masked_lm_model = RobertaForMaskedLM(custom_config).to(device)
embedding_model = RobertaModel(custom_config).to(device)
tokenizer = sbert_model[0].tokenizer
# Copy the weights from transformer_model to masked_lm_model and embedding_model
transformer_model_state_dict = transformer_model.state_dict()
masked_lm_model.roberta.load_state_dict(transformer_model_state_dict, strict=False)
embedding_model.load_state_dict(transformer_model_state_dict, strict=False)
print(f"Weights successfully copied from checkpoint!")
# Set the models in evaluation mode
masked_lm_model.eval()
embedding_model.eval()
def mask_sentence(sentence):
    """Create one copy of the tokenized sentence per token, with that token replaced by <mask>."""
    tokens = tokenizer.tokenize(sentence)
    masked_sentences = []
    for i in range(len(tokens)):
        masked_tokens = tokens.copy()
        masked_tokens[i] = '<mask>'
        masked_sentence = ' '.join(masked_tokens)
        masked_sentences.append(masked_sentence)
    return masked_sentences
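# Illustration only (hypothetical tokens; the real tokenizer may emit subword pieces
# with special prefix characters): if tokenize() returned ["M", "F", "V"],
# mask_sentence would yield ["<mask> F V", "M <mask> V", "M F <mask>"].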
def predict_masked_words(masked_sentences):
    """Run the masked-LM over each masked copy and collect the output logits."""
    predictions = []
    for masked_sentence in masked_sentences:
        inputs = tokenizer.encode(masked_sentence, return_tensors='pt').to(device)
        with torch.no_grad():
            outputs = masked_lm_model(inputs)
        predictions.append(outputs.logits)
    return predictions
def get_sentence_embedding(sentence):
    """Return the token-level hidden states of a sequence as a NumPy array."""
    inputs = tokenizer(sentence, return_tensors='pt', truncation=True, padding='max_length', max_length=128).to(device)
    with torch.no_grad():
        outputs = embedding_model(**inputs)
    return outputs.last_hidden_state.squeeze().cpu().numpy()
def calculate_semantic_change(average_embedding, target_sentence):
    """Distance between a target embedding and the average reference embedding."""
    target_embedding = get_sentence_embedding(target_sentence)
    if DISTANCE == "L2":
        semantic_change = np.linalg.norm(average_embedding - target_embedding)  # Euclidean distance
    elif DISTANCE == "L1":
        semantic_change = np.abs(average_embedding - target_embedding).sum()  # Manhattan distance (sum of absolute differences)
    else:
        raise ValueError(f"Unsupported DISTANCE setting: {DISTANCE}")
    return semantic_change
def calculate_sequence_probability(predictions, masked_sentences, target_sentence):
    """Average probability the masked-LM assigns to each original token at its masked position."""
    total_score = 0
    tokens = tokenizer.tokenize(target_sentence)
    for i, logits in enumerate(predictions):
        masked_index = masked_sentences[i].split().index('<mask>')
        softmax = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)
        original_token_id = tokenizer.convert_tokens_to_ids(tokens[i])
        score = softmax[original_token_id].item()
        total_score += score
    average_score = total_score / len(masked_sentences)
    return average_score
def calculate_inverse_perplexity(sentence):
    """Inverse perplexity of a sequence under the masked-LM (higher means more probable)."""
    inputs = tokenizer(sentence, return_tensors='pt', truncation=True, padding='max_length', max_length=128).to(device)
    with torch.no_grad():
        outputs = masked_lm_model(**inputs, labels=inputs["input_ids"])
    loss = outputs.loss
    perplexity = torch.exp(loss)
    return 1 / perplexity.item()
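# Worked example (illustrative numbers only): a mean token loss of 2.0 gives
# perplexity = exp(2.0) ≈ 7.39, so the inverse perplexity score is ≈ 0.135.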
def get_average_embedding(ref_df):
    """Element-wise mean of the token-level embeddings of all reference sequences."""
    ref_sequences = ref_df["sequence"].tolist()
    print(f"Calculating average embedding for {len(ref_sequences)} reference sequences...")
    embeddings = []
    for i, sentence in enumerate(ref_sequences):
        emb = get_sentence_embedding(sentence)
        embeddings.append(emb)
        print(f"Embedding calculated for reference sequence {i} with shape {emb.shape}")
    average_embedding = np.mean(embeddings, axis=0)
    return average_embedding
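# Note: because get_sentence_embedding pads/truncates to max_length=128, each per-sequence
# embedding is a (128, hidden_size) array, and average_embedding keeps that same shape.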
def get_sc_sp_ip(average_embedding, target_dataset):
    """Compute semantic change (sc), sequence probability (sp) and inverse perplexity (ip) for each target."""
    sc_scores = []
    sp_scores = []
    ip_scores = []
    for i, target in target_dataset.iterrows():
        target_sentence = target["sequence"]
        target_accession = target["accession_id"]
        # Semantic change score
        sc_score = calculate_semantic_change(average_embedding, target_sentence)
        sc_scores.append(sc_score)
        # Sequence probability score
        masked_sentences = mask_sentence(target_sentence)
        predictions = predict_masked_words(masked_sentences)
        sp_score = calculate_sequence_probability(predictions, masked_sentences, target_sentence)
        sp_scores.append(sp_score)
        # Inverse perplexity score
        ip_score = calculate_inverse_perplexity(target_sentence)
        ip_scores.append(ip_score)
        print(f"Target sequence {target_accession} -> sc: {sc_score:.6f}, sp: {sp_score:.6f}, ip: {ip_score:.6f}")
    return sc_scores, sp_scores, ip_scores
def get_results_df(target_dataset, sc_scores, sp_scores, ip_scores):
    """Build the results table: raw scores, log10 scores, per-score ranks and combined ranks."""
    results_df = target_dataset.copy()
    results_df["sc"] = sc_scores
    results_df["sp"] = sp_scores
    results_df["ip"] = ip_scores
    # Combined score gr: the mean of sp and ip
    results_df["gr"] = (results_df["sp"] + results_df["ip"]) / 2
    # Add log10 scores
    results_df["log10(sc)"] = np.log10(results_df["sc"])
    results_df["log10(sp)"] = np.log10(results_df["sp"])
    results_df["log10(ip)"] = np.log10(results_df["ip"])
    results_df["log10(gr)"] = np.log10(results_df["gr"])
    # Add integer ranks for sc, sp, ip and gr (rank 1 = highest score)
    results_df["rank_by_sc"] = results_df["sc"].rank(ascending=False).astype(int)
    results_df["rank_by_sp"] = results_df["sp"].rank(ascending=False).astype(int)
    results_df["rank_by_ip"] = results_df["ip"].rank(ascending=False).astype(int)
    results_df["rank_by_gr"] = results_df["gr"].rank(ascending=False).astype(int)
    # Add rank_by_scsp, rank_by_scip and rank_by_scgr by summing the rank of sc with the rank of sp/ip/gr
    results_df["rank_by_scsp"] = results_df["rank_by_sc"] + results_df["rank_by_sp"]
    results_df["rank_by_scip"] = results_df["rank_by_sc"] + results_df["rank_by_ip"]
    results_df["rank_by_scgr"] = results_df["rank_by_sc"] + results_df["rank_by_gr"]
    # Drop the sequence column
    results_df = results_df.drop(columns=["sequence"])
    # Optional rounding (disabled)
    # results_df = results_df.applymap(lambda x: round(x, 6) if isinstance(x, (int, float)) else x)
    # Sort by rank_by_scgr by default
    results_df = results_df.sort_values(by="rank_by_scgr")
    return results_df
def process_target_data(average_embedding, target_dataset):
    """Score all target sequences and return the ranked results DataFrame."""
    sc_scores, sp_scores, ip_scores = get_sc_sp_ip(average_embedding, target_dataset)
    results_df = get_results_df(target_dataset, sc_scores, sp_scores, ip_scores)
    return results_df
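# Usage sketch (not part of the original script): one way to tie the steps above together.
# The CSV paths and output file name are hypothetical; the "sequence" and "accession_id"
# column names are the ones the functions above expect.
if __name__ == "__main__":
    ref_df = pd.read_csv("reference.csv")    # hypothetical file with a "sequence" column
    target_df = pd.read_csv("target.csv")    # hypothetical file with "accession_id" and "sequence" columns
    average_embedding = get_average_embedding(ref_df)
    results_df = process_target_data(average_embedding, target_df)
    results_df.to_csv("results.csv", index=False)
    print(results_df.head())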