Spaces:
Build error
Build error
File size: 7,673 Bytes
7b6f763 0fc771a 6e4fc64 7b6f763 7df388d 7b6f763 6e4fc64 eafc169 6e4fc64 7b6f763 6e4fc64 7b6f763 6e4fc64 7b6f763 6e4fc64 7b6f763 6e4fc64 7b6f763 6e4fc64 c7135b5 6e4fc64 7b6f763 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
from transformers import pipeline
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
from textblob import TextBlob
from hatesonar import Sonar
import gradio as gr
import torch
# Load the fine-tuned reframing model and its tokenizer from the local checkpoint directory.
MODEL_DIR = "output/reframer"  # Checkpoint directory shared by the model and the tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR)
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
# The seq2seq model is wrapped in the 'summarization' pipeline purely as a text-to-text
# generation interface; its 'summary_text' output is the reframed text.
reframer = pipeline('summarization', model=model, tokenizer=tokenizer)

CHAR_LENGTH_LOWER_BOUND = 15  # Minimum number of characters accepted in the input text
CHAR_LENGTH_HIGHER_BOUND = 150  # Maximum number of characters accepted in the input text
SENTIMENT_THRESHOLD = 0.2  # Maximum TextBlob polarity accepted for the input text
OFFENSIVENESS_CONFIDENCE_THRESHOLD = 0.8  # Sonar confidence above which the input is rejected as offensive

# Messages shown (behind the "[Error]: Invalid Input." banner) when the input is rejected.
LENGTH_ERROR = "The input text is too long or too short. Please try again by inputing text with moderate length."
SENTIMENT_ERROR = "The input text is too positive. Please try again by inputing text with negative sentiment."
OFFENSIVE_ERROR = "The input text is offensive. Please try again by inputing non-offensive text."

CACHE = []  # Most recent reframing records [input_text, strategy, reframed_text], oldest first
MAX_STORE = 5  # Maximum number of reframing records kept in CACHE
BEST_N = 3  # Number of sampled decodes shown by the "Show Best N Results" button
def input_error_message(error_type):
    # type: (str) -> str
    """Build the user-facing message returned when the input text is rejected.

    Args:
        error_type: Explanation of why the input is invalid (e.g. ``LENGTH_ERROR``).

    Returns:
        ``"[Error]: Invalid Input. "`` followed by *error_type*.
    """
    banner = "[Error]: Invalid Input. "
    return "".join((banner, error_type))
def update_cache(cache, new_record, max_store=None):
    # type: (List[List[str]], List[str], Optional[int]) -> List[List[str]]
    """Return the reframing history with *new_record* appended, keeping only the newest entries.

    Args:
        cache: Existing history, oldest record first. It is not modified.
        new_record: ``[input_text, strategy, reframed_text]`` for the latest reframing.
        max_store: Maximum number of records to keep; defaults to the module-level
            ``MAX_STORE`` when omitted.

    Returns:
        A new list holding at most ``max_store`` records, newest last.
    """
    if max_store is None:
        max_store = MAX_STORE
    # Copy instead of appending in place, so callers never see a half-updated list.
    updated = list(cache)
    updated.append(new_record)
    # updated[-0:] would keep everything, so a non-positive limit is handled explicitly.
    if max_store <= 0:
        return []
    # Slicing from the end trims an oversized cache fully rather than by a single item.
    return updated[-max_store:]
_SONAR = None  # Sonar classifier, created on first use and reused by later calls


def _get_sonar():
    # type: () -> Sonar
    """Return the shared Sonar instance, constructing it only on the first call."""
    global _SONAR
    if _SONAR is None:
        _SONAR = Sonar()
    return _SONAR


def _reframe_input_error(input_text, strategy):
    # type: (Optional[str], Optional[str]) -> Optional[str]
    """Return the error message for input that reframe() must reject, or None if it is acceptable."""
    # The strategy Radio presumably passes None when nothing is selected (confirm against the
    # Gradio version in use); the prompt concatenation would then raise TypeError.
    if not strategy:
        return input_error_message("No strategy was selected. Please choose one of the six strategies.")
    # The input must be long enough to have substantial content and short enough to keep a focused idea.
    text_length = len(input_text) if input_text else 0
    if text_length < CHAR_LENGTH_LOWER_BOUND or text_length > CHAR_LENGTH_HIGHER_BOUND:
        return input_error_message(LENGTH_ERROR)
    # Text that is already too positive cannot be meaningfully reframed positively.
    if TextBlob(input_text).sentiment.polarity > SENTIMENT_THRESHOLD:
        return input_error_message(SENTIMENT_ERROR)
    # sonar.ping returns a dict whose 'classes' list holds per-class confidences; as in the
    # original code, the second entry is treated as the offensive-language confidence.
    offensive_confidence = _get_sonar().ping(input_text)['classes'][1]['confidence']
    if offensive_confidence > OFFENSIVENESS_CONFIDENCE_THRESHOLD:
        return input_error_message(OFFENSIVE_ERROR)
    return None


def reframe(input_text, strategy):
    # type: (str, Optional[str]) -> str
    """Reframe the input text with a specified strategy.

    The strategy is concatenated to the input text and passed to the fine-tuned model
    through the ``reframer`` pipeline. Input is validated first (strategy chosen,
    length, sentiment, offensiveness); rejected input yields an error string instead.

    Args:
        input_text: The text to reframe.
        strategy: Name of the reframing strategy selected in the UI.

    Returns:
        The reframed text, or a message from ``input_error_message`` (starting with
        ``"[Error]"``) when the input is rejected.
    """
    error = _reframe_input_error(input_text, strategy)
    if error is not None:
        return error
    # Presumably the prompt format the model was fine-tuned on (show_n_best_decodes builds
    # the same string); keep it unchanged.
    text_with_strategy = input_text + "Strategy: ['" + strategy + "']"
    # The reframer pipeline outputs a list containing one dictionary; its 'summary_text' value is the reframed text.
    reframed_text = reframer(text_with_strategy)[0]['summary_text']
    # NOTE(review): CACHE is a module-level global, so the history is likely shared by every
    # user served by this process.
    global CACHE
    CACHE = update_cache(CACHE, [input_text, strategy, reframed_text])
    return reframed_text
def show_reframe_change(input_text, strategy):
    # type: (str, Optional[str]) -> List[Tuple[str, Optional[str]]]
    """Return a character-level diff from *input_text* to its reframing for ``gr.HighlightedText``.

    Each tuple is ``(character, label)``. The label is the Differ code (``"+"`` for a
    character added in the reframed text, ``"-"`` for one removed from the input text),
    or ``None`` for an unchanged character. If ``reframe`` rejects the input, its error
    message is returned as one unhighlighted segment instead of being diffed.
    """
    from difflib import Differ

    reframed_text = reframe(input_text, strategy)
    # reframe signals rejected input by returning an input_error_message string; diffing the
    # input against that message would only render meaningless +/- highlights.
    if reframed_text.startswith("[Error]"):
        return [(reframed_text, None)]
    # Comparing two strings makes Differ treat every character as a "line". Each token is a
    # two-character code followed by that character; an unchanged character has code "  ".
    return [
        (token[2:], token[0] if token[0] != " " else None)
        for token in Differ().compare(input_text, reframed_text)
    ]
def show_n_best_decodes(input_text, strategy):
    # type: (str, Optional[str]) -> str
    """Generate ``BEST_N`` reframings of *input_text* and return them as numbered lines.

    Generation uses ``do_sample=True``, so the results are ``BEST_N`` random samples rather
    than a ranked n-best list, and they vary between calls.

    NOTE(review): unlike ``reframe``, this path does not apply the length, sentiment, or
    offensiveness checks on the input.

    Args:
        input_text: The text to reframe.
        strategy: Name of the reframing strategy selected in the UI.

    Returns:
        One line per sample of the form ``"<number> <text>"``, joined by newlines, or an
        error message when no strategy is selected.
    """
    # Without a strategy, the prompt concatenation below would raise TypeError.
    if not strategy:
        return input_error_message("No strategy was selected. Please choose one of the six strategies.")
    prompt = [input_text + "Strategy: ['" + strategy + "']"]
    input_ids = torch.tensor(tokenizer(prompt, padding=True)['input_ids'])
    decodes = model.generate(
        input_ids,
        do_sample=True,
        num_return_sequences=BEST_N,
    )
    return "\n".join(
        str(number) + " " + tokenizer.decode(token_ids, skip_special_tokens=True)
        for number, token_ids in enumerate(decodes, start=1)
    )
def show_history(cache):
    # type: (List[List[str]]) -> Any
    """Render the reframing history and return a Gradio update that reveals it in a textbox.

    Args:
        cache: Records ``[input_text, strategy, reframed_text]``, oldest first.

    Returns:
        The result of ``gr.Textbox.update(value=..., visible=True)``, whose value holds one
        newline-terminated line per record.
    """
    lines = [
        "Input text: " + input_text + " Strategy: " + strategy + " -> Reframed text: " + reframed_text + "\n"
        for input_text, strategy, reframed_text in cache
    ]
    return gr.Textbox.update(value="".join(lines), visible=True)
# Build Gradio interface
with gr.Blocks() as demo:
    # Instruction
    gr.Markdown(
        '''
# Positive Reframing
Start inputing negative texts to see how you can see the same event from a positive angle.
''')
    # Input text to be reframed
    text = gr.Textbox(label="Original Text")
    # Input strategy for the reframing
    gr.Markdown(
        '''
Choose one of the six strategies to carry out reframing: \n
**Growth Mindset:** Viewing a challenging event as an opportunity for the author specifically to grow or improve themselves. \n
**Impermanence:** Saying bad things don’t last forever, will get better soon, and/or that others have experienced similar struggles. \n
**Neutralizing:** Replacing a negative word with a neutral word. For example, “This was a terrible day” becomes “This was a long day.” \n
**Optimism:** Focusing on things about the situation itself, in that moment, that are good (not just forecasting a better future). \n
**Self-affirmation:** Talking about what strengths the author already has, or the values they admire, like love, courage, perseverance, etc. \n
**Thankfulness:** Expressing thankfulness or gratitude with key words like appreciate, glad that, thankful for, good thing, etc.
''')
    # The selected value is the strategy label appended to the model prompt.
    strategy = gr.Radio(
        ["thankfulness", "neutralizing", "optimism", "growth", "impermanence", "self_affirmation"], label="Strategy to use?"
    )
    # Trigger button for reframing; the result is shown as a character diff against the input.
    reframe_btn = gr.Button("Reframe")
    best_output = gr.HighlightedText(
        label="Diff",
        combine_adjacent=True,
    ).style(color_map={"+": "green", "-": "red"})
    reframe_btn.click(fn=show_reframe_change, inputs=[text, strategy], outputs=best_output)
    # Trigger button for showing BEST_N sampled reframings
    n_best_btn = gr.Button("Show Best {n} Results".format(n=BEST_N))
    n_best_output = gr.Textbox(interactive=False)
    n_best_btn.click(fn=show_n_best_decodes, inputs=[text, strategy], outputs=n_best_output)
    # Default examples of text and strategy pairs for a quick start. Their outputs go to the
    # diff component (the original referenced an undefined name `output`, a NameError at build time).
    gr.Markdown("## Examples")
    gr.Examples(
        [["I have a lot of homework to do today.", "self_affirmation"], ["This has been the longest and most stressful week of my life!", "optimism"], ["So stressed about the midterms next week.", "thankfulness"]],
        [text, strategy], best_output, show_reframe_change, cache_examples=False, run_on_click=False
    )
    # Link to paper and Github repo
    gr.Markdown(
        '''
For more details: You can read our [paper](https://arxiv.org/abs/2204.02952) or access our [code](https://github.com/SALT-NLP/positive-frames).
''')
demo.launch()