#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from abc import abstractmethod
from scipy.stats import pearsonr
from scipy.signal import windows, convolve
from scipy.ndimage import uniform_filter1d
import numpy as np
import soundfile as sf
import os.path
from numpy.linalg import norm # distance calc in dtw
from dtw import dtw #
from pathbench.utils import normalise_signal, moving_average_filtering
from typing import List, Optional
import librosa
from pathbench.evaluator import ReferenceAudioEvaluator
# Double-precision machine epsilon; added to denominators in this module so
# that normalising a zero-variance segment never divides by zero.
eps = np.finfo(np.float64).eps
def _dgtreal(signal, wi, hop_size, nfft):
    """Discrete Gabor Transform for real signals (replaces ltfatpy.dgtreal).

    Builds a canonical tight Hamming window following the LTFAT recipe
    (circularly sampled Hamming, energy normalisation, division by the square
    root of the frame-operator diagonal) and applies it as a circular STFT
    of the zero-padded signal.

    NOTE(review): each frame's DFT is taken relative to the first sample of
    that frame, so coefficient *phase* may differ from LTFAT's convention
    even where magnitudes agree -- confirm before relying on phase.

    Parameters
    ----------
    signal : array_like
        Real-valued, one-dimensional input signal.
    wi : dict
        Window spec, e.g. ``{'name': ('tight', 'hamming'), 'M': 512}``.
        Only the ``'M'`` entry (window length in samples) is read.
    hop_size : int
        Time shift in samples (LTFAT parameter *a*).
    nfft : int
        Number of frequency channels / FFT length (LTFAT parameter *M*).

    Returns
    -------
    coeffs : ndarray, shape (nfft//2+1, num_frames)
        Complex Gabor coefficients (positive frequencies only).
    Ls : int
        Original length of *signal*.
    g_tight : ndarray
        Canonical tight analysis window that was used (peak at index 0).

    Raises
    ------
    ValueError
        If ``hop_size`` or ``nfft`` is not positive, or if the window length
        is not a positive even number equal to ``nfft``.
    """
    M = int(nfft)
    a = int(hop_size)
    window_length = int(wi['M'])
    if a <= 0 or M <= 0:
        raise ValueError("hop_size and nfft must be positive integers")
    if window_length <= 0 or window_length % 2:
        raise ValueError("window length wi['M'] must be a positive even integer")
    if window_length != M:
        # Every frame is multiplied sample-by-sample with the window before
        # an M-point FFT, so the two lengths have to agree.
        raise ValueError("window length wi['M'] must equal nfft")

    signal = np.asarray(signal, dtype=float)
    Ls = len(signal)
    half = window_length // 2

    # --- Build the canonical tight window (LTFAT convention) ----------------
    # 1. Hamming window centred at index 0 (circular / DFT-even sampling).
    #    The grid is built from integers: a float-step arange can gain or
    #    lose a sample through rounding, which would corrupt the window length.
    k = np.arange(window_length)
    x = np.where(k < half, k, k - window_length) / window_length
    g = 0.54 + 0.46 * np.cos(2 * np.pi * x)
    g *= (np.abs(x) < 0.5).astype(float)  # zero the boundary sample at x = -0.5
    # 2. Energy-normalise (L2 norm = 1).
    g /= np.linalg.norm(g)
    # 3. Pad to the transform length L (smallest multiple of lcm(a, M) that
    #    holds the signal): first half at the start, second half at the end.
    lcm_val = int(np.lcm(a, M))
    L = max(-(-Ls // lcm_val), 1) * lcm_val
    g_long = np.zeros(L)
    g_long[:half] = g[:half]
    g_long[L - half:] = g[half:]
    # 4. Frame-operator diagonal d[n] = M * sum_k |g_long[n + k*a]|^2.  It is
    #    a-periodic, so only one period is stored.
    N = L // a
    d_period = M * np.sum((g_long ** 2).reshape(N, a), axis=0)
    # 5. Diagonal values at the positions the FIR window occupies.
    positions = np.concatenate([np.arange(half), np.arange(L - half, L)])
    root_d = np.sqrt(d_period[positions % a])
    # 6. Canonical tight window.  Where the diagonal is zero the window is
    #    zero as well; emit 0 there instead of 0/0 = NaN, which would
    #    otherwise turn every coefficient into NaN.
    g_tight = np.divide(g, root_d, out=np.zeros_like(g), where=root_d > 0)

    # --- Compute the DGT ---------------------------------------------------
    # Move the window peak from index 0 to the centre so that it lines up
    # with a frame starting `half` samples before the frame centre.
    g_std = np.fft.fftshift(g_tight)
    f = np.zeros(L)
    f[:Ls] = signal
    num_frames = L // a
    frame_starts = np.arange(num_frames) * a - half
    # Circular indexing: frames near the edges wrap around the padded signal.
    indices = (frame_starts[:, None] + np.arange(M)[None, :]) % L
    coeffs = np.fft.rfft(f[indices] * g_std, n=M, axis=1).T
    return np.ascontiguousarray(coeffs), Ls, g_tight
class STOI():
    """STOI / ESTOI-style intelligibility scores against a multi-speaker reference.

    On construction the object builds one reference representation from the
    reference words (``ref_create``) and then scores every test word against
    it (``STOI_value``).  Results are stored per test word in ``stoi_val``
    and ``estoi_val``.

    NOTE(review): ``ref_create`` calls ``self.log_octave_transform_extractor``,
    which is not defined in the portion of the file this class was taken
    from -- confirm where it is provided.
    """

    def __init__(self, reference_words: np.ndarray,
                 test_words: np.ndarray,
                 normalization_method: str,
                 centroid_ind: int,
                 frame_deletion: bool = True,
                 fs: int = 16000):
        '''
        Short Term Objective Intelligibility (STOI) measure

        :param reference_words: sequence of reference signals (each is copied)
        :param test_words: sequence of test signals (each is copied)
        :param normalization_method: 'RMS' or 'zero_mean'; stored only, not
            read by any method of this class shown here
        :param centroid_ind: index of the reference word every other
            reference word is aligned to when the reference is built
        :param frame_deletion: whether repeated frames on the DTW path are
            dropped when scoring the test words
        :param fs: sampling frequency in Hz
        '''
        self.reference_words = [w.copy() for w in reference_words]
        # Every reference word is labelled 1, i.e. all of them take part in
        # building the reference.
        self.train_target = np.ones(len(reference_words))
        self.test_words = [w.copy() for w in test_words]
        self.normalization_method = normalization_method
        self.Tw = 32  # analysis window length in ms
        self.Ts = 16  # analysis window shift in ms
        self.J = 15  # Number of 1/3 octave bands
        self.mn = 150  # Center frequency of first 1/3 octave band in Hz.
        self.centroid_ind = centroid_ind
        self.consecN = 15  # maximum number of consecutive frames per segment
        Beta = -15  # lower SDR-bound
        self.c = 10**(-Beta/20)  # constant for clipping procedure
        self.frame_deletion = frame_deletion
        self.considered_first_bin = 0  # first row that enters the final means
        self.fs = fs
        self.Nw = int(round(1E-3 * self.Tw * self.fs))  # window length in samples
        self.Ns = int(round(1E-3 * self.Ts * self.fs))  # window shift in samples
        self.nfft = int(2**np.ceil(np.log2(self.Nw)))  # next power of two >= Nw
        self.stoi_val = 0
        self.estoi_val = 0
        self.ref_create()
        self.STOI_value()

    @staticmethod
    def thirdoct(fs, N_fft, number_of_bands, mn):
        """
        Builds a one-third octave band matrix for an ``N_fft``-point spectrum.

        :param fs: sampling frequency (Hz)
        :param N_fft: number of bins for the FFT
        :param number_of_bands: number of one-third octave bands, marked as J in the paper
        :param mn: center frequency of the first band (Hz)
        :return: ``(A, cf)`` where ``A`` has shape
            ``(number_of_bands, N_fft // 2 + 1)`` with ones on the bins of
            each band, and ``cf`` holds the band center frequencies
        """
        f = np.linspace(0, fs, N_fft + 1)
        f = f[0:int(N_fft/2) + 1]
        k = np.arange(0, number_of_bands)
        cf = 2**(k/3)*mn
        # Band edges are the geometric means of neighbouring center frequencies.
        fl = np.sqrt((2**(k/3)*mn)*2**((k - 1)/3)*mn)
        fr = np.sqrt((2**(k/3)*mn)*2**((k + 1)/3)*mn)
        A = np.zeros((number_of_bands, len(f)))
        for i in range(len(cf)):
            # Snap each edge to the nearest FFT bin; the upper bin is exclusive.
            fl_ii = np.argmin((f - fl[i])**2)
            fr_ii = np.argmin((f - fr[i])**2)
            A[i, fl_ii:fr_ii] = 1
        return A, cf

    @staticmethod
    def difference_oct(X, Y):
        """Return the element-wise absolute difference of log10(X) and log10(Y)."""
        return np.abs(np.log10(X) - np.log10(Y))

    def align_dtw(self, control, test, frame_deletion: bool, test_time: bool):
        """
        Aligns two TF representations using dynamic time warping (DTW).

        Both inputs are transposed before being handed to librosa, i.e. they
        are expected as (frames, bands) arrays.

        NOTE(review): librosa returns the warping path from the last frame
        pair to the first, so the index arrays returned here are in that
        order and the frame kept from a run of repeats is the one latest in
        time -- confirm this is intended.

        :param control: first representation (DTW ``X``)
        :param test: second representation (DTW ``Y``)
        :param frame_deletion: whether to drop path steps that repeat a frame
        :param test_time: if False, the path is returned as
            (indices into ``control``, indices into ``test``) and repeats are
            judged on the ``control`` indices.  If True, the roles are
            swapped: the first returned array indexes ``test``, the second
            indexes ``control``, and repeats are removed on both in turn.
        :return: tuple of two index arrays as described above
        """
        _, path = librosa.sequence.dtw(X=control.T, Y=test.T, metric='euclidean')
        path = np.transpose(path)
        x_idx, y_idx = path[0, :], path[1, :]

        if not test_time:
            first, second = x_idx, y_idx
            if frame_deletion:
                repeats = 1 + np.where(np.diff(first) == 0)[0]
                first = np.delete(first, repeats)
                second = np.delete(second, repeats)
            return first, second

        first, second = y_idx, x_idx
        # Honour the argument here as well (this branch used to read
        # self.frame_deletion and silently ignore the parameter).
        if frame_deletion:
            # Paper: Repeated frames affects intelligibility measures
            repeats = 1 + np.where(np.diff(first) == 0)[0]
            first = np.delete(first, repeats)
            second = np.delete(second, repeats)
            # Second pass: repeats on the other side, judged after the first pass.
            repeats = 1 + np.where(np.diff(second) == 0)[0]
            first = np.delete(first, repeats)
            second = np.delete(second, repeats)
        return first, second

    def ref_create(self):
        """
        Creates the global reference representation from the reference words.

        Every reference word other than the centroid is DTW-aligned to the
        centroid; for each centroid frame the linear-domain (``10**x``)
        values of all frames aligned to it are averaged and converted back
        to log10.  The centroid itself is only the alignment target and does
        not enter the average.  With a single reference word there is nothing
        to average, so the centroid is used as the reference directly.

        Sets ``reference_log_octave_transforms``, ``test_log_octave_transforms``
        and ``ref_test`` (the same reference repeated once per test word).

        NOTE(review): relies on ``self.log_octave_transform_extractor``,
        which is not defined in the visible part of the file.
        """
        self.reference_log_octave_transforms = self.log_octave_transform_extractor(self.reference_words)
        self.test_log_octave_transforms = self.log_octave_transform_extractor(self.test_words)

        selected = np.where(self.train_target == 1)[0]
        centroid = self.reference_log_octave_transforms[selected[self.centroid_ind]]
        # All selected representations except the centroid.
        others = [i for i in range(len(selected)) if i != self.centroid_ind]

        if not others:
            # Averaging zero utterances would be log10(0/0) = NaN.
            ref_for_tr = np.array(centroid, dtype=float)
        else:
            num_frames = np.size(centroid, 0)
            sum_f = np.zeros_like(centroid)  # summed linear-domain values per frame
            sum_f_num = np.zeros((num_frames, 1))  # number of frames summed per row
            for num in others:
                aln2 = self.reference_log_octave_transforms[selected[num]]
                new_path_cont, new_path_test = self.align_dtw(centroid, aln2,
                                                              frame_deletion=True,
                                                              test_time=False)
                for frame_ind in range(num_frames):
                    matched = new_path_test[new_path_cont == frame_ind]
                    sum_f[frame_ind, :] += np.sum(10**aln2[matched, :], axis=0)
                    sum_f_num[frame_ind, 0] += len(matched)
            ref_for_tr = np.log10(sum_f/sum_f_num)

        # Repeat the reference for all test words
        self.ref_test = [ref_for_tr for _ in range(len(self.test_words))]

    @staticmethod
    def _safe_pearsonr(x, y):
        """
        Helper to calculate Pearson correlation safely.

        Returns (0.0, 1.0) if either input has effectively zero variance
        (constant input), otherwise the result of scipy.stats.pearsonr.
        """
        if np.std(x) < 1e-12 or np.std(y) < 1e-12:
            return 0.0, 1.0
        return pearsonr(x, y)

    def stoi_calculation(self, N, X, Y, frame_shift, subject_id):
        """
        Computes the STOI-style score and stores it in ``stoi_val[subject_id]``.

        For every window of ``N`` consecutive frames, each band of ``Y`` is
        scaled to the energy of the matching band of ``X``, clipped from
        above at ``(c + 1) * X`` and correlated with ``X`` over time.  The
        score is the mean correlation over all bands and windows.

        :param N: number of consecutive frames per window
        :param X: reference representation, shape (bands, frames)
        :param Y: test representation, same shape as ``X``
        :param frame_shift: step between consecutive windows, in frames
        :param subject_id: index of the test word being scored
        """
        segment_ends = range(N, X.shape[1] + 1, frame_shift)
        d_interm = np.zeros((X.shape[0], len(segment_ends)))
        for i, m in enumerate(segment_ends):
            x_segment = X[:, (m - N):m]  # region with length N of clean TF-units for all j
            y_segment = Y[:, (m - N):m]  # region with length N of processed TF-units for all j
            # Per-band gain that equalises the energy of Y to that of X.
            alpha = np.sqrt(np.sum(x_segment ** 2, axis=1) / np.sum(y_segment ** 2, axis=1))
            aY_seg = y_segment * alpha[:, np.newaxis]
            for j in range(self.J):
                upper_bound = (self.c + 1) * x_segment[j, :]
                y_prime = np.minimum(upper_bound, aY_seg[j, :])
                d_interm[j, i], _ = self._safe_pearsonr(x_segment[j, :], y_prime)  # Eq 2 from Parvaneh's paper
        # Windows (columns) containing any NaN are left out of the mean.
        nan_per_column = np.sum(np.isnan(d_interm), axis=0)
        self.stoi_val[subject_id] = np.mean(d_interm[self.considered_first_bin:, nan_per_column == 0])

    def estoi_calculation(self, N, X, Y, frame_shift, subject_id):
        """
        Computes the ESTOI-style score and stores it in ``estoi_val[subject_id]``.

        For every window of ``N`` consecutive frames, each band is normalised
        to zero mean and unit standard deviation over time; each frame of the
        window is then correlated across bands between ``X`` and ``Y``.  The
        score is the mean correlation over all frames and windows.

        :param N: number of consecutive frames per window
        :param X: reference representation, shape (bands, frames)
        :param Y: test representation, same shape as ``X``
        :param frame_shift: step between consecutive windows, in frames
        :param subject_id: index of the test word being scored
        """
        segment_ends = range(N, X.shape[1] + 1, frame_shift)
        d_interm_e = np.zeros((N, len(segment_ends)))
        for ind, m in enumerate(segment_ends):
            y_raw = Y[:, (m - N):m]
            x_raw = X[:, (m - N):m]
            # eps keeps the division finite for bands that are constant in time.
            y_segment = (y_raw - np.mean(y_raw, axis=1, keepdims=True)) / \
                        (np.std(y_raw, axis=1, keepdims=True) + eps)
            x_segment = (x_raw - np.mean(x_raw, axis=1, keepdims=True)) / \
                        (np.std(x_raw, axis=1, keepdims=True) + eps)
            for j in range(N):
                d_interm_e[j, ind], _ = self._safe_pearsonr(x_segment[:, j], y_segment[:, j])  # Eq 4 from Parvaneh's paper
        # Windows (columns) containing any NaN are left out of the mean.
        nan_per_column = np.sum(np.isnan(d_interm_e), axis=0)
        self.estoi_val[subject_id] = np.mean(d_interm_e[self.considered_first_bin:, nan_per_column == 0])

    def STOI_value(self):
        """
        Scores every test word against the reference.

        Each test representation is DTW-aligned to the reference, both are
        converted to the linear domain (``10**x``) and passed to
        ``stoi_calculation`` and ``estoi_calculation``.  If either raises
        ``ValueError`` both scores of that test word are set to NaN; the
        scores of the other test words are unaffected.

        Sets ``stoi_val`` and ``estoi_val`` (one entry per test word) and
        ``aligned_ref`` / ``aligned_test`` (aligned linear-domain arrays of
        the last test word processed, or None when there are no test words).
        """
        number_of_subjects = len(self.test_words)
        self.stoi_val = np.zeros(number_of_subjects)
        self.estoi_val = np.zeros(number_of_subjects)
        self.aligned_ref = None
        self.aligned_test = None
        for subject_id in range(number_of_subjects):
            aln1 = self.test_log_octave_transforms[subject_id]
            new_path_cont, new_path_test = self.align_dtw(aln1, self.ref_test[subject_id],
                                                          frame_deletion=self.frame_deletion,
                                                          test_time=True)
            aln1 = 10 ** aln1[new_path_test, :]
            cont = 10 ** self.ref_test[subject_id][new_path_cont, :]
            self.aligned_ref = cont
            self.aligned_test = aln1
            X = np.transpose(cont)
            Y = np.transpose(aln1)
            frame_shift = 1
            # Use shorter windows when the aligned signal has fewer than consecN frames.
            N = min(self.consecN, X.shape[1])
            try:
                self.stoi_calculation(N, X, Y, frame_shift, subject_id)
                self.estoi_calculation(N, X, Y, frame_shift, subject_id)
            except ValueError:
                # Mark only this test word as failed; replacing the whole
                # arrays would discard earlier results and break later ones.
                self.stoi_val[subject_id] = np.nan
                self.estoi_val[subject_id] = np.nan
class ReferenceEvaluator:
    """Deprecated placeholder kept so existing imports keep working.

    New code should use ReferenceAudioEvaluator instead.  The class only
    records the keyword arguments it is constructed with.
    """

    def __init__(self, **options):
        # Stored untouched under the attribute name the evaluators use.
        self.stoi_kwargs = options
class PSTOIEvaluator(ReferenceAudioEvaluator):
    """An evaluator that uses PSTOI to compute a score."""

    def __init__(self, **kwargs):
        # Keyword arguments are forwarded unchanged to STOI() on every score() call.
        self.stoi_kwargs = kwargs

    def score(
        self,
        utterance_id: str,
        audio_path: str,
        reference_audios: List[tuple[str, float, float]],
        start_time: float = 0.0,
        end_time: float = -1.0,
    ) -> Optional[float]:
        """
        Computes the PSTOI score.

        :param utterance_id: identifier of the utterance (not used here)
        :param audio_path: path of the test audio file
        :param reference_audios: ``(path, start, end)`` triples of the
            reference recordings; an ``end`` of -1 means "until the end of
            the file"
        :param start_time: offset in seconds at which the test audio starts
        :param end_time: end of the test audio in seconds; -1 means "until
            the end of the file"
        :return: ``stoi_val`` of the single test word, or 0.0 when the test
            audio contains only zeros
        """
        # Load at the rate STOI will assume; a mismatch would silently change
        # the analysis window duration.
        fs = self.stoi_kwargs.get("fs", 16000)

        duration = end_time - start_time if end_time != -1 else None
        test_audio, _ = librosa.load(audio_path, sr=fs, offset=start_time, duration=duration, dtype=np.float64)

        # Same guard as ESTOIEvaluator: an all-zero signal cannot be scored.
        if np.all(test_audio == 0):
            print(f"Warning: Test audio {audio_path} is silent. Returning PSTOI score of 0.0.")
            return 0.0

        reference_audios_data = []
        for ref_path, ref_start, ref_end in reference_audios:
            ref_duration = ref_end - ref_start if ref_end != -1 else None
            ref_audio, _ = librosa.load(ref_path, sr=fs, offset=ref_start, duration=ref_duration, dtype=np.float64)
            reference_audios_data.append(ref_audio)

        stoi_object = STOI(
            reference_words=reference_audios_data,
            test_words=[test_audio],
            **self.stoi_kwargs
        )
        return stoi_object.stoi_val[0]
class ESTOIEvaluator(ReferenceAudioEvaluator):
    """An evaluator that uses P-ESTOI to compute a score."""

    def __init__(self, **kwargs):
        # Keyword arguments are forwarded unchanged to STOI() on every score() call.
        self.stoi_kwargs = kwargs

    def score(
        self,
        utterance_id: str,
        audio_path: str,
        reference_audios: List[tuple[str, float, float]],
        start_time: float = 0.0,
        end_time: float = -1.0,
    ) -> Optional[float]:
        """
        Computes the P-ESTOI score.

        :param utterance_id: identifier of the utterance (not used here)
        :param audio_path: path of the test audio file
        :param reference_audios: ``(path, start, end)`` triples of the
            reference recordings; an ``end`` of -1 means "until the end of
            the file"
        :param start_time: offset in seconds at which the test audio starts
        :param end_time: end of the test audio in seconds; -1 means "until
            the end of the file"
        :return: ``estoi_val`` of the single test word, or 0.0 when the test
            audio contains only zeros
        """
        # Load at the rate STOI will assume; a mismatch would silently change
        # the analysis window duration.
        fs = self.stoi_kwargs.get("fs", 16000)

        duration = end_time - start_time if end_time != -1 else None
        test_audio, _ = librosa.load(audio_path, sr=fs, offset=start_time, duration=duration, dtype=np.float64)

        # Check if test_audio is full silence
        if np.all(test_audio == 0):
            print(f"Warning: Test audio {audio_path} is silent. Returning P-ESTOI score of 0.0.")
            return 0.0

        reference_audios_data = []
        for ref_path, ref_start, ref_end in reference_audios:
            ref_duration = ref_end - ref_start if ref_end != -1 else None
            ref_audio, _ = librosa.load(ref_path, sr=fs, offset=ref_start, duration=ref_duration, dtype=np.float64)
            reference_audios_data.append(ref_audio)

        stoi_object = STOI(
            reference_words=reference_audios_data,
            test_words=[test_audio],
            **self.stoi_kwargs
        )
        return stoi_object.estoi_val[0]