Commit 01af21ba authored by Anthony Larcher

cleaning

parent def54548
@@ -52,8 +52,9 @@ from .xsets import IdMapSetPerSpeaker
 from .xsets import SideSampler
 from .res_net import ResBlockWFMS
 from .res_net import ResBlock
-from .res_net import PreResNet34
 from .res_net import PreFastResNet34
+from .res_net import PreHalfResNet34
+from .res_net import PreResNet34
 from ..bosaris import IdMap
 from ..bosaris import Key
 from ..bosaris import Ndx
@@ -417,20 +418,20 @@ class Xtractor(torch.nn.Module):
                 ("batch_norm5", torch.nn.BatchNorm1d(1536))
             ]))
+            self.embedding_size = 512
 
             self.stat_pooling = MeanStdPooling()
             self.stat_pooling_weight_decay = 0
 
             self.before_speaker_embedding = torch.nn.Sequential(OrderedDict([
-                ("linear6", torch.nn.Linear(3072, 512))
+                ("linear6", torch.nn.Linear(3072, self.embedding_size))
             ]))
-            self.embedding_size = 512
 
             if self.loss == "aam":
-                self.after_speaker_embedding = ArcMarginProduct(512,
+                self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
                                                                 int(self.speaker_number),
                                                                 s=64,
                                                                 m=0.2,
-                                                                easy_margin=True)
+                                                                easy_margin=False)
             elif self.loss == "cce":
                 self.after_speaker_embedding = torch.nn.Sequential(OrderedDict([
                     ("activation6", torch.nn.LeakyReLU(0.2)),
@@ -446,7 +447,6 @@ class Xtractor(torch.nn.Module):
             self.sequence_network_weight_decay = 0.0002
             self.before_speaker_embedding_weight_decay = 0.002
             self.after_speaker_embedding_weight_decay = 0.002
-            self.embedding_size = 512
 
         elif model_archi == "resnet34":
@@ -503,6 +503,32 @@ class Xtractor(torch.nn.Module):
             self.before_speaker_embedding_weight_decay = 0.00002
             self.after_speaker_embedding_weight_decay = 0.0002
 
+        elif model_archi == "halfresnet34":
+            self.preprocessor = MelSpecFrontEnd(n_fft=1024,
+                                                win_length=1024,
+                                                hop_length=256,
+                                                n_mels=80)
+            self.sequence_network = PreHalfResNet34()
+            self.embedding_size = 256
+
+            self.before_speaker_embedding = torch.nn.Linear(in_features=5120,
+                                                            out_features=self.embedding_size)
+            self.stat_pooling = AttentivePooling(256, 80, global_context=True)
+
+            self.loss = loss
+            if self.loss == "aam":
+                self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
+                                                                int(self.speaker_number),
+                                                                s=30,
+                                                                m=0.2,
+                                                                easy_margin=False)
+            elif self.loss == 'aps':
+                self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number))
+
+            self.preprocessor_weight_decay = 0.00002
+            self.sequence_network_weight_decay = 0.00002
+            self.stat_pooling_weight_decay = 0.00002
+            self.before_speaker_embedding_weight_decay = 0.00002
+            self.after_speaker_embedding_weight_decay = 0.0002
+
         elif model_archi == "rawnet2":
             if loss not in ["cce", 'aam']:
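The new "halfresnet34" branch feeds an 80-band mel front end into PreHalfResNet34 and attentive pooling, then projects 5120 features to a 256-d embedding. The 5120 is consistent with a ResNet that halves the frequency axis three times while ending at 256 channels, followed by attentive statistics pooling that concatenates per-channel mean and standard deviation. A back-of-envelope check of that arithmetic, under those (assumed) architectural details:

```python
# Hedged sanity check of in_features=5120 above. Assumes PreHalfResNet34
# downsamples frequency by 2 three times (80 -> 40 -> 20 -> 10) and ends
# at 256 channels; attentive stats pooling concatenates mean and std.
n_mels = 80
freq_downsampling = 2 ** 3                                      # assumed
out_channels = 256                                              # assumed
frame_features = out_channels * (n_mels // freq_downsampling)   # 256 * 10 = 2560
pooled_features = 2 * frame_features                            # mean + std
assert pooled_features == 5120
```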
@@ -850,6 +876,7 @@ def update_training_dictionary(dataset_description,
     # Initialize default dictionaries
     dataset_opts["data_path"] = None
     dataset_opts["dataset_csv"] = None
+    dataset_opts["stratify"] = False
     dataset_opts["data_file_extension"] = ".wav"
     dataset_opts["sample_rate"] = 16000
@@ -1029,9 +1056,12 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
     """
     df = pandas.read_csv(dataset_opts["dataset_csv"])
 
+    stratify = None
+    if dataset_opts["stratify"]:
+        stratify = df["speaker_idx"]
     training_df, validation_df = train_test_split(df,
                                                   test_size=dataset_opts["validation_ratio"],
-                                                  stratify=df["speaker_idx"])
+                                                  stratify=stratify)
 
     training_set = SideSet(dataset_opts,
                            set_type="train",
@@ -1055,7 +1085,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
     if training_opts["multi_gpu"]:
         assert dataset_opts["batch_size"] % torch.cuda.device_count() == 0
         assert dataset_opts["batch_size"] % samples_per_speaker == 0
-        batch_size = dataset_opts["batch_size"]//torch.cuda.device_count()
+        batch_size = dataset_opts["batch_size"]//(torch.cuda.device_count() * dataset_opts["train"]["sampler"]["examples_per_speaker"])
         side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
                                    spk_count=model_opts["speaker_number"],
@@ -1068,7 +1098,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
                                    num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"]
                                    )
     else:
-        batch_size = dataset_opts["batch_size"]
+        batch_size = dataset_opts["batch_size"] // dataset_opts["train"]["sampler"]["examples_per_speaker"]
         side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
                                    spk_count=model_opts["speaker_number"],
                                    examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
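Both branches now divide the configured batch size by `examples_per_speaker`, so `dataset_opts["batch_size"]` keeps meaning the total number of segments per step even when the sampler expands each drawn speaker into several examples, and, on multiple GPUs, when the batch is sharded across devices. A worked example of the division, with illustrative numbers (not taken from the commit):

```python
# Illustrative numbers only.
batch_size_cfg = 256        # dataset_opts["batch_size"]
examples_per_speaker = 2    # dataset_opts["train"]["sampler"]["examples_per_speaker"]
gpu_count = 4               # torch.cuda.device_count()

# Multi-GPU branch: speakers drawn per process per step.
per_gpu = batch_size_cfg // (gpu_count * examples_per_speaker)   # 32
# Single-GPU branch: speakers drawn per step.
single = batch_size_cfg // examples_per_speaker                  # 128

# Each drawn speaker contributes examples_per_speaker segments, so the
# configured total is recovered in both cases.
assert per_gpu * examples_per_speaker * gpu_count == batch_size_cfg
assert single * examples_per_speaker == batch_size_cfg
```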
@@ -1115,12 +1145,11 @@ def get_optimizer(model, model_opts, train_opts, training_loader):
     """
     :param model:
-    :param model_yaml:
+    :param model_opts:
+    :param train_opts:
+    :param training_loader:
     :return:
     """
-    """
-    Set the training options
-    """
     if train_opts["optimizer"]["type"] == 'adam':
         _optimizer = torch.optim.Adam
         _options = {'lr': train_opts["lr"]}
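The docstring now matches the actual signature. The body dispatches on `train_opts["optimizer"]["type"]`; a condensed, hedged sketch of that dispatch pattern (only the 'adam' branch is visible in this hunk, the other branch and the error handling are assumptions):

```python
import torch

def build_optimizer(params, train_opts):
    """Hypothetical condensation of get_optimizer's dispatch logic."""
    opt_type = train_opts["optimizer"]["type"]
    if opt_type == "adam":
        _optimizer, _options = torch.optim.Adam, {"lr": train_opts["lr"]}
    elif opt_type == "sgd":  # assumed branch, not shown in the hunk
        _optimizer, _options = torch.optim.SGD, {"lr": train_opts["lr"],
                                                 "momentum": 0.9}
    else:
        raise ValueError(f"unknown optimizer type: {opt_type}")
    return _optimizer(params, **_options)

model = torch.nn.Linear(10, 2)
opt = build_optimizer(model.parameters(),
                      {"optimizer": {"type": "adam"}, "lr": 1e-3})
```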
@@ -1161,11 +1190,15 @@ def get_optimizer(model, model_opts, train_opts, training_loader):
     optimizer = _optimizer(param_list, **_options)
 
     if train_opts["scheduler"]["type"] == 'CyclicLR':
+        cycle_momentum = True
+        if train_opts["optimizer"]["type"] == "aam":
+            cycle_momentum = False
         scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
                                                       base_lr=1e-8,
                                                       max_lr=train_opts["lr"],
                                                       step_size_up=model_opts["speaker_number"] * 2,
                                                       step_size_down=None,
+                                                      cycle_momentum=cycle_momentum,
                                                       mode="triangular2")
 
     elif train_opts["scheduler"]["type"] == "MultiStepLR":
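Exposing `cycle_momentum` matters because PyTorch's CyclicLR cycles the optimizer's `momentum` value by default, and Adam-style optimizers have no `momentum` parameter (they use `betas`), so the default raises a ValueError with the PyTorch versions contemporary to this commit. A minimal reproduction and the fix; note the hunk keys the flag off an optimizer type of "aam", which reads as if it is meant to catch the Adam family:

```python
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

try:
    # Default cycle_momentum=True fails: Adam has no 'momentum' defaults key.
    torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-8, max_lr=1e-3,
                                      step_size_up=100, mode="triangular2")
except ValueError as e:
    print(e)  # "optimizer must support momentum ..." or similar

# With cycle_momentum=False only the learning rate is cycled.
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-8,
                                              max_lr=1e-3, step_size_up=100,
                                              mode="triangular2",
                                              cycle_momentum=False)
```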
@@ -1541,9 +1574,6 @@ def cross_validation(model, validation_loader, device, validation_shape, tar_ind
     negatives = scores[non_indices]
     positives = scores[tar_indices]
 
-    # Faster EER computation available here : https://github.com/gl3lan/fast_eer
-    #equal_error_rate = eer(negatives, positives)
-
     pmiss, pfa = rocch(positives, negatives)
     equal_error_rate = rocch2eer(pmiss, pfa)
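`rocch` and `rocch2eer` come from the bosaris tooling and derive the EER from the ROC convex hull. As a rough cross-check of what is being computed: the EER is the operating point where the miss rate and the false-alarm rate are equal. A simple, hedged numpy approximation for intuition, not the ROCCH algorithm the code actually uses:

```python
import numpy

def naive_eer(positives, negatives):
    """Approximate EER by sweeping thresholds; for intuition only."""
    best = 1.0
    for t in numpy.sort(numpy.concatenate([positives, negatives])):
        pmiss = numpy.mean(positives < t)   # target trials rejected
        pfa = numpy.mean(negatives >= t)    # non-target trials accepted
        best = min(best, max(pmiss, pfa))   # tightest point where rates meet
    return best

rng = numpy.random.default_rng(0)
positives = rng.normal(2.0, 1.0, 1000)   # target trial scores
negatives = rng.normal(0.0, 1.0, 1000)   # non-target trial scores
print(naive_eer(positives, negatives))   # ~0.16 for two-sigma separation
```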
@@ -1595,15 +1625,8 @@ def extract_embeddings(idmap_name,
     else:
         idmap = IdMap(idmap_name)
 
-    #if type(model) is Xtractor:
-    #    min_duration = (model.context_size() - 1) * win_shift + win_duration
-    #    model_cs = model.context_size()
-    #else:
-    #    min_duration = (model.module.context_size() - 1) * win_shift + win_duration
-    #    model_cs = model.module.context_size()
-
     # Create dataset to load the data
-    dataset = IdMapSet(idmap_name=idmap_name,
+    dataset = IdMapSet(idmap_name=idmap,
                        data_path=data_root_name,
                        file_extension=file_extension,
                        transform_pipeline=transform_pipeline,
@@ -1615,7 +1638,6 @@ def extract_embeddings(idmap_name,
                        min_duration=win_duration
                        )
-
     dataloader = DataLoader(dataset,
                             batch_size=1,
                             shuffle=False,
@@ -1624,33 +1646,9 @@ def extract_embeddings(idmap_name,
                             num_workers=num_thread)
 
     with torch.no_grad():
         model.eval()
         model.to(device)
 
-        # Get the size of embeddings to extract
-        if type(model) is Xtractor:
-            name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
-            if name != 'bias':
-                name = name + '.weight'
-            emb_size = model.before_speaker_embedding.state_dict()[name].shape[0]
-        else:
-            name = list(model.module.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
-            if name != 'bias':
-                name = name + '.weight'
-            emb_size = model.module.before_speaker_embedding.state_dict()[name].shape[0]
-
-        # Create the StatServer
-        #embeddings = StatServer()
-        #embeddings.modelset = idmap.leftids
-        #embeddings.segset = idmap.rightids
-        #embeddings.start = idmap.start
-        #embeddings.stop = idmap.stop
-        #embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
-        #embeddings.stat1 = numpy.ones((embeddings.modelset.shape[0], emb_size))
-
         embed = []
         modelset = []
         segset = []
@@ -1662,12 +1660,9 @@ def extract_embeddings(idmap_name,
                                                   desc='xvector extraction',
                                                   mininterval=1,
                                                   disable=None)):
-            #if data.shape[1] > 20000000:
-            #    data = data[...,:20000000]
-            print(f"data.shape = {data.shape}")
             if data.dim() > 2:
                 data = data.squeeze()
-            print(f"data.shape = {data.shape}")
 
             with torch.cuda.amp.autocast(enabled=mixed_precision):
                 tmp_data = torch.split(data, data.shape[0]//(max(1, data.shape[0]//100)))
@@ -1675,9 +1670,6 @@ def extract_embeddings(idmap_name,
                     _, vec = model(x=td.to(device), is_eval=True)
                     embed.append(vec.detach().cpu())
-                    #modelset.extend([mod,] * data.shape[0])
                     modelset.extend(mod * data.shape[0])
                     segset.extend(seg * data.shape[0])
                     starts.extend(numpy.arange(start, start + vec.shape[0] * win_shift, win_shift))
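The loop above splits a long stack of sliding-window segments into chunks of roughly 100 before each forward pass, bounding peak GPU memory. The chunk-size expression is worth unpacking; a short illustration with a made-up window count:

```python
import torch

data = torch.randn(437, 400)  # e.g. 437 sliding windows of features

# data.shape[0] // max(1, data.shape[0] // 100) yields a chunk length close
# to 100; max(1, ...) protects inputs with fewer than 100 windows.
chunk = data.shape[0] // max(1, data.shape[0] // 100)   # 437 // 4 = 109
pieces = torch.split(data, chunk)

print(chunk, [p.shape[0] for p in pieces])  # 109 [109, 109, 109, 109, 1]
```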
@@ -1717,8 +1709,6 @@ def extract_embeddings_per_speaker(idmap_name,
         model = model.to(memory_format=torch.channels_last)
 
-    min_duration = (model.context_size() - 1) * frame_shift + frame_duration
-
     # Create dataset to load the data
     dataset = IdMapSetPerSpeaker(idmap_name=idmap_name,
                                  data_root_path=data_root_name,