Commit 85663495 authored by Gaël Le Lan's avatar Gaël Le Lan
Browse files

add batch_norm

parent 6c2584f8
......@@ -165,9 +165,9 @@ def data_augmentation(speech,
augmentations = numpy.array(list(transform_dict.keys()))[aug_idx]
if "stretch" in augmentations:
strech = torchaudio.functional.TimeStretch()
rate = random.uniform(0.8,1.2)
speech = strech(speech, rate)
stretched_length = int(speech.shape[1] * random.uniform(0.95,1.05))
speech = torch.zeros_like(speech)
speech[:, :min(speech.shape[1], stretched_length)] = torch.tensor(signal.resample(speech, stretched_length))[:, :min(speech.shape[1], stretched_length)]
if "add_reverb" in augmentations:
rir_nfo = rir_df.iloc[random.randrange(rir_df.shape[0])].file_id
......@@ -236,6 +236,7 @@ def data_augmentation(speech,
final_shape = speech.shape[1]
configs = [
({"format": "wav", "encoding": 'ULAW', "bits_per_sample": 8}, "8 bit mu-law"),
({"format": "wav", "encoding": 'ALAW', "bits_per_sample": 8}, "8 bit a-law"),
({"format": "gsm"}, "GSM-FR"),
({"format": "mp3", "compression": -9}, "MP3"),
({"format": "vorbis", "compression": -1}, "Vorbis")
......
......@@ -252,7 +252,8 @@ class ArcMarginProduct(torch.nn.Module):
def forward(self, input, target=None):
# cos(theta)
cosine = torch.nn.functional.linear(torch.nn.functional.normalize(input), torch.nn.functional.normalize(self.weight))
cosine = torch.nn.functional.linear(torch.nn.functional.normalize(input),
torch.nn.functional.normalize(self.weight))
if target == None:
return cosine * self.s
# cos(theta + m)
......
......@@ -271,6 +271,8 @@ class SideSet(Dataset):
self.transform["codec"] = []
if "phone_filtering" in transforms:
self.transform["phone_filtering"] = []
if "stretch" in transforms:
self.transform["stretch"] = []
self.noise_df = None
if "add_noise" in self.transform:
......
......@@ -499,8 +499,8 @@ class Xtractor(torch.nn.Module):
if self.loss == "aam":
self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
int(self.speaker_number),
s = 30,
m = 0.2,
s = 30.0,
m = 0.20,
easy_margin = False)
elif self.loss == 'aps':
......@@ -516,21 +516,24 @@ class Xtractor(torch.nn.Module):
elif model_archi == "halfresnet34":
self.preprocessor = MelSpecFrontEnd(n_fft=1024,
win_length=1024,
hop_length=256,
win_length=400,
hop_length=160,
n_mels=80)
self.sequence_network = PreHalfResNet34()
self.embedding_size = 256
self.before_speaker_embedding = torch.nn.Linear(in_features = 5120,
out_features = self.embedding_size)
self.before_speaker_embedding = torch.nn.Sequential(OrderedDict([
("emb1", torch.nn.Linear(in_features = 5120, out_features = self.embedding_size)),
("bn1", torch.nn.BatchNorm1d(self.embedding_size))
]))
self.stat_pooling = AttentivePooling(256, 80, global_context=True)
self.loss = loss
if self.loss == "aam":
self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
int(self.speaker_number),
s = 30,
m = 0.2,
s = 30.0,
m = 0.20,
easy_margin = False)
elif self.loss == 'aps':
self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number))
......@@ -1219,7 +1222,7 @@ def get_optimizer(model, model_opts, train_opts, training_loader):
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
base_lr=1e-8,
max_lr=train_opts["lr"],
step_size_up=training_loader.__len__() * 8,
step_size_up=training_loader.__len__() * 12,
step_size_down=None,
cycle_momentum=cycle_momentum,
mode="triangular2")
......@@ -1602,7 +1605,7 @@ def train_epoch(model,
running_loss = 0.0
accuracy = 0.0
batch_count = 0
running_loss = 0.0
if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
scheduler.step(training_monitor.best_eer)
else:
......
Supports Markdown
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.