Commit 19953b86 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

minor

parent cd7ea34d
......@@ -448,7 +448,7 @@ class Xtractor(torch.nn.Module):
elif k.startswith("ctrans"):
segmental_layers.append((k, torch.nn.ConvTranspose1d(input_size,
cfg["segmental"][k][":"],
cfg["segmental"][k]["output_channels"],
kernel_size=cfg["segmental"][k]["kernel_size"],
dilation=cfg["segmental"][k]["dilation"])))
elif k.startswith("activation"):
......@@ -649,11 +649,8 @@ def xtrain(speaker_number,
logging.critical(f"Start process at {time.strftime('%H:%M:%S', time.localtime())}")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Start from scratch
if model_name is None and model_yaml in ["xvector", "rawnet2"]:
# Initialize a first model
......@@ -740,7 +737,7 @@ def xtrain(speaker_number,
Then we provide those two
"""
if write_batches_to_disk:
if write_batches_to_disk or dataset_params["batch_size"] > 1:
output_format = "numpy"
else:
output_format = "pytorch"
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment