# ###########################################################################
#
# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
# (C) Cloudera, Inc. 2022
# All rights reserved.
#
# Applicable Open Source License: Apache 2.0
#
# NOTE: Cloudera open source products are modular software products
# made up of hundreds of individual components, each of which was
# individually copyrighted. Each Cloudera open source product is a
# collective work under U.S. Copyright Law. Your license to use the
# collective work is as provided in your written agreement with
# Cloudera. Used apart from the collective work, this file is
# licensed for your use pursuant to the open source license
# identified above.
#
# This code is provided to you pursuant to a written agreement with
# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
# this code. If you do not have a written agreement with Cloudera nor
# with an authorized and properly licensed third party, you do not
# have any rights to access nor to use this code.
#
# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the
# contrary, (A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
# ARISING FROM OR RELATED TO THE CODE; AND (D) WITH RESPECT TO YOUR EXERCISE
# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
# DATA.
#
# ###########################################################################
import pandas as pd
from PIL import Image
import streamlit as st
import streamlit.components.v1 as components

from apps.data_utils import (
    DATA_PACKET,
    format_classification_results,
)
from apps.app_utils import (
    DisableableButton,
    reset_page_progress_state,
    get_cached_style_intensity_classifier,
    get_cached_word_attributions,
    get_sti_metric,
    get_cps_metric,
    generate_style_transfer,
)
from apps.visualization_utils import build_altair_classification_plot

# SESSION STATE UTILS
if "page_progress" not in st.session_state:
    st.session_state.page_progress = 1

if "st_result" not in st.session_state:
    st.session_state.st_result = False

# PAGE CONFIG
ffl_favicon = Image.open("static/images/cldr-favicon.ico")
st.set_page_config(
    page_title="CFFL: Text Style Transfer",
    page_icon=ffl_favicon,
    layout="centered",
    initial_sidebar_state="expanded",
)

# SIDEBAR
ffl_logo = Image.open("static/images/ffllogo2@1x.png")
st.sidebar.image(ffl_logo)

st.sidebar.write(
    "This prototype accompanies our [Text Style Transfer](https://blog.fastforwardlabs.com/2022/03/22/an-introduction-to-text-style-transfer.html)\
    blog series in which we explore the task of automatically neutralizing subjectivity bias in free text."
)

st.sidebar.markdown("## Select a style attribute")
style_attribute = st.sidebar.selectbox(
    "What style would you like to transfer between?", options=DATA_PACKET.keys()
)
STYLE_ATTRIBUTE_DATA = DATA_PACKET[style_attribute]

st.sidebar.markdown("## Start over")
st.sidebar.caption(
    "This application is intended to be run sequentially from top to bottom. If you wish to alter selections after \
    advancing through the app, push the button below to start fresh from the beginning."
)
st.sidebar.button("Restart from beginning", on_click=reset_page_progress_state)

# MAIN CONTENT
st.markdown("# Exploring Intelligent Writing Assistance")

st.write(
    """
The goal of this application is to demonstrate how the NLP task of _text style transfer_ can be applied to enhance the human writing experience.
In this sense, we intend to peel back the curtain on how an intelligent writing assistant might function, walking through the logical steps needed to
automatically re-style a piece of text while building up confidence in the model output.

We emphasize the imperative for a human-in-the-loop user experience when designing natural language generation systems. We believe text style
transfer has the potential to empower writers to better express themselves, but not by blindly generating text. Rather, generative models should work in conjunction with
interpretability methods to help writers understand the nuances of linguistic style and suggest stylistic edits that _may_ improve their writing.

Select a style attribute from the sidebar and enter some text below to get started!
"""
)

## 1. INPUT EXAMPLE
if st.session_state.page_progress > 0:

    st.write("### 1. Input some text")

    with st.container():
        col1_1, col1_2 = st.columns([1, 3])

        with col1_1:
            input_type = st.radio(
                "Type your own or choose from a preset example",
                ("Choose preset", "Enter text"),
                horizontal=False,
            )

        with col1_2:
            if input_type == "Enter text":
                text_sample = st.text_input(
                    f"Enter some text to modify style from {style_attribute}",
                    help="You can also select one of our preset examples by toggling the radio button to the left.",
                )
            else:
                option = st.selectbox(
                    f"Select a preset example to modify style from {style_attribute}",
                    [
                        f"Example {i+1}"
                        for i in range(len(STYLE_ATTRIBUTE_DATA.examples))
                    ],
                    help="You can also enter your own text by toggling the radio button to the left.",
                )
                idx_key = int(option.split(" ")[-1]) - 1
                text_sample = STYLE_ATTRIBUTE_DATA.examples[idx_key]

        st.text_area(
            "Preview Text",
            value=text_sample,
            placeholder="Enter some text above or toggle to choose a preset!",
            disabled=True,
        )

    if text_sample != "":
        db1 = DisableableButton(1, "Let's go!")
        db1.create_enabled_button()

## 2. CLASSIFY INPUT
if st.session_state.page_progress > 1:
    db1.disable()

    st.write("### 2. Detect style")
    st.write(
        f"""
Before we can transfer style, we need to ensure the input text isn't already of the target style! To do so,
we classify the sample text with a model that has been fine-tuned to differentiate between
{STYLE_ATTRIBUTE_DATA.attribute_AND_string} tones.

In a product setting, you could imagine this style detection process running continuously inside your favorite word processor as you write,
prompting you for action when it detects language that is at odds with your desired tone of voice.
"""
    )

    with st.spinner("Detecting style, hang tight!"):
        sic = get_cached_style_intensity_classifier(style_data=STYLE_ATTRIBUTE_DATA)
        cls_result = sic.score(text_sample)
        cls_result_df = pd.DataFrame(
            cls_result[0]["distribution"],
            columns=["Score"],
            index=[v for k, v in sic.pipeline.model.config.id2label.items()],
        )

    with st.container():
        format_cls_result = format_classification_results(
            id2label=sic.pipeline.model.config.id2label, cls_result=cls_result
        )

        st.markdown("##### Distribution Between Style Classes")
        chart = build_altair_classification_plot(format_cls_result)
        st.altair_chart(chart.interactive(), use_container_width=True)

        st.markdown(
            f"""
- **:hugging_face: Model:** [{STYLE_ATTRIBUTE_DATA.cls_model_path}]({STYLE_ATTRIBUTE_DATA.build_model_url("cls")})
- **Input Text:** *"{text_sample}"*
- **Classification Result:** \t {cls_result[0]["label"].capitalize()} ({round(cls_result[0]["score"]*100, 2)}%)
"""
        )
        st.write(" ")

    if cls_result[0]["label"].lower() != STYLE_ATTRIBUTE_DATA.target_attribute:
        st.info(
            f"""
**Looks like we have room for improvement!**

The distribution of classification scores suggests that the input text is more {STYLE_ATTRIBUTE_DATA.attribute_THAN_string}. Therefore,
an automated style transfer may improve our intended tone of voice."""
        )

        db2 = DisableableButton(2, "Let's see why")
        db2.create_enabled_button()

    else:
        st.success(
            f"""**No work to be done!** \n\n\n The distribution of classification scores suggests that the input text is less \
            {STYLE_ATTRIBUTE_DATA.attribute_THAN_string}. Therefore, there's no need to transfer style. \
            Enter a different text prompt or select one of the preset examples to re-run the analysis with."""
        )
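
# --------------------------------------------------------------------------- #
# Illustrative sketch (not called by the app): roughly how a fine-tuned style
# classifier like the one wrapped by `get_cached_style_intensity_classifier()`
# could be scored with the Hugging Face `transformers` pipeline API. The
# `model_path` argument is a stand-in for `STYLE_ATTRIBUTE_DATA.cls_model_path`;
# this is not the app's actual implementation.
# --------------------------------------------------------------------------- #
def sketch_classify_style(text: str, model_path: str):
    """Return the full per-class score distribution for `text`."""
    # Imported inside the function so this sketch adds no import-time dependency.
    from transformers import pipeline

    classifier = pipeline("text-classification", model=model_path, return_all_scores=True)
    # `return_all_scores=True` yields one list of {"label", "score"} dicts per input,
    # i.e. the kind of class distribution the Altair chart above visualizes.
    return classifier(text)[0]
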

## 3. Here's why
if st.session_state.page_progress > 2:
    db2.disable()

    st.write("### 3. Interpret the classification result")
    st.write(
        f"""
Interpreting our model's output is a crucial practice that helps build trust and justify taking real-world action on the
model's predictions. In this case, we apply a popular model interpretability technique called [Integrated Gradients](https://arxiv.org/pdf/1703.01365.pdf)
to the Transformer-based classifier to explain the model's prediction in terms of its features."""
    )

    with st.spinner("Interpreting the prediction, hang tight!"):
        word_attributions_visual = get_cached_word_attributions(
            text_sample=text_sample, style_data=STYLE_ATTRIBUTE_DATA
        )

    components.html(html=word_attributions_visual, height=200, scrolling=True)

    st.write(
        f"""
The visual above displays word attributions using the [Transformers Interpret](https://github.com/cdpierse/transformers-interpret) library.
Positive attribution values (green) indicate tokens that contribute positively towards the
predicted class ({STYLE_ATTRIBUTE_DATA.source_attribute}), while negative values (red) indicate tokens that contribute negatively towards the predicted class.
Visualizing word attributions is a helpful way to build intuition about what makes the input text _{STYLE_ATTRIBUTE_DATA.source_attribute}_!"""
    )

    db3 = DisableableButton(3, "Next")
    db3.create_enabled_button()
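
# --------------------------------------------------------------------------- #
# Illustrative sketch (not called by the app): how per-token attributions like
# those returned by `get_cached_word_attributions()` can be produced with the
# Transformers Interpret library, which applies Integrated Gradients to a
# sequence classification model. The `.data` attribute on the visualization
# object is assumed from the library's documented usage; the app's caching and
# model-loading details may differ.
# --------------------------------------------------------------------------- #
def sketch_word_attributions(text: str, model_path: str) -> str:
    """Return an HTML visualization of per-token attributions for `text`."""
    # Imported inside the function so this sketch adds no import-time dependency.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from transformers_interpret import SequenceClassificationExplainer

    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    explainer = SequenceClassificationExplainer(model, tokenizer)
    explainer(text)  # computes (token, attribution) pairs for the predicted class
    return explainer.visualize().data  # raw HTML suitable for components.html()
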

## 4. SUGGEST GENERATED EDIT
if st.session_state.page_progress > 3:
    db3.disable()

    st.write("### 4. Generate a suggestion")
    st.write(
        f"Now that we've verified the input text is in fact *{STYLE_ATTRIBUTE_DATA.source_attribute}* and understand why that's the case, we can utilize a \
        text style transfer model to generate a suggested replacement that retains the same semantic meaning, but achieves the *{STYLE_ATTRIBUTE_DATA.target_attribute}* target style.\
        \n\n Expand the accordion below to toggle generation parameters, then click the button to transfer style!"
    )

    with st.expander("Toggle generation parameters"):
        # st.markdown("##### Text generation parameters")
        st.write("**max_gen_length**")
        max_gen_length = st.slider(
            "What's the maximum generation length desired?", 1, 250, 200, 10
        )

        st.write("**num_beams**")
        num_beams = st.slider(
            "How many beams to use for beam-search decoding?", 1, 8, 4
        )

        st.write("**temperature**")
        temperature = st.slider(
            "What sensitivity value should be used to model next-token probabilities?",
            0.0,
            1.0,
            1.0,
        )

    st.markdown(
        f"""
**:hugging_face: Model:** [{STYLE_ATTRIBUTE_DATA.seq2seq_model_path}]({STYLE_ATTRIBUTE_DATA.build_model_url("seq2seq")})
"""
    )

    col4_1, col4_2, col4_3 = st.columns([1, 5, 4])

    with col4_2:
        st.markdown(
            f"""
- **Max Generation Length:** {max_gen_length}
- **Number of Beams:** {num_beams}
- **Temperature:** {temperature}
"""
        )

    with col4_3:
        with st.container():
            st.write("")
            st.button(
                "Generate style transfer",
                key="generate_text",
                on_click=generate_style_transfer,
                kwargs={
                    "text_sample": text_sample,
                    "style_data": STYLE_ATTRIBUTE_DATA,
                    "max_gen_length": max_gen_length,
                    "num_beams": num_beams,
                    "temperature": temperature,
                },
            )

    if st.session_state.st_result:
        st.warning(
            f"""**{STYLE_ATTRIBUTE_DATA.source_attribute.capitalize()} Input:** "{text_sample}" """
        )
        st.info(
            f"""
**{STYLE_ATTRIBUTE_DATA.target_attribute.capitalize()} Suggestion:** "{st.session_state.st_result[0]}" """
        )

        db4 = DisableableButton(4, "Next")
        db4.create_enabled_button()
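
# --------------------------------------------------------------------------- #
# Illustrative sketch (not called by the app): how the three generation
# parameters exposed above typically map onto a Hugging Face seq2seq
# `generate()` call. The app's `generate_style_transfer()` callback also handles
# caching and session state; `model_path` stands in for
# `STYLE_ATTRIBUTE_DATA.seq2seq_model_path`.
# --------------------------------------------------------------------------- #
def sketch_generate_style_transfer(
    text: str, model_path: str, max_gen_length: int = 200, num_beams: int = 4, temperature: float = 1.0
) -> str:
    """Generate a restyled version of `text` with beam-search decoding."""
    # Imported inside the function so this sketch adds no import-time dependency.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

    inputs = tokenizer(text, return_tensors="pt")
    output_ids = model.generate(
        **inputs,
        max_length=max_gen_length,  # cap on the generated sequence length
        num_beams=num_beams,        # beam-search width
        temperature=temperature,    # scales next-token probabilities
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
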

## 5. EVALUATE THE SUGGESTION
if st.session_state.page_progress > 4:
    db4.disable()

    st.write("### 5. Evaluate the suggestion")
    st.markdown(
        """
Blindly prompting a writer with style suggestions without first checking their quality would make for a noisy, error-prone product
with a poor user experience. Ultimately, we only want to suggest high-quality edits. But what makes for a suggestion-worthy edit?

A comprehensive quality evaluation for text style transfer output should consider three criteria:

1. **Style strength** - To what degree does the generated text achieve the target style?
2. **Content preservation** - To what degree does the generated text retain the semantic meaning of the source text?
3. **Fluency** - To what degree does the generated text appear as if it were produced naturally by a human?

Below, we apply automated evaluation metrics - _Style Transfer Intensity (STI)_ & _Content Preservation Score (CPS)_ - to
measure the first two of these criteria by comparing the input text to the generated suggestion.
"""
    )

    with st.spinner("Evaluating text style transfer, hang tight!"):
        sti = get_sti_metric(
            input_text=text_sample,
            output_text=st.session_state.st_result[0],
            style_data=STYLE_ATTRIBUTE_DATA,
        )
        cps = get_cps_metric(
            input_text=text_sample,
            output_text=st.session_state.st_result[0],
            style_data=STYLE_ATTRIBUTE_DATA,
        )

    st.markdown(
        """<hr style="height:2px;border:none;color:#333;background-color:#333;" /> """,
        unsafe_allow_html=True,
    )

    col5_1, col5_2, col5_3 = st.columns([3, 1, 3])

    with col5_1:
        st.metric(
            label="Style Transfer Intensity (STI)",
            value=f"{round(sti[0]*100,2)}%",
        )
        st.caption(
            f"""
STI compares the style class distributions (determined by the [style classifier]({STYLE_ATTRIBUTE_DATA.build_model_url("cls")}))
between the input text and generated suggestion using Earth Mover's Distance. Therefore, STI can be thought of as the percentage of the possible target style distribution
that was captured during the transfer.
"""
        )

    with col5_3:
        st.metric(
            label="Content Preservation Score (CPS)",
            value=f"{round(cps[0]*100,2)}%",
        )
        st.caption(
            f"""
CPS compares the latent space embeddings (determined by [SentenceBERT]({STYLE_ATTRIBUTE_DATA.build_model_url("sbert")}))
between the input text and generated suggestion using cosine similarity. Therefore, CPS can be thought of as the percentage of content that was preserved
during the style transfer.
"""
        )
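
# --------------------------------------------------------------------------- #
# Illustrative sketches (not called by the app) of the two ideas described in
# the captions above: STI compares style class distributions with Earth Mover's
# Distance, and CPS compares sentence embeddings with cosine similarity. The
# app's `get_sti_metric()` / `get_cps_metric()` may differ in normalization and
# model choice; `sbert_model_name` below is an assumed default for illustration.
# --------------------------------------------------------------------------- #
def sketch_style_transfer_intensity(input_dist, output_dist) -> float:
    """Earth Mover's Distance between two style class distributions.

    Both arguments are probability distributions over the same ordered set of
    style classes, e.g. [p_subjective, p_neutral] for the input and output text.
    """
    # Imported inside the function so this sketch adds no import-time dependency.
    from scipy.stats import wasserstein_distance

    classes = list(range(len(input_dist)))
    return wasserstein_distance(classes, classes, u_weights=input_dist, v_weights=output_dist)


def sketch_content_preservation_score(
    input_text: str,
    output_text: str,
    sbert_model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
) -> float:
    """Cosine similarity between SentenceBERT embeddings of the input and output text."""
    # Imported inside the function so this sketch adds no import-time dependency.
    from sentence_transformers import SentenceTransformer, util

    model = SentenceTransformer(sbert_model_name)
    embeddings = model.encode([input_text, output_text], convert_to_tensor=True)
    return util.cos_sim(embeddings[0], embeddings[1]).item()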