from transformers import pipeline
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
from textblob import TextBlob
from hatesonar import Sonar
import gradio as gr
import torch

# Load trained model
model = AutoModelForSeq2SeqLM.from_pretrained("output/reframer")
tokenizer = AutoTokenizer.from_pretrained("output/reframer")
reframer = pipeline('summarization', model=model, tokenizer=tokenizer)
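
# Illustrative call to the summarization pipeline (assuming the fine-tuned checkpoint exists under output/reframer);
# note that reframe() below concatenates the strategy tag directly onto the input text:
#   reframer("So stressed about the midterms next week.Strategy: ['thankfulness']")
#   -> [{'summary_text': '<positively reframed sentence>'}]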


CHAR_LENGTH_LOWER_BOUND = 15 # The minimum character length allowed for the input text
CHAR_LENGTH_HIGHER_BOUND = 150 # The maximum character length allowed for the input text
SENTIMENT_THRESHOLD = 0.2 # The maximum TextBlob sentiment polarity allowed for the input text
OFFENSIVENESS_CONFIDENCE_THRESHOLD = 0.8 # The minimum confidence at which a text is treated as offensive

LENGTH_ERROR = "The input text is too long or too short. Please try again by inputing text with moderate length."
SENTIMENT_ERROR = "The input text is too positive. Please try again by inputing text with negative sentiment."
OFFENSIVE_ERROR = "The input text is offensive. Please try again by inputing non-offensive text."

CACHE = [] # A list storing the most recent reframing records
MAX_STORE = 5 # The maximum number of history records to store

BEST_N = 3 # The number of best decodes the user would like to see


def input_error_message(error_type):
    # type: (str) -> str
    """Generate an input error message from error type."""
    return "[Error]: Invalid Input. " + error_type

def update_cache(cache, new_record):
    # type: (List[List[str]], List[str]) -> List[List[str]]
    """Update the cache to store the most recent five reframing histories."""
    cache.append(new_record)
    if len(cache) > MAX_STORE:
      cache = cache[1:]
    return cache
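
# Illustrative behavior of update_cache on hypothetical records r1..r6, assuming MAX_STORE == 5:
#   update_cache([r1, r2, r3, r4, r5], r6) -> [r2, r3, r4, r5, r6]  (the oldest record is dropped)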

def reframe(input_text, strategy):
    # type: (str, str) -> str
    """Reframe the input text with a specified strategy.

    The strategy will be concatenated to the input text and passed to a fine-tuned BART model.

    The reframed positive text will be returned.
    """
    text_with_strategy = input_text +  "Strategy: ['" + strategy + "']"

    # Input Control
    # The input text cannot be too short to ensure it has substantial content to be reframed. It also cannot be too long to ensure the text has a focused idea.
    if len(input_text) < CHAR_LENGTH_LOWER_BOUND or len(input_text) > CHAR_LENGTH_HIGHER_BOUND:
      return input_error_message(LENGTH_ERROR)
    # The input text cannot be too positive to ensure the text can be positively reframed.
    if TextBlob(input_text).sentiment.polarity > SENTIMENT_THRESHOLD:
      return input_error_message(SENTIMENT_ERROR)
    # The input text cannot be offensive.
    sonar = Sonar()
    # sonar.ping(input_text) returns a dictionary; the second entry under the 'classes' key holds the confidence that the input text is offensive language
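    # Illustrative shape of the Sonar output (the confidence values here are made up for illustration):
    #   {'text': '...', 'top_class': 'neither',
    #    'classes': [{'class_name': 'hate_speech', 'confidence': 0.05},
    #                {'class_name': 'offensive_language', 'confidence': 0.10},
    #                {'class_name': 'neither', 'confidence': 0.85}]}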
    if sonar.ping(input_text)['classes'][1]['confidence'] > OFFENSIVENESS_CONFIDENCE_THRESHOLD:
      return input_error_message(OFFENSIVE_ERROR)
    
    # Reframing
    # reframer pipeline outputs a list containing one dictionary where the value for 'summary_text' is the reframed text output
    reframed_text = reframer(text_with_strategy)[0]['summary_text']

    # Update cache
    global CACHE
    CACHE = update_cache(CACHE, [input_text, strategy, reframed_text])

    return reframed_text


def show_reframe_change(input_text, strategy):
    # type: (str, str) -> List[Tuple[str, str]]
    """Compare the addition and deletion of characters in input_text to form reframed_text.

    The returned output is a list of tuples with two elements, the first element being the character in reframed text and the second element being the action performed with respect to the input text.
    """  
    reframed_text = reframe(input_text, strategy)
    from difflib import Differ
    d = Differ()
    return [
        (token[2:], token[0] if token[0] != " " else None)
        for token in d.compare(input_text, reframed_text)
    ]
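
# Rough sketch of show_reframe_change output on hypothetical strings (not actual model output):
#   comparing "bad" with "sad" yields roughly [('b', '-'), ('s', '+'), ('a', None), ('d', None)],
#   which gr.HighlightedText renders with deletions in red and additions in green.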

def show_n_best_decodes(input_text, strategy):
    # type: (str, str) -> str
    """Sample BEST_N candidate reframings for the input text and return them as a numbered, newline-separated string."""
    prompt = [input_text + "Strategy: ['" + strategy + "']"]
    n_best_decodes = model.generate(torch.tensor(tokenizer(prompt, padding=True)['input_ids']), 
                                    do_sample=True,
                                    num_return_sequences=BEST_N
                                    )
    best_n_result = ""
    for i in range(len(n_best_decodes)):
        best_n_result += str(i+1) + " " + tokenizer.decode(n_best_decodes[i], skip_special_tokens=True)
        if i < BEST_N - 1:
            best_n_result += "\n"
    return best_n_result

def show_history(cache):
    # type: (List[List[str]]) -> dict
    """Format the cached reframing records and reveal them in a history textbox."""
    history = ""
    for i in cache:
      input_text, strategy, reframed_text = i
      history += "Input text: " + input_text + " Strategy: " + strategy + " -> Reframed text: " + reframed_text + "\n"
    return gr.Textbox.update(value=history, visible=True)
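
# Note: show_history is defined but not yet connected to any component in the Blocks interface below.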


# Build Gradio interface
with gr.Blocks() as demo:
    # Instruction
    gr.Markdown(
    '''
    # Positive Reframing
    Start inputting negative text to see how the same event can be viewed from a positive angle.
    ''')

    # Input text to be reframed
    text = gr.Textbox(label="Original Text")

    # Input strategy for the reframing
    gr.Markdown(
    '''
    Choose one of the six strategies to carry out reframing: \n
    **Growth Mindset:** Viewing a challenging event as an opportunity for the author specifically to grow or improve themselves. \n
    **Impermanence:** Saying bad things don’t last forever, will get better soon, and/or that others have experienced similar struggles. \n 
    **Neutralizing:** Replacing a negative word with a neutral word. For example, “This was a terrible day” becomes “This was a long day.” \n 
    **Optimism:** Focusing on things about the situation itself, in that moment, that are good (not just forecasting a better future). \n 
    **Self-affirmation:** Talking about what strengths the author already has, or the values they admire, like love, courage, perseverance, etc. \n 
    **Thankfulness:** Expressing thankfulness or gratitude with key words like appreciate, glad that, thankful for, good thing, etc.
    ''')
    strategy = gr.Radio(
        ["thankfulness", "neutralizing", "optimism", "growth", "impermanence", "self_affirmation"], label="Strategy to use?"
    )

    # Trigger button for reframing
    reframe_btn = gr.Button("Reframe")
    best_output = gr.HighlightedText(
        label="Diff",
        combine_adjacent=True,
        ).style(color_map={"+": "green", "-": "red"})
    reframe_btn.click(fn=show_reframe_change, inputs=[text, strategy], outputs=best_output)

    # Trigger button for showing n best reframings
    n_best_btn = gr.Button("Show Best {n} Results".format(n=BEST_N))
    n_best_output = gr.Textbox(interactive=False)
    n_best_btn.click(fn=show_n_best_decodes, inputs=[text, strategy], outputs=n_best_output)

    # Default examples of text and strategy pairs for user to have a quick start
    gr.Markdown("## Examples")
    gr.Examples(
        [["I have a lot of homework to do today.", "self_affirmation"], ["This has been the longest and most stressful week of my life!", "optimism"], ["So stressed about the midterms next week.", "thankfulness"]], 
        [text, strategy], output, show_reframe_change, cache_examples=False, run_on_click=False
    )

    # Link to paper and Github repo
    gr.Markdown(
    '''
    For more details: You can read our [paper](https://arxiv.org/abs/2204.02952) or access our [code](https://github.com/SALT-NLP/positive-frames).
    ''')

demo.launch()