Commit 1c00f39a authored by Anthony Larcher's avatar Anthony Larcher
Browse files

one batch size

parent 421001a1
......@@ -323,6 +323,8 @@ class Xtractor(torch.nn.Module):
if model_archi == "xvector":
self.input_nbdim = 2
if loss not in ["cce", 'aam']:
raise NotImplementedError(f"The valid loss are for now cce and aam ")
else:
......@@ -382,6 +384,7 @@ class Xtractor(torch.nn.Module):
self.embedding_size = 512
elif model_archi == "resnet34":
self.input_nbdim = 3
self.preprocessor = None
self.sequence_network = PreResNet34()
......@@ -414,6 +417,8 @@ class Xtractor(torch.nn.Module):
else:
self.loss = loss
self.input_nbdim = 2
filts = [128, [128, 128], [128, 256], [256, 256]]
self.norm_embedding = True
......@@ -519,6 +524,11 @@ class Xtractor(torch.nn.Module):
else:
self.activation = torch.nn.ReLU()
if cfg["segmental"][list(cfg["segmental"].keys())[0]].startswith("conv2D"):
self.input_nbdim = 3
elif cfg["segmental"][list(cfg["segmental"].keys())[0]].startswith("conv"):
self.input_nbdim = 2
# Create sequential object for the first part of the network
segmental_layers = []
for k in cfg["segmental"].keys():
......@@ -1019,7 +1029,6 @@ def xtrain(speaker_number,
curr_patience = patience
else:
curr_patience -= 1
#writer.close()
for ii in range(torch.cuda.device_count()):
print(torch.cuda.memory_summary(ii))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment