Spaces:
Runtime error
Runtime error
# ########################################################################### | |
# | |
# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) | |
# (C) Cloudera, Inc. 2022 | |
# All rights reserved. | |
# | |
# Applicable Open Source License: Apache 2.0 | |
# | |
# NOTE: Cloudera open source products are modular software products | |
# made up of hundreds of individual components, each of which was | |
# individually copyrighted. Each Cloudera open source product is a | |
# collective work under U.S. Copyright Law. Your license to use the | |
# collective work is as provided in your written agreement with | |
# Cloudera. Used apart from the collective work, this file is | |
# licensed for your use pursuant to the open source license | |
# identified above. | |
# | |
# This code is provided to you pursuant a written agreement with | |
# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute | |
# this code. If you do not have a written agreement with Cloudera nor | |
# with an authorized and properly licensed third party, you do not | |
# have any rights to access nor to use this code. | |
# | |
# Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the | |
# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY | |
# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED | |
# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO | |
# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND | |
# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, | |
# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS | |
# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE | |
# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY | |
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR | |
# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES | |
# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF | |
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF | |
# DATA. | |
# | |
# ########################################################################### | |
from typing import Iterable | |
import altair as alt | |
from captum.attr._utils.visualization import ( | |
VisualizationDataRecord, | |
format_word_importances, | |
_get_color, | |
) | |
try: | |
from IPython.display import display, HTML | |
HAS_IPYTHON = True | |
except ImportError: | |
HAS_IPYTHON = False | |
def format_classname(classname):
    """Wrap a class label in an HTML table-cell element.

    Args:
        classname: Value to show in the cell. It is inserted using default
            string formatting and is not HTML-escaped.

    Returns:
        str: The formatted value surrounded by ``<td>`` and ``</td>``.
    """
    cell_template = "<td>{}</td>"
    return cell_template.format(classname)
def visualize_text(
    datarecords: Iterable[VisualizationDataRecord], legend: bool = True
) -> "HTML":  # In quotes because this type doesn't exist in standalone mode
    """Render attribution records as a Bootstrap-styled HTML table and display it.

    One table row is produced per record, with three cells: the capitalized
    predicted class, the attribution score rounded to two decimals, and the
    markup returned by ``format_word_importances`` for the record's tokens.

    Args:
        datarecords: Records to render. Each one must provide ``pred_class``
            (a string), ``attr_score`` (an object with ``.item()``),
            ``raw_input_ids`` and ``word_attributions``.
        legend: When True, append a Negative / Neutral / Positive color
            legend below the table.

    Returns:
        The ``IPython.display.HTML`` object that was passed to ``display``.

    Raises:
        AssertionError: If IPython could not be imported.
    """
    assert HAS_IPYTHON, (
        "IPython must be available to visualize text. "
        "Please run 'pip install ipython'."
    )

    table_header = (
        "<thead>"
        "<tr>"
        "<th scope='col'><span class='text-nowrap'>Predicted Label</span></th>"
        "<th scope='col'><span class='text-nowrap'>Attribution Score</span></th>"
        "<th scope='col'><span class='text-nowrap'>Feature Importance</span></th>"
        "</tr>"
        "</thead>"
    )

    rows = [table_header]
    for datarecord in datarecords:
        rows.append(
            "".join(
                [
                    "<tbody>",
                    "<tr>",
                    format_classname(f"{datarecord.pred_class.capitalize()}"),
                    format_classname(f"{round(datarecord.attr_score.item(), 2)}"),
                    format_word_importances(
                        datarecord.raw_input_ids, datarecord.word_attributions
                    ),
                    # Close the row; a second opening <tr> here would leave
                    # the markup unbalanced.
                    "</tr>",
                    "</tbody>",
                ]
            )
        )

    dom = [
        '<head><link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous"></head>',
        # The width has to live in a style attribute; a bare `width:100;`
        # token is not a valid HTML attribute.
        '<table style="width: 100%" class="table">',
        "".join(rows),
        "</table>",
    ]

    if legend:
        swatch_template = (
            '<span style="display: inline-block; width: 10px; height: 10px; '
            "border: 1px solid; background-color: "
            '{value}"></span> {label} '
        )
        dom.append("<div class='row'>")
        dom.append("<div class='col-6'>")
        dom.append("<b>Legend: </b>")
        # One swatch per attribution value, colored by captum's _get_color.
        for value, label in zip([-1, 0, 1], ["Negative", "Neutral", "Positive"]):
            dom.append(
                swatch_template.format(value=_get_color(value), label=label)
            )
        dom.append("</div>")
        dom.append("<div class='col-6'></div>")
        dom.append("</div>")

    html = HTML("".join(dom))
    display(html)
    return html
def build_altair_classification_plot(format_cls_result):
    """
    Builds Altair bar chart for classification results.

    Each record is drawn as a horizontal bar segment running from its
    ``percentage_start`` to its ``percentage_end`` value, colored by its
    ``type`` value.

    Args:
        format_cls_result (List): Output from `format_classification_results()`.
            Every record must be indexable by ``"type"``; the chart also reads
            the ``percentage_start`` and ``percentage_end`` fields.

    Returns:
        alt.Chart: The configured bar chart (150 px high).
    """
    # Import pandas directly instead of reaching through `alt.pd`: that
    # attribute is not part of Altair's documented API and is not guaranteed
    # to exist in every Altair release.
    import pandas as pd

    source = pd.DataFrame(format_cls_result)

    # Pin the color assignment to the order in which the types appear.
    color_scale = alt.Scale(
        domain=[record["type"] for record in format_cls_result],
        range=["#00A3AF", "#F96702"],
    )

    c = (
        alt.Chart(source)
        .mark_bar(size=50)
        .encode(
            x=alt.X(
                "percentage_start:Q", axis=alt.Axis(title="Style Distribution (%)")
            ),
            x2=alt.X2("percentage_end:Q"),
            color=alt.Color(
                "type:N",
                legend=alt.Legend(title="Attribute"),
                scale=color_scale,
            ),
        )
        .properties(height=150)
    )
    return c