from collections import OrderedDict
from typing import Optional, Tuple
import torch
import torchaudio
import re
import numpy as np
import librosa
from pathbench.string_clean import clean_text, cached_phonemize
class FATrimmer:
    """Trim leading/trailing silence from audio using CTC forced alignment.

    The transcription is phonemized, mapped to the CTC model's vocabulary and
    force-aligned against the model's frame-level emissions. Audio outside the
    first..last non-blank aligned frame is cut away. Results are kept in a
    bounded LRU cache keyed by the full ``trim`` argument tuple.
    """

    MAX_CACHE_SIZE = 10000

    def __init__(self, model_id: str = "facebook/wav2vec2-xlsr-53-espeak-cv-ft", use_exp: bool = False,
                 max_cache_size: Optional[int] = None):
        """
        Args:
            model_id: Identifier passed to ``get_ctc_model`` to obtain (processor, model, device).
            use_exp: If True, exponentiate the log-softmax emissions before alignment.
                NOTE(review): ``torchaudio.functional.forced_align`` documents log-probability
                emissions; this flag feeds probabilities instead — confirm it is intended.
            max_cache_size: Optional override for the number of cached results; when None,
                the class-level ``MAX_CACHE_SIZE`` is used. ``0`` effectively disables caching.
        """
        from pathbench.model_registry import get_ctc_model
        self.processor, self.model, self.device = get_ctc_model(model_id)
        self.use_exp = use_exp
        if max_cache_size is not None:
            # Instance attribute shadows the class default only for this trimmer.
            self.MAX_CACHE_SIZE = max_cache_size
        self.cache = OrderedDict()

    def _cache_put(self, key, value):
        """Insert *value* under *key*, evicting the least recently used entry when over capacity."""
        self.cache[key] = value
        if len(self.cache) > self.MAX_CACHE_SIZE:
            self.cache.popitem(last=False)

    @staticmethod
    def _speech_span(path: torch.Tensor, blank_id: int) -> Optional[Tuple[int, int]]:
        """Return ``(first, last + 1)`` indices of non-blank entries in a 1-D *path*, or None if all blank."""
        non_blank = torch.nonzero(path != blank_id).flatten()
        if non_blank.numel() == 0:
            return None
        return int(non_blank[0]), int(non_blank[-1]) + 1

    def _phoneme_target_ids(self, transcription: str, language: str, audio_path: str, vocab: dict) -> Optional[list]:
        """Phonemize *transcription* and map it to model vocabulary ids; return None (with a warning) on failure."""
        cleaned = clean_text(transcription)
        phonemized_reference = cached_phonemize(cleaned, language)
        # Word separators "|" become spaces; whitespace runs collapse to single spaces.
        phonemized_reference = re.sub(r"\s+", " ", phonemized_reference.replace("|", " ")).strip()
        if not phonemized_reference:
            print(f"Warning: Could not phonemize reference transcription for {audio_path}.")
            print(f"Unphonemized transcription: '{transcription}'")
            print(f"Updated transcription after cleaning: '{cleaned}'")
            return None

        # Drop the palatalization mark "ʲ", then rewrite "dz" as "z".
        # NOTE(review): presumably because the model vocabulary lacks these symbols — confirm.
        target_phonemes = [p.replace("ʲ", "").replace("dz", "z") for p in phonemized_reference.split()]
        original_count = len(target_phonemes)
        target_phonemes = [p for p in target_phonemes if p in vocab]
        if len(target_phonemes) < original_count:
            print(f"Warning: Some phonemes not in model vocabulary for {audio_path}.")
        if not target_phonemes:
            print(f"Warning: No phonemes in model vocabulary for {audio_path}.")
            return None
        # Every remaining phoneme is a vocab key, so this lookup cannot fail.
        return [vocab[p] for p in target_phonemes]

    def trim(self, audio_path: str, transcription: str, language: str, start_time: float = 0.0, end_time: float = -1.0) -> Optional[Tuple[np.ndarray, int]]:
        """
        Trims silence from the beginning and end of an audio file using forced alignment.
        Returns a tuple of (trimmed_audio_array, sample_rate).

        The segment ``[start_time, end_time)`` (seconds) is loaded at 16 kHz as float64;
        a negative ``end_time`` (default ``-1.0``) reads to the end of the file.

        Returns None when the audio cannot be read, is shorter than 400 samples,
        the transcription cannot be phonemized or has no phonemes in the model
        vocabulary, alignment fails, or trimming raises. If the alignment path
        contains no non-blank frame, the untrimmed audio is returned.

        Results are cached; the returned array is shared with the cache, so
        callers should not modify it in place.
        """
        cache_key = (audio_path, transcription, language, start_time, end_time)
        if cache_key in self.cache:
            self.cache.move_to_end(cache_key)
            return self.cache[cache_key]

        try:
            duration = end_time - start_time if end_time >= 0 else None
            speech, sample_rate = librosa.load(audio_path, sr=16000, offset=start_time, duration=duration, dtype=np.float64)
        except Exception as e:
            print(f"Error reading audio file {audio_path}: {e}")
            return None

        if len(speech) < 400:
            print(f"Warning: Skipping audio file {audio_path} because it is too short ({len(speech)} samples).")
            return None

        # Phonemize before running the model so a bad transcription does not cost an inference pass.
        vocab = self.processor.tokenizer.get_vocab()
        target_ids = self._phoneme_target_ids(transcription, language, audio_path, vocab)
        if target_ids is None:
            return None

        input_values = self.processor(
            speech, sampling_rate=sample_rate, return_tensors="pt"
        ).input_values.to(self.device)
        with torch.no_grad():
            logits = self.model(input_values).logits

        # The same blank id must drive both the aligner and the silence detection below.
        blank_id = vocab.get(self.processor.tokenizer.pad_token, 0)
        emissions = torch.log_softmax(logits, dim=-1)
        if self.use_exp:
            emissions = torch.exp(emissions)
        emissions = emissions.cpu()
        targets = torch.tensor(target_ids, dtype=torch.int32).unsqueeze(0)
        try:
            aligned_path, _scores = torchaudio.functional.forced_align(emissions, targets, blank=blank_id)
        except Exception as e:
            print(f"Forced alignment failed for {audio_path}: {e}")
            return None

        try:
            span = self._speech_span(aligned_path[0], blank_id)
            if span is None:
                print(f"Warning: Could not find start of speech in {audio_path}. Returning full audio.")
                self._cache_put(cache_key, (speech, sample_rate))
                return speech, sample_rate
            start_idx, end_idx = span  # end_idx is exclusive (last non-blank frame + 1)

            # Audio samples per emission frame.
            ratio = speech.shape[0] / emissions.shape[1]
            start_sample = int(start_idx * ratio)
            # (end_idx + 1) keeps one extra frame of audio past the last aligned frame.
            end_sample = int((end_idx + 1) * ratio)
            # Copy so the cached entry does not keep the whole untrimmed buffer alive via a view.
            trimmed_audio = speech[start_sample:end_sample].copy()
            self._cache_put(cache_key, (trimmed_audio, sample_rate))
            return trimmed_audio, sample_rate
        except Exception as e:
            print(f"Error during trimming of {audio_path}: {e}")
            return None