Commit fa238b80 authored by Le Lan Gaël

fixed correlation pooling

parent 9d1a30bb
*.pyc
*.DS_Store
docs
.vscode/settings.json
.gitignore
@@ -87,16 +87,20 @@ class ChannelWiseCorrPooling(torch.nn.Module):
         self.mask = self.mask.to(x.device)
         if self.training:
             x *= torch.nn.functional.dropout(torch.ones((1, x.shape[1], 1, 1), device=x.device), p=self.channels_dropout)
-        #[B, C, Fr, T, m]
-        x = x.reshape(x.shape[0], x.shape[1], self.groups, x.shape[-2], self.merge_freqs_count)
-        #[B, Fr, C, T, m]
-        x = x.permute(0, 2, 1, 3, 4)
-        #[B, Fr*C, T, m]
-        x = x.flatten(start_dim=1, end_dim=2)
+        #[B, T, C, F]
+        x = x.permute(0, 2, 1, 3)
+        #[B, T, C, Fr, f]
+        x = x.reshape(x.shape[0], x.shape[1], x.shape[-2], self.groups, self.merge_freqs_count)
+        #[B, T, f, Fr, C]
+        x = x.permute(0, 1, 4, 3, 2)
+        #[B, T, f, Fr*C]
+        x = x.flatten(start_dim=3, end_dim=4)
+        #[B, Fr*C, T, f]
+        x = x.permute(0, 3, 1, 2)
         #[B, Fr*C', T, f]
         x = self.L_proj(x)
-        #[B, Fr*C', T, m]
-        x = x.reshape(x.shape[0], self.groups, self.out_channels, -1)
+        #[B, Fr, C', Tr]
+        x = x.reshape(x.shape[0], self.groups, self.out_channels, -1)
         x -= torch.mean(x, axis=-1, keepdims=True)
         out = x/(torch.std(x, axis=-1, keepdims=True) + 1e-5)
         #[B, C', C']
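Shape bookkeeping behind the fix: the input appears to be [B, C, T, F], and the old code reshaped it directly into [B, C, Fr, T, m], which splits the trailing T×F block in the wrong order and so mixes time frames with frequency bins. The new code first moves channels aside, splits the frequency axis into (Fr, f) on the last dimension, and only then regroups. Below is a minimal sketch to check the new axis order; the batch/channel/time/frequency sizes, the group count, and the grouped 1x1 conv standing in for L_proj are assumptions, not values taken from the commit.

```python
import torch

# Dummy sizes (assumed): B=batch, C=in_channels, T=frames, F = Fr * f frequency bins.
B, C, T, Fr, f = 8, 256, 50, 5, 2
C_out = 64
x = torch.randn(B, C, T, Fr * f)

x = x.permute(0, 2, 1, 3)              # [B, T, C, F]
x = x.reshape(B, T, C, Fr, f)          # [B, T, C, Fr, f]  split F without touching T
x = x.permute(0, 1, 4, 3, 2)           # [B, T, f, Fr, C]
x = x.flatten(start_dim=3, end_dim=4)  # [B, T, f, Fr*C]
x = x.permute(0, 3, 1, 2)              # [B, Fr*C, T, f]

# Stand-in for self.L_proj (assumed): a grouped 1x1 conv mapping C -> C_out
# independently inside each of the Fr frequency groups.
L_proj = torch.nn.Conv2d(Fr * C, Fr * C_out, kernel_size=1, groups=Fr)
x = L_proj(x)                          # [B, Fr*C', T, f]
x = x.reshape(B, Fr, C_out, -1)        # [B, Fr, C', Tr] with Tr = T * f
assert x.shape == (B, Fr, C_out, T * f)
```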
@@ -545,10 +545,10 @@ class Xtractor(torch.nn.Module):
             #self.embedding_size = 256
             #self.before_speaker_embedding = torch.nn.Linear(in_features = 5120,
             #                                                out_features = self.embedding_size)
-            self.before_speaker_embedding = torch.nn.Linear(in_features = int(48*47*5/2),
+            self.before_speaker_embedding = torch.nn.Linear(in_features = int(64*63*5/2),
                                                             out_features = self.embedding_size)
-            self.stat_pooling = ChannelWiseCorrPooling(in_channels=256, out_channels=48)
+            self.stat_pooling = ChannelWiseCorrPooling(in_channels=256, out_channels=64)
             self.loss = loss
             if self.loss == "aam":
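The in_features update follows directly from the out_channels change: the correlation pooling keeps only the strictly upper triangle of each C'×C' correlation matrix, i.e. C'(C'-1)/2 coefficients per frequency group, and the factor 5 in the expression presumably is the number of frequency groups. A quick check of the arithmetic (the group count is read off the expression, not confirmed elsewhere in this commit):

```python
# Flattened size of the pooled correlation features: groups * C' * (C' - 1) / 2.
c_prime, groups = 64, 5
new_in_features = groups * c_prime * (c_prime - 1) // 2
print(new_in_features)                 # 10080 == int(64*63*5/2)

old_in_features = groups * 48 * 47 // 2
print(old_in_features)                 # 5640  == int(48*47*5/2), the previous value
```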
@@ -1134,7 +1134,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
                                    num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"]
                                    )
     else:
-        batch_size = dataset_opts["batch_size"] // dataset_opts["train"]["sampler"]["examples_per_speaker"]
+        batch_size = dataset_opts["batch_size"]# // dataset_opts["train"]["sampler"]["examples_per_speaker"]
         side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
                                    spk_count=model_opts["speaker_number"],
                                    examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
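The commented-out division changes how the configured batch size is interpreted: the loader now uses dataset_opts["batch_size"] as-is instead of shrinking it by examples_per_speaker, while the SideSampler still draws examples_per_speaker utterances per speaker. An illustrative calculation (the 256 and 4 below are assumed example values, not from the config in this repository):

```python
# Before: loader batches of batch_size // examples_per_speaker samples.
# After:  loader batches of batch_size samples.
config_batch_size = 256       # assumed example value
examples_per_speaker = 4      # assumed example value

old_loader_batch = config_batch_size // examples_per_speaker   # 64 samples per batch
new_loader_batch = config_batch_size                            # 256 samples per batch
speakers_per_batch = new_loader_batch // examples_per_speaker   # 64 distinct speakers per batch
```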
@@ -1594,7 +1594,7 @@ def train_epoch(model,
                 loss += criterion(output, target)
             elif loss_criteria == 'aps':
                 output_tuple, _ = model(data, target=target)
-                loss, output = output_tuple
+                loss, no_margin_output = output_tuple
             else:
                 output, _ = model(data, target=None)
                 loss = criterion(output, target)
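In the 'aps' branch the model call already returns the loss together with a second tensor; the rename makes explicit that this second element carries margin-free logits rather than the margin-applied output that `output` holds in the other branches. For illustration only, a hypothetical head that returns such a (loss, margin-free logits) pair; this is not sidekit's actual 'aps' loss, and every name and constant below is an assumption:

```python
import torch

class MarginStyleHead(torch.nn.Module):
    """Hypothetical head returning (loss, margin-free logits), mirroring the
    tuple unpacked as `loss, no_margin_output` in this hunk."""
    def __init__(self, embedding_size, n_classes, scale=30.0):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(n_classes, embedding_size))
        self.scale = scale

    def forward(self, emb, target):
        # Cosine logits between normalized embeddings and class weights.
        logits = torch.nn.functional.linear(
            torch.nn.functional.normalize(emb),
            torch.nn.functional.normalize(self.weight))
        loss = torch.nn.functional.cross_entropy(self.scale * logits, target)
        return loss, logits   # second element carries no margin, hence "no_margin_output"
```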