Commit 47fa1ecc authored by Théo Mariotte

Add MVDR front end

parent db1e778d
import torchaudio.transforms as T
import torch
import numpy

# optional dependencies: degrade gracefully when they are not installed
HAS_FFT_CONV=True
HAS_GPU_RIR=True
try:
    from fft_conv_pytorch import fft_conv
except ImportError:
    HAS_FFT_CONV=False
try:
    import gpuRIR
except ImportError:
    HAS_GPU_RIR=False
class ReverbAugment():
    """
    Augment a multichannel audio signal with reverberation using room acoustic simulations
......
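# NOTE: the body of ReverbAugment is collapsed in this diff. As a rough, hypothetical sketch of
# the underlying technique (not the committed code), reverberation can be applied by convolving
# each channel with a simulated room impulse response (RIR), e.g. one produced by gpuRIR when
# HAS_GPU_RIR is True:
def _apply_rir_sketch(wav, rir):
    """Illustrative only. wav: (C, T) multichannel signal, rir: (C, L) per-channel RIRs."""
    import torch.nn.functional as F
    C, _ = wav.shape
    kernel = rir.flip(-1).unsqueeze(1)                        # (C, 1, L); flip because conv1d correlates
    padded = F.pad(wav.unsqueeze(0), (rir.shape[-1] - 1, 0))  # causal left padding keeps the length T
    return F.conv1d(padded, kernel, groups=C).squeeze(0)      # (C, T) reverberated signal
# For long RIRs, fft_conv from fft_conv_pytorch (when HAS_FFT_CONV is True) can replace F.conv1d,
# since it mirrors torch's convolution interface.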
@@ -166,12 +166,13 @@ class SACC(torch.nn.Module):
def __init__(self,
in_channels,
feat_typ='stft',
orig_self_att=False,
win_length=400,
hop_length=160,
n_mels=64,
samplerate=16000,
freq_aware=False,
mhsa_kwargs={"att_dim":256,
"orig_self_att":False,
"num_heads":1,
"bias":True,
"activation":None,
@@ -180,8 +181,6 @@ class SACC(torch.nn.Module):
super(SACC,self).__init__()
        # TODO: remove "old_attention"; it is only kept to load models previously trained with the original self-attention
        old_attention=False
        # If the output features need to be upsampled (e.g. when using a WavLM feature extractor)
self.upsampling=None
@@ -197,14 +196,15 @@ class SACC(torch.nn.Module):
self.upsampling=torch.nn.Linear(wavlm_seqlen,target_seqlen)
mhsa_kwargs["input_dim"]=input_dim
mhsa_kwargs["value_dim"]=mhsa_kwargs["num_heads"]
self.self_att = MultiHeadSelfAttention(**mhsa_kwargs)
if old_attention:
if not orig_self_att:
mhsa_kwargs["value_dim"]=mhsa_kwargs["num_heads"]
self.self_att = MultiHeadSelfAttention(**mhsa_kwargs)
else:
mhsa_kwargs["num_heads"]=1
self.self_att = SelfAttention(**mhsa_kwargs)
self.in_channels=in_channels
self.mel_scale = torchaudio.transforms.MelScale(n_mels=n_mels,sample_rate=samplerate,)
self.freq_aware=freq_aware
self.feat_typ=feat_typ
self.softmax = torch.nn.Softmax(dim=-2)
@@ -252,6 +252,40 @@ class SACC(torch.nn.Module):
return feat, w_comb, w_att
else:
return feat
    def get_features(self,x):
        """
        Same as the forward method, but returns intermediate features for model analysis
        """
stft = self.feat_extractor(x)
#permute to fit self attention dimension requirements
if self.feat_typ == "stft":
X = torch.abs(stft).permute(0,-1,-3,-2) #(B,T,C,F)
X_log = torch.log(X+EPS)
X_norm = self.mvn(X_log)
w_comb,w_att = self.self_att(X_norm)
else:
X = stft.permute(0,-1,-3,-2)
X_norm = self.mvn(X) #(B,T,C,F)
w_comb,w_att = self.self_att(X_norm)
        # apply softmax on the weights, as in the original paper
        w_comb=self.softmax(w_comb)
# self-attentive channel combination
weighted_stft = (w_comb * X).sum(-2).permute(0,2,1)
#feat = S.permute(0,2,1)
# Compute log-mel spectrogram only in case of STFT features as input
if self.feat_typ == "stft":
logmel_spec = torch.log(self.mel_scale(weighted_stft)+EPS)
feat = self.mvn(logmel_spec)
if self.feat_typ=="wavlm" and self.upsampling is not None:
feat = self.upsampling(weighted_stft)
return stft, weighted_stft, feat, w_att, w_comb
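    # Usage sketch (illustrative; `model` and `x` are hypothetical names): for a batch x of shape
    # (B, C, n_samples),
    #   stft, weighted_stft, feat, w_att, w_comb = model.get_features(x)
    # returns the multichannel front-end features (STFT or WavLM), the attention-weighted combination,
    # the normalized features passed to the back-end, and the attention weights for analysis.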
@staticmethod
def mvn(x) -> torch.Tensor:
@@ -290,7 +324,6 @@ class FSACC(torch.nn.Module):
hop_length=160,
n_mels=64,
samplerate=16000,
freq_aware=False,
channel_att_dim=256,
freq_att_dim=8,
mhsa_kwargs={"att_dim":256,
@@ -515,29 +548,69 @@ class ComplexSACC(torch.nn.Module):
X_complex = S_real * torch.exp(1j*S_imag)
S = X_complex.abs().permute(0,2,1)
else:
S = (S_real**2 + S_imag**2).permute(0,2,1)
S = torch.sqrt((S_real**2 + S_imag**2)).permute(0,2,1)
#Sc = torch.empty(S_real.shape,dtype=torch.cfloat)
#Sc.real=S_real
#Sc.imag=S_imag
#S = Sc.abs().permute(0,2,1)
# extract Log-Mel features
S_logmel = torch.log(self.mel_scale(S)+EPS)
S_logmel = self.mvn(S_logmel)
if attention_weights_fl:
#w_comb = torch.empty(w_comb_real.shape,dtype=torch.cfloat)
#w_att = torch.empty(w_att_real.shape,dtype=torch.cfloat)
w_comb=w_comb_real+1j*w_comb_imag
w_att=w_att_real+1J*w_att_imag
#w_comb.imag=w_comb_imag
#w_att.real=w_att_real
#w_att.imag=w_att_imag
w_att=w_att_real+1j*w_att_imag
return S_logmel, w_comb, w_att
else:
return S_logmel
    def get_features(self,x):
        """
        Same as the forward method, but returns intermediate features for model analysis
        """
# STFT domain
X = self.multichannel_stft(x,
in_channels=self.in_channels,
n_fft=self.win_length,
win_length=self.win_length,
hop_length=self.hop_length,
center=False,
pad=self.win_length//2-1,)
if self.use_mag_phase:
# process magnitude and phase separately
X_real = X.abs().permute(0,-1,-3,-2)
X_imag = X.angle().permute(0,-1,-3,-2)
else:
# process real and imaginary parts separately
X_real = X.real.permute(0,-1,-3,-2)
X_imag = X.imag.permute(0,-1,-3,-2)
# mean and variance normalization
X_real_n = self.mvn(X_real)
X_imag_n = self.mvn(X_imag)
# compute real and imaginary attention weights
w_comb_real, w_att_real = self.self_att_real(X_real_n)
w_comb_imag, w_att_imag = self.self_att_imag(X_imag_n)
w_comb_real=self.softmax(w_comb_real)
w_comb_imag=self.softmax(w_comb_imag)
# Apply attention to real and imaginary parts and combine them
S_real = (w_comb_real * X_real).sum(-2)
S_imag = (w_comb_imag * X_imag).sum(-2)
        # Get the magnitude of the weighted STFT
        if self.use_mag_phase:
            X_complex = S_real * torch.exp(1j*S_imag)
            S = X_complex.abs().permute(0,2,1)
        else:
            S = torch.sqrt(S_real**2 + S_imag**2).permute(0,2,1)
weighted_stft=S_real + 1j*S_imag
# extract Log-Mel features
S_logmel = torch.log(self.mel_scale(S)+EPS)
S_logmel = self.mvn(S_logmel)
return X, weighted_stft, S_logmel, w_att_real+1j*w_att_imag, w_comb_real+1j*w_comb_imag
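    # Usage sketch (illustrative; `model` and `x` are hypothetical names): for a batch x of shape
    # (B, C, n_samples),
    #   stft, weighted_stft, logmel, w_att, w_comb = model.get_features(x)
    # returns the complex multichannel STFT, the complex attention-weighted STFT, the normalized
    # log-Mel features, and the complex-valued attention weights (real + 1j*imag) for analysis.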
@@ -244,6 +244,7 @@ class SeqToSeq(torch.nn.Module):
bias_fl=cfg["pre_processing"][k]["bias"]
att_dim=cfg["pre_processing"][k]["att_weight_dim"]
n_mels=cfg["pre_processing"][k]["n_mels"]
orig_self_att=cfg["pre_processing"][k]["orig_self_att"]
num_heads = cfg["pre_processing"][k].get("num_heads",1)
ffn_out=cfg["pre_processing"][k].get("ffn_out",True)
if feat_typ == 'stft':
@@ -255,6 +256,7 @@ class SeqToSeq(torch.nn.Module):
win_length=win_length,
hop_length=hop_length,
n_mels=n_mels,
orig_self_att=orig_self_att,
mhsa_kwargs={"att_dim":att_dim,
"num_heads":num_heads,
"bias":bias_fl,
@@ -300,6 +302,7 @@ class SeqToSeq(torch.nn.Module):
att_weight_dim=cfg["pre_processing"][k]["att_weight_dim"]
n_mels=cfg["pre_processing"][k]["n_mels"]
activation = cfg["pre_processing"][k].get("activation",None)
use_mag_phase=cfg["pre_processing"][k].get("use_mag_phase",False)
self.feature_size=n_mels
pre_processing_layers.append(
@@ -310,7 +313,8 @@ class SeqToSeq(torch.nn.Module):
n_mels=n_mels,
samplerate=self.samplerate,
att_bias=bias_fl,
activation=activation,)))
activation=activation,
use_mag_phase=use_mag_phase,)))
input_size = self.feature_size
self.pre_processing = torch.nn.Sequential(OrderedDict(pre_processing_layers))
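        # Illustrative config fragment (key names taken from the lookups above; the entry name "sacc",
        # the values, and any keys elided from this diff are assumptions):
        #   cfg["pre_processing"]["sacc"] = {
        #       "bias": True,
        #       "att_weight_dim": 256,
        #       "n_mels": 64,
        #       "orig_self_att": False,
        #       "num_heads": 1,          # optional, defaults to 1
        #       "ffn_out": True,         # optional, defaults to True
        #       "activation": None,      # optional
        #       "use_mag_phase": False,  # optional, ComplexSACC only
        #   }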
@@ -418,15 +422,25 @@ class SeqToSeq(torch.nn.Module):
return x
else:
if len(self.pre_processing)==1 and isinstance(self.pre_processing[0],SACC):
x,w_comb,w_att = self.pre_processing[0](inputs,attention_weights_fl=True)
return x, w_comb, w_att
#x,w_comb,w_att = self.pre_processing[0](inputs,attention_weights_fl=True)
stft, weighted_stft, feat, w_att, w_comb = self.pre_processing[0].get_features(inputs)
return stft, weighted_stft, feat, w_att, w_comb
if len(self.pre_processing)==1 and isinstance(self.pre_processing[0],ComplexSACC):
x,w_comb,w_att = self.pre_processing[0](inputs,attention_weights_fl=True)
return x, w_comb, w_att
stft, weighted_stft, feat, w_att, w_comb = self.pre_processing[0].get_features(inputs)
return stft, weighted_stft, feat, w_att, w_comb
else:
x = self.pre_processing(inputs)
return x
    def load(self,model_path,):
        try:
            checkpoint = torch.load(model_path)
            print('\nloading pre-trained model from {}'.format(model_path))
            self.load_state_dict(checkpoint['model'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.train()
        except FileNotFoundError:
            # checkpoint file does not exist yet: keep the current (e.g. randomly initialized) weights
            pass
def get_output_size(self):
"""
......
import torch
import torchaudio
import numpy
import scipy.special as bessel
EPS=1e-6
def cross_power_rec(x,y,alpha=0.68,**stft_kw):
    """
    Compute short-time auto- and cross-power spectrum estimates using recursive averaging.
    stft_kw is passed to torchaudio.transforms.Spectrogram and should include power=None so the STFT is complex.
    """
stft = torchaudio.transforms.Spectrogram(**stft_kw)
X=stft(x)
Y=stft(y)
phi_xx = X * X.conj()
phi_yy = Y * Y.conj()
phi_xy = X * Y.conj()
# PSD matrix estimation with recursive averaging
phi_xx[...,1:]=alpha*phi_xx[...,:-1] + (1-alpha)*phi_xx[...,1:]
phi_yy[...,1:]=alpha*phi_yy[...,:-1] + (1-alpha)*phi_yy[...,1:]
phi_xy[...,1:]=alpha*phi_xy[...,:-1] + (1-alpha)*phi_xy[...,1:]
return phi_xx, phi_yy, phi_xy
def spat_coherence_sig(phi_xx,phi_yy,phi_xy):
"""
Spatial coherence from auto/cross-power spectra
"""
den=torch.sqrt(phi_xx*phi_yy)
return phi_xy/den
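# Usage sketch (illustrative values; x and y are the two microphone signals):
#   stft_kw = {"n_fft": 512, "win_length": 400, "hop_length": 160, "power": None}  # complex STFT required
#   phi_xx, phi_yy, phi_xy = cross_power_rec(x, y, alpha=0.68, **stft_kw)
#   sigma_x = spat_coherence_sig(phi_xx, phi_yy, phi_xy)   # complex spatial coherence per T-F bin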
def noise_coherence_model(n_freq,typ="3d_iso",fs=16000,c=340.0,intermic_dist=0.1,):
    """
    Compute the noise spatial coherence model for a set of frequencies. Two models are implemented:
    fully isotropic (3d) and 2d isotropic (often more accurate in room acoustics, since the floor and
    ceiling are more absorbent than the walls).
    :param n_freq: number of frequency bins in the spatial coherence model
    :param typ: spatial coherence model ("3d_iso": 3d isotropic; "2d_iso": 2d isotropic)
    :param fs: sample rate
    :param c: speed of sound
    :param intermic_dist: distance between the two microphones in the spatial coherence model
    """
f_vec = torch.linspace(0,fs/2,n_freq)
wt=2*numpy.pi*f_vec*intermic_dist/c
    if typ=="3d_iso":
        # sinc coherence of a spherically (3d) isotropic diffuse field
        sigma_n = torch.sin(wt)/wt
        sigma_n[0]=1.0
    elif typ=="2d_iso":
        # zeroth-order Bessel coherence of a cylindrically (2d) isotropic diffuse field
        sigma_n = torch.from_numpy(bessel.jv(0,wt.detach().cpu().numpy()))
    else:
        raise ValueError("Unknown spatial coherence model '{}' (expected '3d_iso' or '2d_iso')".format(typ))
    return sigma_n
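# Usage sketch (illustrative values): coherence model matching a 512-point STFT at 16 kHz with 5 cm spacing
#   sigma_n = noise_coherence_model(n_freq=512//2+1, typ="2d_iso", fs=16000, c=340.0, intermic_dist=0.05)
# sigma_n has shape (n_freq,): one real-valued coherence per frequency bin.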
def cdr_estimate(sigma_x,sigma_n):
    """
    Coherent-to-Diffuse Ratio (CDR) estimate from [1], eq. (25)
    [1] A. Schwarz and W. Kellermann, "Coherent-to-Diffuse Power Ratio Estimation for Dereverberation",
    TASLP, vol. 23, no. 6, pp. 1006-1018, 2015, doi: 10.1109/TASLP.2015.2418571
    """
    # expand the noise coherence model so it broadcasts against sigma_x over batch and time frames
    sigma_n=sigma_n.unsqueeze(0).unsqueeze(0).unsqueeze(-1).repeat(sigma_x.shape[0],1,1,sigma_x.shape[-1])
num_fact = sigma_n * sigma_x.real - sigma_x.abs()**2
sqrt_fact = torch.sqrt(sigma_n**2 * sigma_x.real**2 - sigma_n**2 * sigma_x.abs()**2 + sigma_n**2 - 2*sigma_n*sigma_x.real + sigma_x.abs()**2)
num_fact-=sqrt_fact
den_fact = sigma_x.abs()**2-1
return num_fact/(den_fact + EPS)
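# For reference, the closed-form estimator implemented above (eq. (25) in [1]) is
#   CDR = ( G_n*Re(G_x) - |G_x|^2
#           - sqrt( G_n^2*Re(G_x)^2 - G_n^2*|G_x|^2 + G_n^2 - 2*G_n*Re(G_x) + |G_x|^2 ) )
#         / ( |G_x|^2 - 1 )
# where G_x is the estimated signal coherence (sigma_x) and G_n the diffuse-noise coherence model (sigma_n).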
def tf_mask_estimate(cdr_est,g_min=1e-3,mu=0.6):
    """
    Estimate the signal and noise time-frequency masks from the CDR via a spectral-subtraction-like
    gain (cf. [1] above), with a floor g_min on the signal mask and over-subtraction factor mu.
    """
    g_min=g_min*torch.ones_like(cdr_est)
    mask_=torch.sqrt(1-mu/(cdr_est+1))
    sig_mask = torch.max(g_min,mask_)
    noise_mask = 1-sig_mask
    return sig_mask, noise_mask
def get_masks(x,y,alpha=0.68,**kwargs):
    """
    Estimate signal and noise time-frequency masks based on the method proposed in [1]. Estimation is made
    under an isotropic diffuse noise assumption (fully isotropic or 2d isotropic). The masks are computed
    using equations (14) and (25) from the paper.
    [1] A. Schwarz and W. Kellermann, "Coherent-to-Diffuse Power Ratio Estimation for Dereverberation",
    TASLP, vol. 23, no. 6, pp. 1006-1018, 2015, doi: 10.1109/TASLP.2015.2418571
    """
    phi_xx, phi_yy, phi_xy = cross_power_rec(x,y,alpha=alpha,**kwargs["stft_kw"])
    sigma_x = spat_coherence_sig(phi_xx,phi_yy,phi_xy)
    sigma_n = noise_coherence_model(n_freq=kwargs["stft_kw"]["n_fft"]//2+1,
                                    typ=kwargs["model_typ"],
                                    fs=kwargs["fs"],
                                    c=kwargs["c"],
                                    intermic_dist=kwargs["intermic_dist"],)
    cdr_est = cdr_estimate(sigma_x,sigma_n)
    mask_s, mask_n = tf_mask_estimate(cdr_est=cdr_est,g_min=1e-3,mu=0.6,)
    # replace NaNs (e.g. from negative values under the square root) with a small constant
    mask_s[mask_s.isnan()] = EPS
    mask_n[mask_n.isnan()] = EPS
    return mask_s, mask_n
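# End-to-end usage sketch (illustrative; parameter values and the signal names are assumptions):
#   stft_kw = {"n_fft": 512, "win_length": 400, "hop_length": 160, "power": None}  # complex STFT required
#   mask_s, mask_n = get_masks(x_mic0, x_mic1,
#                              alpha=0.68,
#                              stft_kw=stft_kw,
#                              model_typ="2d_iso",
#                              fs=16000,
#                              c=340.0,
#                              intermic_dist=0.05)
#   # mask_s / mask_n can then weight the STFT bins, e.g. to build speech/noise PSD matrices for an MVDR beamformer.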