import os
import json
import random
import torch
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from transformers import RobertaForMaskedLM, RobertaModel, RobertaConfig

# Distance metric used for the semantic change score
DISTANCE = "L2"  # "L1" for Manhattan distance, "L2" for Euclidean distance

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# For reproducibility
seed = 42
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
print(f"Loading model from checkpoint...") | |
sbert_model = SentenceTransformer("model", device=device) | |
transformer_model = sbert_model[0].auto_model.to(device) # Extract the core transformer model | |
# Get the configuration of the SentenceTransformer model | |
sbert_config = transformer_model.config | |
# Create a custom RobertaConfig with the same parameters
custom_config = RobertaConfig(
    vocab_size=sbert_config.vocab_size,
    hidden_size=sbert_config.hidden_size,
    num_hidden_layers=sbert_config.num_hidden_layers,
    num_attention_heads=sbert_config.num_attention_heads,
    intermediate_size=sbert_config.intermediate_size,
    max_position_embeddings=sbert_config.max_position_embeddings,
    type_vocab_size=sbert_config.type_vocab_size,
    initializer_range=sbert_config.initializer_range,
    layer_norm_eps=sbert_config.layer_norm_eps,
    hidden_dropout_prob=sbert_config.hidden_dropout_prob,
    attention_probs_dropout_prob=sbert_config.attention_probs_dropout_prob,
)
# Initialize RobertaForMaskedLM and RobertaModel with the custom configuration
masked_lm_model = RobertaForMaskedLM(custom_config).to(device)
embedding_model = RobertaModel(custom_config).to(device)
tokenizer = sbert_model[0].tokenizer

# Copy the weights from transformer_model to masked_lm_model and embedding_model
transformer_model_state_dict = transformer_model.state_dict()
masked_lm_model.roberta.load_state_dict(transformer_model_state_dict, strict=False)
embedding_model.load_state_dict(transformer_model_state_dict, strict=False)
print("Weights successfully copied from checkpoint!")

# Set the models to evaluation mode
masked_lm_model.eval()
embedding_model.eval()
def mask_sentence(sentence):
    # Tokenize the sentence and build one masked copy per token position
    tokens = tokenizer.tokenize(sentence)
    masked_sentences = []
    for i in range(len(tokens)):
        masked_tokens = tokens.copy()
        masked_tokens[i] = '<mask>'
        masked_sentence = ' '.join(masked_tokens)
        masked_sentences.append(masked_sentence)
    return masked_sentences
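
# Illustrative note (not from the original script; the exact sub-word pieces depend on
# the model's tokenizer): for a four-token input t1 t2 t3 t4, mask_sentence returns four
# copies, each with one position masked, e.g.
#   ["<mask> t2 t3 t4", "t1 <mask> t3 t4", "t1 t2 <mask> t4", "t1 t2 t3 <mask>"]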
def predict_masked_words(masked_sentences):
    # Run the masked-LM head on each masked copy and collect the logits
    predictions = []
    for masked_sentence in masked_sentences:
        inputs = tokenizer.encode(masked_sentence, return_tensors='pt').to(device)
        with torch.no_grad():
            outputs = masked_lm_model(inputs)
        predictions.append(outputs.logits)
    return predictions
def get_sentence_embedding(sentence):
    # Encode the sentence and return its per-token hidden states
    # (shape: max_length x hidden_size, padded/truncated to 128 positions)
    inputs = tokenizer(sentence, return_tensors='pt', truncation=True, padding='max_length', max_length=128).to(device)
    with torch.no_grad():
        outputs = embedding_model(**inputs)
    return outputs.last_hidden_state.squeeze().cpu().numpy()
def calculate_semantic_change(average_embedding, target_sentence):
    target_embedding = get_sentence_embedding(target_sentence)
    diff = average_embedding - target_embedding
    if DISTANCE == "L2":
        semantic_change = np.linalg.norm(diff)  # Euclidean (Frobenius) distance over the per-token embeddings
    elif DISTANCE == "L1":
        # Flatten first so ord=1 gives the Manhattan distance rather than the max column sum of the matrix
        semantic_change = np.linalg.norm(diff.ravel(), ord=1)
    return semantic_change
def calculate_sequence_probability(predictions, masked_sentences, target_sentence):
    total_score = 0
    tokens = tokenizer.tokenize(target_sentence)
    for i, logits in enumerate(predictions):
        # Locate the <mask> token in the encoded input (this accounts for the <s> special
        # token that tokenizer.encode prepends) and read off the probability assigned to
        # the original token at that position
        input_ids = tokenizer.encode(masked_sentences[i])
        masked_index = input_ids.index(tokenizer.mask_token_id)
        softmax = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)
        original_token_id = tokenizer.convert_tokens_to_ids(tokens[i])
        score = softmax[original_token_id].item()
        total_score += score
    average_score = total_score / len(masked_sentences)
    return average_score
def calculate_inverse_perplexity(sentence):
    # Inverse of exp(loss), where loss is the MLM cross-entropy of the unmasked
    # sequence evaluated against itself
    inputs = tokenizer(sentence, return_tensors='pt', truncation=True, padding='max_length', max_length=128).to(device)
    with torch.no_grad():
        outputs = masked_lm_model(**inputs, labels=inputs["input_ids"])
    loss = outputs.loss
    perplexity = torch.exp(loss)
    return 1 / perplexity.item()
def get_average_embedding(ref_df):
    ref_sequences = ref_df["sequence"].tolist()
    print(f"Calculating average embedding for {len(ref_sequences)} reference sequences...")
    embeddings = []
    for i, sentence in enumerate(ref_sequences):
        emb = get_sentence_embedding(sentence)
        embeddings.append(emb)
        print(f"Embedding calculated for reference sequence {i} with shape {emb.shape}")
    average_embedding = np.mean(embeddings, axis=0)
    return average_embedding
def get_sc_sp_ip(average_embedding, target_dataset):
    # Score each target sequence: semantic change (sc), sequence probability (sp),
    # and inverse perplexity (ip)
    sc_scores = []
    sp_scores = []
    ip_scores = []
    for i, target in target_dataset.iterrows():
        target_sentence = target["sequence"]
        target_accession = target["accession_id"]
        # Semantic change score
        sc_score = calculate_semantic_change(average_embedding, target_sentence)
        sc_scores.append(sc_score)
        # Sequence probability score
        masked_sentences = mask_sentence(target_sentence)
        predictions = predict_masked_words(masked_sentences)
        sp_score = calculate_sequence_probability(predictions, masked_sentences, target_sentence)
        sp_scores.append(sp_score)
        # Inverse perplexity score
        ip_score = calculate_inverse_perplexity(target_sentence)
        ip_scores.append(ip_score)
        print(f"Target sequence {target_accession} -> sc: {sc_score:.6f}, sp: {sp_score:.6f}, ip: {ip_score:.6f}")
    return sc_scores, sp_scores, ip_scores
def get_results_df(target_dataset, sc_scores, sp_scores, ip_scores):
    # Create a DataFrame with the results
    results_df = target_dataset.copy()
    results_df["sc"] = sc_scores
    results_df["sp"] = sp_scores
    results_df["ip"] = ip_scores
    # gr score: the mean of sp and ip
    results_df["gr"] = (results_df["sp"] + results_df["ip"]) / 2
    # Add log10 scores
    results_df["log10(sc)"] = np.log10(results_df["sc"])
    results_df["log10(sp)"] = np.log10(results_df["sp"])
    results_df["log10(ip)"] = np.log10(results_df["ip"])
    results_df["log10(gr)"] = np.log10(results_df["gr"])
    # Add rank_by_sc, rank_by_sp, rank_by_ip, and rank_by_gr (rank 1 = highest score)
    results_df["rank_by_sc"] = results_df["sc"].rank(ascending=False)
    results_df["rank_by_sp"] = results_df["sp"].rank(ascending=False)
    results_df["rank_by_ip"] = results_df["ip"].rank(ascending=False)
    results_df["rank_by_gr"] = results_df["gr"].rank(ascending=False)
    # Make ranks integers
    results_df["rank_by_sc"] = results_df["rank_by_sc"].astype(int)
    results_df["rank_by_sp"] = results_df["rank_by_sp"].astype(int)
    results_df["rank_by_ip"] = results_df["rank_by_ip"].astype(int)
    results_df["rank_by_gr"] = results_df["rank_by_gr"].astype(int)
    # Add combined ranks (rank_by_scsp, rank_by_scip, rank_by_scgr) by summing the sc rank with the sp/ip/gr rank
    results_df["rank_by_scsp"] = results_df["rank_by_sc"] + results_df["rank_by_sp"]
    results_df["rank_by_scip"] = results_df["rank_by_sc"] + results_df["rank_by_ip"]
    results_df["rank_by_scgr"] = results_df["rank_by_sc"] + results_df["rank_by_gr"]
    # Drop the sequence column
    results_df = results_df.drop(columns=["sequence"])
    # Apply rounding (disabled)
    # results_df = results_df.applymap(lambda x: round(x, 6) if isinstance(x, (int, float)) else x)
    # By default, sort by rank_by_scgr
    results_df = results_df.sort_values(by="rank_by_scgr")
    return results_df
def process_target_data(average_embedding, target_dataset):
    sc_scores, sp_scores, ip_scores = get_sc_sp_ip(average_embedding, target_dataset)
    results_df = get_results_df(target_dataset, sc_scores, sp_scores, ip_scores)
    return results_df
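

# A minimal usage sketch (not part of the original script). It assumes two CSV files,
# "reference.csv" and "target.csv", each with a "sequence" column, and an "accession_id"
# column for the targets; the file names and this entry point are illustrative only.
if __name__ == "__main__":
    ref_df = pd.read_csv("reference.csv")       # reference sequences (assumed file name)
    target_df = pd.read_csv("target.csv")       # sequences to score (assumed file name)
    avg_emb = get_average_embedding(ref_df)     # mean per-token embedding of the references
    results = process_target_data(avg_emb, target_df)
    results.to_csv("results.csv", index=False)  # ranked sc/sp/ip/gr scores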