Commit 0d43b570 authored by Gaël Le Lan

enhanced add_noise augmentation

parent 942bf6b6
......@@ -163,12 +163,18 @@ class Scores:
:return: a vector of target scores and a vector of non-target scores.
"""
new_score = self.align_with_ndx(key)
tarndx = key.tar & new_score.scoremask
nonndx = key.non & new_score.scoremask
tar = new_score.scoremat[tarndx]
non = new_score.scoremat[nonndx]
return tar, non
if (key.modelset == self.modelset).all() \
and (key.segset == self.segset).all() \
and self.scoremask.shape[0] == key.tar.shape[0] \
and self.scoremask.shape[1] == key.tar.shape[1]:
return self.scoremat[key.tar & self.scoremask], self.scoremat[key.non & self.scoremask]
else:
new_score = self.align_with_ndx(key)
tarndx = key.tar & new_score.scoremask
nonndx = key.non & new_score.scoremask
tar = new_score.scoremat[tarndx]
non = new_score.scoremat[nonndx]
return tar, non
def align_with_ndx(self, ndx):
"""The ordering in the output Scores object corresponds to ndx, so
......
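The fast path added above skips the expensive align_with_ndx copy when the Key already matches this Scores object; both branches reduce to boolean masking of scoremat. A minimal numpy sketch of that masking logic (toy data, not sidekit objects):

    import numpy
    scoremat = numpy.array([[1.2, -0.3],
                            [0.4,  2.1]])
    scoremask = numpy.ones_like(scoremat, dtype=bool)  # all trials scored
    tar = numpy.array([[True, False],
                       [False, True]])                 # target trials
    non = ~tar                                         # non-target trials
    print(scoremat[tar & scoremask])                   # [1.2 2.1]
    print(scoremat[non & scoremask])                   # [-0.3  0.4]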
......@@ -103,10 +103,9 @@ def cosine_scoring(enroll, test, ndx, wccn=None, check_missing=True, device=None
device = torch.device("cuda:0" if torch.cuda.is_available() and s_size_in_bytes < 3e9 else "cpu")
else:
device = device if torch.cuda.is_available() and s_size_in_bytes < 3e9 else torch.device("cpu")
s = torch.mm(torch.FloatTensor(enroll_copy.stat1).to(device), torch.FloatTensor(test_copy.stat1).to(device).T).cpu().numpy()
score = Scores()
score.scoremat = s
score.scoremat = torch.einsum('ij,kj', torch.FloatTensor(enroll_copy.stat1).to(device), torch.FloatTensor(test_copy.stat1).to(device)).cpu().numpy()
score.modelset = clean_ndx.modelset
score.segset = clean_ndx.segset
score.scoremask = clean_ndx.trialmask
......
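The einsum above is equivalent to the matrix product it replaces: einsum('ij,kj', A, B) contracts the shared feature axis j and yields A @ B.T. A quick sanity check, assuming (n_models, dim) and (n_segments, dim) inputs:

    import torch
    a = torch.randn(4, 8)  # enrollment stats
    b = torch.randn(6, 8)  # test stats
    assert torch.allclose(torch.einsum('ij,kj', a, b), torch.mm(a, b.T), atol=1e-6)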
......@@ -485,44 +485,33 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noi
speech = torch.nn.functional.conv1d(speech_[None, ...], rir[None, ...])[0]
if "add_noise" in augmentations:
# Pick a noise sample from the noise_df
noise_row = noise_df.iloc[random.randrange(noise_df.shape[0])]
noise_type = noise_row['type']
noise_start = noise_row['start']
noise_duration = noise_row['duration']
noise_file_id = noise_row['file_id']
# Pick a SNR level
# TODO make SNRs configurable by noise type
if noise_type == 'music':
# Pick a noise type
noise = torch.zeros_like(speech)
noise_idx = random.randrange(3)
# speech
if noise_idx == 0:
# Pick a SNR level
# TODO make SNRs configurable by noise type
snr_db = random.randint(13, 20)
pick_count = random.randint(5, 10)
index_list = random.choices(range(noise_df.loc['speech'].shape[0]), k=pick_count)
for idx in index_list:
noise_row = noise_df.loc['speech'].iloc[idx]
noise += load_noise_seg(noise_row, speech.shape, sample_rate, transform_dict["add_noise"]["data_path"])
noise /= pick_count
# music
elif noise_idx == 1:
snr_db = random.randint(5, 15)
elif noise_type == 'noise':
noise_row = noise_df.loc['music'].iloc[random.randrange(noise_df.loc['music'].shape[0])]
noise += load_noise_seg(noise_row, speech.shape, sample_rate, transform_dict["add_noise"]["data_path"])
# noise
elif noise_idx == 2:
snr_db = random.randint(0, 15)
else:
snr_db = random.randint(13, 20)
if noise_duration * sample_rate > speech.shape[1]:
# We force frame_offset to stay in the 20 first seconds of the file, otherwise it takes too long to load
frame_offset = random.randrange(noise_start * sample_rate, min(int(20*sample_rate), int((noise_start + noise_duration) * sample_rate - speech.shape[1])))
else:
frame_offset = noise_start * sample_rate
noise_fn = transform_dict["add_noise"]["data_path"] + "/" + noise_file_id + ".wav"
if noise_duration * sample_rate > speech.shape[1]:
noise, noise_fs = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(speech.shape[1]))
else:
noise, noise_fs = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(noise_duration * sample_rate))
noise_row = noise_df.loc['noise'].iloc[random.randrange(noise_df.loc['noise'].shape[0])]
noise += load_noise_seg(noise_row, speech.shape, sample_rate, transform_dict["add_noise"]["data_path"])
speech_power = speech.norm(p=2)
noise_power = noise.norm(p=2)
#if numpy.random.randint(0, 2) == 1:
# noise = torch.flip(noise, dims=[0, 1])
if noise.shape[1] < speech.shape[1]:
noise = torch.tensor(numpy.resize(noise.numpy(), speech.shape))
snr = 10 ** (snr_db / 20)  # convert the target SNR from dB to a linear amplitude ratio
scale = snr * noise_power / speech_power
speech = (scale * speech + noise) / 2
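Because .norm(p=2) measures amplitude rather than power, a target SNR of snr_db corresponds to a linear amplitude ratio of 10 ** (snr_db / 20) (the conversion fixed above); scaling speech by snr * noise_power / speech_power then makes the speech-to-noise norm ratio hit that target, and the final / 2 only guards against clipping. A self-contained sketch of the mixing math:

    import torch
    speech = torch.randn(1, 16000)
    noise = torch.randn(1, 16000)
    snr_db = 10
    snr = 10 ** (snr_db / 20)                        # dB -> linear amplitude ratio
    scale = snr * noise.norm(p=2) / speech.norm(p=2)
    mixed = (scale * speech + noise) / 2             # /2 only guards against clipping
    print(20 * torch.log10((scale * speech).norm(p=2) / noise.norm(p=2)))  # ~10.0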
......@@ -539,6 +528,26 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noi
return speech
def load_noise_seg(noise_row, speech_shape, sample_rate, data_path):
noise_start = noise_row['start']
noise_duration = noise_row['duration']
noise_file_id = noise_row['file_id']
frame_offset = noise_start * sample_rate
noise_fn = data_path + "/" + noise_file_id + ".wav"
if noise_duration * sample_rate > speech_shape[1]:
noise_seg, _ = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(speech_shape[1]))
else:
noise_seg, _ = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(noise_duration * sample_rate))
#if numpy.random.randint(0, 2) == 1:
# noise = torch.flip(noise, dims=[0, 1])
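# note: numpy.resize tiles (loops) the noise segment to fill speech_shape when it is shorter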
if noise_seg.shape[1] < speech_shape[1]:
noise_seg = torch.tensor(numpy.resize(noise_seg.numpy(), speech_shape))
return noise_seg
"""
It might not be 100% on topic, but maybe this is interesting for you anyway. If you do not need real-time processing, things can be made easier. Limiting and dynamic compression can be seen as applying a dynamic transfer function, which simply maps input values to output values. A linear function returns the original audio, while a "curved" function applies compression or expansion. Applying a transfer function is as simple as interpolating each sample through the curve.
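A minimal numpy sketch of the idea, with a hypothetical soft-knee curve standing in for the transfer function:

    import numpy
    audio = numpy.clip(numpy.random.randn(16000) * 0.3, -1.0, 1.0)  # stand-in signal
    curve_in = numpy.array([-1.0, -0.5, 0.0, 0.5, 1.0])             # input levels
    curve_out = numpy.array([-0.8, -0.45, 0.0, 0.45, 0.8])          # gentle compression
    compressed = numpy.interp(audio, curve_in, curve_out)           # map each sample through the curve

A linear curve (curve_out equal to curve_in) returns the input unchanged, matching the description above.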
......
......@@ -251,9 +251,6 @@ class ArcMarginProduct(torch.nn.Module):
self.mm = math.sin(math.pi - self.m) * self.m
def forward(self, input, target=None):
assert input.size()[0] == target.size()[0]
assert input.size()[1] == self.in_features
# cos(theta)
cosine = torch.nn.functional.linear(torch.nn.functional.normalize(input), torch.nn.functional.normalize(self.weight))
if target is None:
......@@ -278,7 +275,7 @@ class ArcMarginProduct(torch.nn.Module):
class SoftmaxAngularProto(torch.nn.Module):
# from https://github.com/clovaai/voxceleb_trainer/blob/3bfd557fab5a3e6cd59d717f5029b3a20d22a281/loss/angleproto.py
def __init__(self, spk_count, init_w=10.0, init_b=-5.0, **kwargs):
def __init__(self, spk_count, emb_dim=256, init_w=10.0, init_b=-5.0, **kwargs):
super(SoftmaxAngularProto, self).__init__()
self.test_normalize = True
......@@ -288,7 +285,7 @@ class SoftmaxAngularProto(torch.nn.Module):
self.criterion = torch.nn.CrossEntropyLoss()
self.cce_backend = torch.nn.Sequential(OrderedDict([
("linear8", torch.nn.Linear(256, spk_count))
("linear8", torch.nn.Linear(emb_dim, spk_count))
]))
def forward(self, x, target=None):
......@@ -309,3 +306,57 @@ class SoftmaxAngularProto(torch.nn.Module):
cos_sim_matrix = cos_sim_matrix * self.w + self.b
return cos_sim_matrix, cce_prediction
class SoftmaxMagnet(torch.nn.Module):
# from https://github.com/clovaai/voxceleb_trainer/blob/3bfd557fab5a3e6cd59d717f5029b3a20d22a281/loss/angleproto.py
def __init__(self, spk_count, emb_dim=256, batch_size=256, init_w=1.0, init_b=0.0, **kwargs):
super(SoftmaxMagnet, self).__init__()
self.test_normalize = True
self.w = torch.nn.Parameter(torch.tensor(init_w))
self.b = torch.nn.Parameter(torch.tensor(init_b))
self.magnitude = torch.nn.Sequential(OrderedDict([
("linear9", torch.nn.Linear(emb_dim, 512)),
("relu9", torch.nn.ReLU()),
("linear10", torch.nn.Linear(512, 512)),
("relu10", torch.nn.ReLU()),
("linear11", torch.nn.Linear(512, 1)),
("relu11", torch.nn.ReLU())
]))
self.cce_backend = torch.nn.Sequential(OrderedDict([
("linear8", torch.nn.Linear(emb_dim, spk_count))
]))
self.classification_criterion = torch.nn.CrossEntropyLoss()
self.magnet_criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
def forward(self, x, target=None):
assert x.size()[1] >= 2
cce_prediction = self.cce_backend(x)
if target is None:
return cce_prediction
x = self.magnitude(x) * torch.nn.functional.normalize(x)
x = x.reshape(-1,2,x.size()[-1]).squeeze(1)
out_anchor = torch.mean(x[:,1:,:],1)
out_positive = x[:,0,:]
labels = torch.arange(0, int(out_positive.shape[0]), device=x.device).unsqueeze(1)
cos_sim_matrix = torch.mm(out_positive, out_anchor.T)
self.w.data.clamp_(min=1e-6)  # keep the learned scale positive
cos_sim_matrix = cos_sim_matrix * self.w + self.b
cos_sim_matrix = cos_sim_matrix + math.log(1 / out_positive.shape[0] / (1 - 1 / out_positive.shape[0]))
mask = torch.tile(labels, (1, labels.shape[0])) == labels.T
return self.magnet_criterion(cos_sim_matrix.flatten().unsqueeze(1), mask.float().flatten().unsqueeze(1)), self.classification_criterion(cce_prediction, target)
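The additive constant above, log((1/N) / (1 - 1/N)), is the logit of the within-batch target prior: each positive is scored against N anchors of which exactly one matches, so the bias recenters the similarity logits for BCEWithLogitsLoss. For instance, with 128 anchors per batch:

    import math
    n = 128
    prior = 1.0 / n
    print(math.log(prior / (1 - prior)))  # ~ -4.84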
......@@ -473,6 +473,47 @@ class PreResNet34(torch.nn.Module):
return out
class PreHalfResNet34(torch.nn.Module):
"""
Network that contains only the ResNet part up to pooling, with NO classification layers
"""
def __init__(self, block=BasicBlock, num_blocks=[3, 4, 6, 3], speaker_number=10):
super(PreHalfResNet34, self).__init__()
self.in_planes = 32
self.speaker_number = speaker_number
self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=3,
stride=(1, 1), padding=1, bias=False)
self.bn1 = torch.nn.BatchNorm2d(32)
# With block = [3, 4, 6, 3]
self.layer1 = self._make_layer(block, 32, num_blocks[0], stride=(1, 1))
self.layer2 = self._make_layer(block, 64, num_blocks[1], stride=(2, 2))
self.layer3 = self._make_layer(block, 128, num_blocks[2], stride=(2, 2))
self.layer4 = self._make_layer(block, 256, num_blocks[3], stride=(2, 2))
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return torch.nn.Sequential(*layers)
def forward(self, x):
out = x.unsqueeze(1)
out = out.contiguous(memory_format=torch.channels_last)
out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = out.contiguous(memory_format=torch.contiguous_format)
out = torch.flatten(out, start_dim=1, end_dim=2)
return out
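A shape sketch for this network, assuming the 80-bin mel front end used elsewhere in this file: three stride-2 stages reduce 80 frequency bins to 10, and flattening gives 256 * 10 = 2560 channels, which mean/std pooling later doubles to the 5120 input features of before_speaker_embedding:

    import torch
    net = PreHalfResNet34()        # assumes BasicBlock from this module
    x = torch.randn(2, 80, 200)    # (batch, mel bins, frames)
    print(net(x).shape)            # torch.Size([2, 2560, 25])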
class PreFastResNet34(torch.nn.Module):
"""
Network that contains only the ResNet part up to pooling, with NO classification layers
......
......@@ -189,7 +189,7 @@ class SideSet(Dataset):
df_dict = dict(zip(df.columns, [[], [], [], [], [], [], []]))
# For each segment, get all possible segments with the current overlap
for idx in tqdm.trange(len(tmp_sessions), desc='indexing all ' + set_type + ' segments', mininterval=1):
for idx in tqdm.trange(len(tmp_sessions), desc='indexing all ' + set_type + ' segments', mininterval=1, disable=None):
current_session = tmp_sessions.iloc[idx]
# Compute possible starts
......@@ -230,8 +230,8 @@ class SideSet(Dataset):
self.noise_df = None
if "add_noise" in self.transform:
# Load the noise dataset, filter according to the duration
self.noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
self.noise_df = noise_df.set_index(noise_df.type)
self.rir_df = None
if "add_reverb" in self.transform:
......
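Indexing the noise dataframe by type, as done above, is what lets data_augmentation draw per-category rows with noise_df.loc['speech'], noise_df.loc['music'], and noise_df.loc['noise']. A small sketch with hypothetical rows:

    import pandas
    noise_df = pandas.DataFrame({
        "type": ["speech", "speech", "music", "noise"],
        "file_id": ["s1", "s2", "m1", "n1"],
        "start": [0.0, 0.0, 0.0, 0.0],
        "duration": [4.0, 6.0, 7.0, 3.0],
    })
    noise_df = noise_df.set_index(noise_df.type)
    print(noise_df.loc['speech'].shape[0])  # 2 rows selectable per noise type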
......@@ -56,13 +56,14 @@ from .xsets import SideSampler
from .res_net import ResBlockWFMS
from .res_net import ResBlock
from .res_net import PreResNet34
from .res_net import PreFastResNet34
from .res_net import PreFastResNet34, PreHalfResNet34
from ..bosaris import IdMap
from ..bosaris import Key
from ..bosaris import Ndx
from ..statserver import StatServer
from ..iv_scoring import cosine_scoring
from .sincnet import SincNet
from ..bosaris.detplot import rocch, rocch2eer
from .loss import SoftmaxAngularProto, ArcLinear
from .loss import l2_norm
from .loss import ArcMarginProduct
......@@ -227,10 +228,11 @@ def test_metrics(model,
Returns:
[type]: [description]
"""
idmap_test_filename = 'h5f/idmap_test.h5'
ndx_test_filename = 'h5f/ndx_test.h5'
key_test_filename = 'h5f/key_test.h5'
data_root_name='/home/rsgb7088/data/vox1/test/wav'
idmap_test_filename = 'h5f/vox1_test_cleaned_idmap.h5'
ndx_test_filename = 'h5f/vox1_test_cleaned_ndx.h5'
key_test_filename = 'h5f/vox1_test_cleaned_key.h5'
data_root_name='/hdd/data/vox1/test/wav'
transform_pipeline = dict()
......@@ -243,16 +245,18 @@ def test_metrics(model,
mixed_precision=mixed_precision,
backward=False)
scores = cosine_scoring(xv_stat,
xv_stat,
Ndx(ndx_test_filename),
wccn=None,
check_missing=True,
device=device)
tar, non = cosine_scoring(xv_stat,
xv_stat,
Ndx(ndx_test_filename),
wccn=None,
check_missing=True,
device=device
).get_tar_non(Key(key_test_filename))
#test_eer = eer(numpy.array(non).astype(numpy.double), numpy.array(tar).astype(numpy.double))
pmiss, pfa = rocch(tar, non)
tar, non = scores.get_tar_non(Key(key_test_filename))
test_eer = eer(numpy.array(non).astype(numpy.double), numpy.array(tar).astype(numpy.double))
return test_eer
return rocch2eer(pmiss, pfa)
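The refactor above derives the EER from the ROC convex hull: rocch turns the target/non-target score vectors into (pmiss, pfa) operating points, and rocch2eer interpolates the equal error rate from them. A toy check with the bosaris functions imported above:

    import numpy
    tar = numpy.array([2.0, 1.5, 1.0, 0.8])
    non = numpy.array([0.1, 0.3, 0.9, 1.1])
    pmiss, pfa = rocch(tar, non)   # miss/false-alarm points on the convex hull
    print(rocch2eer(pmiss, pfa))   # equal error rate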
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
......@@ -304,14 +308,14 @@ class AttentivePooling(torch.nn.Module):
# TODO Make convolution parameters configurable
super(AttentivePooling, self).__init__()
self.attention = torch.nn.Sequential(
torch.nn.Conv1d(num_channels * 10 * 3, 128, kernel_size=1),
torch.nn.Conv1d(num_channels * 10, 128, kernel_size=1),
torch.nn.ReLU(),
torch.nn.BatchNorm1d(128),
torch.nn.Tanh(),
torch.nn.Conv1d(128, num_channels * 10, kernel_size=1),
torch.nn.Softmax(dim=2),
)
self.global_context = MeanStdPooling()
#self.global_context = MeanStdPooling()
def new_parameter(self, *size):
out = torch.nn.Parameter(torch.FloatTensor(*size))
......@@ -325,10 +329,10 @@ class AttentivePooling(torch.nn.Module):
:return:
"""
global_context = self.global_context(x).unsqueeze(2).repeat(1, 1, x.shape[-1])
#global_context = self.global_context(x).unsqueeze(2).repeat(1, 1, x.shape[-1])
w = self.attention(torch.cat([x, global_context], dim=1))
#w = self.attention(x)
#w = self.attention(torch.cat([x, global_context], dim=1))
w = self.attention(x)
mu = torch.sum(x * w, dim=2)
rh = torch.sqrt( ( torch.sum((x**2) * w, dim=2) - mu**2 ).clamp(min=1e-5) )
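With the global context removed, w is a per-channel softmax over time, and mu/rh above are the attention-weighted mean and standard deviation (attentive statistics pooling); concatenating them doubles the channel count. A standalone sketch of the weighted statistics:

    import torch
    x = torch.randn(2, 2560, 25)                        # (batch, channels, frames)
    w = torch.softmax(torch.randn(2, 2560, 25), dim=2)  # stands in for self.attention(x)
    mu = torch.sum(x * w, dim=2)
    rh = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
    print(torch.cat((mu, rh), dim=1).shape)             # torch.Size([2, 5120]): mean/std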
......@@ -485,7 +489,6 @@ class Xtractor(torch.nn.Module):
self.after_speaker_embedding_weight_decay = 0.00
elif model_archi == "fastresnet34":
self.preprocessor = MelSpecFrontEnd()
self.sequence_network = PreFastResNet34()
self.embedding_size = 256
......@@ -496,6 +499,33 @@ class Xtractor(torch.nn.Module):
self.stat_pooling = AttentivePooling(128)
self.stat_pooling_weight_decay = 0
self.loss = loss
if self.loss == "aam":
self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
int(self.speaker_number),
s = 30,
m = 0.2,
easy_margin = False)
elif self.loss == 'aps':
self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number))
self.preprocessor_weight_decay = 0.000
self.sequence_network_weight_decay = 0.000
self.stat_pooling_weight_decay = 0.000
self.before_speaker_embedding_weight_decay = 0.00
self.after_speaker_embedding_weight_decay = 0.00
elif model_archi == "halfresnet34":
self.preprocessor = MelSpecFrontEnd()
self.sequence_network = PreHalfResNet34()
self.embedding_size = 512
self.before_speaker_embedding = torch.nn.Linear(in_features = 5120,
out_features = self.embedding_size)
self.stat_pooling = AttentivePooling(256)
self.stat_pooling_weight_decay = 0
self.loss = loss
if self.loss == "aam":
self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
......@@ -796,7 +826,7 @@ class Xtractor(torch.nn.Module):
if is_eval:
x = torch.nn.functional.normalize(x, dim=1)
else:
x = self.after_speaker_embedding(torch.nn.functional.normalize(x, dim=1), target=target), torch.nn.functional.normalize(x, dim=1)
x = self.after_speaker_embedding(x, target=target), torch.nn.functional.normalize(x, dim=1)
return x
......@@ -869,7 +899,7 @@ def xtrain(speaker_number,
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Use a predefined architecture
if model_yaml in ["xvector", "rawnet2", "resnet34", "fastresnet34"]:
if model_yaml in ["xvector", "rawnet2", "resnet34", "fastresnet34", "halfresnet34"]:
if model_name is None:
model = Xtractor(speaker_number, model_yaml, loss=loss)
......@@ -1005,7 +1035,7 @@ def xtrain(speaker_number,
side_sampler = SideSampler(training_set.sessions['speaker_idx'],
speaker_number,
1,
100,
128,
dataset_params["batch_size"])
training_loader = DataLoader(training_set,
......@@ -1056,8 +1086,8 @@ def xtrain(speaker_number,
optimizer = _optimizer(param_list, **_options)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
step_size=10 * training_loader.__len__(),
gamma=0.95)
step_size=12 * training_loader.__len__(),
gamma=0.75)
if mixed_precision:
scaler = torch.cuda.amp.GradScaler()
......@@ -1071,13 +1101,12 @@ def xtrain(speaker_number,
test_eer = 100.
classes = torch.LongTensor(validation_set.sessions['speaker_idx'].to_numpy())
local_device = "cpu" if classes.shape[0] > 3e4 else device
mask = ((torch.ger(classes.to(local_device).float() + 1,
(1 / (classes.to(local_device).float() + 1))) == 1).long() * 2 - 1).float().cpu()
mask = mask.numpy()
mask = mask[numpy.tril_indices(mask.shape[0], -1)]
classes = torch.LongTensor(validation_set.sessions['speaker_idx'].to_numpy())  # LongTensor: speaker indices can exceed 255
mask = classes.unsqueeze(1) == classes.unsqueeze(1).T
tar_indices = torch.tril(mask, -1).numpy()
non_indices = torch.tril(~mask, -1).numpy()
tar_non_ratio = numpy.sum(tar_indices)/numpy.sum(non_indices)
non_indices *= numpy.random.choice([False, True], size=non_indices.shape, p=[1-tar_non_ratio, tar_non_ratio])
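The new masks above keep every same-speaker pair below the diagonal as a target trial and randomly subsample the different-speaker pairs at rate tar_non_ratio, so target and non-target trial counts match in expectation. A toy illustration:

    import torch
    classes = torch.LongTensor([0, 0, 1, 1, 2])
    mask = classes.unsqueeze(1) == classes.unsqueeze(1).T  # (5, 5) same-speaker matrix
    tar_indices = torch.tril(mask, -1)                     # 2 target pairs: (1,0), (3,2)
    non_indices = torch.tril(~mask, -1)                    # 8 non-target pairs before subsampling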
for epoch in range(1, epochs + 1):
# Process one epoch and return the current model
......@@ -1096,7 +1125,7 @@ def xtrain(speaker_number,
# Add the cross validation here
if math.fmod(epoch, 1) == 0:
val_acc, val_loss, val_eer = cross_validation(model, validation_loader, device, [validation_set.__len__(), embedding_size], mask, mixed_precision)
val_acc, val_loss, val_eer = cross_validation(model, validation_loader, device, [validation_set.__len__(), embedding_size], tar_indices, non_indices, mixed_precision)
logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Cross validation accuracy = {val_acc} %, EER = {val_eer * 100} %")
if compute_test_eer:
......@@ -1127,7 +1156,8 @@ def xtrain(speaker_number,
'accuracy': best_accuracy,
'scheduler': scheduler,
'speaker_number' : speaker_number,
'model_archi': model_archi
'model_archi': model_archi,
'loss': loss
}, is_best, filename=tmp_model_name+".pt", best_filename=best_model_name+'.pt')
else:
save_checkpoint({
......@@ -1137,7 +1167,8 @@ def xtrain(speaker_number,
'accuracy': best_accuracy,
'scheduler': scheduler,
'speaker_number': speaker_number,
'model_archi': model_archi
'model_archi': model_archi,
'loss': loss
}, is_best, filename=tmp_model_name+".pt", best_filename=best_model_name+'.pt')
if is_best:
......@@ -1242,7 +1273,7 @@ def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interva
return model
def cross_validation(model, validation_loader, device, validation_shape, mask, mixed_precision=False):
def cross_validation(model, validation_loader, device, validation_shape, tar_indices, non_indices, mixed_precision=False):
"""
:param model:
......@@ -1264,7 +1295,7 @@ def cross_validation(model, validation_loader, device, validation_shape, mask, m
embeddings = torch.zeros(validation_shape)
#classes = torch.zeros([validation_shape[0]])
with torch.no_grad():
for batch_idx, (data, target) in enumerate(tqdm.tqdm(validation_loader, desc='validation compute', mininterval=1)):
for batch_idx, (data, target) in enumerate(tqdm.tqdm(validation_loader, desc='validation compute', mininterval=1, disable=None)):
target = target.squeeze()
target = target.to(device)
batch_size = target.shape[0]
......@@ -1284,14 +1315,17 @@ def cross_validation(model, validation_loader, device, validation_shape, mask, m
#print(classes.shape[0])
local_device = "cpu" if embeddings.shape[0] > 3e4 else device
scores = torch.mm(embeddings.to(local_device), embeddings.to(local_device).T).cpu().numpy()
scores = scores[numpy.tril_indices(scores.shape[0], -1)]
negatives = scores[numpy.argwhere(mask == -1)][:, 0].astype(float)
positives = scores[numpy.argwhere(mask == 1)][:, 0].astype(float)
embeddings = embeddings.to(local_device)
scores = torch.einsum('ij,kj', embeddings, embeddings).cpu().numpy()
negatives = scores[non_indices]
positives = scores[tar_indices]
# Faster EER computation available here : https://github.com/gl3lan/fast_eer
equal_error_rate = eer(negatives, positives)
#pmiss, pfa = rocch(positives, negatives)
#equal_error_rate = rocch2eer(pmiss, pfa)
return (100. * accuracy.cpu().numpy() / validation_shape[0],
loss.cpu().numpy() / ((batch_idx + 1) * batch_size),
equal_error_rate)
......@@ -1331,7 +1365,7 @@ def extract_embeddings(idmap_name,
checkpoint = torch.load(model_filename, map_location=device)
speaker_number = checkpoint["speaker_number"]
model_archi = checkpoint["model_archi"]
model = Xtractor(speaker_number, model_archi=model_archi)
model = Xtractor(speaker_number, model_archi=model_archi, loss=checkpoint["loss"])
model.load_state_dict(checkpoint["model_state_dict"])
else:
model = model_filename
......@@ -1398,7 +1432,8 @@ def extract_embeddings(idmap_name,
with torch.no_grad():
for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader,
desc='xvector extraction',
mininterval=1)):
mininterval=1,
disable=None)):
if data.shape[1] > 20000000:
data = data[...,:20000000]
with torch.cuda.amp.autocast(enabled=mixed_precision):
......@@ -1470,7 +1505,7 @@ def extract_embeddings_per_speaker(idmap_name,
# Process the data
with torch.no_grad():
for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader, desc='xvector extraction', mininterval=1)):
for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader, desc='xvector extraction', mininterval=1, disable=None)):
if data.shape[1] > 20000000:
data = data[..., :20000000]
print(f"Shape of data: {data.shape}")
......