Source code for pathbench.articulatory_precision_evaluator

from typing import Optional

import librosa
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from phonemizer.phonemize import phonemize
from phonemizer.separator import Separator
import re

import numpy as np
from pathbench.evaluator import Evaluator, ReferenceFreeEvaluator, ReferenceTxtEvaluator
from pathbench.string_clean import clean_text


[docs] class PhoneticConfidenceEvaluator(ReferenceFreeEvaluator): """An evaluator that scores based on the model's average confidence in its own greedy-decoded phoneme sequence (no reference text used).""" def __init__(self, model_id: str = "facebook/wav2vec2-xlsr-53-espeak-cv-ft", use_exp: bool = False): self.processor = Wav2Vec2Processor.from_pretrained(model_id) self.model = Wav2Vec2ForCTC.from_pretrained(model_id) self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model.to(self.device) self.use_exp = use_exp print(f"Phonetic model '{model_id}' loaded on {self.device}.")
[docs] def score( self, utterance_id: str, audio_path: str, start_time: float = 0.0, end_time: float = -1.0, ) -> Optional[float]: """ Computes the phonetic confidence score. """ try: duration = end_time - start_time if end_time != -1.0 else None speech, sample_rate = librosa.load( audio_path, sr=16000, offset=start_time, duration=duration ) except Exception as e: print(f"Error reading audio file {audio_path}: {e}") return None if len(speech) < 400: print(f"Warning: Skipping audio file {audio_path} because it is too short ({len(speech)} samples).") return None return self._score_audio(speech, sample_rate)
def _score_audio(self, audio: np.ndarray, fs: int) -> Optional[float]: # Process audio input_values = self.processor( audio, sampling_rate=fs, return_tensors="pt" ).input_values input_values = input_values.to(self.device) # Get model outputs with torch.no_grad(): logits = self.model(input_values).logits # 2. Get the mapping from phonemes to model vocab indices. vocab = self.processor.tokenizer.get_vocab() # reverse mapping of this vocab_reverse = {v: k for k, v in vocab.items()} #print("vocab", vocab) #print("vocab_reverse", vocab_reverse) emissions = torch.log_softmax(logits, dim=-1) emissions = torch.exp(emissions) # torchaudio.functional.forced_align requires CPU tensors emissions = emissions.cpu() # 4. Calculate Articulatory Precision # The following section first shows the segmentation of the raw model output # (including <pad> tokens), and then calculates the articulatory precision score # based on the forced alignment with the reference transcription. # The score is calculated using average probabilities, not log probabilities. # Get the raw segmentation from argmax, which includes <pad> tokens print("\n--- Raw Model Output Segmentation (including <pad>) ---") best_path = torch.argmax(emissions, dim=-1)[0] change_points_raw = (best_path.diff() != 0).nonzero(as_tuple=True)[0] segments_raw = torch.cat([ torch.tensor([0], device=best_path.device), change_points_raw + 1, torch.tensor([best_path.shape[0]], device=best_path.device) ]) probabilities = emissions[0] total_prob = 0 num_phonemes = 0 for i, (start, end) in enumerate(zip(segments_raw[:-1], segments_raw[1:])): token_id = best_path[start].item() avg_prob = probabilities[start:end, token_id].mean().item() token_str = vocab_reverse.get(token_id, "UNK") # If <pad>, skip from calculation if not token_str == self.processor.tokenizer.pad_token: num_phonemes += 1 total_prob += avg_prob #print(f" Segment {i} ({token_str}): frames {start}-{end-1}, avg_prob={avg_prob:.4f}") if num_phonemes > 0: artp_score = total_prob / num_phonemes else: artp_score = 0.0 # Or handle as an error #print(f"\n--- Final Score ---") #print(f"Utterance ID: {utterance_id}") #print(f"Transcription: {transcription}") #print(f"Phonemized Reference: {phonemized_reference}") #print(f"Articulatory Precision Score: {artp_score}") #print("(Note: Score is based on forced alignment of reference text, not raw model output)") return artp_score
[docs] class ArticulatoryPrecisionEvaluator(ReferenceTxtEvaluator): """An evaluator that uses a wav2vec 2.0 model to compute articulatory precision.""" def __init__(self, model_id: str = "facebook/wav2vec2-xlsr-53-espeak-cv-ft"): self.processor = Wav2Vec2Processor.from_pretrained(model_id) self.model = Wav2Vec2ForCTC.from_pretrained(model_id) self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model.to(self.device) print(f"Phonetic model '{model_id}' loaded on {self.device}.")
[docs] def score( self, utterance_id: str, audio_path: str, transcription: str, language: str, start_time: float = 0.0, end_time: float = -1.0, ) -> Optional[float]: """ Computes the articulatory precision score. """ try: duration = end_time - start_time if end_time != -1.0 else None speech, sample_rate = librosa.load( audio_path, sr=16000, offset=start_time, duration=duration ) except Exception as e: print(f"Error reading audio file {audio_path}: {e}") return None if len(speech) < 400: print(f"Warning: Skipping audio file {audio_path} because it is too short ({len(speech)} samples).") return None # Process audio input_values = self.processor( speech, sampling_rate=sample_rate, return_tensors="pt" ).input_values input_values = input_values.to(self.device) # Get model outputs with torch.no_grad(): logits = self.model(input_values).logits # 1. Phonemize the ground truth transcription. separator = Separator(phone=" ", word="|") phonemized_reference = phonemize( clean_text(transcription), language=language, backend="espeak", strip=True, preserve_punctuation=False, separator=separator ) phonemized_reference = re.sub(r"\s+", " ", phonemized_reference.replace("|", " ")).strip() print(f"Phonemized reference for {utterance_id}: {phonemized_reference}") if not phonemized_reference: print(f"Warning: Could not phonemize reference transcription for {utterance_id}.") return None # 2. Get the mapping from phonemes to model vocab indices. vocab = self.processor.tokenizer.get_vocab() # reverse mapping of this vocab_reverse = {v: k for k, v in vocab.items()} #print("vocab", vocab) #print("vocab_reverse", vocab_reverse) # The model seems to have different symbols than the phonemizer. # For example, phonemizer might produce 'ə' but the model has 'ə'. # Let's assume for now they are compatible. target_phonemes = phonemized_reference.split() print(target_phonemes) # ʲ phonemes are not in so remove target_phonemes = [p.replace("ʲ", "") for p in target_phonemes] target_phonemes = [p.replace("dz", "z") for p in target_phonemes] original_count = len(target_phonemes) target_phonemes = [p for p in target_phonemes if p in vocab] if len(target_phonemes) < original_count: print(f"Warning: Some phonemes not in model vocabulary for {audio_path}.") if not target_phonemes: print(f"Warning: No phonemes in model vocabulary for {audio_path}.") return None target_ids = [vocab[p] for p in target_phonemes] # 3. Forced alignment # Based on https://pytorch.org/audio/main/tutorials/forced_alignment_tutorial.html emissions = torch.log_softmax(logits, dim=-1) # torchaudio.functional.forced_align requires CPU tensors emissions = emissions.cpu() targets = torch.tensor(target_ids, dtype=torch.int32).unsqueeze(0) print(emissions.shape) print(targets.shape) print(targets) try: aligned_path, scores = torchaudio.functional.forced_align( emissions, targets, blank=vocab.get(self.processor.tokenizer.pad_token, 0) ) except Exception as e: print(f"Forced alignment failed for {utterance_id}: {e}") return None # 4. Calculate Articulatory Precision # The following section first shows the segmentation of the raw model output # (including <pad> tokens), and then calculates the articulatory precision score # based on the forced alignment with the reference transcription. # The score is calculated using average probabilities, not log probabilities. # Get the raw segmentation from argmax, which includes <pad> tokens print("\n--- Raw Model Output Segmentation (including <pad>) ---") best_path = aligned_path #print("best_path.shape", best_path.shape) #print("emissions.shape", emissions.shape) #print("scores.shape", scores.shape) #print("best path:", best_path) #print("emisisons", emissions) #change_points_raw = (best_path.diff() != 0).nonzero(as_tuple=True)[0] #segments_raw = torch.cat([ # torch.tensor([0], device=best_path.device), # change_points_raw + 1, # torch.tensor([best_path.shape[0]], device=best_path.device) #]) total_prob = 0 num_phonemes = 0 # Convert alignment scores from log-probabilities to probabilities # so the final score is an average probability, not log probability. prob_scores = torch.exp(scores) for i, (token, score) in enumerate(zip(best_path[0,:], prob_scores[0,:])): #print(f" Frame {i}: token_id={token}, score={score}") # If <pad>, skip from calculation if not token == vocab.get(self.processor.tokenizer.pad_token, 0): num_phonemes += 1 total_prob += score if num_phonemes > 0: artp_score = float(total_prob / num_phonemes) else: artp_score = 0.0 # Or handle as an error print(artp_score) return artp_score