Commit ad033a0c authored by Anthony Larcher

modulable xvectors

parent 87794562
@@ -216,13 +216,12 @@ def PLDA_scoring(enroll,
                  Vtrans=None,
                  p_known=0.0,
                  scaling_factor=1.,
                  full_model=False):
-    """Compute the PLDA scores between to sets of vectors. The list of
+    """Compute the PLDA scores between two sets of vectors. The list of
     trials to perform is given in an Ndx object. PLDA matrices have to be
     pre-computed. i-vectors are supposed to be whitened before.
-    Implements the appraoch described in [Lee13]_ including scoring
+    Implements the approach described in [Lee13]_ including scoring
     for partially open-set identification
     :param enroll: a StatServer in which stat1 are i-vectors
......
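A hedged usage sketch of the function above. The parameters before the truncated signature (test, ndx and the PLDA model mean, F, G, Sigma) are assumed to follow sidekit's iv_scoring module, and the file names are placeholders:

    import sidekit

    enroll = sidekit.StatServer('enroll_ivectors.h5')   # stat1 holds i-vectors
    test = sidekit.StatServer('test_ivectors.h5')
    ndx = sidekit.Ndx('trials.h5')                      # trials to score

    # mean, F, G, Sigma come from a previously trained PLDA model;
    # i-vectors are assumed whitened beforehand, as the docstring requires.
    scores = sidekit.iv_scoring.PLDA_scoring(enroll, test, ndx, mean, F, G, Sigma,
                                             p_known=0.0,  # > 0 enables partially open-set scoring
                                             scaling_factor=1.)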
@@ -103,6 +103,69 @@ def prepare_voxceleb1(vox1_root_dir, output_batch_file, seg_duration=4, samplerate=16000):
                           fletcher32=True)
+def prepare_voxceleb2(vox2_root_dir, output_batch_file, seg_duration=4, samplerate=16000):
+    # List wav files in VoxCeleb2
+    vox2_wav_list = [str(f) for f in Path(vox2_root_dir).rglob("*.[wW][aA][vV]")]
+
+    # Split the file list into 5 sublists of roughly equal size
+    lv2 = len(vox2_wav_list)
+    vox2_sublists = [[]] * 5
+    vox2_sublists[0] = vox2_wav_list[:lv2 // 5]
+    vox2_sublists[1] = vox2_wav_list[lv2 // 5: 2 * (lv2 // 5)]
+    vox2_sublists[2] = vox2_wav_list[2 * (lv2 // 5): 3 * (lv2 // 5)]
+    vox2_sublists[3] = vox2_wav_list[3 * (lv2 // 5): 4 * (lv2 // 5)]
+    vox2_sublists[4] = vox2_wav_list[4 * (lv2 // 5):]
+
+    # Number of samples in one segment of seg_duration seconds
+    nb_samp = seg_duration * samplerate
+
+    print("*** Collect information from VoxCeleb2 data ***")
+    # One independent DataFrame per sublist
+    vox2_dfs = [pandas.DataFrame(columns=("database", "speaker_id", "file_id", "duration", "speaker_idx"))
+                for _ in range(5)]
+
+    for idx, sublist in enumerate(vox2_sublists):
+        for fn in tqdm(sublist):
+            file_id = '/'.join(fn.split('/')[-2:]).split('.')[0]
+            speaker_id = fn.split('/')[-3]
+            _set = fn.split('/')[-5]
+            # get the duration (in samples) of the wav file
+            data, _ = soundfile.read(fn)
+            duration = data.shape[0]
+            # DataFrame.append returns a new DataFrame, so keep the result
+            vox2_dfs[idx] = vox2_dfs[idx].append(
+                {"database": "vox2", "speaker_id": speaker_id, "file_id": file_id,
+                 "duration": duration, "speaker_idx": -1, "set": _set},
+                ignore_index=True)
+
+        print("\n\n*** Create 5 HDF5 files with all training data ***")
+        # Create a HDF5 file and fill it with one 4s segment per session
+        obf = output_batch_file + f"_{idx}"
+        with h5py.File(obf, 'w') as fh:
+            for index, row in tqdm(vox2_dfs[idx].iterrows()):
+                session_id = row['speaker_id'] + '/' + row['file_id']
+                # Load the wav signal
+                fn = '/'.join((vox2_root_dir, row['set'], 'wav', session_id)) + ".wav"
+                data, samplerate = soundfile.read(fn, dtype='int16')
+                _nb_samp = len(data)
+                # Randomly select a segment of seg_duration seconds if the file is long enough
+                if _nb_samp > nb_samp:
+                    cut = numpy.random.randint(low=0, high=_nb_samp - nb_samp)
+                    # Write the segment in the HDF5 file
+                    fh.create_dataset(session_id,
+                                      data=data[cut:cut + nb_samp].astype('int16'),
+                                      maxshape=(None,),
+                                      fletcher32=True)
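A hedged invocation sketch for the new function (the root path is hypothetical). It writes five files, output_batch_file suffixed _0 through _4, each holding one random 4 s segment (seg_duration * samplerate = 64000 samples at 16 kHz) per session:

    prepare_voxceleb2("/data/voxceleb2", "vox2_batch.h5", seg_duration=4, samplerate=16000)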
 class PreEmphasis(object):
     """
     Perform pre-emphasis filtering on an audio segment
......
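The PreEmphasis body is truncated above; for reference, a minimal sketch of the classic first-order pre-emphasis filter y[n] = x[n] - a * x[n-1] (the coefficient 0.97 is a common default, not taken from this commit):

    import numpy

    def pre_emphasis(signal, coeff=0.97):
        # Keep the first sample unchanged, high-pass filter the rest
        return numpy.append(signal[0], signal[1:] - coeff * signal[:-1])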
@@ -69,40 +69,103 @@ class Xtractor(torch.nn.Module):
     """
     Class that defines an x-vector extractor based on 5 convolutional layers and a mean standard deviation pooling
     """
-    def __init__(self, spk_number, dropout, activation='LeakyReLU'):
-        super(Xtractor, self).__init__()
-        self.frame_conv0 = torch.nn.Conv1d(30, 512, 5, dilation=1)
-        self.frame_conv1 = torch.nn.Conv1d(512, 512, 3, dilation=2)
-        self.frame_conv2 = torch.nn.Conv1d(512, 512, 3, dilation=3)
-        self.frame_conv3 = torch.nn.Conv1d(512, 512, 1)
-        self.frame_conv4 = torch.nn.Conv1d(512, 3 * 512, 1)
-        self.seg_lin0 = torch.nn.Linear(3 * 512 * 2, 512)
-        self.dropout_lin0 = torch.nn.Dropout(p=dropout)
-        self.seg_lin1 = torch.nn.Linear(512, 512)
-        self.dropout_lin1 = torch.nn.Dropout(p=dropout)
-        self.seg_lin2 = torch.nn.Linear(512, spk_number)
-        #
-        self.norm0 = torch.nn.BatchNorm1d(512)
-        self.norm1 = torch.nn.BatchNorm1d(512)
-        self.norm2 = torch.nn.BatchNorm1d(512)
-        self.norm3 = torch.nn.BatchNorm1d(512)
-        self.norm4 = torch.nn.BatchNorm1d(3 * 512)
-        self.norm6 = torch.nn.BatchNorm1d(512)
-        self.norm7 = torch.nn.BatchNorm1d(512)
-        #
-        if activation == 'LeakyReLU':
-            self.activation = torch.nn.LeakyReLU(0.2)
-        elif activation == 'ReLU':
-            self.activation = torch.nn.ReLU()
-        elif activation == 'PReLU':
-            self.activation = torch.nn.PReLU()
-        elif activation == 'ReLU6':
-            self.activation = torch.nn.ReLU6()
-        elif activation == 'SELU':
-            self.activation = torch.nn.SELU()
-        else:
-            raise ValueError("Activation function is not implemented")
+    def __init__(self, speaker_number, config=None):
+        """
+        If config is None, the default architecture is created
+
+        :param config: path to a YAML file describing the architecture
+        """
+        super(Xtractor, self).__init__()
+        self.speaker_number = speaker_number
+        if config is None:
+            self.sequence_network = nn.Sequential(OrderedDict([
+                ("conv1", torch.nn.Conv1d(30, 512, 5, dilation=1)),
+                ("activation1", torch.nn.LeakyReLU(0.2)),
+                ("norm1", torch.nn.BatchNorm1d(512)),
+                ("conv2", torch.nn.Conv1d(512, 512, 3, dilation=2)),
+                ("activation2", torch.nn.LeakyReLU(0.2)),
+                ("norm2", torch.nn.BatchNorm1d(512)),
+                ("conv3", torch.nn.Conv1d(512, 512, 3, dilation=3)),
+                ("activation3", torch.nn.LeakyReLU(0.2)),
+                ("norm3", torch.nn.BatchNorm1d(512)),
+                ("conv4", torch.nn.Conv1d(512, 512, 1)),
+                ("activation4", torch.nn.LeakyReLU(0.2)),
+                ("norm4", torch.nn.BatchNorm1d(512)),
+                ("conv5", torch.nn.Conv1d(512, 1536, 1)),
+                ("activation5", torch.nn.LeakyReLU(0.2)),
+                ("norm5", torch.nn.BatchNorm1d(1536))
+            ]))
+            # mean + standard deviation pooling doubles the 1536 channels
+            self.before_speaker_embedding = nn.Sequential(OrderedDict([
+                ("linear6", torch.nn.Linear(3072, 512))
+            ]))
+            self.after_speaker_embedding = nn.Sequential(OrderedDict([
+                ("activation6", torch.nn.LeakyReLU(0.2)),
+                ("norm6", torch.nn.BatchNorm1d(512)),
+                ("linear7", torch.nn.Linear(512, 512)),
+                ("activation7", torch.nn.LeakyReLU(0.2)),
+                ("norm7", torch.nn.BatchNorm1d(512)),
+                ("linear8", torch.nn.Linear(512, self.speaker_number))
+            ]))
+        else:
+            # Load the YAML configuration
+            with open(config, 'r') as fh:
+                cfg = yaml.load(fh, Loader=yaml.FullLoader)
+            # Get the feature size
+            self.feature_size = cfg["feature_size"]
+            input_size = self.feature_size
+            # Get the activation function
+            if cfg["activation"] == 'LeakyReLU':
+                self.activation = torch.nn.LeakyReLU(0.2)
+            elif cfg["activation"] == 'PReLU':
+                self.activation = torch.nn.PReLU()
+            elif cfg["activation"] == 'ReLU6':
+                self.activation = torch.nn.ReLU6()
+            else:
+                self.activation = torch.nn.ReLU()
+            # Create a sequential object for the first part of the network
+            segmental_layers = []
+            for k in cfg["segmental"].keys():
+                if k.startswith("conv"):
+                    segmental_layers.append((k, torch.nn.Conv1d(input_size,
+                                                                cfg["segmental"][k]["output_channels"],
+                                                                cfg["segmental"][k]["kernel_size"],
+                                                                dilation=cfg["segmental"][k]["dilation"])))
+                    input_size = cfg["segmental"][k]["output_channels"]
+                elif k.startswith("activation"):
+                    segmental_layers.append((k, self.activation))
+                elif k.startswith('norm'):
+                    segmental_layers.append((k, torch.nn.BatchNorm1d(input_size)))
+            self.sequence_network = nn.Sequential(OrderedDict(segmental_layers))
+            # Create a sequential object for the second part of the network
+            # (the mean + standard deviation pooling doubles the channel count)
+            input_size = input_size * 2
+            embedding_layers = []
+            for k in cfg["embedding"].keys():
+                if k.startswith("lin"):
+                    if cfg["embedding"][k]["output"] == "speaker_number":
+                        embedding_layers.append((k, torch.nn.Linear(input_size, self.speaker_number)))
+                    else:
+                        embedding_layers.append((k, torch.nn.Linear(input_size, cfg["embedding"][k]["output"])))
+                        input_size = cfg["embedding"][k]["output"]
+                elif k.startswith("activation"):
+                    embedding_layers.append((k, self.activation))
+                elif k.startswith('norm'):
+                    embedding_layers.append((k, torch.nn.BatchNorm1d(input_size)))
+                elif k.startswith('dropout'):
+                    embedding_layers.append((k, torch.nn.Dropout(p=cfg["embedding"][k])))
+            self.before_speaker_embedding = nn.Sequential(OrderedDict(embedding_layers))

    def produce_embeddings(self, x):
        """
@@ -123,12 +186,12 @@ class Xtractor(torch.nn.Module):
-        embedding_a = self.seg_lin0(seg_emb)
-        return embedding_a
-    def forward(self, x):
+    def forward(self, x, is_eval=False):
         """
         :param x:
         :return:
         """
-        seg_emb_0 = self.produce_embeddings(x)
-        # batch-normalisation after this layer
-        seg_emb_1 = self.norm6(self.activation(seg_emb_0))
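For the configuration branch of __init__ above, here is a hedged example of what yaml.load would need to return. Only the key names (feature_size, activation, segmental, embedding, output_channels, kernel_size, dilation, output) come from the parsing code; the layer names and sizes are illustrative assumptions. Since yaml.FullLoader fills ordinary dicts, which preserve insertion order on Python 3.7+, layers are instantiated in file order:

    cfg = {
        "feature_size": 30,
        "activation": "LeakyReLU",
        "segmental": {
            "conv1": {"output_channels": 512, "kernel_size": 5, "dilation": 1},
            "activation1": None,   # value unused, the key prefix selects self.activation
            "norm1": None,
            "conv2": {"output_channels": 512, "kernel_size": 3, "dilation": 2},
            "activation2": None,
            "norm2": None,
        },
        "embedding": {
            "linear1": {"output": 512},
            "activation1": None,
            "norm1": None,
            "dropout1": 0.5,       # becomes torch.nn.Dropout(p=0.5)
            "linear2": {"output": "speaker_number"},
        },
    }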
@@ -137,18 +200,20 @@ class Xtractor(torch.nn.Module):
-        # No batch-normalisation after this layer
-        result = self.seg_lin2(seg_emb_2)
-        return result
-    def extract(self, x):
-        """
-        Extract x-vector given an input sequence of features
-        :param x:
-        :return:
-        """
-        embedding_a = self.produce_embeddings(x)
-        embedding_b = self.seg_lin1(self.norm6(self.activation(embedding_a)))
-        return embedding_a, embedding_b
+        x = self.sequence_network(x)
+        # Mean and standard deviation pooling
+        mean = torch.mean(x, dim=2)
+        std = torch.std(x, dim=2)
+        x = torch.cat([mean, std], dim=1)
+        x = self.before_speaker_embedding(x)
+        if is_eval:
+            return x
+        x = self.after_speaker_embedding(x)
+        return x
     def init_weights(self):
         """
......
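A hedged end-to-end sketch of the refactored extractor; the batch size, frame count and speaker count are assumptions. It also shows why the first linear layer takes 3072 inputs: the mean and standard deviation pooling concatenates two 1536-dimensional vectors:

    import torch

    model = Xtractor(speaker_number=1000)       # default architecture
    features = torch.randn(8, 30, 200)          # (batch, cepstral coefficients, frames)
    logits = model(features)                    # (8, 1000), fed to the training loss
    xvectors = model(features, is_eval=True)    # (8, 512), embeddings for scoring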
@@ -1273,7 +1273,7 @@ class StatServer:
         """
         sts_per_model = sidekit.StatServer()
         sts_per_model.modelset = numpy.unique(self.modelset)
-        sts_per_model.segset = sts_per_model.modelset
+        sts_per_model.segset = copy.deepcopy(sts_per_model.modelset)
         sts_per_model.stat0 = numpy.zeros((sts_per_model.modelset.shape[0], self.stat0.shape[1]), dtype=STAT_TYPE)
         sts_per_model.stat1 = numpy.zeros((sts_per_model.modelset.shape[0], self.stat1.shape[1]), dtype=STAT_TYPE)
         sts_per_model.start = numpy.empty(sts_per_model.segset.shape, '|O')
@@ -1295,7 +1295,7 @@ class StatServer:
         """
         sts_per_model = sidekit.StatServer()
         sts_per_model.modelset = numpy.unique(self.modelset)
-        sts_per_model.segset = sts_per_model.modelset
+        sts_per_model.segset = copy.deepcopy(sts_per_model.modelset)
         sts_per_model.stat0 = numpy.zeros((sts_per_model.modelset.shape[0], self.stat0.shape[1]), dtype=STAT_TYPE)
         sts_per_model.stat1 = numpy.zeros((sts_per_model.modelset.shape[0], self.stat1.shape[1]), dtype=STAT_TYPE)
         sts_per_model.start = numpy.empty(sts_per_model.segset.shape, '|O')
......
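A short sketch of why this change matters: with plain assignment the two attributes share one numpy array, so mutating segset silently mutates modelset; copy.deepcopy makes them independent:

    import copy
    import numpy

    modelset = numpy.array(['spk1', 'spk2'], dtype='|O')
    segset = modelset                   # alias, as in the old line
    segset[0] = 'changed'
    print(modelset[0])                  # 'changed' -- mutated through the alias

    modelset = numpy.array(['spk1', 'spk2'], dtype='|O')
    segset = copy.deepcopy(modelset)    # independent copy, as in the new line
    segset[0] = 'changed'
    print(modelset[0])                  # 'spk1' -- unaffected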