Commit 0ae9bb73 authored by Anthony Larcher

debug

parent 1618f7ea
@@ -381,9 +381,9 @@ class Xtractor(torch.nn.Module):
                 ("linear8", torch.nn.Linear(512, int(self.speaker_number)))
                 ]))
-            self.sequence_network_weight_decay = 0.0002
-            self.before_speaker_embedding_weight_decay = 0.002
-            self.after_speaker_embedding_weight_decay = 0.002
+            self.sequence_network_weight_decay = 0.0
+            self.before_speaker_embedding_weight_decay = 0.0
+            self.after_speaker_embedding_weight_decay = 0.0
             self.embedding_size = 512
         elif model_archi == "resnet34":
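
This hunk disables weight decay on all three sub-networks of the x-vector extractor. For context, per-module attributes like these are typically consumed through optimizer parameter groups; the sketch below shows that pattern under the assumption that the attributes set in `__init__` are read at optimizer construction time (the wiring itself is illustrative, not sidekit's actual training code).

```python
import torch

def build_optimizer(model, lr=1e-4):
    # Illustrative only: one parameter group per sub-network, each with
    # its own weight decay taken from the attributes set in __init__.
    param_groups = [
        {"params": model.sequence_network.parameters(),
         "weight_decay": model.sequence_network_weight_decay},          # now 0.0
        {"params": model.before_speaker_embedding.parameters(),
         "weight_decay": model.before_speaker_embedding_weight_decay},  # now 0.0
        {"params": model.after_speaker_embedding.parameters(),
         "weight_decay": model.after_speaker_embedding_weight_decay},   # now 0.0
    ]
    return torch.optim.Adam(param_groups, lr=lr)
```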
@@ -403,7 +403,7 @@ class Xtractor(torch.nn.Module):
             self.after_speaker_embedding = ArcMarginProduct(256,
                                                             int(self.speaker_number),
                                                             s = 30.0,
-                                                            m = 0.50,
+                                                            m = 0.20,
                                                             easy_margin = True)
             self.preprocessor_weight_decay = 0.000
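
The additive angular margin is relaxed from 0.50 to 0.20 (the scale s = 30 is unchanged), which makes the AAM-softmax loss easier to optimize. As a reminder of what s and m control, here is a minimal AAM-softmax sketch; it is a generic illustration of the technique, not the ArcMarginProduct implementation, and it omits the easy_margin handling.

```python
import torch
import torch.nn.functional as F

def aam_logits(embeddings, weight, target, s=30.0, m=0.20):
    # cos(theta) between L2-normalized embeddings and class weight vectors.
    cosine = F.linear(F.normalize(embeddings), F.normalize(weight))
    theta = torch.acos(cosine.clamp(-1.0 + 1e-7, 1.0 - 1e-7))
    # Add the margin m to the angle of the target class only, then scale
    # by s: a smaller m means a weaker penalty on the target class.
    one_hot = F.one_hot(target, num_classes=weight.shape[0]).bool()
    return s * torch.where(one_hot, torch.cos(theta + m), cosine)
```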
@@ -1082,7 +1082,7 @@ def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interval,
         optimizer.zero_grad()
         if scaler is not None:
-            with autocast():
+            with torch.cuda.amp.autocast():
                 if loss_criteria == 'aam':
                     output, _ = model(data, target=target)
                 else:
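
The bare `autocast()` call would raise a NameError unless it had been imported explicitly; qualifying it as `torch.cuda.amp.autocast` removes that dependency. For reference, the standard mixed-precision training step with a GradScaler looks roughly like this (a generic torch.cuda.amp sketch, not the exact body of train_epoch):

```python
import torch

scaler = torch.cuda.amp.GradScaler()

def train_step(model, data, target, optimizer, criterion):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():       # forward pass runs in mixed precision
        output = model(data)
        loss = criterion(output, target)
    scaler.scale(loss).backward()         # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)                # unscale gradients, then step
    scaler.update()                       # adapt the scale factor for the next step
    return loss.item()
```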
@@ -1158,7 +1158,7 @@ def cross_validation(model, validation_loader, device, validation_shape, mixed_precision
             batch_size = target.shape[0]
             target = target.squeeze().to(device)
             data = data.squeeze().to(device)
-            with torch.cuda.amp.autocastautocast(enabled=mixed_precision):
+            with torch.cuda.amp.autocast(enabled=mixed_precision):
                 if loss_criteria == "aam":
                     batch_predictions, batch_embeddings = model(data, target=target, is_eval=False)
                 else:
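
This hunk fixes the doubled `autocastautocast` typo. On the validation side, autocast is simply toggled by the `enabled` flag, so the same code path serves full- and mixed-precision runs; a generic sketch:

```python
import torch

def predict(model, data, mixed_precision=False):
    # enabled=False turns autocast into a no-op, so full precision is used.
    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=mixed_precision):
            return model(data)
```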
@@ -1544,7 +1544,7 @@ def train_cumulative_epoch(model,
         target = target.to(device)
-        RESUME HERE TO DO THE BACKWARD PASS ONLY EVERY n BATCHES
+        #RESUME HERE TO DO THE BACKWARD PASS ONLY EVERY n BATCHES
         if batch_idx % cumulative_frequency == 0:
             optimizer.zero_grad()
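
The commented-out reminder asks for the backward pass to run only every n batches, i.e. gradient accumulation, which is what the `batch_idx % cumulative_frequency` test works toward. A generic sketch of the technique (the loop and names here are illustrative, not the body of train_cumulative_epoch):

```python
def train_accumulated(model, loader, optimizer, criterion, accumulation_steps=4):
    # Gradients from several batches are summed in .grad before a single
    # optimizer step, emulating a batch accumulation_steps times larger.
    optimizer.zero_grad()
    for batch_idx, (data, target) in enumerate(loader):
        loss = criterion(model(data), target) / accumulation_steps  # keep gradient scale comparable
        loss.backward()                                             # accumulates into .grad
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
```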
@@ -1625,7 +1625,14 @@ def extract_embeddings(idmap_name,
     else:
         idmap = IdMap(idmap_name)
-    min_duration = (model.context_size() - 1) * frame_shift + frame_duration
+    if type(model) is Xtractor:
+        min_duration = (model.context_size() - 1) * frame_shift + frame_duration
+        model_cs = model.context_size()
+    else:
+        min_duration = (model.module.context_size() - 1) * frame_shift + frame_duration
+        model_cs = model.module.context_size()
     # Create dataset to load the data
     dataset = IdMapSet(idmap_name=idmap_name,
@@ -1633,7 +1640,7 @@ def extract_embeddings(idmap_name,
                         file_extension=file_extension,
                         transform_pipeline=transform_pipeline,
                         frame_rate=int(1. / frame_shift),
-                        min_duration=(model.context_size() + 2) * frame_shift * 2
+                        min_duration=(model_cs + 2) * frame_shift * 2
                         )
     dataloader = DataLoader(dataset,
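
Both hunks above branch on `type(model) is Xtractor` because a model wrapped in torch.nn.DataParallel exposes the original network only under `.module`; `model_cs` caches the context size so the later `min_duration` computation works in both cases. A small helper along these lines would factor out the branch (hypothetical, not part of the commit):

```python
import torch

def unwrap(model):
    # DataParallel / DistributedDataParallel keep the wrapped network in
    # `.module`; plain modules are returned unchanged.
    wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    return model.module if isinstance(model, wrappers) else model

# e.g. min_duration = (unwrap(model).context_size() - 1) * frame_shift + frame_duration
```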
@@ -1649,11 +1656,18 @@ def extract_embeddings(idmap_name,
     model.to(device)
     # Get the size of embeddings to extract
-    name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
-    if name != 'bias':
-        name = name + '.weight'
-    emb_size = model.before_speaker_embedding.state_dict()[name].shape[0]
+    if type(model) is Xtractor:
+        name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
+        if name != 'bias':
+            name = name + '.weight'
+        emb_size = model.before_speaker_embedding.state_dict()[name].shape[0]
+    else:
+        name = list(model.module.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
+        if name != 'bias':
+            name = name + '.weight'
+        emb_size = model.module.before_speaker_embedding.state_dict()[name].shape[0]
     # Create the StatServer
     embeddings = StatServer()
     embeddings.modelset = idmap.leftids
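
The embedding size is read off the last entry of `before_speaker_embedding.state_dict()`: for a Linear layer both `weight` (out_features, in_features) and `bias` (out_features,) carry the output dimension in `shape[0]`. With the hypothetical `unwrap` helper sketched above, the duplicated branch could collapse to:

```python
def embedding_size(model):
    state = unwrap(model).before_speaker_embedding.state_dict()
    # Last key is e.g. 'linear6.weight' or 'linear6.bias'; keep the layer
    # name and read the weight unless the entry is a bare bias.
    name = list(state.keys())[-1].split('.')[0]
    if name != 'bias':
        name = name + '.weight'
    return state[name].shape[0]
```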
@@ -1668,7 +1682,7 @@ def extract_embeddings(idmap_name,
         for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader, desc='xvector extraction', mininterval=1)):
             if data.shape[1] > 20000000:
                 data = data[...,:20000000]
-            with autocast(enabled=mixed_precision):
+            with torch.cuda.amp.autocast(enabled=mixed_precision):
                 vec = model(data.to(device), is_eval=True)
                 embeddings.stat1[idx, :] = vec.detach().cpu()