Source code for pathbench.asr_evaluators

from typing import Optional
import re
import os

import jiwer
import librosa
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from phonemizer.phonemize import phonemize
from phonemizer.separator import Separator
from pyctcdecode import build_ctcdecoder

import numpy as np
from pathbench.evaluator import ReferenceTxtEvaluator, ReferenceFreeEvaluator
from pathbench.string_clean import clean_text


class ASREvaluator(ReferenceTxtEvaluator):
    """Computes WER using an ASR model.

    The audio segment is transcribed with greedy (argmax) CTC decoding and
    compared against the reference transcription after both texts have been
    normalised with ``clean_text``.
    """

    def __init__(self, model_id: str):
        """Load the processor and CTC model and move the model to the device.

        Args:
            model_id: Identifier handed to ``from_pretrained`` for both the
                ``Wav2Vec2Processor`` and the ``Wav2Vec2ForCTC`` model.
        """
        self.processor = Wav2Vec2Processor.from_pretrained(model_id)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_id)
        # Use the GPU whenever torch reports one, otherwise stay on the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        print(f"ASR model '{model_id}' loaded on {self.device}.")

    def score(
        self,
        utterance_id: str,
        audio_path: str,
        transcription: str,
        language: str,
        start_time: float = 0.0,
        end_time: float = -1.0,
    ) -> Optional[float]:
        """Return the WER of the model's transcription of an audio segment.

        Args:
            utterance_id: Identifier of the utterance; not read by this method.
            audio_path: Path of the audio file to load.
            transcription: Reference text the prediction is compared against.
            language: Language code; not read by this method.
            start_time: Forwarded to ``librosa.load`` as ``offset``.
            end_time: End of the segment; a negative value means "read to the
                end of the file".

        Returns:
            ``jiwer.wer`` of the cleaned reference versus the cleaned
            prediction, or ``None`` when the audio cannot be loaded, the
            segment has fewer than 400 samples, or the cleaned reference is
            empty.
        """
        try:
            # duration=None makes librosa read everything after the offset.
            duration = (end_time - start_time) if end_time >= 0 else None
            speech, sample_rate = librosa.load(
                audio_path,
                sr=16000,
                mono=True,
                offset=start_time,
                duration=duration,
            )
        except Exception as e:
            # Best effort: an unreadable file is reported and skipped.
            print(f"Error reading audio file {audio_path}: {e}")
            return None

        # 400 samples is 25 ms at 16 kHz; presumably the shortest input the
        # model can process -- TODO confirm.
        if len(speech) < 400:
            print(f"Warning: Skipping {audio_path} — too short ({len(speech)} samples).")
            return None

        input_values = self.processor(
            speech, sampling_rate=sample_rate, return_tensors="pt"
        ).input_values.to(self.device)

        with torch.no_grad():
            logits = self.model(input_values).logits

        # Greedy CTC decoding: most likely token per frame, then collapse.
        predicted_ids = torch.argmax(logits, dim=-1)
        predicted_transcription = self.processor.batch_decode(predicted_ids)[0]

        print(f"Reference: {transcription}")
        print(f"Predicted: {predicted_transcription}")

        cleaned_reference = clean_text(transcription)
        cleaned_prediction = clean_text(predicted_transcription)
        print("Cleaned Reference:", cleaned_reference)
        print("Cleaned Predicted:", cleaned_prediction)

        # WER divides by the number of reference words, so an empty reference
        # has no defined score; report it as unscorable instead of handing it
        # to jiwer.
        if not cleaned_reference.strip():
            print(f"Warning: Skipping {audio_path} — reference is empty after cleaning.")
            return None

        return jiwer.wer(cleaned_reference, cleaned_prediction)
class PEREvaluator(ReferenceTxtEvaluator):
    """Computes PER using a language-specific ASR model.

    The audio is transcribed to text with a Wav2Vec2 CTC model; the reference
    and the predicted text are then both phonemized with espeak and compared
    phone by phone.
    """

    # Supported language codes and the checkpoint loaded for each of them.
    _MODEL_IDS = {
        "en": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
        "en-us": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
        "es": "jonatasgrosman/wav2vec2-large-xlsr-53-spanish",
        "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
        "it": "jonatasgrosman/wav2vec2-large-xlsr-53-italian",
    }

    # Language code -> language name handed to the espeak backend.
    _ESPEAK_LANGUAGES = {
        "en": "en-us",
        "en-us": "en-us",
        "es": "es",
        "nl": "nl",
        "it": "it",
    }

    def __init__(self, language: str):
        """Load the ASR model registered for ``language``.

        Args:
            language: One of the keys of ``_MODEL_IDS``.

        Raises:
            ValueError: If ``language`` has no registered model.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if language not in self._MODEL_IDS:
            raise ValueError(f"Language '{language}' is not supported for PEREvaluator.")
        model_id = self._MODEL_IDS[language]
        self.processor = Wav2Vec2Processor.from_pretrained(model_id)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_id)
        self.model.to(self.device)
        print(f"ASR model '{model_id}' for language '{language}' loaded on {self.device}.")
        self.language = language

    @staticmethod
    def _to_phonemes(text: str, espeak_lang: str) -> str:
        """Clean ``text`` and phonemize it: phones split by ' ', words by '|'."""
        return phonemize(
            clean_text(text),
            language=espeak_lang,
            backend="espeak",
            strip=True,
            preserve_punctuation=False,
            separator=Separator(phone=" ", word="|"),
        )

    @staticmethod
    def _flatten(phonemized: str) -> str:
        """Replace word separators by spaces and collapse whitespace runs."""
        return re.sub(r"\s+", " ", phonemized.replace("|", " ")).strip()

    def score(
        self,
        utterance_id: str,
        audio_path: str,
        transcription: str,
        language: str,
        start_time: float = 0.0,
        end_time: float = -1.0,
    ) -> Optional[float]:
        """Return the phoneme error rate of the ASR output for a segment.

        Args:
            utterance_id: Identifier of the utterance; not read by this method.
            audio_path: Path of the audio file to load.
            transcription: Reference text.
            language: Must equal the language given to the constructor.
            start_time: Forwarded to ``librosa.load`` as ``offset``.
            end_time: End of the segment; a negative value means "read to the
                end of the file".

        Returns:
            ``jiwer.wer`` computed over space-separated phones, or ``None``
            when the language does not match, the audio cannot be loaded, the
            segment has fewer than 400 samples, or the phonemized reference
            is empty.
        """
        if language != self.language:
            print(
                f"Warning: PEREvaluator initialized for '{self.language}' "
                f"but received '{language}'. Skipping."
            )
            return None

        try:
            # duration=None makes librosa read everything after the offset.
            duration = (end_time - start_time) if end_time >= 0 else None
            speech, sample_rate = librosa.load(
                audio_path,
                sr=16000,
                mono=True,
                offset=start_time,
                duration=duration,
            )
        except Exception as e:
            # Best effort: an unreadable file is reported and skipped.
            print(f"Error reading audio file {audio_path}: {e}")
            return None

        if len(speech) < 400:
            print(f"Warning: Skipping {audio_path} — too short ({len(speech)} samples).")
            return None

        input_values = self.processor(
            speech, sampling_rate=sample_rate, return_tensors="pt", padding="longest"
        ).input_values.to(self.device)

        with torch.no_grad():
            logits = self.model(input_values).logits

        # Greedy CTC decoding: most likely token per frame, then collapse.
        predicted_ids = torch.argmax(logits, dim=-1)
        predicted_transcription = self.processor.batch_decode(predicted_ids)[0]
        print(f"Reference: {transcription}")

        espeak_lang = self._ESPEAK_LANGUAGES.get(language, language)
        phonemized_reference = self._to_phonemes(transcription, espeak_lang)
        phonemized_prediction = self._to_phonemes(predicted_transcription, espeak_lang)
        print(f"Phonemized Reference: {phonemized_reference}")
        print(f"Phonemized Predicted: {phonemized_prediction}")

        cleaned_reference = self._flatten(phonemized_reference)
        cleaned_prediction = self._flatten(phonemized_prediction)
        print("Cleaned Phonemized Reference:", cleaned_reference)
        print("Cleaned Phonemized Predicted:", cleaned_prediction)

        # The error rate divides by the number of reference phones, so an
        # empty reference has no defined score.
        if not cleaned_reference:
            print(f"Warning: Skipping {audio_path} — reference is empty after phonemization.")
            return None

        # Every phone is a whitespace-separated token, so the word-level error
        # rate over these strings is the phoneme error rate. The expression
        # must sit on the same line as ``return``; a bare ``return`` would
        # yield None and discard the score.
        return jiwer.wer(cleaned_reference, cleaned_prediction)
class DirectPEREvaluator(ReferenceTxtEvaluator):
    """Computes PER using the espeak-cv-ft model directly.

    The model's decoded output is used as the predicted phone sequence as-is
    (it is not passed through ``phonemize``); only the reference text is
    phonemized with espeak.
    """

    def __init__(self):
        """Load ``facebook/wav2vec2-xlsr-53-espeak-cv-ft`` onto the device."""
        self.processor = Wav2Vec2Processor.from_pretrained(
            "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
        )
        self.model = Wav2Vec2ForCTC.from_pretrained(
            "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
        )
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def score(
        self,
        utterance_id: str,
        audio_path: str,
        transcription: str,
        language: str,
        start_time: float = 0.0,
        end_time: float = -1.0,
    ) -> Optional[float]:
        """Return the phoneme error rate of the model output for a segment.

        Args:
            utterance_id: Identifier of the utterance; not read by this method.
            audio_path: Path of the audio file to load.
            transcription: Reference text; phonemized with espeak.
            language: Language name handed to espeak unchanged.
            start_time: Forwarded to ``librosa.load`` as ``offset``.
            end_time: End of the segment; a negative value means "read to the
                end of the file".

        Returns:
            ``jiwer.wer`` computed over space-separated phones, or ``None``
            when the audio cannot be loaded, the segment has fewer than 400
            samples, or the phonemized reference is empty.
        """
        try:
            # duration=None makes librosa read everything after the offset.
            duration = (end_time - start_time) if end_time >= 0 else None
            speech, sample_rate = librosa.load(
                audio_path,
                sr=16000,
                mono=True,
                offset=start_time,
                duration=duration,
            )
        except Exception as e:
            # Best effort: an unreadable file is reported and skipped.
            print(f"Error reading audio file {audio_path}: {e}")
            return None

        if len(speech) < 400:
            print(f"Warning: Skipping {audio_path} — too short ({len(speech)} samples).")
            return None

        input_values = self.processor(
            speech, sampling_rate=sample_rate, return_tensors="pt", padding="longest"
        ).input_values.to(self.device)

        with torch.no_grad():
            logits = self.model(input_values).logits

        # Greedy CTC decoding: most likely token per frame, then collapse.
        predicted_ids = torch.argmax(logits, dim=-1)
        predicted_transcription = self.processor.batch_decode(predicted_ids)[0]
        print(f"Reference: {transcription}")

        # NOTE(review): ``language`` reaches espeak unmapped, whereas
        # PEREvaluator in this module maps "en" to "en-us" first -- confirm
        # which espeak voice is intended here.
        phonemized_reference = phonemize(
            clean_text(transcription),
            language=language,
            backend="espeak",
            strip=True,
            preserve_punctuation=False,
            separator=Separator(phone=" ", word="|"),
        )
        print(f"Phonemized Reference: {phonemized_reference}")
        print(f"Phonemized Predicted: {predicted_transcription}")

        # Drop word boundaries and normalise whitespace on both sides so that
        # each whitespace-separated token is one phone.
        cleaned_reference = re.sub(
            r"\s+", " ", phonemized_reference.replace("|", " ")
        ).strip()
        cleaned_prediction = re.sub(
            r"\s+", " ", predicted_transcription.replace("|", " ")
        ).strip()
        print("Cleaned Phonemized Reference:", cleaned_reference)
        print("Cleaned Phonemized Predicted:", cleaned_prediction)

        # The error rate divides by the number of reference phones, so an
        # empty reference has no defined score.
        if not cleaned_reference:
            print(f"Warning: Skipping {audio_path} — reference is empty after phonemization.")
            return None

        return jiwer.wer(cleaned_reference, cleaned_prediction)
class DoubleASREvaluator(ReferenceFreeEvaluator):
    """Computes PER between greedy and LM-based CTC decoding.

    No reference text is needed: the same logits are decoded once greedily
    and once with a KenLM-backed beam-search decoder, and the two phonemized
    outputs are compared (the greedy output acts as the reference).
    """

    # Supported language codes and the checkpoint loaded for each of them.
    _MODEL_IDS = {
        "en": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
        "en-us": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
        "es": "jonatasgrosman/wav2vec2-large-xlsr-53-spanish",
        "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
        "it": "jonatasgrosman/wav2vec2-large-xlsr-53-italian",
    }

    # Directory (relative to the working directory) holding the LM files.
    _LM_DIR = "lms"

    # Base language code -> language-model file inside ``_LM_DIR``.
    _LM_FILES = {
        "en": "wiki_en_token.arpa",
        "nl": "wiki_nl_token.arpa",
        "es": "wiki_es_token.arpa.bin",
        "it": "wiki_it_token.arpa.bin",
    }

    # Language code -> language name handed to the espeak backend.
    _ESPEAK_LANGUAGES = {
        "en": "en-us",
        "en-us": "en-us",
        "es": "es",
        "nl": "nl",
        "it": "it",
    }

    def __init__(self, language: str):
        """Load the ASR model for ``language`` and build the LM decoder.

        ``self.decoder`` stays ``None`` when no language-model file is found
        on disk; ``score`` then refuses to run.

        Args:
            language: One of the keys of ``_MODEL_IDS``.

        Raises:
            ValueError: If ``language`` has no registered model.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if language not in self._MODEL_IDS:
            raise ValueError(f"Language '{language}' is not supported for DoubleASREvaluator.")
        model_id = self._MODEL_IDS[language]
        self.processor = Wav2Vec2Processor.from_pretrained(model_id)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_id)
        self.model.to(self.device)
        print(f"ASR model '{model_id}' for language '{language}' loaded on {self.device}.")
        self.language = language

        # "en-us" shares the "en" language model. The English fallback is
        # currently unreachable (every supported language has an entry) and
        # is kept only as a safety net.
        lm_lang = language.split("-")[0]
        lm_file = self._LM_FILES.get(lm_lang, self._LM_FILES["en"])
        lm_path = os.path.join(self._LM_DIR, lm_file)

        # Prefer a "<name>.arpa.bin" file next to a plain ARPA file when one
        # exists on disk.
        if lm_path.endswith(".arpa"):
            binary_path = lm_path + ".bin"
            if os.path.exists(binary_path):
                lm_path = binary_path

        self.decoder = None
        if os.path.exists(lm_path):
            # The decoder needs the labels ordered by token id so that label i
            # corresponds to logit column i.
            vocab = self.processor.tokenizer.get_vocab()
            labels = [
                token for token, _ in sorted(vocab.items(), key=lambda item: item[1])
            ]
            self.decoder = build_ctcdecoder(labels, kenlm_model_path=lm_path)
            print(f"CTC decoder for '{language}' with LM '{lm_path}' built.")
        else:
            print(f"Warning: Language model for '{language}' not found at '{lm_path}'.")

    def score(
        self,
        utterance_id: str,
        audio_path: str,
        start_time: float = 0.0,
        end_time: float = -1.0,
    ) -> Optional[float]:
        """Load an audio segment and score it with ``_score_audio``.

        Args:
            utterance_id: Identifier of the utterance; not read by this method.
            audio_path: Path of the audio file to load.
            start_time: Forwarded to ``librosa.load`` as ``offset``.
            end_time: End of the segment; a negative value means "read to the
                end of the file".

        Returns:
            The result of ``_score_audio``, or ``None`` when no decoder was
            built, the audio cannot be loaded, or the segment has fewer than
            400 samples.
        """
        if self.decoder is None:
            print(f"Error: No decoder available for language '{self.language}'.")
            return None

        try:
            # duration=None makes librosa read everything after the offset.
            duration = (end_time - start_time) if end_time >= 0 else None
            speech, sample_rate = librosa.load(
                audio_path,
                sr=16000,
                mono=True,
                offset=start_time,
                duration=duration,
            )
        except Exception as e:
            # Best effort: an unreadable file is reported and skipped.
            print(f"Error reading audio file {audio_path}: {e}")
            return None

        if len(speech) < 400:
            print(f"Warning: Skipping {audio_path} — too short ({len(speech)} samples).")
            return None

        return self._score_audio(speech, sample_rate)

    def _to_phonemes(self, text: str) -> str:
        """Clean ``text`` and phonemize it: phones split by ' ', words by '|'."""
        espeak_lang = self._ESPEAK_LANGUAGES.get(self.language, self.language)
        return phonemize(
            clean_text(text),
            language=espeak_lang,
            backend="espeak",
            strip=True,
            preserve_punctuation=False,
            separator=Separator(phone=" ", word="|"),
        )

    @staticmethod
    def _flatten(phonemized: str) -> str:
        """Replace word separators by spaces and collapse whitespace runs."""
        return re.sub(r"\s+", " ", phonemized.replace("|", " ")).strip()

    def _score_audio(self, audio: np.ndarray, fs: int) -> Optional[float]:
        """Compare greedy and LM-based decodings of ``audio``.

        Args:
            audio: Waveform samples handed to the processor.
            fs: Sampling rate of ``audio``.

        Returns:
            ``jiwer.wer`` over space-separated phones with the greedy output
            as reference and the LM output as hypothesis, or ``None`` when
            the greedy output contains no phones.
        """
        input_values = self.processor(
            audio, sampling_rate=fs, return_tensors="pt"
        ).input_values.to(self.device)

        with torch.no_grad():
            logits = self.model(input_values).logits

        # Decoding 1: greedy argmax over the vocabulary.
        predicted_ids = torch.argmax(logits, dim=-1)
        greedy_transcription = self.processor.batch_decode(predicted_ids)[0]
        # Decoding 2: LM-backed decoder on the logits of the single batch item.
        lm_transcription = self.decoder.decode(logits.cpu().numpy()[0])
        print(f"Greedy: {greedy_transcription}")
        print(f"With LM: {lm_transcription}")

        phonemized_greedy = self._to_phonemes(greedy_transcription)
        phonemized_lm = self._to_phonemes(lm_transcription)
        print(f"Phonemized Greedy: {phonemized_greedy}")
        print(f"Phonemized With LM: {phonemized_lm}")

        cleaned_greedy = self._flatten(phonemized_greedy)
        cleaned_lm = self._flatten(phonemized_lm)
        print("Cleaned Phonemized Greedy:", cleaned_greedy)
        print("Cleaned Phonemized With LM:", cleaned_lm)

        # The greedy output is the reference side of the comparison; the error
        # rate divides by its length, so an empty greedy output has no
        # defined score.
        if not cleaned_greedy:
            print("Warning: Greedy transcription is empty; cannot compute a score.")
            return None

        return jiwer.wer(cleaned_greedy, cleaned_lm)