import gradio as gr
import jiwer
import pandas as pd
import logging
from typing import List, Optional, Tuple, Dict
from llama_cpp import Llama
import os
# Set up logging
# Configure the root logger to write timestamped records to a stream handler.
# ``force=True`` removes any handlers already attached to the root logger so
# this configuration is the one that takes effect.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    force=True,
)

# Module-level logger used by the code below.
logger = logging.getLogger(__name__)
# Initialize LLM
# Relative path to the GGUF model file that is loaded at import time.
MODEL_PATH = "./DeepSeek-R1-Distill-Qwen-1.5B-Q3_K_M.gguf"

# Loading is best-effort: if the model cannot be constructed for any reason,
# ``llm`` is set to ``None`` so the module still imports. Code that uses
# ``llm`` must check for ``None`` before calling it.
try:
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=2048,  # Context window
        n_threads=4,  # CPU threads
        n_batch=512,  # Batch size
        verbose=False,  # Disable verbose output
    )
except Exception as e:
    # ``logger.exception`` records the full traceback, which is needed to
    # tell a missing model file apart from a bad build or unsupported format.
    logger.exception("Failed to initialize LLM: %s", e)
    llm = None
else:
    logger.info("LLM initialized successfully")
def calculate_wer_metrics(
    hypothesis: str,
    reference: str,
    normalize: bool = True,
    words_to_filter: Optional[List[str]] = None
) -> Dict:
    """
    Calculate WER metrics between hypothesis and reference texts.

    Args:
        hypothesis (str): The hypothesis text
        reference (str): The reference text
        normalize (bool): Whether to normalize texts before comparison
            (expand common English contractions, lowercase, collapse
            whitespace, remove punctuation). When False the texts are passed
            to ``jiwer.compute_measures`` without custom transforms and
            ``words_to_filter`` is not applied.
        words_to_filter (List[str], optional): Words to filter out of both
            texts before comparison. Matching is case-insensitive and is done
            against whole words of the normalized text. Only applied when
            ``normalize`` is True.

    Returns:
        dict: The measures dictionary returned by ``jiwer.compute_measures``

    Raises:
        ValueError: If either input is not a non-blank string, if the
            transformation fails, or if either text contains no words after
            normalization/filtering
    """
    # NOTE(review): the raw texts are logged verbatim at INFO level; confirm
    # this is acceptable if inputs can contain sensitive content.
    logger.info(
        "Calculating WER metrics with inputs - Hypothesis: %s, Reference: %s",
        hypothesis,
        reference,
    )

    # Validate inputs
    if not all(isinstance(text, str) and text.strip() for text in (hypothesis, reference)):
        raise ValueError("Both hypothesis and reference texts must contain non-empty strings")

    if not normalize:
        return jiwer.compute_measures(
            truth=reference,
            hypothesis=hypothesis
        )

    # Define basic transformations. The last step turns each text into a list
    # of sentences, each of which is a list of words.
    basic_transform = jiwer.Compose([
        jiwer.ExpandCommonEnglishContractions(),
        jiwer.ToLowerCase(),
        jiwer.RemoveMultipleSpaces(),
        jiwer.RemovePunctuation(),
        jiwer.Strip(),
        jiwer.ReduceToListOfListOfWords()
    ])

    # Build the lookup set once instead of once per word; blank entries
    # (e.g. from a trailing comma) are ignored.
    blocked_words = {
        word.strip().lower()
        for word in (words_to_filter or [])
        if word and word.strip()
    }

    if blocked_words:
        def filter_words_transform(sentences: List[List[str]]) -> List[List[str]]:
            # Runs after ReduceToListOfListOfWords, so the input is a list of
            # word lists and the filter must be applied per sentence.
            return [
                [word for word in sentence if word.lower() not in blocked_words]
                for sentence in sentences
            ]

        transformation = jiwer.Compose([
            basic_transform,
            filter_words_transform
        ])
    else:
        transformation = basic_transform

    # Pre-check the transformed text
    try:
        transformed_ref = transformation(reference)
        transformed_hyp = transformation(hypothesis)
    except Exception as e:
        logger.error("Transformation error: %s", e)
        raise ValueError(f"Error during text transformation: {e}") from e

    # ``any`` over a list of word lists is False when every sentence is empty,
    # which also covers the ``[[]]`` result produced for a word-less string.
    if not any(transformed_ref) or not any(transformed_hyp):
        if blocked_words:
            reason = "Text is empty after filtering words"
        else:
            reason = "Text is empty after normalization"
        logger.error("Transformation error: %s", reason)
        raise ValueError(f"Error during text transformation: {reason}")

    logger.debug("Transformed reference: %s", transformed_ref)
    logger.debug("Transformed hypothesis: %s", transformed_hyp)

    return jiwer.compute_measures(
        truth=reference,
        hypothesis=hypothesis,
        truth_transform=transformation,
        hypothesis_transform=transformation
    )
def extract_medical_terms(text: str) -> List[str]:
    """
    Extract medical terms from text using the module-level ``llm``.

    The model is asked for a comma-separated list; the completion is split on
    commas and cleaned up.

    Args:
        text (str): Text to extract medical terms from

    Returns:
        List[str]: The extracted terms, stripped of surrounding whitespace.
        Empty when the LLM is not initialized or the model call fails.
    """
    if llm is None:
        logger.error("LLM not initialized")
        return []

    prompt = (
        "Extract all medical terms from the following text.\n"
        "Return only the medical terms as a comma-separated list.\n"
        f"Text: {text}"
    )

    try:
        # NOTE(review): the "\n\n" stop sequence and the 256-token limit may
        # end generation before a reasoning model reaches its final answer;
        # confirm against real model output.
        response = llm(
            prompt,
            max_tokens=256,
            temperature=0.1,
            stop=["Text:", "\n\n"],
            echo=False
        )
        response_text = response['choices'][0]['text'].strip()
    except Exception as e:
        # Best-effort: a failed extraction yields no terms instead of
        # aborting the caller; the traceback is kept in the log.
        logger.exception("Error in medical term extraction: %s", e)
        return []

    # Remove thinking process if present: keep only what follows the last
    # closing reasoning tag. ``rpartition`` returns the whole string as the
    # final element when the tag is absent.
    medical_terms_text = response_text.rpartition('</think>')[2].strip()

    candidates = (term.strip() for term in medical_terms_text.split(','))
    # Drop empty entries and leftover tag-like fragments.
    return [
        term for term in candidates
        if term and not term.startswith('<') and not term.endswith('>')
    ]
def _normalize_terms(terms: List[str]) -> set:
    """Return the distinct non-blank terms, case-folded with whitespace collapsed."""
    normalized = (" ".join(term.split()).casefold() for term in terms or ())
    return {term for term in normalized if term}


def calculate_medical_recall(
    hypothesis_terms: List[str],
    reference_terms: List[str]
) -> float:
    """
    Calculate medical term recall rate.

    Terms are compared as sets after case-folding and whitespace
    normalization, so duplicates, letter case and stray spaces do not affect
    the result. Blank terms are ignored.

    Args:
        hypothesis_terms (List[str]): Medical terms from hypothesis
        reference_terms (List[str]): Medical terms from reference

    Returns:
        float: Fraction of distinct reference terms that also appear among
        the hypothesis terms. When there are no reference terms, 1.0 if there
        are no hypothesis terms either, otherwise 0.0.
    """
    reference_set = _normalize_terms(reference_terms)
    hypothesis_set = _normalize_terms(hypothesis_terms)

    if not reference_set:
        return 1.0 if not hypothesis_set else 0.0

    return len(reference_set & hypothesis_set) / len(reference_set)
def process_inputs(
    reference: str,
    hypothesis: str,
    normalize: bool,
    words_to_filter: str
) -> Tuple[str, str, str, str]:
    """
    Process inputs and calculate both WER and medical term recall metrics.

    Args:
        reference (str): Reference text
        hypothesis (str): Hypothesis text
        normalize (bool): Whether to normalize text
        words_to_filter (str): Comma-separated words to filter

    Returns:
        Tuple[str, str, str, str]: HTML for the main metrics table, HTML for
        the error analysis table, HTML explanation (including the extracted
        medical terms), and an error message. On success the error message is
        empty; on failure the first three items are empty. When either text is
        missing, the first item carries a prompt to provide both texts.
    """
    if not reference or not hypothesis:
        return "Please provide both reference and hypothesis texts.", "", "", ""

    try:
        # Calculate WER metrics first: this is cheap and validates the
        # inputs, so invalid text fails before the slower LLM calls below.
        filter_words = None
        if words_to_filter:
            # Skip blank entries such as those left by a trailing comma.
            filter_words = [
                word.strip() for word in words_to_filter.split(",") if word.strip()
            ] or None
        measures = calculate_wer_metrics(
            hypothesis=hypothesis,
            reference=reference,
            normalize=normalize,
            words_to_filter=filter_words
        )

        # Extract medical terms
        logger.info("Extracting medical terms from reference text...")
        reference_terms = extract_medical_terms(reference)
        logger.info("Reference terms extracted: %s", reference_terms)
        logger.info("Extracting medical terms from hypothesis text...")
        hypothesis_terms = extract_medical_terms(hypothesis)
        logger.info("Hypothesis terms extracted: %s", hypothesis_terms)

        # Calculate medical recall
        # NOTE(review): when extraction yields no terms for either text
        # (for example because the LLM is unavailable) the recall shown is
        # whatever calculate_medical_recall returns for two empty lists.
        med_recall = calculate_medical_recall(hypothesis_terms, reference_terms)

        # Format metrics
        metrics_df = pd.DataFrame({
            'Metric': ['WER', 'MER', 'WIL', 'WIP', 'Medical Term Recall'],
            'Value': [
                f"{measures['wer']:.3f}",
                f"{measures['mer']:.3f}",
                f"{measures['wil']:.3f}",
                f"{measures['wip']:.3f}",
                f"{med_recall:.3f}"
            ]
        })

        # Format error analysis
        error_df = pd.DataFrame({
            'Metric': ['Substitutions', 'Deletions', 'Insertions', 'Hits'],
            'Count': [
                measures['substitutions'],
                measures['deletions'],
                measures['insertions'],
                measures['hits']
            ]
        })

        # Format medical terms comparison
        med_terms_df = pd.DataFrame({
            'Source': ['Reference', 'Hypothesis'],
            'Medical Terms': [
                ', '.join(reference_terms),
                ', '.join(hypothesis_terms)
            ]
        })

        metrics_html = metrics_df.to_html(index=False)
        error_html = error_df.to_html(index=False)
        med_terms_html = med_terms_df.to_html(index=False)

        # The explanation is rendered as HTML, where plain line breaks
        # collapse, so the list is emitted as real markup.
        explanation = (
            "<h3>Metrics Explanation:</h3>"
            "<ul>"
            "<li>WER (Word Error Rate): The percentage of words that were incorrectly predicted</li>"
            "<li>MER (Match Error Rate): The percentage of words that were incorrectly matched</li>"
            "<li>WIL (Word Information Lost): The percentage of word information that was lost</li>"
            "<li>WIP (Word Information Preserved): The percentage of word information that was preserved</li>"
            "<li>Medical Term Recall: The proportion of reference medical terms that were correctly "
            "identified in the hypothesis</li>"
            "</ul>"
            "<h3>Extracted Medical Terms:</h3>"
            f"{med_terms_html}"
        )

        return metrics_html, error_html, explanation, ""
    except Exception as e:
        # UI boundary: report the failure to the user instead of raising,
        # and keep the traceback in the log for diagnosis.
        error_msg = f"Error in processing: {str(e)}"
        logger.exception("Error in processing: %s", e)
        return "", "", "", error_msg
def load_example() -> Tuple[str, str]:
    """Return the built-in demonstration pair as ``(reference, hypothesis)``."""
    reference_text = "The patient shows signs of heart attack and hypertension."
    hypothesis_text = "The patient shows signs of heart attack and high blood pressure."
    return reference_text, hypothesis_text
def create_interface() -> gr.Blocks:
    """
    Create the Gradio interface.

    Builds the page (title, description, the two text inputs, the options,
    the action buttons and the output areas) and connects the buttons to
    ``load_example`` and ``process_inputs``.

    Returns:
        gr.Blocks: The assembled, not yet launched, interface
    """
    intro_text = (
        "This tool helps you evaluate the Word Error Rate (WER) between a reference "
        "text and a hypothesis text. WER is commonly used in speech recognition and "
        "machine translation evaluation."
    )

    with gr.Blocks(title="WER Evaluation Tool") as demo:
        gr.Markdown("# Word Error Rate (WER) Evaluation Tool")
        gr.Markdown(intro_text)

        # Text inputs, one column each
        with gr.Row():
            with gr.Column():
                reference_box = gr.Textbox(
                    lines=5,
                    label="Reference Text",
                    placeholder="Enter the reference text here...",
                )
            with gr.Column():
                hypothesis_box = gr.Textbox(
                    lines=5,
                    label="Hypothesis Text",
                    placeholder="Enter the hypothesis text here...",
                )

        # Options passed through to the WER calculation
        with gr.Row():
            normalize_box = gr.Checkbox(
                value=True,
                label="Normalize text (lowercase, remove punctuation)",
            )
            filter_box = gr.Textbox(
                label="Words to filter (comma-separated)",
                placeholder="e.g., um, uh, ah",
            )

        # Action buttons
        with gr.Row():
            load_example_button = gr.Button("Load Example")
            calculate_button = gr.Button("Calculate WER", variant="primary")

        # Output areas
        with gr.Row():
            metrics_view = gr.HTML(label="Main Metrics")
            errors_view = gr.HTML(label="Error Analysis")
        explanation_view = gr.HTML()
        error_message_view = gr.HTML()

        # Event handlers
        load_example_button.click(
            fn=load_example,
            outputs=[reference_box, hypothesis_box],
        )
        calculate_button.click(
            fn=process_inputs,
            inputs=[reference_box, hypothesis_box, normalize_box, filter_box],
            outputs=[metrics_view, errors_view, explanation_view, error_message_view],
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and serve it until the process exits.
    logger.info("Application started")
    try:
        app = create_interface()
        # Bind to all network interfaces on port 7860, without a share link,
        # with Gradio's debug mode enabled.
        app.launch(
            debug=True,
            share=False,
            server_port=7860,
            server_name="0.0.0.0",
        )
    except Exception as exc:
        # Log the failure, then re-raise so the process still fails visibly.
        logger.error(f"Failed to launch application: {exc}")
        raise