wer-analysis / app.py
divi212's picture
Add in option to ignore punctuation and case
d4e7648 verified
raw
history blame
3.96 kB
import html
import typing as T

import gradio as gr
from jiwer import process_words, RemovePunctuation, ToLowerCase, Compose
def make_string(words: T.List[str]) -> str:
    """Return *words* concatenated into one string with a single space between each."""
    return str.join(" ", words)
# Static color legend shown above the highlighted transcript.
_LEGEND_HTML = """
<div style="margin-top: 10px;">
<strong>Legend</strong><br>
<span style="color:green;">Insertion</span>: Green<br>
<span style="color:purple;">Substitution</span>: Purple<br>
<span style="color:red; text-decoration:line-through;">Deletion</span>: Red<br>
</div>
"""


def _html_span(words: T.List[str], style: str) -> str:
    """Return *words* space-joined and HTML-escaped, wrapped in a <span> with inline *style*."""
    # The words come straight from user input; escape them so characters such as
    # "<" or "&" are displayed literally instead of being parsed as markup.
    text = html.escape(make_string(words), quote=False)
    return f'<span style="{style}">{text}</span>'


def highlight_errors(
    ground_truth: str,
    hypothesis: str,
    remove_punctuation: bool,
    to_lower_case: bool,
) -> T.Tuple[str, float, int, int, int]:
    """
    Align a hypothesis transcript against a ground truth and visualize word errors.

    Both strings are first normalized according to the flags, then aligned with
    jiwer's ``process_words``. The aligned words are rendered as HTML (user text is
    HTML-escaped): insertions in green, substitutions in purple (hypothesis word
    followed by the struck-through ground-truth word), and deletions in red with a
    strikethrough. A color legend is prepended.

    Args:
        ground_truth: Reference transcript.
        hypothesis: Transcript to evaluate against the reference.
        remove_punctuation: If True, apply jiwer's ``RemovePunctuation`` to both strings.
        to_lower_case: If True, apply jiwer's ``ToLowerCase`` to both strings.

    Returns:
        A tuple of (legend + highlighted transcript HTML, word error rate,
        number of substitutions, number of insertions, number of deletions).

    Raises:
        gr.Error: If jiwer raises ``ValueError`` for the given inputs, so that its
            message is shown in the Gradio UI instead of a generic error.
    """
    # Build the normalization pipeline; punctuation removal runs before lowercasing.
    transforms = []
    if remove_punctuation:
        transforms.append(RemovePunctuation())
    if to_lower_case:
        transforms.append(ToLowerCase())
    transform = Compose(transforms)

    reference_text = transform(ground_truth)
    hypothesis_text = transform(hypothesis)
    try:
        processed = process_words(reference=reference_text, hypothesis=hypothesis_text)
    except ValueError as err:
        # jiwer signals unscorable input with ValueError (e.g. an empty reference,
        # which can also result from punctuation-only text once punctuation is
        # removed — confirm against the installed jiwer version).
        raise gr.Error(str(err)) from err

    highlighted_text: T.List[str] = []
    # Walk each aligned reference/hypothesis pair and its alignment chunks.
    for alignment, ref, hyp in zip(
        processed.alignments, processed.references, processed.hypotheses
    ):
        for chunk in alignment:
            ref_words = ref[chunk.ref_start_idx : chunk.ref_end_idx]
            hyp_words = hyp[chunk.hyp_start_idx : chunk.hyp_end_idx]
            if chunk.type == "equal":
                # Matching words are shown without highlighting (still escaped).
                highlighted_text.extend(
                    html.escape(word, quote=False) for word in ref_words
                )
            elif chunk.type == "insert":
                # Words present only in the hypothesis: green.
                highlighted_text.append(_html_span(hyp_words, "color:green;"))
            elif chunk.type == "substitute":
                # Hypothesis words in purple, followed by the ground-truth words
                # they replaced, in purple and struck through.
                highlighted_text.append(_html_span(hyp_words, "color:purple;"))
                highlighted_text.append(
                    _html_span(ref_words, "color:purple; text-decoration:line-through;")
                )
            elif chunk.type == "delete":
                # Ground-truth words missing from the hypothesis: red, struck through.
                highlighted_text.append(
                    _html_span(ref_words, "color:red; text-decoration:line-through;")
                )

    combined_output = f"{_LEGEND_HTML}<br>{make_string(highlighted_text)}"
    return (
        combined_output,
        processed.wer,
        processed.substitutions,
        processed.insertions,
        processed.deletions,
    )
# Build the web UI: the two transcripts and the normalization toggles go in;
# the highlighted transcript and the error counts come out.
_input_components = [
    gr.Textbox(label="Ground Truth"),
    gr.Textbox(label="Hypothesis"),
    gr.Checkbox(label="Ignore Punctuation"),
    gr.Checkbox(label="Ignore Case"),
]
_output_components = [
    gr.HTML(label="Highlighted Transcript"),
    gr.Number(label="Word Error Rate"),
    gr.Number(label="Substitutions"),
    gr.Number(label="Insertions"),
    gr.Number(label="Deletions"),
]
interface = gr.Interface(
    fn=highlight_errors,
    inputs=_input_components,
    outputs=_output_components,
    title="WER Analysis",
)
interface.launch()