xvector.py 92.1 KB
Newer Older
Anthony Larcher's avatar
merge    
Anthony Larcher committed
2001
        for batch_idx, (data, target) in enumerate(tqdm.tqdm(validation_loader, desc='validation compute', mininterval=1, disable=None)):
2002
2003
            target = target.squeeze()
            target = target.to(device)
Anthony Larcher's avatar
Anthony Larcher committed
2004
            batch_size = target.shape[0]
2005
            data = data.squeeze().to(device)
Anthony Larcher's avatar
debug    
Anthony Larcher committed
2006
            with torch.cuda.amp.autocast(enabled=mixed_precision):
Anthony Larcher's avatar
merge    
Anthony Larcher committed
2007
2008
2009
2010
2011
2012
2013
                output, batch_embeddings = model(data, target=None, is_eval=True)
                if loss_criteria == 'cce':
                    batch_embeddings = l2_norm(batch_embeddings)
                if loss_criteria == 'smn':
                    batch_embeddings, batch_predictions = output
                else:
                    batch_predictions = output
2014
2015
                accuracy += (torch.argmax(batch_predictions.data, 1) == target).sum()
                loss += criterion(batch_predictions, target)
Anthony Larcher's avatar
merge    
Anthony Larcher committed
2016
2017
            embeddings[cursor:cursor + batch_size,:] = batch_embeddings.detach().cpu()
            cursor += batch_size
2018

Anthony Larcher's avatar
merge    
Anthony Larcher committed
2019
    local_device = "cpu" if embeddings.shape[0] > 3e4 else device
Anthony Larcher's avatar
merge    
Anthony Larcher committed
2020
2021
2022
2023
    embeddings = embeddings.to(local_device)
    scores = torch.einsum('ij,kj', embeddings, embeddings).cpu().numpy()
    negatives = scores[non_indices]
    positives = scores[tar_indices]
2024

Anthony Larcher's avatar
Anthony Larcher committed
2025
    # Faster EER computation available here : https://github.com/gl3lan/fast_eer
Anthony Larcher's avatar
merge    
Anthony Larcher committed
2026
2027
2028
2029
    #equal_error_rate = eer(negatives, positives)

    pmiss, pfa = rocch(positives, negatives)
    equal_error_rate = rocch2eer(pmiss, pfa)
Anthony Larcher's avatar
Anthony Larcher committed
2030

Anthony Larcher's avatar
Anthony Larcher committed
2031
    return (100. * accuracy.cpu().numpy() / validation_shape[0],
Anthony Larcher's avatar
Anthony Larcher committed
2032
            loss.cpu().numpy() / ((batch_idx + 1) * batch_size),
Anthony Larcher's avatar
Anthony Larcher committed
2033
            equal_error_rate)
2034
2035


Anthony Larcher's avatar
Anthony Larcher committed
2036
2037
2038
2039
def extract_embeddings(idmap_name,
                       model_filename,
                       data_root_name,
                       device,
                       loss,
                       file_extension="wav",
                       transform_pipeline="",
                       frame_shift=0.01,
                       frame_duration=0.025,
                       extract_after_pooling=False,
                       num_thread=1,
                       mixed_precision=False):
    """Extract one speaker embedding (x-vector) per segment of an IdMap.

    :param idmap_name: IdMap object, or path to an IdMap file, listing the
        segments to process
    :param model_filename: path to a model checkpoint, or an already-loaded
        model object
    :param data_root_name: root directory of the audio files
    :param device: torch device used for inference
    :param loss: unused here; kept for backward compatibility of the signature
    :param file_extension: extension of the audio files (default "wav")
    :param transform_pipeline: description of the transformation pipeline
        applied on the fly by the dataset
    :param frame_shift: frame shift in seconds (default 0.01)
    :param frame_duration: frame duration in seconds (default 0.025)
    :param extract_after_pooling: unused here; kept for backward compatibility
    :param num_thread: number of DataLoader worker processes
    :param mixed_precision: enable torch.cuda.amp autocast during inference

    :return: a StatServer whose stat1 holds one embedding per segment
    """
    def _embedding_size(net):
        # Output dimension of the last layer of before_speaker_embedding.
        # The last state-dict key is either "<layer>.weight"/"<layer>.bias"
        # or a bare "bias"; probe the matching weight to get its size.
        key = list(net.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
        if key != 'bias':
            key = key + '.weight'
        return net.before_speaker_embedding.state_dict()[key].shape[0]

    # Load the model from a checkpoint when a filename is given,
    # otherwise use the provided model object directly.
    if isinstance(model_filename, str):
        checkpoint = torch.load(model_filename, map_location=device)
        speaker_number = checkpoint["speaker_number"]
        model_archi = checkpoint["model_archi"]
        model = Xtractor(speaker_number, model_archi=model_archi, loss=checkpoint["loss"])
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model = model_filename

    if isinstance(idmap_name, IdMap):
        idmap = idmap_name
    else:
        idmap = IdMap(idmap_name)

    # Context size of the network (unwrap DataParallel if needed).
    if type(model) is Xtractor:
        model_cs = model.context_size()
    else:
        model_cs = model.module.context_size()
    # Shortest signal the network can process; currently informational only.
    min_duration = (model_cs - 1) * frame_shift + frame_duration

    # Create dataset to load the data
    dataset = IdMapSet(idmap_name=idmap_name,
                       data_path=data_root_name,
                       file_extension=file_extension,
                       transform_pipeline=transform_pipeline,
                       min_duration=(model_cs + 2) * frame_shift * 2
                       )

    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            drop_last=False,
                            pin_memory=True,
                            num_workers=num_thread)

    with torch.no_grad():
        model.eval()
        model.to(device)

        # Size of the embeddings to extract (unwrap DataParallel if needed).
        emb_size = _embedding_size(model if type(model) is Xtractor else model.module)

        # Create the StatServer that will receive the embeddings
        embeddings = StatServer()
        embeddings.modelset = idmap.leftids
        embeddings.segset = idmap.rightids
        embeddings.start = idmap.start
        embeddings.stop = idmap.stop
        embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
        embeddings.stat1 = numpy.ones((embeddings.modelset.shape[0], emb_size))

        # Process the data
        for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader,
                                                                      desc='xvector extraction',
                                                                      mininterval=1,
                                                                      disable=None)):
            # Truncate pathologically long signals to bound memory usage.
            if data.shape[1] > 20000000:
                data = data[..., :20000000]
            with torch.cuda.amp.autocast(enabled=mixed_precision):
                # The model returns (classifier output, embedding) in eval mode.
                _, vec = model(x=data.to(device), is_eval=True)
            embeddings.stat1[idx, :] = vec.detach().cpu()

    return embeddings


Anthony Larcher's avatar
Anthony Larcher committed
2144
2145
2146
2147
2148
2149
2150
2151
def extract_embeddings_per_speaker(idmap_name,
                                   model_filename,
                                   data_root_name,
                                   device,
                                   file_extension="wav",
                                   transform_pipeline=None,
                                   frame_shift=0.01,
                                   frame_duration=0.025,
                                   extract_after_pooling=False,
                                   num_thread=1):
    """Extract one speaker embedding per speaker (all segments of a speaker
    are concatenated by the dataset before extraction).

    :param idmap_name: IdMap object, or path to an IdMap file, listing the
        segments to process
    :param model_filename: path to a model checkpoint, or an already-loaded
        model object
    :param data_root_name: root directory of the audio files
    :param device: torch device used for inference
    :param file_extension: extension of the audio files (default "wav")
    :param transform_pipeline: description of the transformation pipeline
        applied on the fly by the dataset
    :param frame_shift: frame shift in seconds (default 0.01)
    :param frame_duration: frame duration in seconds (default 0.025)
    :param extract_after_pooling: if True, use the size of the layer feeding
        before_speaker_embedding; otherwise use the model's embedding_size
    :param num_thread: number of DataLoader worker processes

    :return: a StatServer whose stat1 holds one embedding per speaker
    """
    # Load the model
    if isinstance(model_filename, str):
        # map_location makes a CUDA-saved checkpoint loadable on any device
        # (consistent with extract_embeddings).
        checkpoint = torch.load(model_filename, map_location=device)
        model_archi = checkpoint["model_archi"]
        model = Xtractor(checkpoint["speaker_number"], model_archi=model_archi, loss="aam")
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model = model_filename

    model = model.to(memory_format=torch.channels_last)

    # Shortest signal the network can process; currently informational only.
    min_duration = (model.context_size() - 1) * frame_shift + frame_duration

    # Create dataset to load the data
    dataset = IdMapSetPerSpeaker(idmap_name=idmap_name,
                                 data_root_path=data_root_name,
                                 file_extension=file_extension,
                                 transform_pipeline=transform_pipeline,
                                 frame_rate=int(1. / frame_shift),
                                 min_duration=(model.context_size() + 2) * frame_shift * 2)

    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            drop_last=False,
                            pin_memory=True,
                            num_workers=num_thread)

    with torch.no_grad():
        model.eval()
        model.to(device)

        # Get the size of embeddings to extract
        name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0] + '.weight'
        if extract_after_pooling:
            emb_size = model.before_speaker_embedding.state_dict()[name].shape[1]
        else:
            emb_size = model.embedding_size

        # Create the StatServer that will receive the embeddings
        embeddings = StatServer()
        embeddings.modelset = dataset.output_im.leftids
        embeddings.segset = dataset.output_im.rightids
        embeddings.start = dataset.output_im.start
        embeddings.stop = dataset.output_im.stop
        embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
        embeddings.stat1 = numpy.ones((embeddings.modelset.shape[0], emb_size))

        # Process the data
        for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader, desc='xvector extraction', mininterval=1)):
            # Truncate pathologically long signals to bound memory usage.
            if data.shape[1] > 20000000:
                data = data[..., :20000000]
            # NOTE(review): extract_embeddings unpacks the model output as
            # (output, embedding); here the raw return value is used as the
            # embedding — confirm the model's return convention on this path.
            vec = model(data.to(device), is_eval=True)
            embeddings.stat1[idx, :] = vec.detach().cpu()

    return embeddings

Anthony Larcher's avatar
Anthony Larcher committed
2214

Anthony Larcher's avatar
Anthony Larcher committed
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
def extract_sliding_embedding(idmap_name,
                              window_len,
                              window_shift,
                              model_filename,
                              data_root_name,
                              device,
                              sample_rate=16000,
                              file_extension="wav",
                              transform_pipeline=None,
                              num_thread=1,
                              mixed_precision=False):
    """Extract embeddings over a sliding window on each segment of an IdMap.

    :param idmap_name: IdMap object, or path to an IdMap file, listing the
        segments to process
    :param window_len: length of the sliding window
    :param window_shift: shift between two consecutive windows
    :param model_filename: path to a model checkpoint, or an already-loaded
        model object
    :param data_root_name: root directory of the audio files
    :param device: torch device used for inference
    :param sample_rate: sample rate of the audio (default 16000)
    :param file_extension: extension of the audio files (default "wav")
    :param transform_pipeline: description of the transformation pipeline
        applied on the fly by the dataset
    :param num_thread: number of DataLoader worker processes
    :param mixed_precision: enable torch.cuda.amp autocast during inference

    :return: a StatServer whose stat1 holds one embedding per window position
    """
    # Validate the input IdMap early (parsing the file raises on bad input).
    if not isinstance(idmap_name, IdMap):
        IdMap(idmap_name)

    # Load the model
    if isinstance(model_filename, str):
        checkpoint = torch.load(model_filename, map_location=device)
        speaker_number = checkpoint["speaker_number"]
        model_archi = checkpoint["model_archi"]
        model = Xtractor(speaker_number, model_archi=model_archi)
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model = model_filename

    # Create dataset to load the data
    dataset = IdMapSet(idmap_name=idmap_name,
                       data_path=data_root_name,
                       file_extension=file_extension,
                       transform_pipeline=transform_pipeline,
                       sliding_window=True,
                       window_len=window_len,
                       window_shift=window_shift,
                       sample_rate=sample_rate,
                       min_duration=0.1
                       )

    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            drop_last=False,
                            pin_memory=True,
                            num_workers=num_thread)

    with torch.no_grad():

        model.eval()
        model.to(device)

        # Size of the embeddings to extract (unwrap DataParallel if needed).
        net = model if type(model) is Xtractor else model.module
        name = list(net.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
        if name != 'bias':
            name = name + '.weight'
        emb_size = net.before_speaker_embedding.state_dict()[name].shape[0]

        # Collected per-chunk embedding tensors and per-window metadata.
        embedding_list = []
        modelset = []
        segset = []
        starts = []

        # Process the data
        for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader,
                                                                      desc='xvector extraction',
                                                                      mininterval=1)):
            with torch.cuda.amp.autocast(enabled=mixed_precision):
                data = data.squeeze()
                n_windows = data.shape[0]
                # Split the stack of windows into chunks of roughly 100
                # to bound GPU memory usage.
                for chunk in torch.split(data, n_windows // (n_windows // 100)):
                    # NOTE(review): extract_embeddings unpacks the model
                    # output as (output, embedding); confirm the return
                    # convention on this call path.
                    vec = model(x=chunk.to(device), is_eval=True)
                    embedding_list.append(vec.detach().cpu())
                modelset += [mod, ] * n_windows
                segset += [seg, ] * n_windows
                # One start time per window; computed from the window count
                # (not a float-stepped arange) so lengths always match.
                starts += (float(start) + window_shift * numpy.arange(n_windows)).tolist()

    # Create the StatServer
    # (the collected vectors are kept in embedding_list so that building the
    #  StatServer does not clobber them — this was a bug in the original)
    embeddings = StatServer()
    embeddings.modelset = numpy.array(modelset).astype('>U')
    embeddings.segset = numpy.array(segset).astype('>U')
    embeddings.start = numpy.array(starts)
    embeddings.stop = numpy.array(starts) + window_len
    embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
    embeddings.stat1 = torch.cat(embedding_list).numpy()

    return embeddings