Commit 53a1e215 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

debug transformation

parents 9a6ce925 50d0d1f3
......@@ -272,3 +272,38 @@ class ArcMarginProduct(torch.nn.Module):
output = output * self.s
return output
class SoftmaxAngularProto(torch.nn.Module):
# from https://github.com/clovaai/voxceleb_trainer/blob/3bfd557fab5a3e6cd59d717f5029b3a20d22a281/loss/angleproto.py
def __init__(self, spk_count, init_w=10.0, init_b=-5.0, **kwargs):
super(SoftmaxAngularProto, self).__init__()
self.test_normalize = True
self.w = torch.nn.Parameter(torch.tensor(init_w))
self.b = torch.nn.Parameter(torch.tensor(init_b))
self.criterion = torch.nn.CrossEntropyLoss()
self.cce_backend = torch.nn.Sequential(OrderedDict([
("linear8", torch.nn.Linear(256, spk_count))
]))
def forward(self, x, target=None):
assert x.size()[1] >= 2
cce_prediction = self.cce_backend(x)
if target==None:
return cce_prediction
x = x.reshape(-1,2,x.size()[-1]).squeeze(1)
out_anchor = torch.mean(x[:,1:,:],1)
out_positive = x[:,0,:]
cos_sim_matrix = torch.nn.functional.cosine_similarity(out_positive.unsqueeze(-1),out_anchor.unsqueeze(-1).transpose(0,2))
torch.clamp(self.w, 1e-6)
cos_sim_matrix = cos_sim_matrix * self.w + self.b
return cos_sim_matrix, cce_prediction
......@@ -69,7 +69,8 @@ class SideSampler(torch.utils.data.Sampler):
self.spk_count = spk_count
self.examples_per_speaker = examples_per_speaker
self.samples_per_speaker = samples_per_speaker
self.batch_size = batch_size
assert batch_size % examples_per_speaker == 0
self.batch_size = batch_size//examples_per_speaker
# reference all segment indexes per speaker
for idx in range(self.spk_count):
......@@ -220,16 +221,18 @@ class SideSet(Dataset):
self.sessions = pandas.DataFrame.from_dict(df_dict)
self.len = len(self.sessions)
self.transform = []
self.transform = dict()
if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
self.transform_list = self.transformation["pipeline"].split(',')
transforms = self.transformation["pipeline"].split(',')
if "add_noise" in transforms:
self.transform["add_noise"] = self.transformation["add_noise"]
if "add_reverb" in transforms:
self.transform["add_reverb"] = self.transformation["add_reverb"]
self.noise_df = None
if "add_noise" in self.transform:
# Load the noise dataset, filter according to the duration
noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
tmp_df = noise_df.loc[noise_df['duration'] > self.duration]
self.noise_df = tmp_df['file_id'].tolist()
self.noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
self.rir_df = None
if "add_reverb" in self.transform:
......@@ -246,7 +249,7 @@ class SideSet(Dataset):
current_session = self.sessions.iloc[index]
nfo = soundfile.info(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}")
start_frame = int(current_session['start'] * self.sample_rate)
start_frame = int(current_session['start'])
if start_frame + self.sample_number >= nfo.frames:
start_frame = numpy.min(nfo.frames - self.sample_number - 1)
......
......@@ -59,7 +59,7 @@ from ..bosaris import Ndx
from ..statserver import StatServer
from ..iv_scoring import cosine_scoring
from .sincnet import SincNet
from .loss import ArcLinear
from .loss import SoftmaxAngularProto, ArcLinear
from .loss import l2_norm
from .loss import ArcMarginProduct
......@@ -240,7 +240,8 @@ def test_metrics(model,
xv_stat,
Ndx(ndx_test_filename),
wccn=None,
check_missing=True)
check_missing=True,
device=device)
tar, non = scores.get_tar_non(Key(key_test_filename))
test_eer = eer(numpy.array(non).astype(numpy.double), numpy.array(tar).astype(numpy.double))
......@@ -441,12 +442,16 @@ class Xtractor(torch.nn.Module):
self.stat_pooling = MeanStdPooling()
self.stat_pooling_weight_decay = 0
self.loss = "aam"
self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
int(self.speaker_number),
s = 30,
m = 0.4,
easy_margin = False)
self.loss = loss
if self.loss == "aam":
self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
int(self.speaker_number),
s = 30,
m = 0.2,
easy_margin = False)
elif self.loss == 'aps':
self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number))
self.preprocessor_weight_decay = 0.000
self.sequence_network_weight_decay = 0.000
......@@ -733,7 +738,7 @@ class Xtractor(torch.nn.Module):
else:
return self.after_speaker_embedding(x), x
elif self.loss == "aam":
elif self.loss in ['aam', 'aps']:
if is_eval:
x = torch.nn.functional.normalize(x, dim=1)
else:
......@@ -813,7 +818,7 @@ def xtrain(speaker_number,
if model_yaml in ["xvector", "rawnet2", "resnet34", "fastresnet34"]:
if model_name is None:
model = Xtractor(speaker_number, model_yaml)
model = Xtractor(speaker_number, model_yaml, loss=loss)
else:
logging.critical(f"*** Load model from = {model_name}")
......@@ -877,8 +882,8 @@ def xtrain(speaker_number,
else:
# Load the model
logging.critical(f"*** Load model from = {model_name}")
checkpoint = torch.load(model_name)
model = Xtractor(speaker_number, model_yaml)
checkpoint = torch.load(model_name, map_location=device)
model = Xtractor(speaker_number, model_yaml, loss=loss)
"""
Here we remove all layers that we don't want to reload
......@@ -905,6 +910,9 @@ def xtrain(speaker_number,
if p.requires_grad) + \
sum(p.numel()
for p in model.before_speaker_embedding.parameters()
if p.requires_grad) + \
sum(p.numel()
for p in model.stat_pooling.parameters()
if p.requires_grad)))
embedding_size = model.embedding_size
......@@ -1118,15 +1126,25 @@ def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interva
with torch.cuda.amp.autocast():
if loss_criteria == 'aam':
output, _ = model(data, target=target)
loss = criterion(output, target)
elif loss_criteria == 'aps':
output_tuple, _ = model(data, target=target)
cos_sim_matx, output = output_tuple
loss = criterion(cos_sim_matx, torch.arange(0, int(data.shape[0]/2), device=device)) + criterion(output, target)
else:
output, _ = model(data, target=None)
loss = criterion(output, target)
loss = criterion(output, target)
else:
if loss_criteria == 'aam':
output, _ = model(data, target=target)
loss = criterion(output, target)
elif loss_criteria == 'aps':
output_tuple, _ = model(data, target=target)
cos_sim_matx, output = output_tuple
loss = criterion(cos_sim_matx, torch.arange(0, int(data.shape[0]/2), device=device)) + criterion(output, target)
else:
output, _ = model(data, target=None)
loss = criterion(output, target)
loss = criterion(output, target)
if not torch.isnan(loss):
if scaler is not None:
scaler.scale(loss).backward()
......@@ -1192,8 +1210,10 @@ def cross_validation(model, validation_loader, device, validation_shape, mixed_p
target = target.squeeze().to(device)
data = data.squeeze().to(device)
with torch.cuda.amp.autocast(enabled=mixed_precision):
if loss_criteria == "aam":
if loss_criteria == 'aam':
batch_predictions, batch_embeddings = model(data, target=target, is_eval=False)
elif loss_criteria == 'aps':
batch_predictions, batch_embeddings = model(data, target=None, is_eval=False)
else:
batch_predictions, batch_embeddings = model(data, target=None, is_eval=False)
batch_embeddings = l2_norm(batch_embeddings)
......@@ -1203,15 +1223,9 @@ def cross_validation(model, validation_loader, device, validation_shape, mixed_p
embeddings[batch_idx * batch_size:batch_idx * batch_size + batch_predictions.shape[0],:] = batch_embeddings.detach().cpu()
classes[batch_idx * batch_size:batch_idx * batch_size + batch_predictions.shape[0]] = target.detach().cpu()
if classes.shape[0] > 2e4:
local_device = "cpu"
else:
local_device = device
mask = ((torch.ger(classes.to(local_device).float() + 1,
(1 / (classes.to(local_device).float() + 1))) == 1).long() * 2 - 1).float().cpu()
mask = mask.numpy()
scores = torch.mm(embeddings.to(local_device), embeddings.to(local_device).T).cpu()
scores = scores.numpy()
#print(classes.shape[0])
local_device = "cpu" if embeddings.shape[0] > 3e4 else device
scores = torch.mm(embeddings.to(local_device), embeddings.to(local_device).T).cpu().numpy()
scores = scores[numpy.tril_indices(scores.shape[0], -1)]
mask = mask[numpy.tril_indices(mask.shape[0], -1)]
negatives = scores[numpy.argwhere(mask == -1)][:, 0].astype(float)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment