Commit a3c81d7d authored by Le Lan Gaël
Browse files

speedup

parent 5da3e20f
......@@ -55,11 +55,18 @@ class MeanStdPooling(torch.nn.Module):
def forward(self, x):
    """
    Pool a sequence of frame-level features into a single utterance-level
    vector made of per-dimension means and standard deviations.

    :param x: [B, C*F, T] tensor; a 4-D tensor is accepted as well and is
        reshaped down to three dimensions before pooling
        (NOTE(review): axis comments in the original suggest [B, C, F, T],
        but the permute/flatten pair actually merges dims 1 and 3 — confirm
        the expected layout against the upstream network's output)
    :return: [B, 2*C*F] tensor — means concatenated with standard deviations
    """
    if x.dim() == 4:
        # Swap the last two axes, then merge dims 1-2 so the statistics
        # below are always taken over the final axis of a 3-D tensor.
        x = x.permute(0, 1, 3, 2).flatten(start_dim=1, end_dim=2)
    # Unbiased std (torch default), pooled over the last axis; the two
    # statistics are stacked along the feature dimension.
    return torch.cat([torch.mean(x, dim=2), torch.std(x, dim=2)], dim=1)
......@@ -141,9 +148,14 @@ class AttentivePooling(torch.nn.Module):
def forward(self, x):
"""
:param x: [B, C*F, T] Tensor
:param x: [B, C*F, T]
:return:
"""
if len(x.shape) == 4:
# [B, C, F, T]
x = x.permute(0, 1, 3, 2)
# [B, C*F, T]
x = x.flatten(start_dim=1, end_dim=2)
if self.global_context:
w = self.attention(torch.cat([x, self.gc(x).unsqueeze(2).repeat(1, 1, x.shape[-1])], dim=1))
else:
......
......@@ -480,17 +480,18 @@ class PreResNet34(torch.nn.Module):
:param x:
:return:
"""
out = x.unsqueeze(1)
out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = self.layer5(out)
out = self.layer6(out)
out = self.layer7(out)
out = torch.flatten(out, start_dim=1, end_dim=2)
return out
x = x.unsqueeze(1)
x = x.permute(0, 1, 3, 2)
x = x.to(memory_format=torch.channels_last)
x = torch.nn.functional.relu(self.bn1(self.conv1(x)))
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.layer5(x)
x = self.layer6(x)
x = self.layer7(x)
return x
class PreHalfResNet34(torch.nn.Module):
......@@ -535,6 +536,9 @@ class PreHalfResNet34(torch.nn.Module):
:param x:
:return:
"""
x = x.unsqueeze(1)
x = x.permute(0, 1, 3, 2)
x = x.to(memory_format=torch.channels_last)
x = torch.nn.functional.relu(self.bn1(self.conv1(x)))
x = self.layer1(x)
x = self.layer2(x)
......@@ -553,7 +557,7 @@ class PreFastResNet34(torch.nn.Module):
self.speaker_number = speaker_number
self.conv1 = torch.nn.Conv2d(1, 16, kernel_size=7,
stride=(2, 1), padding=3, bias=False)
stride=(1, 2), padding=3, bias=False)
self.bn1 = torch.nn.BatchNorm2d(16)
# With block = [3, 4, 6, 3]
......@@ -585,12 +589,15 @@ class PreFastResNet34(torch.nn.Module):
:param x:
:return:
"""
out = torch.nn.functional.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
return out
x = x.unsqueeze(1)
x = x.permute(0, 1, 3, 2)
x = x.to(memory_format=torch.channels_last)
x = torch.nn.functional.relu(self.bn1(self.conv1(x)))
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x
def ResNet34():
......
......@@ -564,7 +564,7 @@ class Xtractor(torch.nn.Module):
self.sequence_network_weight_decay = 0.00002
self.stat_pooling_weight_decay = 0.00002
self.before_speaker_embedding_weight_decay = 0.00002
self.after_speaker_embedding_weight_decay = 0.0002
self.after_speaker_embedding_weight_decay = 0.000
elif model_archi == "rawnet2":
......@@ -813,15 +813,8 @@ class Xtractor(torch.nn.Module):
if self.preprocessor is not None:
x = self.preprocessor(x, is_eval)
x = x.unsqueeze(1)
# Does not work for FastResNet34 !
x = x.permute(0, 1, 3, 2)
x = x.to(memory_format=torch.channels_last)
x = self.sequence_network(x)
#x = x.to(memory_format=torch.contiguous_format)
#print(x.shape)
#x = torch.flatten(x, start_dim=1, end_dim=2)
#print(x.shape)
# Mean and Standard deviation pooling
x = self.stat_pooling(x)
x = self.before_speaker_embedding(x)
......@@ -1123,7 +1116,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
if training_opts["multi_gpu"]:
assert dataset_opts["batch_size"] % torch.cuda.device_count() == 0
assert dataset_opts["batch_size"] % samples_per_speaker == 0
batch_size = dataset_opts["batch_size"]//(torch.cuda.device_count() * dataset_opts["train"]["sampler"]["examples_per_speaker"])
batch_size = dataset_opts["batch_size"]//torch.cuda.device_count()
side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
spk_count=model_opts["speaker_number"],
......@@ -1136,7 +1129,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"]
)
else:
batch_size = dataset_opts["batch_size"]# // dataset_opts["train"]["sampler"]["examples_per_speaker"]
batch_size = dataset_opts["batch_size"]
side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
spk_count=model_opts["speaker_number"],
examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment