Spaces:
Build error
Build error
from transformers import pipeline | |
from transformers import AutoModelForSeq2SeqLM | |
from transformers import AutoTokenizer | |
from textblob import TextBlob | |
from hatesonar import Sonar | |
import gradio as gr | |
import torch | |
# Load the fine-tuned seq2seq checkpoint and its tokenizer from the local
# "output/reframer" directory, then wrap both in a summarization pipeline.
tokenizer = AutoTokenizer.from_pretrained("output/reframer")
model = AutoModelForSeq2SeqLM.from_pretrained("output/reframer")
reframer = pipeline(task='summarization', model=model, tokenizer=tokenizer)
# --- Input validation thresholds ---
# Minimum character length threshold for the input text.
CHAR_LENGTH_LOWER_BOUND = 15
# Maximum character length threshold for the input text.
CHAR_LENGTH_HIGHER_BOUND = 150
# Maximum TextBlob sentiment score for the input text.
SENTIMENT_THRESHOLD = 0.2
# Threshold for the confidence score of a text being offensive.
OFFENSIVENESS_CONFIDENCE_THRESHOLD = 0.8

# --- User-facing validation messages ---
LENGTH_ERROR = (
    "The input text is too long or too short. "
    "Please try again by inputing text with moderate length."
)
SENTIMENT_ERROR = (
    "The input text is too positive. "
    "Please try again by inputing text with negative sentiment."
)
OFFENSIVE_ERROR = (
    "The input text is offensive. "
    "Please try again by inputing non-offensive text."
)

# --- History and decoding settings ---
# Most recent reframing records; reframe() stores each one as
# [input_text, strategy, reframed_text].
CACHE = []
# Maximum number of history records to keep.
MAX_STORE = 5
# Number of best decodes to show to the user.
BEST_N = 3
def input_error_message(error_type):
    # type: (str) -> str
    """Build the message reported for an invalid input.

    The given description is appended to the fixed
    "[Error]: Invalid Input. " prefix.
    """
    prefix = "[Error]: Invalid Input. "
    return "".join((prefix, error_type))
def update_cache(cache, new_record, max_store=None):
    # type: (List[List[str]], List[str], Optional[int]) -> List[List[str]]
    """Append a record to the history and drop the oldest ones over the limit.

    Args:
        cache: History list, oldest record first. It is modified in place.
        new_record: Record to append; reframe() passes
            ``[input_text, strategy, reframed_text]``.
        max_store: Maximum number of records to keep. When omitted, the
            module-level ``MAX_STORE`` is used.

    Returns:
        The same ``cache`` object, holding at most ``max_store`` records.
    """
    limit = MAX_STORE if max_store is None else max_store
    cache.append(new_record)
    # Trim in place (rather than rebinding to a sliced copy) so the list the
    # caller passed in and the list returned never disagree, and remove the
    # whole overflow so an already-oversized history is brought back in bounds.
    overflow = len(cache) - limit
    if overflow > 0:
        del cache[:overflow]
    return cache
# Shared hatesonar classifier, created on first use by _get_sonar().
_SONAR = None


def _get_sonar():
    """Return the shared Sonar classifier, constructing it on first use."""
    global _SONAR
    if _SONAR is None:
        _SONAR = Sonar()
    return _SONAR


def reframe(input_text, strategy):
    # type: (str, str) -> str
    """Reframe the input text with a specified strategy.

    The input is validated first. If it passes, the strategy is concatenated
    to the input text, the result is passed to the ``reframer`` pipeline, and
    the reframing is recorded in the module-level ``CACHE``.

    Args:
        input_text: Text to be reframed.
        strategy: Name of the reframing strategy to apply.

    Returns:
        The reframed text, or, when validation fails, the input text followed
        by an error message produced by ``input_error_message``.
    """
    global CACHE

    # Input control
    # The input text cannot be too short, so that it has substantial content
    # to reframe, nor too long, so that it keeps a focused idea.
    if not CHAR_LENGTH_LOWER_BOUND <= len(input_text) <= CHAR_LENGTH_HIGHER_BOUND:
        return input_text + input_error_message(LENGTH_ERROR)

    # A strategy must be chosen; without this guard a missing selection
    # (None) raises TypeError when the prompt string is built below.
    if not strategy:
        return input_text + input_error_message(
            "No strategy is selected. Please try again by choosing a strategy."
        )

    # The input text cannot be too positive, so that it can be positively
    # reframed. Use the named threshold instead of a duplicated literal.
    if TextBlob(input_text).sentiment.polarity > SENTIMENT_THRESHOLD:
        return input_text + input_error_message(SENTIMENT_ERROR)

    # The input text cannot be offensive. ping() outputs a dictionary; the
    # second score under the key 'classes' is the confidence for the input
    # text being offensive language. The classifier is reused across calls
    # rather than rebuilt for every request.
    offensive_confidence = _get_sonar().ping(input_text)['classes'][1]['confidence']
    if offensive_confidence > OFFENSIVENESS_CONFIDENCE_THRESHOLD:
        return input_text + input_error_message(OFFENSIVE_ERROR)

    # Reframing
    # The pipeline outputs a list containing one dictionary whose
    # 'summary_text' value is the reframed text.
    text_with_strategy = input_text + "Strategy: ['" + strategy + "']"
    reframed_text = reframer(text_with_strategy)[0]['summary_text']

    # Update cache
    CACHE = update_cache(CACHE, [input_text, strategy, reframed_text])
    return reframed_text
def show_reframe_change(input_text, strategy):
    # type: (str, str) -> List[Tuple[str, str]]
    """Reframe the text and return a character-level diff against the input.

    The input text and the reframed text are compared character by character
    with ``difflib.Differ``. Each returned tuple holds the character and the
    diff code of its entry ("+" or "-"), or ``None`` when the entry starts
    with a space, i.e. the character is unchanged.
    """
    from difflib import Differ

    reframed_text = reframe(input_text, strategy)
    highlighted = []
    for entry in Differ().compare(input_text, reframed_text):
        code = entry[0]
        # Differ entries are "<code> <item>": the payload starts at index 2.
        highlighted.append((entry[2:], None if code == " " else code))
    return highlighted
def show_n_best_decodes(input_text, strategy):
    # type: (str, str) -> str
    """Sample ``BEST_N`` reframings of the input and return them as text.

    The strategy is concatenated to the input text in the same prompt format
    used by ``reframe``, and ``BEST_N`` sequences are sampled from the model.

    Args:
        input_text: Text to be reframed.
        strategy: Name of the reframing strategy to apply.

    Returns:
        One line per decode, formatted as "<rank> <text>" and separated by
        newlines, or an error message when no strategy is selected.
    """
    # NOTE(review): unlike reframe(), this path applies none of the length,
    # sentiment or offensiveness checks -- confirm that this is intended.
    if not strategy:
        # Without this guard a missing selection (None) raises TypeError
        # when the prompt string is built.
        return input_error_message(
            "No strategy is selected. Please try again by choosing a strategy."
        )

    prompt = [input_text + "Strategy: ['" + strategy + "']"]
    # Let the tokenizer build the tensors and pass its attention mask along
    # instead of wrapping only the raw id lists in torch.tensor.
    encoded = tokenizer(prompt, padding=True, return_tensors="pt")
    n_best_decodes = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        do_sample=True,
        num_return_sequences=BEST_N,
    )
    decoded = tokenizer.batch_decode(n_best_decodes, skip_special_tokens=True)
    return "\n".join(
        str(rank) + " " + text for rank, text in enumerate(decoded, start=1)
    )
def show_history(cache):
    # type: (List[List[str]]) -> Any
    """Render the reframing history as text for a visible Textbox.

    Each record in ``cache`` must unpack into input text, strategy and
    reframed text. One newline-terminated line is produced per record, and
    the combined text is handed to ``gr.Textbox.update`` with ``visible=True``.
    """
    lines = [
        "Input text: " + input_text
        + " Strategy: " + strategy
        + " -> Reframed text: " + reframed_text + "\n"
        for input_text, strategy, reframed_text in cache
    ]
    return gr.Textbox.update(value="".join(lines), visible=True)
# Build Gradio interface

# Instruction text shown at the top of the page.
_INTRO_MARKDOWN = '''
# Positive Reframing
**Start inputing negative texts to see how you can see the same event from a positive angle.**
'''

# Description of the six reframing strategies offered by the radio selector.
_STRATEGY_MARKDOWN = '''
**Choose one of the six strategies to carry out reframing:** \n
**Growth Mindset:** Viewing a challenging event as an opportunity for the author specifically to grow or improve themselves. \n
**Impermanence:** Saying bad things don’t last forever, will get better soon, and/or that others have experienced similar struggles. \n
**Neutralizing:** Replacing a negative word with a neutral word. For example, “This was a terrible day” becomes “This was a long day.” \n
**Optimism:** Focusing on things about the situation itself, in that moment, that are good (not just forecasting a better future). \n
**Self-affirmation:** Talking about what strengths the author already has, or the values they admire, like love, courage, perseverance, etc. \n
**Thankfulness:** Expressing thankfulness or gratitude with key words like appreciate, glad that, thankful for, good thing, etc.
'''

# Link to paper and Github repo.
_DETAILS_MARKDOWN = '''
For more details: You can read our [paper](https://arxiv.org/abs/2204.02952) or access our [code](https://github.com/SALT-NLP/positive-frames).
'''

# Default (text, strategy) pairs that give the user a quick start.
_EXAMPLE_INPUTS = [
    ["I have a lot of homework to do today.", "self_affirmation"],
    ["So stressed about the midterm next week.", "optimism"],
    ["I failed my math quiz I am such a loser.", "growth"],
]

with gr.Blocks() as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    # Input text to be reframed
    text = gr.Textbox(label="Original Text")

    # Input strategy for the reframing
    gr.Markdown(_STRATEGY_MARKDOWN)
    strategy = gr.Radio(
        [
            "thankfulness",
            "neutralizing",
            "optimism",
            "growth",
            "impermanence",
            "self_affirmation",
        ],
        label="Strategy to use?",
    )

    # Button that runs the reframing and shows it as a highlighted diff
    reframe_button = gr.Button("Reframe")
    best_output = gr.HighlightedText(
        label="Diff",
        combine_adjacent=True,
    ).style(color_map={"+": "green", "-": "red"})
    reframe_button.click(
        fn=show_reframe_change,
        inputs=[text, strategy],
        outputs=best_output,
    )

    # Button that shows the n best sampled reframings
    n_best_button = gr.Button("Show Best {n} Results".format(n=BEST_N))
    n_best_output = gr.Textbox(interactive=False)
    n_best_button.click(
        fn=show_n_best_decodes,
        inputs=[text, strategy],
        outputs=n_best_output,
    )

    # Clickable examples that fill in the text and strategy inputs
    gr.Markdown("## Examples")
    gr.Examples(
        _EXAMPLE_INPUTS,
        [text, strategy],
        best_output,
        show_reframe_change,
        cache_examples=False,
        run_on_click=False,
    )

    gr.Markdown(_DETAILS_MARKDOWN)

demo.launch()