Commit 5da3e20f authored by Le Lan Gaël's avatar Le Lan Gaël
Browse files

Merge branch 'dev-gl3lan' into corr_pooling

parents fa238b80 be3c8478
*.pyc
*.DS_Store
docs
.vscode/settings.json
.vscode
.gitignore
.vscode
.history
......@@ -173,6 +173,9 @@ def data_augmentation(speech,
aug_idx = random.sample(range(len(transform_dict.keys())), k=transform_number)
augmentations = numpy.array(list(transform_dict.keys()))[aug_idx]
if "none" in augmentations:
pass
if "stretch" in augmentations:
strech = torchaudio.functional.TimeStretch()
rate = random.uniform(0.8,1.2)
......@@ -261,6 +264,7 @@ def data_augmentation(speech,
final_shape = speech.shape[1]
configs = [
({"format": "wav", "encoding": 'ULAW', "bits_per_sample": 8}, "8 bit mu-law"),
({"format": "wav", "encoding": 'ALAW', "bits_per_sample": 8}, "8 bit a-law"),
({"format": "gsm"}, "GSM-FR"),
({"format": "mp3", "compression": -9}, "MP3"),
({"format": "vorbis", "compression": -1}, "Vorbis")
......
......@@ -106,7 +106,7 @@ class SideSampler(torch.utils.data.Sampler):
self.segment_cursors = numpy.zeros((len(self.labels_to_indices),), dtype=numpy.int)
def __iter__(self):
def __iter__(self):
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
numpy.random.seed(self.seed + self.epoch)
......@@ -175,7 +175,7 @@ class SideSet(Dataset):
overlap=0.,
dataset_df=None,
min_duration=0.165,
output_format="pytorch",
output_format="pytorch"
):
"""
......@@ -269,6 +269,8 @@ class SideSet(Dataset):
self.transform["codec"] = []
if "phone_filtering" in transforms:
self.transform["phone_filtering"] = []
if "stretch" in transforms:
self.transform["stretch"] = []
self.noise_df = None
if "add_noise" in self.transform:
......@@ -416,18 +418,27 @@ class IdMapSet(Dataset):
start = int(self.idmap.start[index] * 0.01 * self.sample_rate)
if self.idmap.stop[index] is None:
nfo = torchaudio.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
if nfo.sample_rate != self.sample_rate:
speech = torchaudio.transforms.Resample(nfo.sample_rate, self.sample_rate).forward(speech)
duration = int(speech.shape[1] - start)
else:
duration = int(self.idmap.stop[index] * 0.01 * self.sample_rate) - start
# TODO Check if that code is still relevant with torchaudio.load() in case of sample_rate mismatch
nfo = torchaudio.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
assert nfo.sample_rate == self.sample_rate
conversion_rate = nfo.sample_rate // self.sample_rate
duration = (int(self.idmap.stop[index] * 0.01 * self.sample_rate) - start)
# add this in case the segment is too short
if duration <= self.min_duration * self.sample_rate:
middle = start + duration // 2
start = int(max(0, int(middle - (self.min_duration * self.sample_rate / 2))))
duration = int(self.min_duration * self.sample_rate)
speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
frame_offset=start,
num_frames=duration)
frame_offset=start * conversion_rate,
num_frames=duration * conversion_rate)
if nfo.sample_rate != self.sample_rate:
speech = torchaudio.transforms.Resample(nfo.sample_rate, self.sample_rate).forward(speech)
speech += 10e-6 * torch.randn(speech.shape)
......
......@@ -530,7 +530,8 @@ class Xtractor(torch.nn.Module):
m = 0.2,
easy_margin = False)
elif self.loss == 'aps':
self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number))
self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number),
emb_dim=self.embedding_size)
self.preprocessor_weight_decay = 0.00002
self.sequence_network_weight_decay = 0.00002
self.stat_pooling_weight_decay = 0.00002
......@@ -1095,7 +1096,8 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
training_df, validation_df = train_test_split(df,
test_size=dataset_opts["validation_ratio"],
stratify=stratify)
# TODO
torch.manual_seed(training_opts['torch_seed'] + local_rank)
torch.cuda.manual_seed(training_opts['torch_seed'] + local_rank)
......@@ -1105,7 +1107,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
transform_number=dataset_opts['train']['transform_number'],
overlap=dataset_opts['train']['overlap'],
dataset_df=training_df,
output_format="pytorch",
output_format="pytorch"
)
validation_set = SideSet(dataset_opts,
......@@ -1628,7 +1630,7 @@ def train_epoch(model,
if math.fmod(batch_idx, training_opts["log_interval"]) == 0:
batch_size = target.shape[0]
training_monitor.update(training_loss=loss.item(),
training_acc=100.0 * accuracy.item() / ((batch_idx + 1) * batch_size))
training_acc=100.0 * accuracy / ((batch_idx + 1) * batch_size))
training_monitor.logger.info('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
training_monitor.current_epoch,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment