Source code for pathbench.vad

from collections import OrderedDict
from typing import Optional, Tuple

import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from phonemizer.phonemize import phonemize
from phonemizer.separator import Separator
import re
import numpy as np
import librosa

from pathbench.string_clean import clean_text


class FATrimmer:
    """Trim leading and trailing silence from audio using CTC forced alignment.

    The reference transcription is phonemized with espeak and force-aligned
    against the frame-level emissions of a Wav2Vec2 CTC model. The audio is
    then cut to the span between the first and the last non-blank frame.

    Successful results are memoized in a bounded least-recently-used cache
    keyed by the full argument tuple of :meth:`trim`. Failures (``None``
    results) are never cached.
    """

    # Upper bound on the number of memoized results kept in ``self.cache``.
    MAX_CACHE_SIZE = 10000

    def __init__(
        self,
        model_id: str = "facebook/wav2vec2-xlsr-53-espeak-cv-ft",
        use_exp: bool = False,
    ):
        """Load the processor and the CTC model.

        Args:
            model_id: Identifier passed to ``from_pretrained`` for both the
                processor and the model.
            use_exp: If True, emissions are exponentiated after the
                log-softmax before they are handed to forced alignment.
        """
        self.processor = Wav2Vec2Processor.from_pretrained(model_id)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_id)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.use_exp = use_exp
        # Insertion order doubles as recency order (oldest entry first).
        self.cache = OrderedDict()
        print(f"Phonetic model '{model_id}' loaded on {self.device}.")

    def _cache_put(self, key, value):
        """Store ``value`` under ``key``, evicting the oldest entry when full."""
        self.cache[key] = value
        if len(self.cache) > self.MAX_CACHE_SIZE:
            self.cache.popitem(last=False)

    @staticmethod
    def _speech_span(frame_labels, blank_id):
        """Return the ``(start, end)`` frame span covering all non-blank labels.

        Args:
            frame_labels: Sequence of per-frame label ids.
            blank_id: Label id that marks a blank (non-speech) frame.

        Returns:
            ``(start, end)`` where ``start`` is the index of the first
            non-blank frame and ``end`` is one past the last non-blank frame,
            or ``None`` if every frame is blank (or there are no frames).
        """
        first = next(
            (i for i, label in enumerate(frame_labels) if label != blank_id),
            None,
        )
        if first is None:
            return None
        # A non-blank frame exists, so the backward scan always terminates.
        last = next(
            i
            for i in range(len(frame_labels) - 1, -1, -1)
            if frame_labels[i] != blank_id
        )
        return first, last + 1

    def _target_ids(self, audio_path, transcription, language, vocab):
        """Phonemize ``transcription`` and map the phonemes to vocabulary ids.

        Args:
            audio_path: Only used in warning messages.
            transcription: Reference text for the audio.
            language: espeak language code passed to the phonemizer.
            vocab: Mapping from phoneme token to model vocabulary id.

        Returns:
            The list of vocabulary ids, or ``None`` (after printing a warning)
            when phonemization yields nothing or no phoneme is in ``vocab``.
        """
        cleaned = clean_text(transcription)
        phonemized_reference = phonemize(
            cleaned,
            language=language,
            backend="espeak",
            strip=True,
            preserve_punctuation=False,
            separator=Separator(phone=" ", word="|"),
        )
        # Word boundaries are irrelevant for alignment: flatten everything
        # into a single space-separated phoneme sequence.
        phonemized_reference = re.sub(
            r"\s+", " ", phonemized_reference.replace("|", " ")
        ).strip()
        if not phonemized_reference:
            print(f"Warning: Could not phonemize reference transcription for {audio_path}.")
            print(f"Unphonemized transcription: '{transcription}'")
            print(f"Updated transcription after cleaning: '{cleaned}'")
            return None

        # Normalise espeak output before the vocabulary lookup: drop the
        # palatalization mark and rewrite "dz" as "z" (presumably because the
        # model vocabulary lacks these symbols -- confirm against the model).
        phonemes = [
            p.replace("ʲ", "").replace("dz", "z")
            for p in phonemized_reference.split()
        ]
        known = [p for p in phonemes if p in vocab]
        if len(known) < len(phonemes):
            print(f"Warning: Some phonemes not in model vocabulary for {audio_path}.")
        if not known:
            print(f"Warning: No phonemes in model vocabulary for {audio_path}.")
            return None
        return [vocab[p] for p in known]

    def trim(
        self,
        audio_path: str,
        transcription: str,
        language: str,
        start_time: float = 0.0,
        end_time: float = -1.0,
    ) -> Optional[Tuple[np.ndarray, int]]:
        """Trim silence from both ends of an audio segment via forced alignment.

        Args:
            audio_path: Path of the audio file to load.
            transcription: Reference text spoken in the segment.
            language: espeak language code used for phonemization.
            start_time: Offset in seconds at which loading starts.
            end_time: End of the segment in seconds; ``-1`` means "until the
                end of the file".

        Returns:
            ``(trimmed_audio_array, sample_rate)``. The untrimmed audio is
            returned when the alignment contains no non-blank frame. ``None``
            is returned when the audio cannot be read, is shorter than 400
            samples, the transcription yields no usable phonemes, or forced
            alignment / trimming fails.
        """
        cache_key = (audio_path, transcription, language, start_time, end_time)
        if cache_key in self.cache:
            self.cache.move_to_end(cache_key)  # mark as most recently used
            return self.cache[cache_key]

        try:
            duration = end_time - start_time if end_time != -1 else None
            speech, sample_rate = librosa.load(
                audio_path,
                sr=16000,
                offset=start_time,
                duration=duration,
                dtype=np.float64,
            )
        except Exception as e:
            print(f"Error reading audio file {audio_path}: {e}")
            return None

        # NOTE(review): 400 samples is presumably the shortest input the
        # model can process (25 ms at 16 kHz) -- confirm.
        if len(speech) < 400:
            print(f"Warning: Skipping audio file {audio_path} because it is too short ({len(speech)} samples).")
            return None

        # Validate the transcription before running the model so that an
        # unusable reference does not cost a forward pass.
        vocab = self.processor.tokenizer.get_vocab()
        target_ids = self._target_ids(audio_path, transcription, language, vocab)
        if target_ids is None:
            return None

        input_values = self.processor(
            speech, sampling_rate=sample_rate, return_tensors="pt"
        ).input_values
        with torch.no_grad():
            logits = self.model(input_values.to(self.device)).logits

        emissions = torch.log_softmax(logits, dim=-1)
        if self.use_exp:
            emissions = torch.exp(emissions)
        emissions = emissions.cpu()
        targets = torch.tensor(target_ids, dtype=torch.int32).unsqueeze(0)
        # The same blank id must be used for the alignment and for locating
        # speech frames afterwards; it is not necessarily 0.
        blank_id = vocab.get(self.processor.tokenizer.pad_token, 0)
        try:
            aligned_path, _scores = torchaudio.functional.forced_align(
                emissions, targets, blank=blank_id
            )
        except Exception as e:
            print(f"Forced alignment failed for {audio_path}: {e}")
            return None

        try:
            span = self._speech_span(aligned_path[0].tolist(), blank_id)
            if span is None:
                print(f"Warning: Could not find start of speech in {audio_path}. Returning full audio.")
                result = (speech, sample_rate)
            else:
                first_frame, end_frame = span
                # Number of audio samples represented by one emission frame.
                ratio = speech.shape[0] / emissions.shape[1]
                # ``end_frame`` is already exclusive, so it converts directly
                # to the exclusive sample bound without a further +1.
                result = (
                    speech[int(first_frame * ratio):int(end_frame * ratio)],
                    sample_rate,
                )
            self._cache_put(cache_key, result)
            return result
        except Exception as e:
            print(f"Error during trimming of {audio_path}: {e}")
            return None