from pathlib import Path
from typing import Dict, Iterator, Any, Optional, List, Tuple
import numpy as np
from pathbench.string_clean import clean_text
def _load_kaldi_style_file(file_path: Path, num_parts: int) -> Dict[str, list[str]]:
    """Parse a whitespace-separated Kaldi-style table (e.g. wav.scp, text, utt2spk).

    Each line is split on whitespace into at most ``num_parts`` fields, so the
    last field keeps any whitespace it contains (e.g. a multi-word transcript).
    The first field becomes the key and the remaining fields become the value
    list. Lines that do not produce exactly ``num_parts`` fields (blank or short
    lines) are skipped. If a key repeats, the later line wins.

    Args:
        file_path: path of the file to read (decoded as UTF-8).
        num_parts: expected number of fields per line, key included.

    Returns:
        Mapping of key -> list of the remaining ``num_parts - 1`` fields.
    """
    table: Dict[str, list[str]] = {}
    split_limit = num_parts - 1
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            fields = raw_line.strip().split(maxsplit=split_limit)
            if len(fields) != num_parts:
                continue
            key, *rest = fields
            table[key] = rest
    return table
# Maps the short language code found in a dataset's ``language`` file to the
# language identifier handed to the phonemiser (the values look like
# espeak-style codes, e.g. "en-us", "cmn" -- confirm against the backend).
# Codes missing from this table are looked up with themselves as the default.
PHONEMISER_LANG_MAPPING = dict(
    en="en-us",
    nl="nl",
    it="it",
    es="es",
    zh="cmn",
)
class Dataset:
    """Handles a speech dataset in a Kaldi-style format.

    The dataset directory may contain the Kaldi-style files ``segments``,
    ``wav.scp``, ``text``, ``utt2spk`` and ``spk2gender``, the score files
    ``spk2score``, ``utt2score`` and ``spk2age``, and an optional ``language``
    file holding a short language code. Missing files are treated as empty.

    When ``use_reference`` is true, iteration also yields reference audios for
    each utterance, selected by ``reference_type``:

    * ``"control"`` -- utterances from the dataset at ``reference_path`` whose
      cleaned transcription matches and whose speaker has the same gender;
    * ``"all"`` -- utterances with a matching cleaned transcription from other
      speakers, searched in this dataset and in the one at ``reference_path``;
    * ``"custom"`` -- paths taken from ``reference_mapping`` (utt_id -> paths);
    * ``"none"`` -- no references; ``None`` is yielded in their place.
    """

    def __init__(self, dataset_path: str, use_reference: bool = False,
                 reference_path: Optional[str] = None,
                 reference_type: str = 'control',
                 reference_mapping: Optional[dict] = None):
        """Load every Kaldi-style file present under ``dataset_path``.

        Args:
            dataset_path: directory containing the Kaldi-style files.
            use_reference: whether iteration should also yield reference audios.
            reference_path: directory of the reference dataset; required for
                ``reference_type`` ``"control"`` and ``"all"``.
            reference_type: one of ``"control"``, ``"all"``, ``"custom"``,
                ``"none"`` (only checked when ``use_reference`` is true).
            reference_mapping: utt_id -> list of reference audio paths (or a
                single path); required for ``reference_type="custom"``.

        Raises:
            FileNotFoundError: if ``dataset_path`` is not a directory.
            ValueError: if a setting required by ``reference_type`` is missing,
                or ``reference_type`` is not supported.
        """
        self.path = Path(dataset_path)
        if not self.path.is_dir():
            raise FileNotFoundError(f"Dataset directory not found: {self.path}")

        # Load language from file or default to 'en'
        lang_file = self.path / "language"
        if lang_file.exists():
            lang = lang_file.read_text(encoding="utf-8").strip()
        else:
            lang = "en"
            print(
                f"Warning: 'language' file not found in {self.path}. Defaulting to 'en'."
            )
        # Codes without an entry in the mapping are used as-is.
        self.language = PHONEMISER_LANG_MAPPING.get(lang, lang)

        # segments: utt_id -> [rec_id, start, end], times still as strings.
        self.segments = self._load_if_exists("segments", 4)
        self.wav_scp = self._load_single_values_if_exists("wav.scp")
        self.text = self._load_single_values_if_exists("text")
        self.utt2spk = self._load_single_values_if_exists("utt2spk")
        self.spk2score = self._load_scores_if_exists("spk2score")
        self.utt2score = self._load_scores_if_exists("utt2score")
        self.spk2gender = self._load_single_values_if_exists("spk2gender")
        self.spk2age = self._load_scores_if_exists("spk2age")

        self.use_reference = use_reference
        self.reference_path = reference_path
        self.reference_type = reference_type
        self.reference_mapping = reference_mapping
        self.reference_dataset = None
        # Built lazily by _utts_by_cleaned_text().
        self._cleaned_text_index = None

        if self.use_reference:
            if self.reference_type in ('control', 'all'):
                if not self.reference_path:
                    raise ValueError('reference_path is required for control/all reference types')
                self.reference_dataset = Dataset(self.reference_path)
            elif self.reference_type == 'custom':
                if not self.reference_mapping:
                    raise ValueError('reference_mapping is required for custom reference type')
            elif self.reference_type != "none":
                # Reject bad configuration here rather than on the first iteration.
                raise ValueError(f"Unsupported reference_type: {self.reference_type}")

    def _load_if_exists(self, filename: str, num_parts: int) -> Dict[str, List[str]]:
        """Load ``filename`` from the dataset directory, or return ``{}`` if absent."""
        file_path = self.path / filename
        if file_path.exists():
            return _load_kaldi_style_file(file_path, num_parts)
        return {}

    def _load_single_values_if_exists(self, filename: str) -> Dict[str, str]:
        """Load a two-column Kaldi file as a flat ``key -> value`` mapping."""
        return {key: values[0] for key, values in self._load_if_exists(filename, 2).items()}

    def _load_scores_if_exists(self, filename: str) -> Dict[str, float]:
        """Load a ``<key> <score>`` file into a dict of floats.

        Blank lines are skipped. A score of ``N/A`` is stored as ``np.nan``
        (with a printed warning). Returns ``{}`` if the file does not exist.

        Raises:
            ValueError: if a non-blank line does not have exactly two fields,
                or a score is neither ``N/A`` nor parseable by ``float``.
        """
        file_path = self.path / filename
        scores: Dict[str, float] = {}
        if not file_path.exists():
            return scores
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                parts = line.split()
                if not parts:
                    continue  # tolerate blank lines, e.g. a trailing newline
                if len(parts) != 2:
                    raise ValueError(
                        f"Malformed line {line_no} in {file_path}: "
                        f"expected '<key> <score>', got {line.strip()!r}"
                    )
                key, score = parts
                if score == 'N/A':
                    print("Warning: Found 'N/A' score for key:", key)
                    scores[key] = np.nan
                else:
                    scores[key] = float(score)
        return scores

    def __iter__(self) -> Iterator[Tuple[str, str, str, Optional[List[Tuple[str, float, float]]], float, float]]:
        """
        Iterates over utterances, yielding utterance ID, audio path, transcription,
        a list of reference audios, start time, and end time.

        Each reference is an ``(audio_path, start, end)`` tuple. The reference
        list is ``None`` when references are disabled or ``reference_type`` is
        ``"none"``. Without a ``segments`` file, start/end are ``0.0``/``-1.0``
        (presumably meaning the whole recording). Utterances whose audio path
        cannot be resolved are skipped.
        """
        for utt_id in self.get_utterances():
            resolved = Dataset._resolve_audio(self, utt_id)
            if resolved is None:
                continue
            audio_path, start_time, end_time = resolved
            transcription = self.text.get(utt_id, "")
            reference_audio_paths = None
            if self.use_reference and self.reference_type != "none":
                reference_audio_paths = self._get_reference_audios(utt_id, transcription)
            yield utt_id, audio_path, transcription, reference_audio_paths, start_time, end_time

    @staticmethod
    def _resolve_audio(dataset: "Dataset", utt_id: str) -> Optional[Tuple[str, float, float]]:
        """Return ``(audio_path, start, end)`` for ``utt_id`` in ``dataset``.

        With a ``segments`` file, the utterance is looked up there and its
        recording ID resolved through ``wav.scp``. Otherwise ``utt_id`` is
        looked up directly in ``wav.scp`` and the span is ``(0.0, -1.0)``.
        Returns ``None`` if the utterance or its audio path is missing.
        """
        if dataset.segments:
            segment = dataset.segments.get(utt_id)
            if segment is None:
                return None
            rec_id, start_str, end_str = segment
            audio_path = dataset.wav_scp.get(rec_id)
            start_time, end_time = float(start_str), float(end_str)
        else:
            audio_path = dataset.wav_scp.get(utt_id)
            start_time, end_time = 0.0, -1.0
        if not audio_path:
            return None
        return audio_path, start_time, end_time

    def _get_reference_audios(self, utt_id: str, transcription: str) -> List[Tuple[str, float, float]]:
        """Return the reference audios for one utterance according to ``reference_type``.

        Returns ``[]`` for ``reference_type="none"``.

        Raises:
            ValueError: for an unsupported ``reference_type``.
        """
        if self.reference_type == "control":
            return self._load_same_text_references(utt_id, transcription)
        if self.reference_type == "all":
            current_speaker = self.utt2spk.get(utt_id)
            return self._load_all_same_text_references(transcription, current_speaker)
        if self.reference_type == "custom":
            return self._load_custom_references(utt_id)
        if self.reference_type == "none":
            return []
        raise ValueError(f"Unsupported reference_type: {self.reference_type}")

    def _utts_by_cleaned_text(self) -> Dict[str, List[str]]:
        """Group utterance IDs by their ``clean_text``-normalised transcription.

        Built on first use and cached, so ``clean_text`` runs once per
        utterance instead of once per (query, utterance) pair. IDs keep the
        iteration order of ``self.text``. Assumes ``clean_text`` returns a
        hashable value (e.g. ``str``). The cache is not refreshed if
        ``self.text`` is modified after the first lookup.
        """
        index = getattr(self, "_cleaned_text_index", None)
        if index is None:
            index = {}
            for utt_id, trans in self.text.items():
                index.setdefault(clean_text(trans), []).append(utt_id)
            self._cleaned_text_index = index
        return index

    def _load_same_text_references(self, utt_id: str, transcription: str) -> List[Tuple[str, float, float]]:
        """
        Loads reference audios from control speakers with the same transcription and gender.

        Returns ``[]`` when there is no reference dataset, or when the current
        utterance has no known speaker or speaker gender.
        """
        if self.reference_dataset is None:
            return []
        current_speaker = self.utt2spk.get(utt_id)
        if not current_speaker:
            return []
        current_gender = self.spk2gender.get(current_speaker)
        if not current_gender:
            return []

        ref = self.reference_dataset
        ref_paths = []
        for ref_utt_id in ref._utts_by_cleaned_text().get(clean_text(transcription), []):
            ref_speaker = ref.utt2spk.get(ref_utt_id)
            if not ref_speaker or ref.spk2gender.get(ref_speaker) != current_gender:
                continue
            resolved = Dataset._resolve_audio(ref, ref_utt_id)
            if resolved is not None:
                ref_paths.append(resolved)
        return ref_paths

    def _load_all_same_text_references(self, transcription: str, current_speaker: str) -> List[Tuple[str, float, float]]:
        """
        Loads all reference audios with the same transcription from different speakers,
        from both the main dataset and the reference dataset.
        """
        # Search in the main dataset (e.g., pathological)
        ref_paths = Dataset._find_matching_references_in_dataset(self, transcription, current_speaker)
        # Search in the reference dataset (e.g., control)
        if self.reference_dataset is not None:
            ref_paths.extend(
                Dataset._find_matching_references_in_dataset(
                    self.reference_dataset, transcription, current_speaker
                )
            )
        return ref_paths

    @staticmethod
    def _find_matching_references_in_dataset(dataset, transcription: str, current_speaker: str) -> List[Tuple[str, float, float]]:
        """Return audios in ``dataset`` whose cleaned transcription matches and
        whose speaker differs from ``current_speaker``."""
        paths = []
        for utt_id in dataset._utts_by_cleaned_text().get(clean_text(transcription), []):
            # The speaker check is against the original utterance's speaker.
            if dataset.utt2spk.get(utt_id) == current_speaker:
                continue
            resolved = Dataset._resolve_audio(dataset, utt_id)
            if resolved is not None:
                paths.append(resolved)
        return paths

    def _load_custom_references(self, utt_id: str) -> List[Tuple[str, float, float]]:
        """
        Loads reference audios based on a custom mapping.
        Assumes the mapping is from utterance ID to a list of audio file paths;
        a single path (``str`` or ``Path``) is also accepted.
        Segments are not supported for custom references.
        """
        if not self.reference_mapping:
            return []
        paths = self.reference_mapping.get(utt_id, [])
        if isinstance(paths, (str, Path)):
            # Treat a bare path as a one-element list instead of iterating its characters.
            paths = [paths]
        return [(path, 0.0, -1.0) for path in paths]

    def get_utterances(self):
        """Returns a list of utterance IDs (from ``segments`` if present, else ``wav.scp``)."""
        if self.segments:
            return list(self.segments.keys())
        return list(self.wav_scp.keys())