Source code for pathbench.asr_evaluators

from typing import Optional
import re
import os

import jiwer
import librosa
import torch
from pyctcdecode import build_ctcdecoder

import numpy as np
from pathbench.evaluator import ReferenceTxtEvaluator, ReferenceFreeEvaluator
from pathbench.string_clean import clean_text, cached_phonemize


class ASREvaluator(ReferenceTxtEvaluator):
    """Computes WER using an ASR model.

    The audio is transcribed by greedy (argmax) decoding of a CTC model and
    the result is compared with the reference transcription via
    ``jiwer.wer`` after both strings went through ``clean_text``.
    """

    def __init__(self, model_id: str):
        """Load processor, model and device for ``model_id``.

        Args:
            model_id: Identifier handed to
                ``pathbench.model_registry.get_ctc_model``.
        """
        # Local import, kept from the original implementation.
        from pathbench.model_registry import get_ctc_model

        self.processor, self.model, self.device = get_ctc_model(model_id)

    def score(
        self,
        utterance_id: str,
        audio_path: str,
        transcription: str,
        language: str,
        start_time: float = 0.0,
        end_time: float = -1.0,
    ) -> Optional[float]:
        """Return the WER of the model's transcription of one utterance.

        Args:
            utterance_id: Identifier of the utterance (not used here).
            audio_path: Path of the audio file to load.
            transcription: Reference transcription.
            language: Language of the utterance (not used here).
            start_time: Offset in seconds at which loading starts.
            end_time: End of the segment in seconds; a negative value means
                "load until the end of the file".

        Returns:
            The word error rate, or ``None`` when the utterance cannot be
            scored: the audio could not be read, it is shorter than 400
            samples, or the reference is empty after cleaning.
        """
        try:
            duration = end_time - start_time if end_time >= 0 else None
            speech, sample_rate = librosa.load(
                audio_path, sr=16000, mono=True, offset=start_time, duration=duration
            )
        except Exception as e:
            # Best effort: an unreadable file skips the utterance instead of
            # aborting the whole evaluation run.
            print(f"Error reading audio file {audio_path}: {e}")
            return None

        # 400 samples correspond to 25 ms at the 16 kHz loading rate.
        if len(speech) < 400:
            print(f"Warning: Skipping {audio_path} — too short ({len(speech)} samples).")
            return None

        input_values = self.processor(
            speech, sampling_rate=sample_rate, return_tensors="pt"
        ).input_values.to(self.device)

        with torch.no_grad():
            logits = self.model(input_values).logits

        # Greedy CTC decoding: most likely token per frame.
        predicted_ids = torch.argmax(logits, dim=-1)
        predicted_transcription = self.processor.batch_decode(predicted_ids)[0]

        print(f"Reference: {transcription}")
        print(f"Predicted: {predicted_transcription}")

        cleaned_reference = clean_text(transcription)
        cleaned_prediction = clean_text(predicted_transcription)

        print("Cleaned Reference:", cleaned_reference)
        print("Cleaned Predicted:", cleaned_prediction)

        # WER is normalised by the reference length, so it is undefined for an
        # empty reference (and, depending on the jiwer version, jiwer.wer
        # raises ValueError for it). Skip such utterances like other
        # unscorable inputs.
        if not cleaned_reference.strip():
            print(f"Warning: Skipping {audio_path} — empty reference after cleaning.")
            return None

        return jiwer.wer(cleaned_reference, cleaned_prediction)
class PEREvaluator(ReferenceTxtEvaluator):
    """Computes PER using a language-specific ASR model.

    The audio is transcribed to text by greedy (argmax) CTC decoding; both
    the reference and the predicted text are then phonemized with
    ``cached_phonemize`` and compared with ``jiwer.wer`` on the resulting
    space-separated token strings.
    """

    # Supported languages and the CTC model used for each of them.
    _MODEL_IDS = {
        "en": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
        "en-us": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
        "es": "jonatasgrosman/wav2vec2-large-xlsr-53-spanish",
        "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
        "it": "jonatasgrosman/wav2vec2-large-xlsr-53-italian",
        "cmn": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    }

    # Language code handed to the phonemizer for each supported language.
    _ESPEAK_LANGUAGES = {
        "en": "en-us",
        "en-us": "en-us",
        "es": "es",
        "nl": "nl",
        "it": "it",
        "cmn": "cmn",
    }

    def __init__(self, language: str):
        """Load the CTC model registered for ``language``.

        Args:
            language: One of the keys of ``_MODEL_IDS``.

        Raises:
            ValueError: If ``language`` is not supported.
        """
        # Local import, kept from the original implementation.
        from pathbench.model_registry import get_ctc_model

        if language not in self._MODEL_IDS:
            raise ValueError(f"Language '{language}' is not supported for PEREvaluator.")

        model_id = self._MODEL_IDS[language]
        self.processor, self.model, self.device = get_ctc_model(model_id)
        self.language = language

    def score(
        self,
        utterance_id: str,
        audio_path: str,
        transcription: str,
        language: str,
        start_time: float = 0.0,
        end_time: float = -1.0,
    ) -> Optional[float]:
        """Return the phoneme error rate for one utterance.

        Args:
            utterance_id: Identifier of the utterance (not used here).
            audio_path: Path of the audio file to load.
            transcription: Reference transcription (orthographic text).
            language: Language of the utterance; must equal the language the
                evaluator was created for.
            start_time: Offset in seconds at which loading starts.
            end_time: End of the segment in seconds; a negative value means
                "load until the end of the file".

        Returns:
            The error rate, or ``None`` when the utterance cannot be scored:
            language mismatch, unreadable audio, fewer than 400 samples, or
            an empty phonemized reference.
        """
        # Exact string comparison: "en" and "en-us" count as different here.
        if language != self.language:
            print(
                f"Warning: PEREvaluator initialized for '{self.language}' "
                f"but received '{language}'. Skipping."
            )
            return None

        try:
            duration = end_time - start_time if end_time >= 0 else None
            speech, sample_rate = librosa.load(
                audio_path, sr=16000, mono=True, offset=start_time, duration=duration
            )
        except Exception as e:
            # Best effort: an unreadable file skips the utterance instead of
            # aborting the whole evaluation run.
            print(f"Error reading audio file {audio_path}: {e}")
            return None

        # 400 samples correspond to 25 ms at the 16 kHz loading rate.
        if len(speech) < 400:
            print(f"Warning: Skipping {audio_path} — too short ({len(speech)} samples).")
            return None

        input_values = self.processor(
            speech, sampling_rate=sample_rate, return_tensors="pt", padding="longest"
        ).input_values.to(self.device)

        with torch.no_grad():
            logits = self.model(input_values).logits

        # Greedy CTC decoding: most likely token per frame.
        predicted_ids = torch.argmax(logits, dim=-1)
        predicted_transcription = self.processor.batch_decode(predicted_ids)[0]

        print(f"Reference: {transcription}")

        espeak_lang = self._ESPEAK_LANGUAGES.get(language, language)

        phonemized_reference = cached_phonemize(clean_text(transcription), espeak_lang)
        phonemized_prediction = cached_phonemize(clean_text(predicted_transcription), espeak_lang)

        print(f"Phonemized Reference: {phonemized_reference}")
        print(f"Phonemized Predicted: {phonemized_prediction}")

        # Turn "|" separators into spaces and collapse whitespace so both
        # strings become plain space-separated token sequences.
        cleaned_reference = re.sub(r"\s+", " ", phonemized_reference.replace("|", " ")).strip()
        cleaned_prediction = re.sub(r"\s+", " ", phonemized_prediction.replace("|", " ")).strip()

        print("Cleaned Phonemized Reference:", cleaned_reference)
        print("Cleaned Phonemized Predicted:", cleaned_prediction)

        # The error rate is normalised by the reference length, so it is
        # undefined for an empty reference (and, depending on the jiwer
        # version, jiwer.wer raises ValueError for it).
        if not cleaned_reference:
            print(f"Warning: Skipping {audio_path} — empty reference after phonemization.")
            return None

        return jiwer.wer(cleaned_reference, cleaned_prediction)
class DirectPEREvaluator(ReferenceTxtEvaluator):
    """Computes PER using the espeak-cv-ft model directly.

    Unlike ``PEREvaluator`` the model output is used as the predicted phoneme
    string without a phonemization step; only the reference transcription is
    phonemized.
    """

    def __init__(self):
        """Load ``facebook/wav2vec2-xlsr-53-espeak-cv-ft`` via the model registry."""
        # Local import, kept from the original implementation.
        from pathbench.model_registry import get_ctc_model

        self.processor, self.model, self.device = get_ctc_model(
            "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
        )

    def score(
        self,
        utterance_id: str,
        audio_path: str,
        transcription: str,
        language: str,
        start_time: float = 0.0,
        end_time: float = -1.0,
    ) -> Optional[float]:
        """Return the phoneme error rate for one utterance.

        Args:
            utterance_id: Identifier of the utterance (not used here).
            audio_path: Path of the audio file to load.
            transcription: Reference transcription (orthographic text).
            language: Language code passed to ``cached_phonemize`` for the
                reference.
            start_time: Offset in seconds at which loading starts.
            end_time: End of the segment in seconds; a negative value means
                "load until the end of the file".

        Returns:
            The error rate, or ``None`` when the utterance cannot be scored:
            unreadable audio, fewer than 400 samples, or an empty phonemized
            reference.
        """
        try:
            duration = end_time - start_time if end_time >= 0 else None
            speech, sample_rate = librosa.load(
                audio_path, sr=16000, mono=True, offset=start_time, duration=duration
            )
        except Exception as e:
            # Best effort: an unreadable file skips the utterance instead of
            # aborting the whole evaluation run.
            print(f"Error reading audio file {audio_path}: {e}")
            return None

        # 400 samples correspond to 25 ms at the 16 kHz loading rate.
        if len(speech) < 400:
            print(f"Warning: Skipping {audio_path} — too short ({len(speech)} samples).")
            return None

        input_values = self.processor(
            speech, sampling_rate=sample_rate, return_tensors="pt", padding="longest"
        ).input_values.to(self.device)

        with torch.no_grad():
            logits = self.model(input_values).logits

        # Greedy CTC decoding: most likely token per frame.
        predicted_ids = torch.argmax(logits, dim=-1)
        predicted_transcription = self.processor.batch_decode(predicted_ids)[0]

        print(f"Reference: {transcription}")

        # NOTE(review): `language` is passed through unmapped here, whereas
        # PEREvaluator maps "en" to "en-us" before phonemizing — confirm that
        # this difference is intended.
        phonemized_reference = cached_phonemize(clean_text(transcription), language)

        print(f"Phonemized Reference: {phonemized_reference}")
        print(f"Phonemized Predicted: {predicted_transcription}")

        # Turn "|" separators into spaces and collapse whitespace so both
        # strings become plain space-separated token sequences.
        cleaned_reference = re.sub(
            r"\s+", " ", phonemized_reference.replace("|", " ")
        ).strip()
        cleaned_prediction = re.sub(
            r"\s+", " ", predicted_transcription.replace("|", " ")
        ).strip()

        print("Cleaned Phonemized Reference:", cleaned_reference)
        print("Cleaned Phonemized Predicted:", cleaned_prediction)

        # The error rate is normalised by the reference length, so it is
        # undefined for an empty reference (and, depending on the jiwer
        # version, jiwer.wer raises ValueError for it).
        if not cleaned_reference:
            print(f"Warning: Skipping {audio_path} — empty reference after phonemization.")
            return None

        return jiwer.wer(cleaned_reference, cleaned_prediction)
class DoubleASREvaluator(ReferenceFreeEvaluator):
    """Computes PER between greedy and LM-based CTC decoding.

    No reference transcription is needed: the same logits are decoded once
    greedily (argmax) and once with a KenLM-backed ``pyctcdecode`` decoder,
    both results are phonemized, and the greedy result serves as the
    reference for ``jiwer.wer``.
    """

    # Supported languages and the CTC model used for each of them.
    _MODEL_IDS = {
        "en": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
        "en-us": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
        "es": "jonatasgrosman/wav2vec2-large-xlsr-53-spanish",
        "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
        "it": "jonatasgrosman/wav2vec2-large-xlsr-53-italian",
        "cmn": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    }

    # Language code handed to the phonemizer for each supported language.
    _ESPEAK_LANGUAGES = {
        "en": "en-us",
        "en-us": "en-us",
        "es": "es",
        "nl": "nl",
        "it": "it",
        "cmn": "cmn",
    }

    def __init__(self, language: str, lms_dir: str = "lms"):
        """Load the CTC model and, if available, the language-model decoder.

        Args:
            language: One of the keys of ``_MODEL_IDS``.
            lms_dir: Directory containing the language-model files. The
                default ``"lms"`` is a relative path and is therefore
                resolved against the current working directory.

        Raises:
            ValueError: If ``language`` is not supported.
        """
        # Local import, kept from the original implementation.
        from pathbench.model_registry import get_ctc_model

        if language not in self._MODEL_IDS:
            raise ValueError(f"Language '{language}' is not supported for DoubleASREvaluator.")

        model_id = self._MODEL_IDS[language]
        self.processor, self.model, self.device = get_ctc_model(model_id)
        self.language = language

        lm_paths = {
            "en": os.path.join(lms_dir, "wiki_en_token.arpa"),
            "nl": os.path.join(lms_dir, "wiki_nl_token.arpa"),
            "es": os.path.join(lms_dir, "wiki_es_token.arpa.bin"),
            "it": os.path.join(lms_dir, "wiki_it_token.arpa.bin"),
            "cmn": os.path.join(lms_dir, "wiki_zh_token.arpa"),
        }

        # Regional variants share the base language's LM ("en-us" -> "en").
        lm_lang = language.split("-")[0]
        if lm_lang not in lm_paths:
            # Not reachable with the current tables (every supported language
            # has an entry); kept as the original English fallback.
            lm_lang = "en"
        lm_path = lm_paths[lm_lang]

        # Prefer the ".arpa.bin" sibling of an ".arpa" file when it exists.
        if lm_path.endswith(".arpa"):
            bin_path = lm_path + ".bin"
            if os.path.exists(bin_path):
                lm_path = bin_path

        # Stays None when no LM file is found; score() then returns None.
        self.decoder = None
        if os.path.exists(lm_path):
            vocab_dict = self.processor.tokenizer.get_vocab()
            # Decoder labels: vocabulary tokens ordered by their token id.
            labels = sorted(vocab_dict, key=vocab_dict.get)
            self.decoder = build_ctcdecoder(labels, kenlm_model_path=lm_path)
            print(f"CTC decoder for '{language}' with LM '{lm_path}' built.")
        else:
            print(f"Warning: Language model for '{language}' not found at '{lm_path}'.")

    def score(
        self,
        utterance_id: str,
        audio_path: str,
        start_time: float = 0.0,
        end_time: float = -1.0,
    ) -> Optional[float]:
        """Return the greedy-vs-LM phoneme error rate for one utterance.

        Args:
            utterance_id: Identifier of the utterance (not used here).
            audio_path: Path of the audio file to load.
            start_time: Offset in seconds at which loading starts.
            end_time: End of the segment in seconds; a negative value means
                "load until the end of the file".

        Returns:
            The error rate, or ``None`` when the utterance cannot be scored:
            no decoder was built, unreadable audio, fewer than 400 samples,
            or an empty greedy transcription.
        """
        if self.decoder is None:
            print(f"Error: No decoder available for language '{self.language}'.")
            return None

        try:
            duration = end_time - start_time if end_time >= 0 else None
            speech, sample_rate = librosa.load(
                audio_path, sr=16000, mono=True, offset=start_time, duration=duration
            )
        except Exception as e:
            # Best effort: an unreadable file skips the utterance instead of
            # aborting the whole evaluation run.
            print(f"Error reading audio file {audio_path}: {e}")
            return None

        # 400 samples correspond to 25 ms at the 16 kHz loading rate.
        if len(speech) < 400:
            print(f"Warning: Skipping {audio_path} — too short ({len(speech)} samples).")
            return None

        return self._score_audio(speech, sample_rate)

    def _score_audio(self, audio: np.ndarray, fs: int) -> Optional[float]:
        """Score an already loaded waveform.

        Args:
            audio: Waveform samples passed to the processor.
            fs: Sampling rate of ``audio``, passed to the processor.

        Returns:
            The error rate between the phonemized greedy and LM
            transcriptions, or ``None`` when no decoder is available or the
            greedy transcription is empty after phonemization.
        """
        # Guard for direct callers that bypass score().
        if self.decoder is None:
            print(f"Error: No decoder available for language '{self.language}'.")
            return None

        input_values = self.processor(
            audio, sampling_rate=fs, return_tensors="pt"
        ).input_values.to(self.device)

        with torch.no_grad():
            logits = self.model(input_values).logits

        # Greedy CTC decoding: most likely token per frame.
        predicted_ids = torch.argmax(logits, dim=-1)
        greedy_transcription = self.processor.batch_decode(predicted_ids)[0]

        # LM decoding of the first (only) item of the batch.
        lm_transcription = self.decoder.decode(logits.cpu().numpy()[0])

        print(f"Greedy: {greedy_transcription}")
        print(f"With LM: {lm_transcription}")

        espeak_lang = self._ESPEAK_LANGUAGES.get(self.language, self.language)

        phonemized_greedy = cached_phonemize(clean_text(greedy_transcription), espeak_lang)
        phonemized_lm = cached_phonemize(clean_text(lm_transcription), espeak_lang)

        print(f"Phonemized Greedy: {phonemized_greedy}")
        print(f"Phonemized With LM: {phonemized_lm}")

        # Turn "|" separators into spaces and collapse whitespace so both
        # strings become plain space-separated token sequences.
        cleaned_greedy = re.sub(r"\s+", " ", phonemized_greedy.replace("|", " ")).strip()
        cleaned_lm = re.sub(r"\s+", " ", phonemized_lm.replace("|", " ")).strip()

        print("Cleaned Phonemized Greedy:", cleaned_greedy)
        print("Cleaned Phonemized With LM:", cleaned_lm)

        # The greedy output is the reference of the comparison. The error
        # rate is normalised by its length, so it is undefined when the
        # greedy output is empty (and, depending on the jiwer version,
        # jiwer.wer raises ValueError for it).
        if not cleaned_greedy:
            print("Warning: Skipping — empty greedy transcription.")
            return None

        return jiwer.wer(cleaned_greedy, cleaned_lm)