Commit 1b6ebce4 authored by Gaël Le Lan

accelerate add_reverb

parent 56eef1b6
@@ -455,10 +455,7 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noi
     """
     # Select the data augmentation randomly
     if len(transform_dict.keys()) >= transform_number:
         aug_idx = numpy.arange(len(transform_dict.keys()))
     else:
-        aug_idx = random.choice(numpy.arange(len(transform_dict.keys())), k=transform_number)
+        aug_idx = random.sample(range(len(transform_dict.keys())), k=transform_number)
     augmentations = numpy.array(list(transform_dict.keys()))[aug_idx]
     if "phone_filtering" in augmentations:
@@ -477,12 +474,16 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noi
         speech = strech(speech, rate)
     if "add_reverb" in augmentations:
-        rir_nfo = random.randrange(len(rir_df))
-        rir_fn = transform_dict["add_noise"]["data_path"] + "/" + rir_nfo + ".wav"
+        rir_nfo = rir_df.iloc[random.randrange(rir_df.shape[0])].file_id
+        rir_fn = transform_dict["add_reverb"]["data_path"] + "/" + rir_nfo + ".wav"
         rir, rir_fs = torchaudio.load(rir_fn)
-        rir = rir[rir_nfo[1], :] #keep selected channel
-        speech_ = torch.nn.functional.pad(speech, (rir.shape[1]-1, 0))
-        speech = torch.nn.functional.conv1d(speech_[None, ...], rir[None, ...])[0]
+        #rir = rir[rir_nfo[1], :] #keep selected channel
+        delta = speech.size(-1) - rir.size(-1)
+        kernel = torch.nn.functional.pad(rir, (0, delta))
+        #speech_ = torch.nn.functional.pad(speech, (kernel.shape[1], 0))
+        # Multiply in frequency domain to convolve in time domain
+        result = torch.fft.rfft(speech) * torch.fft.rfft(kernel)
+        speech = torch.fft.irfft(result, n=speech.size(-1))
     if "add_noise" in augmentations:
         # Pick a noise type
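This hunk is the acceleration the commit message refers to: direct time-domain convolution with `conv1d` costs roughly O(T·L) for a T-sample waveform and an L-tap impulse response, while multiplying `rfft` spectra and inverting costs O(T log T). Since the RIR is zero-padded only up to the speech length before the transform, the product is a circular convolution truncated to T samples rather than an exact linear one, generally an acceptable trade-off for augmentation. A self-contained sketch of the same sequence (the function name and toy shapes are assumptions, not from the patch):

```python
import torch

def fft_reverb(speech: torch.Tensor, rir: torch.Tensor) -> torch.Tensor:
    """Apply a room impulse response by multiplying in the frequency
    domain, mirroring the padded rfft/irfft sequence in the diff."""
    # Zero-pad the RIR on the right so both signals share one FFT length.
    kernel = torch.nn.functional.pad(rir, (0, speech.size(-1) - rir.size(-1)))
    # Multiplication of spectra == (circular) convolution in time.
    result = torch.fft.rfft(speech) * torch.fft.rfft(kernel)
    return torch.fft.irfft(result, n=speech.size(-1))

speech = torch.randn(1, 48000)   # 3 s at 16 kHz, single channel
rir = torch.randn(1, 8000)       # toy impulse response
wet = fft_reverb(speech, rir)    # same shape as speech
```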
@@ -495,7 +496,7 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noi
             # TODO make SNRs configurable by noise type
             snr_db = random.randint(13, 20)
             pick_count = random.randint(3, 7)
-            index_list = random.choices(range(noise_df.loc['speech'].shape[0]), k=pick_count)
+            index_list = random.sample(range(noise_df.loc['speech'].shape[0]), k=pick_count)
             for idx in index_list:
                 noise_row = noise_df.loc['speech'].iloc[idx]
                 noise += load_noise_seg(noise_row, speech.shape, sample_rate, transform_dict["add_noise"]["data_path"])
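With `random.sample`, the 3-7 noise segments summed here are now guaranteed to be distinct rows of the noise index. The SNR scaling itself is not shown in this hunk; purely for illustration, mixing at a target SNR under the usual power-ratio definition looks like the sketch below (everything in it is an assumption, not code from the patch):

```python
import torch

def mix_at_snr(speech: torch.Tensor, noise: torch.Tensor, snr_db: float) -> torch.Tensor:
    """Scale `noise` so that 10*log10(P_speech / P_noise) == snr_db, then add it."""
    speech_power = speech.pow(2).mean()
    noise_power = noise.pow(2).mean().clamp_min(1e-12)  # avoid divide-by-zero
    gain = torch.sqrt(speech_power / (noise_power * 10.0 ** (snr_db / 10.0)))
    return speech + gain * noise
```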
......
@@ -308,10 +308,10 @@ class SoftmaxAngularProto(torch.nn.Module):
         return cos_sim_matrix, cce_prediction

-class SoftmaxMagnet(torch.nn.Module):
+class AngularProximityMagnet(torch.nn.Module):
     # from https://github.com/clovaai/voxceleb_trainer/blob/3bfd557fab5a3e6cd59d717f5029b3a20d22a281/loss/angleproto.py
     def __init__(self, spk_count, emb_dim=256, batch_size=512, init_w=10.0, init_b=-5.0, **kwargs):
-        super(SoftmaxMagnet, self).__init__()
+        super(AngularProximityMagnet, self).__init__()
         self.test_normalize = True
@@ -319,17 +319,17 @@ class SoftmaxMagnet(torch.nn.Module):
         self.b1 = torch.nn.Parameter(torch.tensor(init_b))
         self.b2 = torch.nn.Parameter(torch.tensor(+5.54))

-        last_linear = torch.nn.Linear(256, 1)
-        last_linear.bias.data += 1
+        #last_linear = torch.nn.Linear(512, 1)
+        #last_linear.bias.data += 1

-        self.magnitude = torch.nn.Sequential(OrderedDict([
-            ("linear9", torch.nn.Linear(emb_dim, 1)),
-            ("relu9", torch.nn.ReLU()),
-            ("linear9", torch.nn.Linear(256, 256)),
-            ("relu9", torch.nn.ReLU()),
-            ("linear9", last_linear),
-            ("relu9", torch.nn.ReLU())
-        ]))
+        #self.magnitude = torch.nn.Sequential(OrderedDict([
+        #    ("linear9", torch.nn.Linear(emb_dim, 512)),
+        #    ("relu9", torch.nn.ReLU()),
+        #    ("linear10", torch.nn.Linear(512, 512)),
+        #    ("relu10", torch.nn.ReLU()),
+        #    ("linear11", last_linear),
+        #    ("relu11", torch.nn.ReLU())
+        #    ]))

         self.cce_backend = torch.nn.Sequential(OrderedDict([
             ("linear8", torch.nn.Linear(emb_dim, spk_count))
@@ -342,11 +342,11 @@ class SoftmaxMagnet(torch.nn.Module):
         assert x.size()[1] >= 2

         cce_prediction = self.cce_backend(x)
+        #x = self.magnitude(x) * torch.nn.functional.normalize(x)

         if target==None:
-            return cce_prediction
+            return x, cce_prediction

-        x = self.magnitude(x) * torch.nn.functional.normalize(x)
         x = x.reshape(-1,2,x.size()[-1]).squeeze(1)
         out_anchor = torch.mean(x[:,1:,:],1)
         out_positive = x[:,0,:]
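The trailing context shows the angular-prototypical pairing this class inherits from the cited clovaai `angleproto` recipe: embeddings arrive as consecutive (positive, anchor) pairs, and cosine similarities between positives and anchor means are scaled by learned `(w, b)` factors before a cross-entropy against the diagonal. A toy rendering of that scoring (the 10.0/-5.0 values mirror `init_w`/`init_b`; batch size and tensors are illustrative):

```python
import torch

x = torch.randn(8, 2, 256)                     # 8 speakers, 2 utterances each
out_anchor = torch.mean(x[:, 1:, :], 1)        # anchor view per speaker
out_positive = x[:, 0, :]                      # positive view per speaker
cos_sim_matrix = torch.nn.functional.cosine_similarity(
    out_positive.unsqueeze(-1),
    out_anchor.unsqueeze(-1).transpose(0, 2))  # (8, 8): positive i vs anchor j
label = torch.arange(8)                        # matching pairs on the diagonal
loss = torch.nn.functional.cross_entropy(10.0 * cos_sim_matrix - 5.0, label)
```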
......
@@ -242,8 +242,7 @@ class SideSet(Dataset):
         self.rir_df = None
         if "add_reverb" in self.transform:
             # load the RIR database
-            tmp_rir_df = pandas.read_csv(self.transformation["add_reverb"]["rir_db_csv"])
-            self.rir_df = zip(tmp_rir_df['file_id'].tolist(), tmp_rir_df['channel'].tolist())
+            self.rir_df = pandas.read_csv(self.transformation["add_reverb"]["rir_db_csv"])

     def __getitem__(self, index):
         """
......
@@ -64,7 +64,7 @@ from ..statserver import StatServer
from ..iv_scoring import cosine_scoring
from .sincnet import SincNet
from ..bosaris.detplot import rocch, rocch2eer
-from .loss import SoftmaxAngularProto, ArcLinear, SoftmaxMagnet
+from .loss import SoftmaxAngularProto, ArcLinear, AngularProximityMagnet
from .loss import l2_norm
from .loss import ArcMarginProduct
from torch.cuda.amp import autocast, GradScaler
@@ -251,17 +251,14 @@ def test_metrics(model,
                          backward=False)
     tar, non = cosine_scoring(xv_stat,
-                              xv_stat,
-                              Ndx(ndx_test_filename),
-                              wccn=None,
-                              check_missing=True,
-                              device=device
-                              ).get_tar_non(Key(key_test_filename))
+                              xv_stat,
+                              Ndx(ndx_test_filename),
+                              wccn=None,
+                              check_missing=True,
+                              device=device
+                              ).get_tar_non(Key(key_test_filename))
     pmiss, pfa = rocch(tar, non)
-    test_eer = rocch2eer(pmiss, pfa)
-    return test_eer
+    return rocch2eer(pmiss, pfa)

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
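In the hunk above, the `cosine_scoring` continuation lines change only in indentation; the substantive edit is returning `rocch2eer(pmiss, pfa)` directly instead of going through a temporary. For reference, the ROC-convex-hull helpers imported earlier turn target/non-target score arrays into an EER like so (scores are made up; the absolute import path mirrors the relative import in the diff):

```python
import numpy
from sidekit.bosaris.detplot import rocch, rocch2eer

tar = numpy.array([2.3, 1.8, 1.1, 0.9])    # same-speaker trial scores
non = numpy.array([0.4, -0.2, 0.1, -0.8])  # different-speaker trial scores
pmiss, pfa = rocch(tar, non)               # ROC convex hull vertices
print(rocch2eer(pmiss, pfa))               # equal error rate
```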
@@ -501,6 +498,18 @@ class Xtractor(torch.nn.Module):
             self.before_speaker_embedding = torch.nn.Linear(in_features = 2560,
                                                             out_features = self.embedding_size)

+            #last_linear = torch.nn.Linear(2560, 1)
+            #last_linear.bias.data += 1
+            #self.magnitude = torch.nn.Sequential(OrderedDict([
+            #("linear9", torch.nn.Linear(2560, 256)),
+            #("relu9", torch.nn.ReLU()),
+            #("linear10", torch.nn.Linear(256, 256)),
+            #("relu10", torch.nn.ReLU()),
+            # ("linear11", last_linear),
+            # ("relu11", torch.nn.ReLU())
+            # ]))

             self.stat_pooling = AttentivePooling(128, 80, global_context=False)
             self.stat_pooling_weight_decay = 0
@@ -511,28 +520,25 @@
                                                             s = 30,
                                                             m = 0.2,
                                                             easy_margin = False)
             elif self.loss == 'aps':
                 self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number))
             elif self.loss == 'smn':
-                self.after_speaker_embedding = SoftmaxMagnet(int(self.speaker_number))
+                self.after_speaker_embedding = AngularProximityMagnet(int(self.speaker_number))

             self.preprocessor_weight_decay = 0.0000
             self.sequence_network_weight_decay = 0.0000
             self.stat_pooling_weight_decay = 0.0000
+            #self.magnitude_weight_decay = 0.0000
             self.before_speaker_embedding_weight_decay = 0.0000
             self.after_speaker_embedding_weight_decay = 0.0000

         elif model_archi == "halfresnet34":
             self.preprocessor = MelSpecFrontEnd()
             self.sequence_network = PreHalfResNet34()
             self.embedding_size = 256
-            self.before_speaker_embedding = torch.nn.Linear(in_features = 2560,
+            self.before_speaker_embedding = torch.nn.Linear(in_features = 5120,
                                                             out_features = self.embedding_size)
             self.stat_pooling = AttentivePooling(256, 80)
             self.stat_pooling_weight_decay = 0

             self.loss = loss
             if self.loss == "aam":
                 self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
@@ -540,15 +546,13 @@
                                                             s = 30,
                                                             m = 0.2,
                                                             easy_margin = False)
             elif self.loss == 'aps':
                 self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number))

-            self.preprocessor_weight_decay = 0.000
-            self.sequence_network_weight_decay = 0.000
-            self.stat_pooling_weight_decay = 0.000
-            self.before_speaker_embedding_weight_decay = 0.00
-            self.after_speaker_embedding_weight_decay = 0.00
+            self.preprocessor_weight_decay = 0.00002
+            self.sequence_network_weight_decay = 0.00002
+            self.stat_pooling_weight_decay = 0.00002
+            self.before_speaker_embedding_weight_decay = 0.00002
+            self.after_speaker_embedding_weight_decay = 0.00002
elif model_archi == "rawnet2":
@@ -828,12 +832,15 @@ class Xtractor(torch.nn.Module):
                 return x
             else:
                 return self.after_speaker_embedding(x), x

-        elif self.loss in ['aam', 'aps', 'smn']:
+        elif self.loss in ['aam', 'aps']:
             if is_eval:
                 x = torch.nn.functional.normalize(x, dim=1)
             else:
                 x = self.after_speaker_embedding(x, target=target), torch.nn.functional.normalize(x, dim=1)

+        elif self.loss == 'smn':
+            if not is_eval:
+                x = self.after_speaker_embedding(x, target=target), x

         return x
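Note the asymmetry this split introduces: the 'aam'/'aps' branches hand back an L2-normalized embedding, while the new 'smn' branch passes the raw embedding through, in training and at eval time alike (plausibly because `AngularProximityMagnet` sets `self.test_normalize = True` and defers normalization to scoring). A toy illustration of the difference:

```python
import torch

x = torch.randn(4, 256)                            # pre-loss embeddings
aps_emb = torch.nn.functional.normalize(x, dim=1)  # 'aam'/'aps' branch
smn_emb = x                                        # 'smn' branch: unchanged
print(aps_emb.norm(dim=1))  # all ones
print(smn_emb.norm(dim=1))  # unconstrained magnitudes
```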
@@ -1045,7 +1052,7 @@ def xtrain(speaker_number,
     side_sampler = SideSampler(training_set.sessions['speaker_idx'],
                                speaker_number,
                                2,
-                               256,
+                               128,
                                dataset_params["batch_size"])

     training_loader = DataLoader(training_set,
@@ -1083,6 +1090,7 @@ def xtrain(speaker_number,
         param_list.append({'params': model.preprocessor.parameters(), 'weight_decay': model.preprocessor_weight_decay})
         param_list.append({'params': model.sequence_network.parameters(), 'weight_decay': model.sequence_network_weight_decay})
         param_list.append({'params': model.stat_pooling.parameters(), 'weight_decay': model.stat_pooling_weight_decay})
+        #param_list.append({'params': model.magnitude.parameters(), 'weight_decay': model.magnitude_weight_decay})
         param_list.append({'params': model.before_speaker_embedding.parameters(), 'weight_decay': model.before_speaker_embedding_weight_decay})
         param_list.append({'params': model.after_speaker_embedding.parameters(), 'weight_decay': model.after_speaker_embedding_weight_decay})
@@ -1091,6 +1099,7 @@ def xtrain(speaker_number,
         param_list.append({'params': model.module.preprocessor.parameters(), 'weight_decay': model.module.preprocessor_weight_decay})
         param_list.append({'params': model.module.sequence_network.parameters(), 'weight_decay': model.module.sequence_network_weight_decay})
         param_list.append({'params': model.module.stat_pooling.parameters(), 'weight_decay': model.module.stat_pooling_weight_decay})
+        #param_list.append({'params': model.module.magnitude.parameters(), 'weight_decay': model.module.magnitude_weight_decay})
         param_list.append({'params': model.module.before_speaker_embedding.parameters(), 'weight_decay': model.module.before_speaker_embedding_weight_decay})
         param_list.append({'params': model.module.after_speaker_embedding.parameters(), 'weight_decay': model.module.after_speaker_embedding_weight_decay})
@@ -1146,6 +1155,7 @@ def xtrain(speaker_number,
             test_eer = test_metrics(model, device, num_thread, mixed_precision)
             #logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Reversed Test EER = {rev_eer * 100} %")
             logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Test EER = {test_eer * 100} %")
+            #logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Test EER = {test_eer_norm * 100} %")

         # remember best accuracy and save checkpoint
         if compute_test_eer:
@@ -1154,6 +1164,7 @@ def xtrain(speaker_number,
             contrastive_eer = test_metrics(model, device, num_thread, mixed_precision, corpus='VCTK5+')
             #logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Reversed Test EER = {rev_eer * 100} %")
             logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - VCTK EER = {contrastive_eer * 100} %")
+            #logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - VCTK EER = {contrastive_eer_norm * 100} %")
             best_eer = min(test_eer, best_eer)
         else:
             is_best = val_eer < best_eer
@@ -1277,7 +1288,6 @@ def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interva
                 epoch, batch_idx + 1, training_loader.__len__(),
                 100. * batch_idx / training_loader.__len__(), loss.item(),
                 100.0 * accuracy.item() / ((batch_idx + 1) * batch_size)))
-            else:
             save_checkpoint({
                 'epoch': epoch,
@@ -1324,9 +1334,13 @@ def cross_validation(model, validation_loader, device, validation_shape, tar_ind
             batch_size = target.shape[0]
             data = data.squeeze().to(device)
             with torch.cuda.amp.autocast(enabled=mixed_precision):
-                batch_predictions, batch_embeddings = model(data, target=None, is_eval=False)
+                output, batch_embeddings = model(data, target=None, is_eval=False)
                 if loss_criteria == 'cce':
                     batch_embeddings = l2_norm(batch_embeddings)
+                if loss_criteria == 'smn':
+                    batch_embeddings, batch_predictions = output
+                else:
+                    batch_predictions = output
                 accuracy += (torch.argmax(batch_predictions.data, 1) == target).sum()
                 loss += criterion(batch_predictions, target)
             embeddings[cursor:cursor + batch_size,:] = batch_embeddings.detach().cpu()
......