Commit 01af21ba authored by Anthony Larcher

cleaning

parent def54548
@@ -52,8 +52,9 @@ from .xsets import IdMapSetPerSpeaker
 from .xsets import SideSampler
 from .res_net import ResBlockWFMS
 from .res_net import ResBlock
-from .res_net import PreResNet34
 from .res_net import PreFastResNet34
+from .res_net import PreHalfResNet34
+from .res_net import PreResNet34
 from ..bosaris import IdMap
 from ..bosaris import Key
 from ..bosaris import Ndx
@@ -417,20 +418,20 @@ class Xtractor(torch.nn.Module):
                 ("batch_norm5", torch.nn.BatchNorm1d(1536))
             ]))
+            self.embedding_size = 512
             self.stat_pooling = MeanStdPooling()
             self.stat_pooling_weight_decay = 0
             self.before_speaker_embedding = torch.nn.Sequential(OrderedDict([
-                ("linear6", torch.nn.Linear(3072, 512))
+                ("linear6", torch.nn.Linear(3072, self.embedding_size))
             ]))
-            self.embedding_size = 512
             if self.loss == "aam":
-                self.after_speaker_embedding = ArcMarginProduct(512,
+                self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
                                                                 int(self.speaker_number),
                                                                 s=64,
                                                                 m=0.2,
-                                                                easy_margin=True)
+                                                                easy_margin=False)
             elif self.loss == "cce":
                 self.after_speaker_embedding = torch.nn.Sequential(OrderedDict([
                     ("activation6", torch.nn.LeakyReLU(0.2)),
@@ -446,7 +447,6 @@ class Xtractor(torch.nn.Module):
             self.sequence_network_weight_decay = 0.0002
             self.before_speaker_embedding_weight_decay = 0.002
             self.after_speaker_embedding_weight_decay = 0.002
-            self.embedding_size = 512
         elif model_archi == "resnet34":
@@ -503,6 +503,32 @@ class Xtractor(torch.nn.Module):
             self.before_speaker_embedding_weight_decay = 0.00002
             self.after_speaker_embedding_weight_decay = 0.0002
+        elif model_archi == "halfresnet34":
+            self.preprocessor = MelSpecFrontEnd(n_fft=1024,
+                                                win_length=1024,
+                                                hop_length=256,
+                                                n_mels=80)
+            self.sequence_network = PreHalfResNet34()
+            self.embedding_size = 256
+
+            self.before_speaker_embedding = torch.nn.Linear(in_features = 5120,
+                                                            out_features = self.embedding_size)
+
+            self.stat_pooling = AttentivePooling(256, 80, global_context=True)
+
+            self.loss = loss
+            if self.loss == "aam":
+                self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
+                                                                int(self.speaker_number),
+                                                                s = 30,
+                                                                m = 0.2,
+                                                                easy_margin = False)
+            elif self.loss == 'aps':
+                self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number))
+
+            self.preprocessor_weight_decay = 0.00002
+            self.sequence_network_weight_decay = 0.00002
+            self.stat_pooling_weight_decay = 0.00002
+            self.before_speaker_embedding_weight_decay = 0.00002
+            self.after_speaker_embedding_weight_decay = 0.0002
         elif model_archi == "rawnet2":
             if loss not in ["cce", 'aam']:
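
The new halfresnet34 branch chains a mel-spectrogram front end, a half-width ResNet34, attentive statistics pooling and a 256-dimensional linear bottleneck. A rough shape walk-through under assumed 16 kHz input and the layer sizes above (the exact sidekit module outputs are not verified here):

# Assumed shape contract for the halfresnet34 branch:
# waveform (batch, samples)                        16 kHz audio
#   -> MelSpecFrontEnd                             (batch, 80 mels, T frames)
#   -> PreHalfResNet34                             (batch, 256 ch, 10, T/8)   # 80 mels / 8
#   -> AttentivePooling(256, 80, global_context)   (batch, 5120)              # 256*10 mean + std
#   -> Linear(5120, 256)                           (batch, 256)               # x-vector embedding
#   -> ArcMarginProduct(256, n_speakers)           (batch, n_speakers) logits
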
@@ -850,6 +876,7 @@ def update_training_dictionary(dataset_description,
     # Initialize default dictionaries
     dataset_opts["data_path"] = None
     dataset_opts["dataset_csv"] = None
+    dataset_opts["stratify"] = False
     dataset_opts["data_file_extension"] = ".wav"
     dataset_opts["sample_rate"] = 16000
@@ -1029,9 +1056,12 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
     """
     df = pandas.read_csv(dataset_opts["dataset_csv"])
+    stratify = None
+    if dataset_opts["stratify"]:
+        stratify = df["speaker_idx"]
     training_df, validation_df = train_test_split(df,
                                                   test_size=dataset_opts["validation_ratio"],
-                                                  stratify=df["speaker_idx"])
+                                                  stratify=stratify)
     training_set = SideSet(dataset_opts,
                            set_type="train",
@@ -1055,7 +1085,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
     if training_opts["multi_gpu"]:
         assert dataset_opts["batch_size"] % torch.cuda.device_count() == 0
         assert dataset_opts["batch_size"] % samples_per_speaker == 0
-        batch_size = dataset_opts["batch_size"]//torch.cuda.device_count()
+        batch_size = dataset_opts["batch_size"]//(torch.cuda.device_count() * dataset_opts["train"]["sampler"]["examples_per_speaker"])
         side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
                                    spk_count=model_opts["speaker_number"],
@@ -1068,7 +1098,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
                                    num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"]
                                    )
     else:
-        batch_size = dataset_opts["batch_size"]
+        batch_size = dataset_opts["batch_size"] // dataset_opts["train"]["sampler"]["examples_per_speaker"]
         side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
                                    spk_count=model_opts["speaker_number"],
                                    examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
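
One plausible reading of the new arithmetic: the configured batch size counts utterances per global step, while the DataLoader batch size counts sampled speakers, each contributing examples_per_speaker utterances. A sketch with hypothetical numbers:

# Hypothetical values illustrating the batch-size arithmetic above.
total_batch_size = 256       # dataset_opts["batch_size"] (utterances per global step)
examples_per_speaker = 2     # dataset_opts["train"]["sampler"]["examples_per_speaker"]
gpu_count = 4                # torch.cuda.device_count()

# multi-GPU: shard across devices, then divide out the per-speaker replication
per_gpu_batch = total_batch_size // (gpu_count * examples_per_speaker)   # 32
# single GPU: only the per-speaker replication is divided out
single_gpu_batch = total_batch_size // examples_per_speaker              # 128
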
@@ -1115,12 +1145,11 @@ def get_optimizer(model, model_opts, train_opts, training_loader):
     """
     :param model:
-    :param model_yaml:
+    :param model_opts:
+    :param train_opts:
+    :param training_loader:
     :return:
     """
-    """
-    Set the training options
-    """
     if train_opts["optimizer"]["type"] == 'adam':
         _optimizer = torch.optim.Adam
         _options = {'lr': train_opts["lr"]}
@@ -1161,11 +1190,15 @@ def get_optimizer(model, model_opts, train_opts, training_loader):
     optimizer = _optimizer(param_list, **_options)
     if train_opts["scheduler"]["type"] == 'CyclicLR':
+        cycle_momentum = True
+        if train_opts["optimizer"]["type"] == "aam":
+            cycle_momentum = False
         scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
                                                       base_lr=1e-8,
                                                       max_lr=train_opts["lr"],
                                                       step_size_up=model_opts["speaker_number"] * 2,
                                                       step_size_down=None,
+                                                      cycle_momentum=cycle_momentum,
                                                       mode="triangular2")
     elif train_opts["scheduler"]["type"] == "MultiStepLR":
         scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
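
CyclicLR cycles the optimizer's momentum by default and raises an error for optimizers that expose no momentum hyper-parameter (Adam uses betas instead), hence the new cycle_momentum flag. A standalone sketch with assumed values and a hypothetical model:

import torch

model = torch.nn.Linear(10, 2)   # hypothetical model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
                                              base_lr=1e-8,
                                              max_lr=1e-3,
                                              step_size_up=1000,
                                              cycle_momentum=False,  # mandatory with Adam
                                              mode="triangular2")

for step in range(10):           # training-loop skeleton
    optimizer.step()
    scheduler.step()             # one scheduler step per batch
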
@@ -1541,9 +1574,6 @@ def cross_validation(model, validation_loader, device, validation_shape, tar_indices
     negatives = scores[non_indices]
     positives = scores[tar_indices]
     # Faster EER computation available here : https://github.com/gl3lan/fast_eer
-    #equal_error_rate = eer(negatives, positives)
     pmiss, pfa = rocch(positives, negatives)
     equal_error_rate = rocch2eer(pmiss, pfa)
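
The equal error rate is the operating point where the miss rate equals the false-alarm rate; rocch builds the ROC convex hull and rocch2eer reads the EER off it. A toy computation mirroring the lines above, assuming the two functions are importable from sidekit's bosaris detplot module (scores are made up):

import numpy
from sidekit.bosaris.detplot import rocch, rocch2eer

positives = numpy.array([2.1, 1.7, 0.9, 1.5])    # target-trial scores
negatives = numpy.array([-1.2, 0.3, -0.7, 0.1])  # non-target-trial scores

pmiss, pfa = rocch(positives, negatives)   # ROC convex hull
equal_error_rate = rocch2eer(pmiss, pfa)   # point where pmiss == pfa
print(f"EER = {equal_error_rate:.2%}")
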
@@ -1595,15 +1625,8 @@ def extract_embeddings(idmap_name,
     else:
         idmap = IdMap(idmap_name)
-    #if type(model) is Xtractor:
-    #    min_duration = (model.context_size() - 1) * win_shift + win_duration
-    #    model_cs = model.context_size()
-    #else:
-    #    min_duration = (model.module.context_size() - 1) * win_shift + win_duration
-    #    model_cs = model.module.context_size()
     # Create dataset to load the data
-    dataset = IdMapSet(idmap_name=idmap_name,
+    dataset = IdMapSet(idmap_name=idmap,
                        data_path=data_root_name,
                        file_extension=file_extension,
                        transform_pipeline=transform_pipeline,
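
The dataset now receives the already-parsed IdMap object rather than the raw file name, matching the dual input handled just above. A sketch of that convention, using the IdMap class imported from sidekit's bosaris package at the top of the file (the helper name is hypothetical):

from sidekit.bosaris import IdMap

def load_idmap(idmap_name):
    # Accept either an IdMap instance or a path to a stored one (sketch).
    if isinstance(idmap_name, IdMap):
        return idmap_name
    return IdMap(idmap_name)
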
@@ -1615,7 +1638,6 @@ def extract_embeddings(idmap_name,
                        min_duration=win_duration
                        )
     dataloader = DataLoader(dataset,
                             batch_size=1,
                             shuffle=False,
@@ -1624,33 +1646,9 @@ def extract_embeddings(idmap_name,
                             num_workers=num_thread)
     with torch.no_grad():
         model.eval()
         model.to(device)
-        # Get the size of embeddings to extract
-        if type(model) is Xtractor:
-            name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
-            if name != 'bias':
-                name = name + '.weight'
-            emb_size = model.before_speaker_embedding.state_dict()[name].shape[0]
-        else:
-            name = list(model.module.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
-            if name != 'bias':
-                name = name + '.weight'
-            emb_size = model.module.before_speaker_embedding.state_dict()[name].shape[0]
-        # Create the StatServer
-        #embeddings = StatServer()
-        #embeddings.modelset = idmap.leftids
-        #embeddings.segset = idmap.rightids
-        #embeddings.start = idmap.start
-        #embeddings.stop = idmap.stop
-        #embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
-        #embeddings.stat1 = numpy.ones((embeddings.modelset.shape[0], emb_size))
         embed = []
         modelset= []
         segset = []
@@ -1662,12 +1660,9 @@ def extract_embeddings(idmap_name,
                                                  desc='xvector extraction',
                                                  mininterval=1,
                                                  disable=None)):
-            #if data.shape[1] > 20000000:
-            #    data = data[...,:20000000]
             print(f"data.shape = {data.shape}")
             if data.dim() > 2:
                 data = data.squeeze()
             print(f"data.shape = {data.shape}")
             with torch.cuda.amp.autocast(enabled=mixed_precision):
                 tmp_data = torch.split(data,data.shape[0]//(max(1, data.shape[0]//100)))
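
The extraction loop splits a long batch of sliding windows into chunks of roughly 100 rows before the forward pass, bounding GPU memory. A standalone sketch with a dummy tensor and no model:

import torch

data = torch.randn(437, 16000)                         # e.g. 437 sliding windows
chunk = data.shape[0] // max(1, data.shape[0] // 100)  # here 437 // 4 = 109 rows
for td in torch.split(data, chunk):
    print(td.shape)   # chunks of `chunk` rows; the last one may be smaller
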
@@ -1675,9 +1670,6 @@ def extract_embeddings(idmap_name,
                 for td in tmp_data:
                     _, vec = model(x=td.to(device), is_eval=True)
                     embed.append(vec.detach().cpu())
-            #modelset.extend([mod,] * data.shape[0])
             modelset.extend(mod * data.shape[0])
             segset.extend(seg * data.shape[0])
             starts.extend(numpy.arange(start, start + vec.shape[0] * win_shift, win_shift))
@@ -1717,8 +1709,6 @@ def extract_embeddings_per_speaker(idmap_name,
     model = model.to(memory_format=torch.channels_last)
     min_duration = (model.context_size() - 1) * frame_shift + frame_duration
     # Create dataset to load the data
     dataset = IdMapSetPerSpeaker(idmap_name=idmap_name,
                                  data_root_path=data_root_name,