Commit 0a2924b9 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

add fileset

parent a31a26d6
......@@ -727,24 +727,34 @@ def xtrain(speaker_number,
First we load the dataframe from CSV file in order to split it for training and validation purpose
Then we provide those two
"""
if write_batches_to_disk:
output_format = "numpy"
else:
output_format = "pytorch"
with open(dataset_yaml, "r") as fh:
dataset_params = yaml.load(fh, Loader=yaml.FullLoader)
df = pandas.read_csv(dataset_params["dataset_description"])
training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
torch.manual_seed(dataset_params['seed'])
training_set = SideSet(dataset_yaml,
set_type="train",
dataset_df=training_df,
chunk_per_segment=dataset_params['train']['chunk_per_segment'],
overlap=dataset_params['train']['overlap'])
overlap=dataset_params['train']['overlap'],
output_format=output_format)
validation_set = SideSet(dataset_yaml, set_type="validation", dataset_df=validation_df)
validation_set = SideSet(dataset_yaml,
set_type="validation",
dataset_df=validation_df,
output_format=output_format)
if write_batches_to_disk:
training_set.write_to_disk(dataset_params["batch_size"], train_batch_fn_format, num_thread)
validation_set.write_to_disk(dataset_params["batch_size"], val_batch_fn_format, num_thread)
else:
training_loader = DataLoader(training_set,
batch_size=dataset_params["batch_size"],
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment