Commit 6f10df46 authored by Anthony Larcher
Browse files

debug spkset

parent b405f263
......@@ -414,4 +414,4 @@ wavfile.write("output_comp.wav", sr, x3)
"""
\ No newline at end of file
"""
......@@ -48,7 +48,7 @@ from .xsets import FileSet
from .xsets import IdMapSet
from .xsets import IdMapSet_per_speaker
from .xsets import SpkSet
from .res_net import RawPreprocessor, ResBlockWFMS, ResBlock, PreResNet34
from .res_net import RawPreprocessor, ResBlockWFMS, ResBlock, PreResNet34, PreFastResNet34
from ..bosaris import IdMap
from ..bosaris import Key
from ..bosaris import Ndx
......@@ -417,7 +417,7 @@ class Xtractor(torch.nn.Module):
self.preprocessor = None
self.sequence_network = PreFastResNet34()
self.before_speaker_embedding = torch.nn.Linear(in_features = 1280,
self.before_speaker_embedding = torch.nn.Linear(in_features = 2560,
out_features = 256)
self.stat_pooling = MeanStdPooling()
......@@ -804,7 +804,7 @@ def xtrain(speaker_number,
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Start from scratch
if model_name is None and model_yaml in ["xvector", "rawnet2", "resnet34"]:
if model_name is None and model_yaml in ["xvector", "rawnet2", "resnet34", "fastresnet34"]:
# Initialize a first model
if model_yaml == "xvector":
model = Xtractor(speaker_number, "xvector", loss=loss)
......@@ -812,6 +812,8 @@ def xtrain(speaker_number,
model = Xtractor(speaker_number, "rawnet2")
elif model_yaml == "resnet34":
model = Xtractor(speaker_number, "resnet34")
elif model_yaml == "fastresnet34":
model = Xtractor(speaker_number, "fastresnet34")
model_archi = model_yaml
else:
with open(model_yaml, 'r') as fh:
......@@ -906,12 +908,20 @@ def xtrain(speaker_number,
training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
torch.manual_seed(dataset_params['seed'])
training_set = SpkSet(dataset_yaml,
set_type="train",
dataset_df=training_df,
overlap=dataset_params['train']['overlap'],
output_format="pytorch",
windowed=True)
#training_set = SpkSet(dataset_yaml,
# set_type="train",
# dataset_df=training_df,
# overlap=dataset_params['train']['overlap'],
# output_format="pytorch",
# windowed=True)
training_set = SideSet(dataset_yaml,
set_type="train",
overlap=dataset_params['train']['overlap'],
dataset_df=training_df,
output_format="pytorch",
)
validation_set = SideSet(dataset_yaml,
set_type="validation",
......@@ -939,7 +949,7 @@ def xtrain(speaker_number,
shuffle=True,
drop_last=True,
pin_memory=True,
num_workers=num_thread,
num_workers=1,#num_thread,
persistent_workers=True)
validation_loader = DataLoader(validation_set,
......@@ -1024,6 +1034,8 @@ def xtrain(speaker_number,
scaler=scaler,
clipping=clipping)
print("end of train epoch")
# Add the cross validation here
if math.fmod(epoch, 5) == 0:
val_acc, val_loss, val_eer = cross_validation(model, validation_loader, device, [validation_set.__len__(), embedding_size], mixed_precision)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment