coyotte508 HF staff commited on
Commit
c28c458
·
1 Parent(s): 7ba8d64

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -164
app.py CHANGED
@@ -1,167 +1,7 @@
1
- from transformers import pipeline
2
- from transformers import AutoModelForSeq2SeqLM
3
- from transformers import AutoTokenizer
4
- from textblob import TextBlob
5
- from hatesonar import Sonar
6
  import gradio as gr
7
- import torch
8
 
9
- # Load trained model
10
- model = AutoModelForSeq2SeqLM.from_pretrained("output/reframer")
11
- tokenizer = AutoTokenizer.from_pretrained("output/reframer")
12
- reframer = pipeline('summarization', model=model, tokenizer=tokenizer)
13
 
14
-
15
- CHAR_LENGTH_LOWER_BOUND = 15 # The minimum character length threshold for the input text
16
- CHAR_LENGTH_HIGHER_BOUND = 150 # The maximum character length threshold for the input text
17
- SENTIMENT_THRESHOLD = 0.2 # The maximum Textblob sentiment score for the input text
18
- OFFENSIVENESS_CONFIDENCE_THRESHOLD = 0.8 # The threshold for the confidence score of a text being offensive
19
-
20
- LENGTH_ERROR = "The input text is too long or too short. Please try again by inputing text with moderate length."
21
- SENTIMENT_ERROR = "The input text is too positive. Please try again by inputing text with negative sentiment."
22
- OFFENSIVE_ERROR = "The input text is offensive. Please try again by inputing non-offensive text."
23
-
24
- CACHE = [] # A list storing the most recent 5 reframing history
25
- MAX_STORE = 5 # The maximum number of history user would like to store
26
-
27
- BEST_N = 3 # The number of best decodes user would like to seee
28
-
29
-
30
- def input_error_message(error_type):
31
- # type: (str) -> str
32
- """Generate an input error message from error type."""
33
- return "[Error]: Invalid Input. " + error_type
34
-
35
- def update_cache(cache, new_record):
36
- # type: List[List[str, str, str]] -> List[List[str, str, str]]
37
- """Update the cache to store the most recent five reframing histories."""
38
- cache.append(new_record)
39
- if len(cache) > MAX_STORE:
40
- cache = cache[1:]
41
- return cache
42
-
43
- def reframe(input_text, strategy):
44
- # type: (str, str) -> str
45
- """Reframe the input text with a specified strategy.
46
-
47
- The strategy will be concetenated to the input text and passed to a finetuned BART model.
48
-
49
- The reframed positive text will be returned.
50
- """
51
- text_with_strategy = input_text + "Strategy: ['" + strategy + "']"
52
-
53
- # Input Control
54
- # The input text cannot be too short to ensure it has substantial content to be reframed. It also cannot be too long to ensure the text has a focused idea.
55
- if len(input_text) < CHAR_LENGTH_LOWER_BOUND or len(input_text) > CHAR_LENGTH_HIGHER_BOUND:
56
- return input_text + input_error_message(LENGTH_ERROR)
57
- # The input text cannot be too positive to ensure the text can be positively reframed.
58
- if TextBlob(input_text).sentiment.polarity > 0.2:
59
- return input_text + input_error_message(SENTIMENT_ERROR)
60
- # The input text cannot be offensive.
61
- sonar = Sonar()
62
- # sonar.ping(input_text) outputs a dictionary and the second score under the key classes is the confidence for the input text being offensive language
63
- if sonar.ping(input_text)['classes'][1]['confidence'] > OFFENSIVENESS_CONFIDENCE_THRESHOLD:
64
- return input_text + input_error_message(OFFENSIVE_ERROR)
65
-
66
- # Reframing
67
- # reframer pipeline outputs a list containing one dictionary where the value for 'summary_text' is the reframed text output
68
- reframed_text = reframer(text_with_strategy)[0]['summary_text']
69
-
70
- # Update cache
71
- global CACHE
72
- CACHE = update_cache(CACHE, [input_text, strategy, reframed_text])
73
-
74
- return reframed_text
75
-
76
-
77
- def show_reframe_change(input_text, strategy):
78
- # type: (str, str) -> List[Tuple[str, str]]
79
- """Compare the addition and deletion of characters in input_text to form reframed_text.
80
-
81
- The returned output is a list of tuples with two elements, the first element being the character in reframed text and the second element being the action performed with respect to the input text.
82
- """
83
- reframed_text = reframe(input_text, strategy)
84
- from difflib import Differ
85
- d = Differ()
86
- return [
87
- (token[2:], token[0] if token[0] != " " else None)
88
- for token in d.compare(input_text, reframed_text)
89
- ]
90
-
91
- def show_n_best_decodes(input_text, strategy):
92
- # type: (str, str) -> str
93
- prompt = [input_text + "Strategy: ['" + strategy + "']"]
94
- n_best_decodes = model.generate(torch.tensor(tokenizer(prompt, padding=True)['input_ids']),
95
- do_sample=True,
96
- num_return_sequences=BEST_N
97
- )
98
- best_n_result = ""
99
- for i in range(len(n_best_decodes)):
100
- best_n_result += str(i+1) + " " + tokenizer.decode(n_best_decodes[i], skip_special_tokens=True)
101
- if i < BEST_N - 1:
102
- best_n_result += "\n"
103
- return best_n_result
104
-
105
- def show_history(cache):
106
- # type: List[List[str, str, str]] -> str
107
- history = ""
108
- for i in cache:
109
- input_text, strategy, reframed_text = i
110
- history += "Input text: " + input_text + " Strategy: " + strategy + " -> Reframed text: " + reframed_text + "\n"
111
- return gr.Textbox.update(value=history, visible=True)
112
-
113
-
114
- # Build Gradio interface
115
- with gr.Blocks() as demo:
116
- # Instruction
117
- gr.Markdown(
118
- '''
119
- # Positive Reframing
120
- **Start inputing negative texts to see how you can see the same event from a positive angle.**
121
- ''')
122
-
123
- # Input text to be reframed
124
- text = gr.Textbox(label="Original Text")
125
-
126
- # Input strategy for the reframing
127
- gr.Markdown(
128
- '''
129
- **Choose one of the six strategies to carry out reframing:** \n
130
- **Growth Mindset:** Viewing a challenging event as an opportunity for the author specifically to grow or improve themselves. \n
131
- **Impermanence:** Saying bad things don’t last forever, will get better soon, and/or that others have experienced similar struggles. \n
132
- **Neutralizing:** Replacing a negative word with a neutral word. For example, “This was a terrible day” becomes “This was a long day.” \n
133
- **Optimism:** Focusing on things about the situation itself, in that moment, that are good (not just forecasting a better future). \n
134
- **Self-affirmation:** Talking about what strengths the author already has, or the values they admire, like love, courage, perseverance, etc. \n
135
- **Thankfulness:** Expressing thankfulness or gratitude with key words like appreciate, glad that, thankful for, good thing, etc.
136
- ''')
137
- strategy = gr.Radio(
138
- ["thankfulness", "neutralizing", "optimism", "growth", "impermanence", "self_affirmation"], label="Strategy to use?"
139
- )
140
-
141
- # Trigger button for reframing
142
- greet_btn = gr.Button("Reframe")
143
- best_output = gr.HighlightedText(
144
- label="Diff",
145
- combine_adjacent=True,
146
- ).style(color_map={"+": "green", "-": "red"})
147
- greet_btn.click(fn=show_reframe_change, inputs=[text, strategy], outputs=best_output)
148
-
149
- # Trigger button for showing n best reframings
150
- greet_btn = gr.Button("Show Best {n} Results".format(n=BEST_N))
151
- n_best_output = gr.Textbox(interactive=False)
152
- greet_btn.click(fn=show_n_best_decodes, inputs=[text, strategy], outputs=n_best_output)
153
-
154
- # Default examples of text and strategy pairs for user to have a quick start
155
- gr.Markdown("## Examples")
156
- gr.Examples(
157
- [["I have a lot of homework to do today.", "self_affirmation"], ["So stressed about the midterm next week.", "optimism"], ["I failed my math quiz I am such a loser.", "growth"]],
158
- [text, strategy], best_output, show_reframe_change, cache_examples=False, run_on_click=False
159
- )
160
-
161
- # Link to paper and Github repo
162
- gr.Markdown(
163
- '''
164
- For more details: You can read our [paper](https://arxiv.org/abs/2204.02952) or access our [code](https://github.com/SALT-NLP/positive-frames).
165
- ''')
166
-
167
- demo.launch()
 
 
 
 
 
 
1
  import gradio as gr
 
2
 
3
+ def greet(name):
4
+ return "Hello " + name + "!!"
 
 
5
 
6
+ iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
+ iface.launch()