"""Gradio app that scores a hypothesis transcript against a reference.

Two kinds of metrics are reported:

* Word-level error metrics (WER, MER, WIL, WIP and the edit-operation
  counts) computed with ``jiwer``.
* A "medical term recall": the fraction of medical terms that a local
  LLM extracts from the reference and that it also extracts from the
  hypothesis.
"""

import html
import logging
import os
from typing import Dict, List, Optional, Tuple

import gradio as gr
import jiwer
import pandas as pd
from llama_cpp import Llama

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True,
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Initialize LLM
MODEL_PATH = "./DeepSeek-R1-Distill-Qwen-1.5B-Q3_K_M.gguf"

# Marker that closes the model's reasoning section; everything after the
# last occurrence is treated as the actual answer.
# NOTE(review): the tag literals were empty strings in the previous revision
# (which made the parsing always fail); "</think>" is the assumed original
# based on the DeepSeek-R1 model in use -- confirm.
_THINK_CLOSE_TAG = "</think>"

try:
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=2048,      # Context window
        n_threads=4,     # CPU threads
        n_batch=512,     # Batch size
        verbose=False    # Disable verbose output
    )
    logger.info("LLM initialized successfully")
except Exception as e:
    # The app stays usable for WER without the model; term extraction
    # then returns an empty list.
    logger.error("Failed to initialize LLM: %s", e)
    llm = None


def _has_words(sentences) -> bool:
    """Return True if any sentence in *sentences* contains at least one word."""
    return any(sentences)


def _build_transformation(words_to_filter: Optional[List[str]]) -> jiwer.Compose:
    """Build the jiwer transformation pipeline used when normalizing.

    The pipeline ends with ``ReduceToListOfListOfWords``, so any step added
    after it operates on a list of sentences, each a list of words.
    """
    steps = [
        jiwer.ExpandCommonEnglishContractions(),
        jiwer.ToLowerCase(),
        jiwer.RemoveMultipleSpaces(),
        jiwer.RemovePunctuation(),
        jiwer.Strip(),
        jiwer.ReduceToListOfListOfWords(),
    ]

    # Build the lookup set once; blank entries would never match a word.
    blocked = {
        word.strip().lower()
        for word in (words_to_filter or [])
        if word and word.strip()
    }

    if blocked:
        def filter_words_transform(sentences: List[List[str]]) -> List[List[str]]:
            # Input is a list of sentences (lists of words), so the filter
            # has to be applied inside each sentence.
            filtered = [
                [word for word in sentence if word.lower() not in blocked]
                for sentence in sentences
            ]
            if not _has_words(filtered):
                raise ValueError("Text is empty after filtering words")
            return filtered

        steps.append(filter_words_transform)

    return jiwer.Compose(steps)


def calculate_wer_metrics(
    hypothesis: str,
    reference: str,
    normalize: bool = True,
    words_to_filter: Optional[List[str]] = None
) -> Dict:
    """
    Calculate WER metrics between hypothesis and reference texts.

    Args:
        hypothesis (str): The hypothesis text
        reference (str): The reference text
        normalize (bool): Whether to normalize texts before comparison
        words_to_filter (List[str], optional): Words to filter out before
            comparison. Only applied when ``normalize`` is True.

    Returns:
        dict: Dictionary containing WER metrics, as returned by
        ``jiwer.compute_measures``

    Raises:
        ValueError: If inputs are invalid or result in empty text after processing
    """
    logger.info(
        "Calculating WER metrics with inputs - Hypothesis: %s, Reference: %s",
        hypothesis,
        reference,
    )

    # Validate inputs
    if not hypothesis.strip() or not reference.strip():
        raise ValueError("Both hypothesis and reference texts must contain non-empty strings")

    if not normalize:
        return jiwer.compute_measures(
            truth=reference,
            hypothesis=hypothesis
        )

    transformation = _build_transformation(words_to_filter)

    # Pre-check the transformed text so that failures surface with a
    # clear message instead of an error from inside jiwer.
    try:
        transformed_ref = transformation(reference)
        transformed_hyp = transformation(hypothesis)

        # A result such as [[]] is a non-empty list, so check for words.
        if not _has_words(transformed_ref) or not _has_words(transformed_hyp):
            raise ValueError("Text is empty after normalization")

        logger.debug("Transformed reference: %s", transformed_ref)
        logger.debug("Transformed hypothesis: %s", transformed_hyp)
    except Exception as e:
        logger.error("Transformation error: %s", e)
        raise ValueError(f"Error during text transformation: {str(e)}") from e

    return jiwer.compute_measures(
        truth=reference,
        hypothesis=hypothesis,
        truth_transform=transformation,
        hypothesis_transform=transformation
    )


def extract_medical_terms(text: str) -> List[str]:
    """Extract medical terms from text using the local LLM.

    Args:
        text (str): Text to extract medical terms from

    Returns:
        List[str]: Extracted terms; an empty list if the LLM is not
        initialized or the extraction fails.
    """
    if llm is None:
        logger.error("LLM not initialized")
        return []

    prompt = (
        "Extract all medical terms from the following text. "
        "Return only the medical terms as a comma-separated list. "
        f"Text: {text}"
    )

    try:
        response = llm(
            prompt,
            max_tokens=256,
            temperature=0.1,
            stop=["Text:", "\n\n"],
            echo=False
        )

        response_text = response['choices'][0]['text'].strip()

        # Remove thinking process if present: keep only what follows the
        # last closing tag. rpartition yields the whole text when the tag
        # is absent.
        medical_terms_text = response_text.rpartition(_THINK_CLOSE_TAG)[2].strip()

        medical_terms = [term.strip() for term in medical_terms_text.split(',')]
        # Drop empty entries and leftover tag-like fragments.
        return [
            term for term in medical_terms
            if term and not term.startswith('<') and not term.endswith('>')
        ]

    except Exception as e:
        # Best effort: a failed extraction must not break the WER report.
        logger.error("Error in medical term extraction: %s", e)
        return []


def _normalize_terms(terms: List[str]) -> set:
    """Return the set of non-empty terms, stripped and case-folded."""
    return {term.strip().casefold() for term in terms if term and term.strip()}


def calculate_medical_recall(
    hypothesis_terms: List[str],
    reference_terms: List[str]
) -> float:
    """
    Calculate medical term recall rate.

    Terms are compared case-insensitively and ignoring surrounding
    whitespace; duplicates count once.

    Args:
        hypothesis_terms (List[str]): Medical terms from hypothesis
        reference_terms (List[str]): Medical terms from reference

    Returns:
        float: Recall rate. When there are no reference terms, 1.0 if there
        are also no hypothesis terms and 0.0 otherwise.
    """
    reference_set = _normalize_terms(reference_terms)
    hypothesis_set = _normalize_terms(hypothesis_terms)

    if not reference_set:
        return 1.0 if not hypothesis_set else 0.0

    return len(reference_set & hypothesis_set) / len(reference_set)


def process_inputs(
    reference: str,
    hypothesis: str,
    normalize: bool,
    words_to_filter: str
) -> Tuple[str, str, str, str]:
    """
    Process inputs and calculate both WER and medical term recall metrics.

    Args:
        reference (str): Reference text
        hypothesis (str): Hypothesis text
        normalize (bool): Whether to normalize text
        words_to_filter (str): Comma-separated words to filter

    Returns:
        Tuple[str, str, str, str]: HTML for the main metrics, HTML for the
        error analysis, HTML for the explanation section, and an error
        message (empty on success).
    """
    # Reject missing or whitespace-only input before doing any LLM work.
    if not (reference and reference.strip()) or not (hypothesis and hypothesis.strip()):
        return "Please provide both reference and hypothesis texts.", "", "", ""

    try:
        # Extract medical terms
        logger.info("Extracting medical terms from reference text...")
        reference_terms = extract_medical_terms(reference)
        logger.info("Reference terms extracted: %s", reference_terms)

        logger.info("Extracting medical terms from hypothesis text...")
        hypothesis_terms = extract_medical_terms(hypothesis)
        logger.info("Hypothesis terms extracted: %s", hypothesis_terms)

        # Calculate medical recall
        med_recall = calculate_medical_recall(hypothesis_terms, reference_terms)

        # Calculate WER metrics; blank entries ("um,,uh") are dropped.
        filter_words = None
        if words_to_filter:
            filter_words = [
                word.strip() for word in words_to_filter.split(",") if word.strip()
            ] or None

        measures = calculate_wer_metrics(
            hypothesis=hypothesis,
            reference=reference,
            normalize=normalize,
            words_to_filter=filter_words
        )

        # Format metrics
        metrics_df = pd.DataFrame({
            'Metric': ['WER', 'MER', 'WIL', 'WIP', 'Medical Term Recall'],
            'Value': [
                f"{measures['wer']:.3f}",
                f"{measures['mer']:.3f}",
                f"{measures['wil']:.3f}",
                f"{measures['wip']:.3f}",
                f"{med_recall:.3f}"
            ]
        })

        # Format error analysis
        error_df = pd.DataFrame({
            'Metric': ['Substitutions', 'Deletions', 'Insertions', 'Hits'],
            'Count': [
                measures['substitutions'],
                measures['deletions'],
                measures['insertions'],
                measures['hits']
            ]
        })

        # Format medical terms comparison
        med_terms_df = pd.DataFrame({
            'Source': ['Reference', 'Hypothesis'],
            'Medical Terms': [
                ', '.join(reference_terms),
                ', '.join(hypothesis_terms)
            ]
        })

        metrics_html = metrics_df.to_html(index=False)
        error_html = error_df.to_html(index=False)
        med_terms_html = med_terms_df.to_html(index=False)

        # NOTE(review): the two labels carry no markup here; presumably
        # headings/explanatory text were intended -- confirm.
        explanation = (
            "\n\nMetrics Explanation:\n\n"
            "Extracted Medical Terms:\n\n"
            f"{med_terms_html} "
        )

        return metrics_html, error_html, explanation, ""

    except Exception as e:
        error_msg = f"Error in processing: {str(e)}"
        logger.error(error_msg)
        # The message is rendered by an HTML component, so escape it.
        return "", "", "", html.escape(error_msg)


def load_example() -> Tuple[str, str]:
    """Load example texts for demonstration.

    Returns:
        Tuple[str, str]: Example reference text and example hypothesis text.
    """
    return (
        "The patient shows signs of heart attack and hypertension.",
        "The patient shows signs of heart attack and high blood pressure."
    )


def create_interface() -> gr.Blocks:
    """Create the Gradio interface.

    Returns:
        gr.Blocks: The assembled interface, with the example and calculate
        buttons wired to ``load_example`` and ``process_inputs``.
    """
    with gr.Blocks(title="WER Evaluation Tool") as interface:
        gr.Markdown("# Word Error Rate (WER) Evaluation Tool")
        gr.Markdown(
            "This tool helps you evaluate the Word Error Rate (WER) between a reference "
            "text and a hypothesis text. WER is commonly used in speech recognition and "
            "machine translation evaluation."
        )

        with gr.Row():
            with gr.Column():
                reference = gr.Textbox(
                    label="Reference Text",
                    placeholder="Enter the reference text here...",
                    lines=5
                )
            with gr.Column():
                hypothesis = gr.Textbox(
                    label="Hypothesis Text",
                    placeholder="Enter the hypothesis text here...",
                    lines=5
                )

        with gr.Row():
            normalize = gr.Checkbox(
                label="Normalize text (lowercase, remove punctuation)",
                value=True
            )
            words_to_filter = gr.Textbox(
                label="Words to filter (comma-separated)",
                placeholder="e.g., um, uh, ah"
            )

        with gr.Row():
            example_btn = gr.Button("Load Example")
            calculate_btn = gr.Button("Calculate WER", variant="primary")

        with gr.Row():
            metrics_output = gr.HTML(label="Main Metrics")
            error_output = gr.HTML(label="Error Analysis")

        explanation_output = gr.HTML()
        error_msg_output = gr.HTML()

        # Event handlers
        example_btn.click(
            load_example,
            outputs=[reference, hypothesis]
        )

        calculate_btn.click(
            process_inputs,
            inputs=[reference, hypothesis, normalize, words_to_filter],
            outputs=[metrics_output, error_output, explanation_output, error_msg_output]
        )

    return interface


if __name__ == "__main__":
    logger.info("Application started")
    try:
        app = create_interface()
        app.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            debug=True
        )
    except Exception as e:
        logger.error(f"Failed to launch application: {str(e)}")
        raise