# Hugging Face Spaces page header captured with this file - Space status: Runtime error
import streamlit as st | |
import torch | |
import numpy as np | |
from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification | |
import torch.nn.functional as F | |
from copy import copy | |
from torch.nn.functional import softmax | |
# Tokenizer used by every function in this file (masking, encoding and
# decoding of replacement tokens).
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Two masked-language models. Judging by the checkpoint names they are tuned
# on positive and negative text respectively - confirm against the model cards.
bert_mlm_positive = BertForMaskedLM.from_pretrained(
    "ewriji/heil-A.412C-positive",
    return_dict=True,
)
bert_mlm_negative = BertForMaskedLM.from_pretrained(
    "ewriji/heil-A.412C-negative",
    return_dict=True,
)

# Sequence classifier used to rank candidate rewrites; callers read column 1
# of its logits.
classification_model = BertForSequenceClassification.from_pretrained(
    "ewriji/heil-A.412C-classification",
    return_dict=True,
)
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """
    Propose single-word substitutions for ``sentence`` using the two masked LMs.

    - split the sentence on whitespace and build one copy per word in which
      that word is replaced by the tokenizer's mask token
    - score every original word with
      ``(P_positive + epsilon) / (P_negative + epsilon)``, where each P is the
      probability the respective masked LM assigns, at the mask position, to
      the first word piece of the original word
    - keep the :num_tokens: words with the LOWEST ratio
    - for each kept word take the :k_best: most probable tokens at the mask
      according to bert_mlm_positive and substitute them into the sentence

    :param sentence: raw text; words are whatever ``str.split()`` yields
    :param num_tokens: how many words to replace (capped at the word count)
    :param k_best: candidates per replaced word (capped at the vocabulary size)
    :param epsilon: additive smoothing applied to both probabilities
    :return: a list of candidate sentences (up to k_best * num_tokens). The
        entries are grouped per replaced word, lowest ratio first, and inside a
        group ordered by descending positive-LM probability. An empty list is
        returned when the sentence contains no words.
    """
    words = sentence.split()
    if not words:
        # Nothing to mask; the tokenizer and topk calls below cannot handle
        # an empty batch.
        return []

    # Row i of the batch is the sentence with words[i] hidden behind the mask.
    batch = [
        words[:i] + [tokenizer.mask_token] + words[i + 1:]
        for i in range(len(words))
    ]
    encoded = tokenizer(batch, padding=True, return_tensors="pt", is_split_into_words=True)
    # (row, column) of each mask token. Assumes exactly one mask per row, i.e.
    # the input itself does not contain the literal mask token.
    mask_ids = (encoded["input_ids"] == tokenizer.mask_token_id).nonzero().cpu()

    # Pure inference: no autograd graph is needed for either model.
    with torch.no_grad():
        positive_logits = bert_mlm_positive(**encoded).logits
        negative_logits = bert_mlm_negative(**encoded).logits

    # Vocabulary id of the first word piece of every original word. A word
    # that produces no word pieces falls back to the unknown token.
    first_piece_ids = []
    for word in words:
        piece_ids = tokenizer.encode(word, add_special_tokens=False)
        first_piece_ids.append(piece_ids[0] if piece_ids else tokenizer.unk_token_id)

    rows, cols = mask_ids[:, 0], mask_ids[:, 1]
    positive_at_mask = softmax(positive_logits[rows, cols], dim=-1)
    negative_at_mask = softmax(negative_logits[rows, cols], dim=-1)

    word_rows = torch.arange(len(words))
    original_ids = torch.tensor(first_piece_ids)
    positive_prob = positive_at_mask[word_rows, original_ids]
    negative_prob = negative_at_mask[word_rows, original_ids]
    ratio = (positive_prob + epsilon) / (negative_prob + epsilon)

    # topk rejects k larger than the number of elements, so cap both requests.
    num_tokens = min(num_tokens, len(words))
    lowest_ratio = torch.topk(ratio, k=num_tokens, largest=False, dim=-1)

    k_best = min(k_best, positive_at_mask.shape[-1])
    best_candidates = torch.topk(positive_at_mask[lowest_ratio.indices], k=k_best, dim=-1)

    # Build the rewritten sentences, one per (replaced word, candidate token).
    replaced_words = []
    positions = lowest_ratio.indices.tolist()
    for position, candidate_ids in zip(positions, best_candidates.indices.tolist()):
        for token in tokenizer.convert_ids_to_tokens(candidate_ids):
            candidate = list(words)
            candidate[position] = token
            replaced_words.append(' '.join(candidate))
    return replaced_words
def evaluate_top(model, sentences):
    """
    Run ``model`` on each sentence and stack the resulting logits.

    :param model: a model whose output exposes ``.logits`` when called with
        the tokenizer's encoding
    :param sentences: non-empty iterable of raw (unsplit) strings
    :return: tensor with one row of logits per sentence, in input order
    """
    predictions = []
    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        for sentence in sentences:
            # Each sentence is a plain string, so it is tokenized as raw text;
            # is_split_into_words is only meaningful for pre-split word lists.
            inputs = tokenizer(sentence, padding=True, return_tensors="pt")
            predictions.append(model(**inputs).logits)
    return torch.cat(predictions, dim=0)
def get_replacements_with_classifier(model, sentence, num_tokens, k_best, m_best, epsilon=1e-3):
    """
    Generate candidate rewrites and keep the ones the classifier scores highest.

    Candidates come from ``get_replacements`` in consecutive groups of
    :k_best: (one group per replaced word). Every group is scored with column
    1 of the classifier logits (presumably the "positive" class - confirm
    against the classifier's label mapping) and its :m_best: highest-scoring
    sentences are kept.

    :param model: sequence classifier passed through to ``evaluate_top``
    :param sentence: raw input text
    :param num_tokens: number of words to replace
    :param k_best: masked-LM candidates generated per replaced word
    :param m_best: candidates kept per replaced word (capped at the group size)
    :param epsilon: smoothing forwarded to ``get_replacements``
    :return: list of up to num_tokens * m_best sentences; empty when the
        sentence is blank or any of the three counts is not positive
    """
    if not sentence.strip() or min(num_tokens, k_best, m_best) <= 0:
        # Nothing can be generated or selected; skip the model calls entirely.
        return []

    replacements = get_replacements(sentence, num_tokens, k_best, epsilon=epsilon)
    top_m_replacements = []
    for group_no in range(num_tokens):
        group = replacements[group_no * k_best:(group_no + 1) * k_best]
        if not group:
            # Fewer candidates than requested were produced; no groups remain.
            break
        scores = evaluate_top(model, group)[:, 1].flatten()
        # topk rejects k larger than the number of scores.
        best = torch.topk(scores, k=min(m_best, len(group)))
        top_m_replacements.extend(group[i] for i in best.indices.tolist())
    return top_m_replacements
# Page setup: set_page_config is issued before any other Streamlit element.
st.set_page_config(page_title="A + B calculator pro max", layout="centered")
st.markdown("## Dude, let's convert some negative vibes to positive")

negative = st.text_input("Gimme ya review", value='')

# The script runs top to bottom on every rerun, including the first render
# when the text box is still empty. Only invoke the models once there is
# text to rewrite, and only index the result when a candidate came back.
positive = ''
if negative.strip():
    candidates = get_replacements_with_classifier(
        classification_model,
        negative,
        1,   # replace a single word
        20,  # masked-LM candidates for that word
        1,   # keep the classifier's best candidate
    )
    if candidates:
        positive = candidates[0]

st.text(positive)