Source code for pathbench.articulatory_precision_evaluator

from typing import Optional

import librosa
import torch
import torchaudio
import re

import numpy as np
from pathbench.evaluator import Evaluator, ReferenceFreeEvaluator, ReferenceTxtEvaluator
from pathbench.string_clean import clean_text, cached_phonemize


[docs] class PhoneticConfidenceEvaluator(ReferenceFreeEvaluator): """An evaluator that scores based on the model's average confidence in its own greedy-decoded phoneme sequence (no reference text used).""" def __init__(self, model_id: str = "facebook/wav2vec2-xlsr-53-espeak-cv-ft", use_exp: bool = False): from pathbench.model_registry import get_ctc_model self.processor, self.model, self.device = get_ctc_model(model_id) self.use_exp = use_exp
[docs] def score( self, utterance_id: str, audio_path: str, start_time: float = 0.0, end_time: float = -1.0, ) -> Optional[float]: """ Computes the phonetic confidence score. """ try: duration = end_time - start_time if end_time != -1.0 else None speech, sample_rate = librosa.load( audio_path, sr=16000, offset=start_time, duration=duration ) except Exception as e: print(f"Error reading audio file {audio_path}: {e}") return None if len(speech) < 400: print(f"Warning: Skipping audio file {audio_path} because it is too short ({len(speech)} samples).") return None return self._score_audio(speech, sample_rate)
def _score_audio(self, audio: np.ndarray, fs: int) -> Optional[float]: # Process audio input_values = self.processor( audio, sampling_rate=fs, return_tensors="pt" ).input_values input_values = input_values.to(self.device) # Get model outputs with torch.no_grad(): logits = self.model(input_values).logits # 2. Get the mapping from phonemes to model vocab indices. vocab = self.processor.tokenizer.get_vocab() # reverse mapping of this vocab_reverse = {v: k for k, v in vocab.items()} #print("vocab", vocab) #print("vocab_reverse", vocab_reverse) emissions = torch.log_softmax(logits, dim=-1) emissions = torch.exp(emissions) # torchaudio.functional.forced_align requires CPU tensors emissions = emissions.cpu() # 4. Calculate Articulatory Precision # The following section first shows the segmentation of the raw model output # (including <pad> tokens), and then calculates the articulatory precision score # based on the forced alignment with the reference transcription. # The score is calculated using average probabilities, not log probabilities. # Get the raw segmentation from argmax, which includes <pad> tokens print("\n--- Raw Model Output Segmentation (including <pad>) ---") best_path = torch.argmax(emissions, dim=-1)[0] change_points_raw = (best_path.diff() != 0).nonzero(as_tuple=True)[0] segments_raw = torch.cat([ torch.tensor([0], device=best_path.device), change_points_raw + 1, torch.tensor([best_path.shape[0]], device=best_path.device) ]) probabilities = emissions[0] total_prob = 0 num_phonemes = 0 for i, (start, end) in enumerate(zip(segments_raw[:-1], segments_raw[1:])): token_id = best_path[start].item() avg_prob = probabilities[start:end, token_id].mean().item() token_str = vocab_reverse.get(token_id, "UNK") # If <pad>, skip from calculation if not token_str == self.processor.tokenizer.pad_token: num_phonemes += 1 total_prob += avg_prob #print(f" Segment {i} ({token_str}): frames {start}-{end-1}, avg_prob={avg_prob:.4f}") if num_phonemes > 0: artp_score = total_prob / num_phonemes else: artp_score = 0.0 # Or handle as an error #print(f"\n--- Final Score ---") #print(f"Utterance ID: {utterance_id}") #print(f"Transcription: {transcription}") #print(f"Phonemized Reference: {phonemized_reference}") #print(f"Articulatory Precision Score: {artp_score}") #print("(Note: Score is based on forced alignment of reference text, not raw model output)") return artp_score
[docs] class ArticulatoryPrecisionEvaluator(ReferenceTxtEvaluator): """An evaluator that uses a wav2vec 2.0 model to compute articulatory precision.""" def __init__(self, model_id: str = "facebook/wav2vec2-xlsr-53-espeak-cv-ft"): from pathbench.model_registry import get_ctc_model self.processor, self.model, self.device = get_ctc_model(model_id)
[docs] def score( self, utterance_id: str, audio_path: str, transcription: str, language: str, start_time: float = 0.0, end_time: float = -1.0, ) -> Optional[float]: """ Computes the articulatory precision score. """ try: duration = end_time - start_time if end_time != -1.0 else None speech, sample_rate = librosa.load( audio_path, sr=16000, offset=start_time, duration=duration ) except Exception as e: print(f"Error reading audio file {audio_path}: {e}") return None if len(speech) < 400: print(f"Warning: Skipping audio file {audio_path} because it is too short ({len(speech)} samples).") return None # Process audio input_values = self.processor( speech, sampling_rate=sample_rate, return_tensors="pt" ).input_values input_values = input_values.to(self.device) # Get model outputs with torch.no_grad(): logits = self.model(input_values).logits # 1. Phonemize the ground truth transcription. phonemized_reference = cached_phonemize(clean_text(transcription), language) phonemized_reference = re.sub(r"\s+", " ", phonemized_reference.replace("|", " ")).strip() print(f"Phonemized reference for {utterance_id}: {phonemized_reference}") if not phonemized_reference: print(f"Warning: Could not phonemize reference transcription for {utterance_id}.") return None # 2. Get the mapping from phonemes to model vocab indices. vocab = self.processor.tokenizer.get_vocab() # reverse mapping of this vocab_reverse = {v: k for k, v in vocab.items()} #print("vocab", vocab) #print("vocab_reverse", vocab_reverse) # The model seems to have different symbols than the phonemizer. # For example, phonemizer might produce 'ə' but the model has 'ə'. # Let's assume for now they are compatible. target_phonemes = phonemized_reference.split() print(target_phonemes) # ʲ phonemes are not in so remove target_phonemes = [p.replace("ʲ", "") for p in target_phonemes] target_phonemes = [p.replace("dz", "z") for p in target_phonemes] original_count = len(target_phonemes) target_phonemes = [p for p in target_phonemes if p in vocab] if len(target_phonemes) < original_count: print(f"Warning: Some phonemes not in model vocabulary for {audio_path}.") if not target_phonemes: print(f"Warning: No phonemes in model vocabulary for {audio_path}.") return None target_ids = [vocab[p] for p in target_phonemes] # 3. Forced alignment # Based on https://pytorch.org/audio/main/tutorials/forced_alignment_tutorial.html emissions = torch.log_softmax(logits, dim=-1) # torchaudio.functional.forced_align requires CPU tensors emissions = emissions.cpu() targets = torch.tensor(target_ids, dtype=torch.int32).unsqueeze(0) print(emissions.shape) print(targets.shape) print(targets) try: aligned_path, scores = torchaudio.functional.forced_align( emissions, targets, blank=vocab.get(self.processor.tokenizer.pad_token, 0) ) except Exception as e: print(f"Forced alignment failed for {utterance_id}: {e}") return None # 4. Calculate Articulatory Precision # The following section first shows the segmentation of the raw model output # (including <pad> tokens), and then calculates the articulatory precision score # based on the forced alignment with the reference transcription. # The score is calculated using average probabilities, not log probabilities. # Get the raw segmentation from argmax, which includes <pad> tokens print("\n--- Raw Model Output Segmentation (including <pad>) ---") best_path = aligned_path #print("best_path.shape", best_path.shape) #print("emissions.shape", emissions.shape) #print("scores.shape", scores.shape) #print("best path:", best_path) #print("emisisons", emissions) #change_points_raw = (best_path.diff() != 0).nonzero(as_tuple=True)[0] #segments_raw = torch.cat([ # torch.tensor([0], device=best_path.device), # change_points_raw + 1, # torch.tensor([best_path.shape[0]], device=best_path.device) #]) total_prob = 0 num_phonemes = 0 # Convert alignment scores from log-probabilities to probabilities # so the final score is an average probability, not log probability. prob_scores = torch.exp(scores) for i, (token, score) in enumerate(zip(best_path[0,:], prob_scores[0,:])): #print(f" Frame {i}: token_id={token}, score={score}") # If <pad>, skip from calculation if not token == vocab.get(self.processor.tokenizer.pad_token, 0): num_phonemes += 1 total_prob += score if num_phonemes > 0: artp_score = float(total_prob / num_phonemes) else: artp_score = 0.0 # Or handle as an error print(artp_score) return artp_score