from typing import Optional
import librosa
import torch
import torchaudio
import re
import os
from pyctcdecode import build_ctcdecoder
import numpy as np
from pathbench.evaluator import ReferenceFreeEvaluator
from pathbench.string_clean import clean_text, cached_phonemize
# [docs]  (Sphinx "view source" link marker left over from HTML extraction; not code)
class ArtPDoubleASREvaluator(ReferenceFreeEvaluator):
    """An evaluator that uses wav2vec 2.0 models to compute articulatory precision.

    Two CTC models are combined ("double ASR"):

    * a language-specific ASR model whose logits are decoded with a KenLM
      language model to obtain a pseudo-reference transcription, and
    * a phonetic model whose frame-wise emissions are force-aligned against
      the phonemized pseudo-reference.

    The score is the mean probability of the aligned non-blank frames.
    """

    # Language code -> language-specific ASR checkpoint.
    _ASR_MODEL_IDS = {
        "en": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
        "en-us": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
        "es": "jonatasgrosman/wav2vec2-large-xlsr-53-spanish",
        "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
        "it": "jonatasgrosman/wav2vec2-large-xlsr-53-italian",
        "cmn": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    }

    # Assuming the 'lms' directory is at the project root (i.e. the path is
    # resolved relative to the current working directory).
    _LM_DIR = "lms"

    # Base language code (region suffix stripped) -> language model file name.
    _LM_FILES = {
        "en": "wiki_en_token.arpa",
        "nl": "wiki_nl_token.arpa",
        "es": "wiki_es_token.arpa.bin",
        "it": "wiki_it_token.arpa.bin",
        "cmn": "wiki_zh_token.arpa",
    }

    def __init__(self, language: str, model_id: str = "facebook/wav2vec2-xlsr-53-espeak-cv-ft"):
        """Load both CTC models and, if a language model file exists, the LM decoder.

        Args:
            language: Language code; must be a key of ``_ASR_MODEL_IDS``.
            model_id: Identifier of the phonetic CTC model passed to
                ``get_ctc_model``.

        Raises:
            ValueError: If ``language`` has no associated ASR model.
        """
        from pathbench.model_registry import get_ctc_model

        # Validate before loading anything so an unsupported language fails fast.
        if language not in self._ASR_MODEL_IDS:
            raise ValueError(f"Language '{language}' is not supported for ArtPDoubleASREvaluator.")

        self.language = language
        self.phonetic_processor, self.phonetic_model, self.device = get_ctc_model(model_id)
        # Keep the ASR model's own device: its inputs must be moved there, which
        # is not necessarily the device the phonetic model lives on.
        self.processor, self.model, self.asr_device = get_ctc_model(
            self._ASR_MODEL_IDS[language]
        )
        self.decoder = self._build_decoder(language)

    def _build_decoder(self, language: str):
        """Return a KenLM-backed CTC decoder for ``language``, or ``None`` if no LM file exists."""
        # Region variants such as "en-us" share the LM of their base language.
        # Every supported language has an entry, so no cross-language fallback
        # is needed (or desirable).
        lm_file = self._LM_FILES.get(language.split("-")[0])
        lm_path = os.path.join(self._LM_DIR, lm_file) if lm_file else None

        # Prefer a binary build of a text ARPA model when one sits next to it.
        if lm_path and lm_path.endswith(".arpa"):
            bin_path = lm_path + ".bin"
            if os.path.exists(bin_path):
                print(f"Found binary LM file: {bin_path}, using it.")
                lm_path = bin_path

        if not (lm_path and os.path.exists(lm_path)):
            print(f"Warning: Language model for '{language}' not found at '{lm_path}'. No decoder built.")
            return None

        # The decoder expects labels ordered by token id, matching the logit columns.
        vocab = self.processor.tokenizer.get_vocab()
        labels = sorted(vocab, key=vocab.get)
        decoder = build_ctcdecoder(
            labels,
            kenlm_model_path=lm_path,
        )
        print(f"CTC decoder for '{language}' with LM '{lm_path}' built.")
        return decoder

    def score(
        self,
        utterance_id: str,
        audio_path: str,
        start_time: float = 0.0,
        end_time: float = -1.0,
    ) -> Optional[float]:
        """Compute the articulatory precision score for (a segment of) an audio file.

        Args:
            utterance_id: Identifier of the utterance (not used by this evaluator).
            audio_path: Path of the audio file; it is loaded at 16 kHz.
            start_time: Segment start in seconds.
            end_time: Segment end in seconds; ``-1.0`` means "until the end of
                the file".

        Returns:
            The score, or ``None`` when no decoder is available, the audio
            cannot be read, the segment is shorter than 400 samples, or forced
            alignment fails.
        """
        if self.decoder is None:
            print(f"Error: No decoder available for language '{self.language}'.")
            return None

        duration = None if end_time == -1.0 else end_time - start_time
        try:
            speech, sample_rate = librosa.load(
                audio_path, sr=16000, offset=start_time, duration=duration
            )
        except Exception as e:  # best-effort: an unreadable file must not abort a batch
            print(f"Error reading audio file {audio_path}: {e}")
            return None

        # 400 samples is 25 ms at 16 kHz; presumably the shortest input the
        # models can process -- TODO confirm.
        if len(speech) < 400:
            print(f"Warning: Skipping audio file {audio_path} because it is too short ({len(speech)} samples).")
            return None

        return self._score_audio(speech, sample_rate)

    def _score_audio(self, audio: np.ndarray, fs: int) -> Optional[float]:
        """Score an in-memory waveform.

        Args:
            audio: Waveform samples.
            fs: Sampling rate of ``audio`` in Hz.

        Returns:
            Mean probability of the force-aligned non-blank frames, ``0.0`` if
            the alignment contains no non-blank frame, or ``None`` if forced
            alignment fails or no target phoneme can be mapped to the vocabulary.
        """
        # Get transcription from language-specific ASR with LM.
        asr_inputs = self.processor(
            audio, sampling_rate=fs, return_tensors="pt"
        ).input_values.to(self.asr_device)
        with torch.no_grad():
            asr_logits = self.model(asr_inputs).logits
        lm_transcription = self.decoder.decode(asr_logits.cpu().numpy()[0])

        # Process audio with phonetic model.
        phonetic_inputs = self.phonetic_processor(
            audio, sampling_rate=fs, return_tensors="pt"
        ).input_values.to(self.device)
        with torch.no_grad():
            logits = self.phonetic_model(phonetic_inputs).logits

        # 1. Phonemize the n-gram improved transcription.
        phonemized_reference = cached_phonemize(clean_text(lm_transcription), self.language)
        phonemized_reference = re.sub(r"\s+", " ", phonemized_reference.replace("|", " ")).strip()
        print(f"Phonemized reference: {phonemized_reference}")

        # 2. Map phonemes to model vocab indices.
        vocab = self.phonetic_processor.tokenizer.get_vocab()
        target_phonemes = []
        for phoneme in phonemized_reference.split():
            # Palatalised ("ʲ") phonemes are not in the vocab, so strip the
            # marker; "dz" is mapped to "z".
            phoneme = phoneme.replace("ʲ", "").replace("dz", "z")
            if phoneme in vocab:
                target_phonemes.append(phoneme)
        if not target_phonemes:
            print("Warning: No recognisable phonemes from ASR transcription. Falling back to 'a'.")
            if "a" not in vocab:
                print("Warning: Fallback phoneme 'a' is not in the phonetic vocabulary.")
                return None
            target_phonemes = ["a"]
        target_ids = [vocab[p] for p in target_phonemes]

        # 3. Forced alignment (performed on CPU).
        blank_id = vocab.get(self.phonetic_processor.tokenizer.pad_token, 0)
        emissions = torch.log_softmax(logits, dim=-1).cpu()
        targets = torch.tensor(target_ids, dtype=torch.int32).unsqueeze(0)
        try:
            aligned_path, scores = torchaudio.functional.forced_align(
                emissions, targets, blank=blank_id
            )
        except Exception as e:
            print(f"Forced alignment failed: {e}")
            return None

        # 4. Calculate articulatory precision.
        return self._mean_non_blank_probability(aligned_path, scores, blank_id)

    @staticmethod
    def _mean_non_blank_probability(aligned_path, scores, blank_id: int) -> float:
        """Average the probabilities of the aligned frames that are not blank.

        Only the first batch item is used. Alignment scores are converted from
        log-probabilities to probabilities so that the result is an average
        probability, not an average log-probability. Note that the average is
        taken over frames, not over distinct phonemes.
        """
        non_blank = aligned_path[0] != blank_id
        num_frames = int(non_blank.sum())
        if num_frames == 0:
            return 0.0
        frame_probs = torch.exp(scores[0])
        return float(frame_probs[non_blank].sum() / num_frames)