Commit 1618f7ea authored by Anthony Larcher

autocast

parent 9906ef08
@@ -163,7 +163,9 @@ def plot_classes_preds(model, speech, labels):
 def test_metrics(model,
                  device,
-                 speaker_number):
+                 speaker_number,
+                 num_thread,
+                 mixed_precision):
     """Compute model metrics
     Args:
@@ -202,7 +204,9 @@ def test_metrics(model,
                                    model_filename=model,
                                    data_root_name=data_root_name,
                                    device=device,
-                                   transform_pipeline=transform_pipeline)
+                                   transform_pipeline=transform_pipeline,
+                                   num_thread=num_thread,
+                                   mixed_precision=mixed_precision)
     scores = cosine_scoring(xv_stat,
                             xv_stat,
@@ -400,7 +404,7 @@ class Xtractor(torch.nn.Module):
                                                  int(self.speaker_number),
                                                  s = 30.0,
                                                  m = 0.50,
-                                                 easy_margin = False)
+                                                 easy_margin = True)
         self.preprocessor_weight_decay = 0.000
         self.sequence_network_weight_decay = 0.000
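
For reference, easy_margin toggles how ArcFace-style layers handle angles where cos(theta + m) stops being monotonic. The sketch below shows the two branches as they appear in the widespread ArcMarginProduct recipe; it is an illustration of that common implementation, not necessarily this repository's exact code:

import math
import torch

# Hypothetical cos(theta) logits for 4 samples and 10 classes, in [-1, 1]
cosine = torch.rand(4, 10) * 2 - 1
sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(0, 1))
m = 0.50
phi = cosine * math.cos(m) - sine * math.sin(m)   # cos(theta + m)

easy_margin = True
if easy_margin:
    # margin applied only where cos(theta) > 0
    phi = torch.where(cosine > 0, phi, cosine)
else:
    # phase-corrected fallback for theta + m beyond pi
    th = math.cos(math.pi - m)
    mm = math.sin(math.pi - m) * m
    phi = torch.where(cosine > th, phi, cosine - mm)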
@@ -689,15 +693,15 @@ class Xtractor(torch.nn.Module):
         if self.loss == "cce":
             if is_eval:
-                return self.after_speaker_embedding(x), x
+                return x
             else:
-                return self.after_speaker_embedding(x)
+                return self.after_speaker_embedding(x), x
         elif self.loss == "aam":
-            if not is_eval:
-                x = self.after_speaker_embedding(torch.nn.functional.normalize(x, dim=1), target=target), torch.nn.functional.normalize(x, dim=1)
+            if is_eval:
+                x = torch.nn.functional.normalize(x, dim=1)
             else:
-                x = self.after_speaker_embedding(torch.nn.functional.normalize(x, dim=1), target=None), torch.nn.functional.normalize(x, dim=1)
+                x = self.after_speaker_embedding(torch.nn.functional.normalize(x, dim=1), target=target), torch.nn.functional.normalize(x, dim=1)
         return x
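
The hunk above changes what forward returns: a (logits, embedding) pair during training and the embedding alone at evaluation time (L2-normalised first in the AAM branch). A minimal sketch of that convention with a hypothetical toy module, not the class in this file:

import torch
import torch.nn.functional as F

class ToyXtractor(torch.nn.Module):
    """Hypothetical miniature model illustrating the return convention above."""
    def __init__(self, feat_dim=10, emb_dim=8, n_spk=5):
        super().__init__()
        self.embed = torch.nn.Linear(feat_dim, emb_dim)
        self.head = torch.nn.Linear(emb_dim, n_spk)  # stand-in for after_speaker_embedding

    def forward(self, x, target=None, is_eval=False):
        x = self.embed(x)
        if is_eval:
            return F.normalize(x, dim=1)   # embedding only, as in the AAM branch
        return self.head(x), x             # (logits for the loss, embedding)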
@@ -726,6 +730,7 @@ def xtrain(speaker_number,
            tmp_model_name=None,
            best_model_name=None,
            multi_gpu=True,
+           mixed_precision=False,
            clipping=False,
            opt=None,
            reset_parts=[],
@@ -836,7 +841,13 @@ def xtrain(speaker_number,
         if name.split(".")[0] in freeze_parts:
             param.requires_grad = False
-    print(model)
+    logging.critical(model)
+    logging.critical("model_parameters_count: {:d}".format(
+        sum(p.numel()
+            for p in model.parameters()
+            if p.requires_grad)))
+    embedding_size = model.embedding_size
     if torch.cuda.device_count() > 1 and multi_gpu:
@@ -904,11 +915,11 @@ def xtrain(speaker_number,
                                  drop_last=True,
                                  pin_memory=True,
                                  num_workers=num_thread,
-                                 persistent_workers=False)
+                                 persistent_workers=True)
     validation_loader = DataLoader(validation_set,
                                    batch_size=batch_size,
-                                   drop_last=True,
+                                   drop_last=False,
                                    pin_memory=True,
                                    num_workers=num_thread,
                                    persistent_workers=False)
@@ -950,6 +961,11 @@ def xtrain(speaker_number,
                                                      last_epoch=-1,
                                                      verbose=False)
+    if mixed_precision:
+        scaler = torch.cuda.amp.GradScaler()
+    else:
+        scaler = None
     best_accuracy = 0.0
     best_accuracy_epoch = 1
+    best_eer = 100
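
As an aside, the same conditional behaviour can be obtained without the if/else, because GradScaler itself takes an enabled flag and degrades to a no-op when it is False. A sketch of that variant (an alternative, not what this commit does):

import torch

mixed_precision = True  # stands in for xtrain()'s new argument
scaler = torch.cuda.amp.GradScaler(enabled=mixed_precision)
# With enabled=False, scaler.scale(loss) returns the loss unchanged and
# scaler.step(optimizer) reduces to optimizer.step(), so a single code
# path can serve both the mixed-precision and full-precision cases.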
@@ -960,9 +976,10 @@ def xtrain(speaker_number,
                                                validation_loader,
                                                device,
                                                [validation_set.__len__(),
-                                               embedding_size])
+                                               embedding_size],
+                                               mixed_precision)
-    test_eer = test_metrics(model, device, speaker_number)
+    test_eer = test_metrics(model, device, speaker_number, num_thread, mixed_precision)
     logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Initial metrics - Cross validation accuracy = {val_acc} %, EER = {val_eer * 100} %")
     logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Initial metrics - Test EER = {test_eer * 100} %")
@@ -978,14 +995,15 @@ def xtrain(speaker_number,
                             optimizer,
                             scheduler,
                             dataset_params["log_interval"],
-                            device=device,
+                            device,
+                            scaler=scaler,
                             clipping=clipping)
         # Add the cross validation here
-        if math.fmod(epoch, 136) == 0:
-            val_acc, val_loss, val_eer = cross_validation(model, validation_loader, device, [validation_set.__len__(), embedding_size])
+        if math.fmod(epoch, 5) == 0:
+            val_acc, val_loss, val_eer = cross_validation(model, validation_loader, device, [validation_set.__len__(), embedding_size], mixed_precision)
-            test_eer = test_metrics(model, device, speaker_number)
+            test_eer = test_metrics(model, device, speaker_number, num_thread, mixed_precision)
             logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Cross validation accuracy = {val_acc} %, EER = {val_eer * 100} %")
             logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Test EER = {test_eer * 100} %")
@@ -994,7 +1012,8 @@ def xtrain(speaker_number,
             #scheduler.step(val_loss)
             # remember best accuracy and save checkpoint
-            is_best = val_acc > best_accuracy
+            is_best = test_eer > best_eer
+            best_eer = max(test_eer, best_eer)
             best_accuracy = max(val_acc, best_accuracy)
             if tmp_model_name is None:
@@ -1034,7 +1053,7 @@ def xtrain(speaker_number,
     logging.critical(f"Best accuracy {best_accuracy * 100.} obtained at epoch {best_accuracy_epoch}")
-def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interval, device, clipping=False):
+def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interval, device, scaler=None, clipping=False):
     """
     :param model:
@@ -1062,20 +1081,30 @@ def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interva
         target = target.to(device)
         optimizer.zero_grad()
-        with torch.cuda.amp.autocast(enabled=False):
+        if scaler is not None:
+            with autocast():
+                if loss_criteria == 'aam':
+                    output, _ = model(data, target=target)
+                else:
+                    output, _ = model(data, target=None)
+                loss = criterion(output, target)
+        else:
             if loss_criteria == 'aam':
                 output, _ = model(data, target=target)
             else:
-                output = model(data, target=None)
+                output, _ = model(data, target=None)
             loss = criterion(output, target)
-        if not torch.isnan(loss):
-            loss.backward()
-            if clipping:
-                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
-            running_loss += loss.item()
-            optimizer.step()
+        if scaler is not None:
+            scaler.scale(loss).backward()
+            if clipping:
+                scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            loss.backward()
+            optimizer.step()
+        running_loss += loss.item()
         accuracy += (torch.argmax(output.data, 1) == target).sum()
         if batch_idx % log_interval == 0:
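
The new branch above follows the standard torch.cuda.amp recipe: forward and loss under autocast, backward on the scaled loss, unscale before clipping, then scaler.step and scaler.update. A self-contained sketch of one optimisation step, with a hypothetical model and batch (names are illustrative, not from this file; requires a CUDA device):

import torch
from torch.cuda.amp import autocast, GradScaler

model = torch.nn.Linear(20, 5).cuda()        # hypothetical stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
scaler = GradScaler()

data = torch.randn(8, 20, device="cuda")
target = torch.randint(0, 5, (8,), device="cuda")

optimizer.zero_grad()
with autocast():                              # forward + loss in mixed precision
    loss = criterion(model(data), target)
scaler.scale(loss).backward()                 # backward on the scaled loss
scaler.unscale_(optimizer)                    # unscale before clipping gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
scaler.step(optimizer)                        # skips the step if grads overflowed
scaler.update()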
@@ -1103,7 +1132,7 @@ def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interva
     return model
-def cross_validation(model, validation_loader, device, validation_shape):
+def cross_validation(model, validation_loader, device, validation_shape, mixed_precision=False):
     """
     :param model:
@@ -1129,11 +1158,11 @@ def cross_validation(model, validation_loader, device, validation_shape):
             batch_size = target.shape[0]
             target = target.squeeze().to(device)
             data = data.squeeze().to(device)
-            with torch.cuda.amp.autocast(enabled=False):
+            with torch.cuda.amp.autocast(enabled=mixed_precision):
                 if loss_criteria == "aam":
-                    batch_predictions, batch_embeddings = model(data, target=None, is_eval=True)
+                    batch_predictions, batch_embeddings = model(data, target=target, is_eval=False)
                 else:
-                    batch_predictions, batch_embeddings = model(data, target=None, is_eval=True)
+                    batch_predictions, batch_embeddings = model(data, target=None, is_eval=False)
                 batch_embeddings = l2_norm(batch_embeddings)
                 accuracy += (torch.argmax(batch_predictions.data, 1) == target).sum()
@@ -1158,7 +1187,7 @@ def cross_validation(model, validation_loader, device, validation_shape):
     # Faster EER computation available here : https://github.com/gl3lan/fast_eer
     equal_error_rate = eer(negatives, positives)
-    return 100. * accuracy.cpu().numpy() / ((batch_idx + 1) * batch_size), \
+    return 100. * accuracy.cpu().numpy() / validation_shape[0], \
            loss.cpu().numpy() / ((batch_idx + 1) * batch_size), equal_error_rate
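
For context, eer(negatives, positives) locates the operating point where the false-acceptance and false-rejection rates are equal. A rough numpy sketch of that computation, as a simplified illustration without interpolation (not the implementation used here or in fast_eer):

import numpy as np

def rough_eer(negatives, positives):
    """Approximate EER: find the threshold where FAR and FRR are closest,
    and return their mean at that point."""
    thresholds = np.sort(np.concatenate([negatives, positives]))
    far = np.array([np.mean(negatives >= t) for t in thresholds])  # impostors accepted
    frr = np.array([np.mean(positives < t) for t in thresholds])   # targets rejected
    i = np.argmin(np.abs(far - frr))
    return (far[i] + frr[i]) / 2.

neg = np.random.normal(0., 1., 1000)   # hypothetical impostor trial scores
pos = np.random.normal(2., 1., 1000)   # hypothetical genuine trial scores
print(rough_eer(neg, pos))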
@@ -1575,10 +1604,11 @@ def extract_embeddings(idmap_name,
                        frame_shift=0.01,
                        frame_duration=0.025,
                        extract_after_pooling=False,
-                       num_thread=1):
+                       num_thread=1,
+                       mixed_precision=False):
     # Load the model
     if isinstance(model_filename, str):
-        checkpoint = torch.load(model_filename)
+        checkpoint = torch.load(model_filename, map_location=device)
         if speaker_number is None:
             speaker_number = checkpoint["speaker_number"]
         if model_yaml is None:
@@ -1619,7 +1649,9 @@ def extract_embeddings(idmap_name,
     model.to(device)
     # Get the size of embeddings to extract
-    name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0] + '.weight'
+    name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
+    if name != 'bias':
+        name = name + '.weight'
     emb_size = model.before_speaker_embedding.state_dict()[name].shape[0]
     # Create the StatServer
@@ -1636,7 +1668,8 @@ def extract_embeddings(idmap_name,
         for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader, desc='xvector extraction', mininterval=1)):
             if data.shape[1] > 20000000:
                 data = data[...,:20000000]
-            preds, vec = model(data.to(device), is_eval=True)
+            with autocast(enabled=mixed_precision):
+                vec = model(data.to(device), is_eval=True)
             embeddings.stat1[idx, :] = vec.detach().cpu()
     return embeddings
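
At extraction time no GradScaler is involved; wrapping the forward pass in autocast is enough, since only inference runs. A minimal sketch of that pattern with a hypothetical extractor (requires a CUDA device; the torch.no_grad() guard is an extra precaution not shown in the hunk above):

import torch
from torch.cuda.amp import autocast

model = torch.nn.Linear(20, 8).cuda().eval()   # hypothetical embedding extractor
data = torch.randn(4, 20, device="cuda")

with torch.no_grad():
    with autocast(enabled=True):               # half-precision forward only
        vec = model(data)
embeddings = vec.float().detach().cpu()        # cast back to float32 for storage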