Source code for pathbench.nad_evaluator

from typing import List, Optional

import numpy as np
from dtw import dtw
import torch
import librosa

from pathbench.evaluator import ReferenceAudioEvaluator, ReferenceTxtAndAudioEvaluator
from pathbench.vad import FATrimmer
from pathbench.model_registry import get_featurizer as load_wav2vec2_featurizer


class NADEvaluator(ReferenceAudioEvaluator):
    """
    An evaluator that computes the Normalized Alignment Distance (NAD) using
    DTW on wav2vec2 features.

    The score of a test utterance is the mean of the normalized DTW distances
    between its features and the features of each usable reference audio.
    Extracted features (and extraction failures) are cached per
    ``(audio_path, start_time, end_time)`` for the lifetime of the instance.
    """

    def __init__(self, model_id="facebook/wav2vec2-large", layer=10):
        """
        Args:
            model_id: Identifier handed to the featurizer loader.
            layer: Layer argument handed to the featurizer loader.
        """
        self.featurizer = load_wav2vec2_featurizer(model_id, layer)
        self.min_feature_len = 2  # DTW requires at least 2 feature vectors
        self._feature_cache = {}  # (audio_path, start_time, end_time) -> (features, err)

    def _get_features(self, audio_path, start_time, end_time):
        """
        Return ``(features, error)`` for an audio file or segment, using the cache.

        Exactly one element of the returned pair is ``None``: ``features`` on
        failure, ``error`` (a message string) on success. Failures are cached
        as well, so a broken file is only processed once.
        """
        cache_key = (audio_path, start_time, end_time)
        if cache_key not in self._feature_cache:
            self._feature_cache[cache_key] = self._compute_features(
                audio_path, start_time, end_time
            )
        return self._feature_cache[cache_key]

    def _compute_features(self, audio_path, start_time, end_time):
        """Load the audio at 16 kHz and featurize it; returns ``(features, error)``."""
        try:
            # end_time == -1.0 is the sentinel for "read until the end of the file".
            duration = end_time - start_time if end_time != -1.0 else None
            audio, _ = librosa.load(
                audio_path, sr=16000, offset=start_time, duration=duration
            )
            if audio is None or len(audio) == 0:
                return None, f"Audio at {audio_path} could not be loaded or is empty."

            features = self.featurizer(audio)
            # NOTE(review): assumes axis 0 of the featurizer output indexes
            # feature vectors (frames) - confirm against the featurizer.
            if features.shape[0] < self.min_feature_len:
                return None, (
                    f"Feature length for {audio_path} is {features.shape[0]}, "
                    f"which is less than minimum {self.min_feature_len}."
                )
            return features, None
        except Exception as e:
            # Deliberately broad: one unreadable file must not abort a whole
            # evaluation run; the failure is reported through the error slot.
            return None, f"Failed to process {audio_path}: {e}"

    def _mean_dtw_distance(self, utterance_id, test_feats, ref_feats):
        """
        Average the normalized DTW distance of ``test_feats`` to each reference.

        References whose DTW computation raises are reported and ignored.
        Returns ``None`` when no reference yields a valid distance, instead of
        the NaN (plus RuntimeWarning) that ``np.nanmean`` gives for all-NaN input.
        """
        distances = []
        for r_feats in ref_feats:
            try:
                distance = dtw(test_feats, r_feats, distance_only=True).normalizedDistance
                distances.append(distance)
            except Exception as e:
                # This can happen if, even after all checks, features are
                # problematic (e.g., all zeros)
                print(f"Error during DTW calculation for {utterance_id}: {e}")
                distances.append(np.nan)

        if not distances or np.all(np.isnan(distances)):
            return None
        return np.nanmean(distances)

    def score(
        self,
        utterance_id: str,
        audio_path: str,
        reference_audios: List[tuple[str, float, float]],
        start_time: float = 0.0,
        end_time: float = -1.0,
    ) -> Optional[float]:
        """
        Computes the average DTW distance between test and reference audio.

        Args:
            utterance_id: Identifier used in diagnostic messages.
            audio_path: Path of the test audio.
            reference_audios: ``(path, start_time, end_time)`` triples of the
                reference audios.
            start_time: Start offset of the test segment, in seconds.
            end_time: End of the test segment in seconds, or ``-1.0`` for the
                end of the file.

        Returns:
            The mean normalized DTW distance over all usable references, or
            ``None`` if there are no references, the test audio cannot be
            featurized, no reference can be featurized, or every DTW
            computation fails.
        """
        if not reference_audios:
            return None

        test_feats, err = self._get_features(audio_path, start_time, end_time)
        if err:
            print(f"Error: Failed to get features for test audio {utterance_id}: {err}")
            return None

        ref_feats = []
        for ref_path, ref_start, ref_end in reference_audios:
            r_feats, err = self._get_features(ref_path, ref_start, ref_end)
            if err:
                print(f"Warning: Failed to get features for ref {ref_path} in group {utterance_id}, skipping ref. Error: {err}")
            else:
                ref_feats.append(r_feats)

        # --- Calculate DTW ---
        if test_feats is None or not ref_feats:
            print(f"Error: Could not obtain valid features for DTW calculation for group {utterance_id}.")
            return None

        return self._mean_dtw_distance(utterance_id, test_feats, ref_feats)
class TrimmedNADEvaluator(ReferenceTxtAndAudioEvaluator):
    """
    An evaluator that computes the Normalized Alignment Distance (NAD) using
    DTW on wav2vec2 features.

    Falls back to untrimmed audio for the whole group if trimming or
    featurization fails for any member of the group. Trimming is only
    attempted when a trimmer is configured and neither the test audio nor any
    reference is restricted to a segment.
    """

    def __init__(self, model_id="facebook/wav2vec2-large", layer=10, trimmer: Optional[FATrimmer] = None):
        """
        Args:
            model_id: Identifier handed to the featurizer loader.
            layer: Layer argument handed to the featurizer loader.
            trimmer: Optional trimmer; without one, audio is never trimmed.
        """
        self.featurizer = load_wav2vec2_featurizer(model_id, layer)
        self.trimmer = trimmer
        self.min_feature_len = 2  # DTW requires at least 2 feature vectors
        # (audio_path, start_time, end_time, use_trimming, transcription, language)
        #   -> (features, err); the last two key fields are None unless the
        #   trimmer is actually invoked for that entry.
        self._feature_cache = {}

    def _get_features(self, audio_path, transcription, language, start_time, end_time, use_trimming):
        """
        Return ``(features, error)`` for an audio file, optionally trimmed; cached.

        Exactly one element of the returned pair is ``None``. The trimmer is
        only invoked when ``use_trimming`` is set, a trimmer is configured and
        the whole file is requested (no segment bounds). A trim that yields no
        audio is reported as an error rather than silently replaced by the
        untrimmed audio, so that the caller can fall back for the whole group.
        """
        use_segment = start_time != 0.0 or end_time != -1.0
        attempt_trim = bool(use_trimming and self.trimmer and not use_segment)

        # The trimmer's output depends on the transcription and language, so
        # they must be part of the key whenever the trimmer is used; untrimmed
        # features do not depend on them and stay shareable.
        cache_key = (
            audio_path,
            start_time,
            end_time,
            use_trimming,
            transcription if attempt_trim else None,
            language if attempt_trim else None,
        )
        if cache_key not in self._feature_cache:
            self._feature_cache[cache_key] = self._compute_features(
                audio_path, transcription, language, start_time, end_time, attempt_trim
            )
        return self._feature_cache[cache_key]

    def _compute_features(self, audio_path, transcription, language, start_time, end_time, attempt_trim):
        """Load (trimmed or plain) audio and featurize it; returns ``(features, error)``."""
        try:
            # 1. Load audio (either trimmed or from segment/file)
            if attempt_trim:
                trimmed_data = self.trimmer.trim(audio_path, transcription, language, start_time, end_time)
                if not trimmed_data or len(trimmed_data[0]) == 0:
                    return None, f"Trimming produced no audio for {audio_path}."
                audio, _ = trimmed_data
            else:
                # end_time == -1.0 is the sentinel for "read until the end of the file".
                duration = end_time - start_time if end_time != -1.0 else None
                audio, _ = librosa.load(
                    audio_path, sr=16000, offset=start_time, duration=duration
                )

            if audio is None or len(audio) == 0:
                return None, f"Audio at {audio_path} could not be loaded or is empty."

            # 2. Featurize
            features = self.featurizer(audio)
            # NOTE(review): assumes axis 0 of the featurizer output indexes
            # feature vectors (frames) - confirm against the featurizer.
            if features.shape[0] < self.min_feature_len:
                return None, (
                    f"Feature length for {audio_path} is {features.shape[0]}, "
                    f"which is less than minimum {self.min_feature_len}."
                )
            return features, None
        except Exception as e:
            # Deliberately broad: one unreadable file must not abort a whole
            # evaluation run; the failure is reported through the error slot.
            return None, f"Failed to process {audio_path}: {e}"

    def _mean_dtw_distance(self, utterance_id, test_feats, ref_feats):
        """
        Average the normalized DTW distance of ``test_feats`` to each reference.

        References whose DTW computation raises are reported and ignored.
        Returns ``None`` when no reference yields a valid distance, instead of
        the NaN (plus RuntimeWarning) that ``np.nanmean`` gives for all-NaN input.
        """
        distances = []
        for r_feats in ref_feats:
            try:
                distance = dtw(test_feats, r_feats, distance_only=True).normalizedDistance
                distances.append(distance)
            except Exception as e:
                # This can happen if, even after all checks, features are
                # problematic (e.g., all zeros)
                print(f"Error during DTW calculation for {utterance_id}: {e}")
                distances.append(np.nan)

        if not distances or np.all(np.isnan(distances)):
            return None
        return np.nanmean(distances)

    def score(
        self,
        utterance_id: str,
        audio_path: str,
        transcription: str,
        language: str,
        reference_audios: List[tuple[str, float, float]],
        start_time: float = 0.0,
        end_time: float = -1.0,
    ) -> Optional[float]:
        """
        Computes the average DTW distance. If trimming/featurizing fails for
        any audio in a group (test or any reference), it falls back to
        untrimmed for all.

        Args:
            utterance_id: Identifier used in diagnostic messages.
            audio_path: Path of the test audio.
            transcription: Transcription passed to the trimmer (for the test
                audio and for every reference).
            language: Language passed to the trimmer.
            reference_audios: ``(path, start_time, end_time)`` triples of the
                reference audios.
            start_time: Start offset of the test segment, in seconds.
            end_time: End of the test segment in seconds, or ``-1.0`` for the
                end of the file.

        Returns:
            The mean normalized DTW distance over all usable references, or
            ``None`` if there are no references, the test audio cannot be
            featurized, no reference can be featurized, or every DTW
            computation fails.
        """
        if not reference_audios:
            return None

        # Check if trimming should be attempted at all
        use_test_segment = start_time != 0.0 or end_time != -1.0
        use_ref_segments = any(
            ref_start != 0.0 or ref_end != -1.0
            for _, ref_start, ref_end in reference_audios
        )
        use_trimming = bool(self.trimmer) and not (use_test_segment or use_ref_segments)

        test_feats = None
        ref_feats = []

        # --- Pass 1: Attempt to get features with trimming enabled ---
        if use_trimming:
            errors = []
            test_feats, err = self._get_features(
                audio_path, transcription, language, start_time, end_time, use_trimming=True
            )
            if err:
                errors.append(err)

            for ref_path, ref_start, ref_end in reference_audios:
                r_feats, err = self._get_features(
                    ref_path, transcription, language, ref_start, ref_end, use_trimming=True
                )
                if err:
                    errors.append(err)
                ref_feats.append(r_feats)

            # If any error occurred, discard all results from this pass
            if errors or test_feats is None or any(f is None for f in ref_feats):
                print(f"Warning: Failed to get trimmed features for group {utterance_id}. Falling back to untrimmed. Errors: {errors}")
                test_feats = None
                ref_feats = []
                use_trimming = False  # Force fallback

        # --- Pass 2: Get features with trimming disabled (if pass 1 failed or was skipped) ---
        if not use_trimming:
            test_feats, err = self._get_features(
                audio_path, transcription, language, start_time, end_time, use_trimming=False
            )
            if err:
                print(f"Error: Failed to get untrimmed features for test audio {utterance_id}: {err}")
                return None

            ref_feats = []
            for ref_path, ref_start, ref_end in reference_audios:
                r_feats, err = self._get_features(
                    ref_path, transcription, language, ref_start, ref_end, use_trimming=False
                )
                if err:
                    print(f"Warning: Failed to get untrimmed features for ref {ref_path} in group {utterance_id}, skipping ref. Error: {err}")
                else:
                    ref_feats.append(r_feats)

        # --- Pass 3: Calculate DTW ---
        if test_feats is None or not ref_feats:
            print(f"Error: Could not obtain valid features for DTW calculation for group {utterance_id}.")
            return None

        return self._mean_dtw_distance(utterance_id, test_feats, ref_feats)