Commit ebdf1b53 authored by Anthony Larcher
parents 8e8fb525 9b6795f6
 *.pyc
 *.DS_Store
 docs
-.vscode
+.gitignore
+.vscode
+.history
@@ -173,6 +173,9 @@ def data_augmentation(speech,
     aug_idx = random.sample(range(len(transform_dict.keys())), k=transform_number)
     augmentations = numpy.array(list(transform_dict.keys()))[aug_idx]
+    if "none" in augmentations:
+        pass
     if "stretch" in augmentations:
         strech = torchaudio.functional.TimeStretch()
         rate = random.uniform(0.8, 1.2)
@@ -261,6 +264,7 @@ def data_augmentation(speech,
     final_shape = speech.shape[1]
     configs = [
         ({"format": "wav", "encoding": 'ULAW', "bits_per_sample": 8}, "8 bit mu-law"),
+        ({"format": "wav", "encoding": 'ALAW', "bits_per_sample": 8}, "8 bit a-law"),
         ({"format": "gsm"}, "GSM-FR"),
         ({"format": "mp3", "compression": -9}, "MP3"),
         ({"format": "vorbis", "compression": -1}, "Vorbis")
...
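The codec augmentation presumably draws one of these configs and round-trips the signal through it. A minimal sketch with torchaudio's apply_codec, assuming `speech` is a [1, T] float tensor at `sample_rate` and that `final_shape` holds the original sample count as above:

    import random
    import torchaudio

    param, title = random.choice(configs)
    degraded = torchaudio.functional.apply_codec(speech, sample_rate, **param)
    degraded = degraded[:, :final_shape]  # codecs may pad or resample; crop back to the input length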
@@ -55,14 +55,68 @@ class MeanStdPooling(torch.nn.Module):
     def forward(self, x):
         """
-        :param x:
+        :param x: [B, C*F, T]
         :return:
         """
+        if len(x.shape) == 4:
+            # [B, C, F, T]
+            x = x.permute(0, 1, 3, 2)
+            x = x.flatten(start_dim=1, end_dim=2)
+        # [B, C*F]
         mean = torch.mean(x, dim=2)
+        # [B, C*F]
         std = torch.std(x, dim=2)
+        # [B, 2*C*F]
         return torch.cat([mean, std], dim=1)
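A quick shape check of the new 4-D branch. This is a sketch: the [B, C, T, F] input and the no-argument constructor are assumptions based on the code above.

    import torch

    pool = MeanStdPooling()
    feat = torch.randn(8, 256, 200, 10)  # assumed [B, C, T, F] trunk output
    emb = pool(feat)                     # permute -> flatten -> mean/std over time
    print(emb.shape)                     # torch.Size([8, 5120]) == [B, 2*C*F]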
+class ChannelWiseCorrPooling(torch.nn.Module):
+
+    def __init__(self, in_channels=256, out_channels=64, in_freqs=10, channels_dropout=0.25):
+        super(ChannelWiseCorrPooling, self).__init__()
+        self.channels_dropout = channels_dropout
+        self.merge_freqs_count = 2
+        assert in_freqs % self.merge_freqs_count == 0
+        self.groups = in_freqs // self.merge_freqs_count
+        self.out_channels = out_channels
+        self.out_dim = int(self.out_channels * (self.out_channels - 1) / 2) * self.groups
+        self.L_proj = torch.nn.Conv2d(in_channels * self.groups, out_channels * self.groups, kernel_size=(1, 1), groups=self.groups)
+        #self.L_proj = torch.nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
+        self.mask = torch.tril(torch.ones((out_channels, out_channels)), diagonal=-1).type(torch.BoolTensor)
+
+    def forward(self, x):
+        """
+        :param x: [B, C, T, F]
+        :return:
+        """
+        batch_size = x.shape[0]
+        num_locations = x.shape[-1] * x.shape[-2] / self.groups
+        self.mask = self.mask.to(x.device)
+        if self.training:
+            x *= torch.nn.functional.dropout(torch.ones((1, x.shape[1], 1, 1), device=x.device), p=self.channels_dropout)
+        # [B, T, C, F]
+        x = x.permute(0, 2, 1, 3)
+        # [B, T, C, Fr, f]
+        x = x.reshape(x.shape[0], x.shape[1], x.shape[-2], self.groups, self.merge_freqs_count)
+        # [B, T, f, Fr, C]
+        x = x.permute(0, 1, 4, 3, 2)
+        # [B, T, f, Fr*C]
+        x = x.flatten(start_dim=3, end_dim=4)
+        # [B, Fr*C, T, f]
+        x = x.permute(0, 3, 1, 2)
+        # [B, Fr*C', T, f]
+        x = self.L_proj(x)
+        # [B, Fr, C', Tr]
+        x = x.reshape(x.shape[0], self.groups, self.out_channels, -1)
+        x -= torch.mean(x, axis=-1, keepdims=True)
+        out = x / (torch.std(x, axis=-1, keepdims=True) + 1e-5)
+        # [B, Fr, C', C']
+        out = torch.einsum('abci,abdi->abcd', out, out)
+        # [B, Fr*C'*(C'-1)/2]
+        out = torch.masked_select(out, self.mask).reshape(batch_size, -1)
+        out = out / num_locations
+        return out
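A usage sketch for the new correlation pooling; the shapes are assumptions taken from the docstring. Note that out_dim = int(64*63/2)*5 = 10080 when the default 10 frequency bands are merged into 5 groups, which matches the int(64*63*5/2) input size given to before_speaker_embedding in the experimental Xtractor branch further down.

    import torch

    pool = ChannelWiseCorrPooling(in_channels=256, out_channels=64, in_freqs=10)
    pool.eval()                           # bypass the channel-dropout branch
    feat = torch.randn(8, 256, 150, 10)   # [B, C, T, F]
    out = pool(feat)
    print(out.shape, pool.out_dim)        # torch.Size([8, 10080]) 10080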
 class AttentivePooling(torch.nn.Module):
     """
     Mean and Standard deviation attentive pooling
@@ -94,9 +148,14 @@ class AttentivePooling(torch.nn.Module):
     def forward(self, x):
         """
-        :param x:
+        :param x: [B, C*F, T]
         :return:
         """
+        if len(x.shape) == 4:
+            # [B, C, F, T]
+            x = x.permute(0, 1, 3, 2)
+            # [B, C*F, T]
+            x = x.flatten(start_dim=1, end_dim=2)
         if self.global_context:
             w = self.attention(torch.cat([x, self.gc(x).unsqueeze(2).repeat(1, 1, x.shape[-1])], dim=1))
         else:
...
@@ -480,17 +480,18 @@ class PreResNet34(torch.nn.Module):
         :param x:
         :return:
         """
-        out = x.unsqueeze(1)
-        out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
-        out = self.layer1(out)
-        out = self.layer2(out)
-        out = self.layer3(out)
-        out = self.layer4(out)
-        out = self.layer5(out)
-        out = self.layer6(out)
-        out = self.layer7(out)
-        out = torch.flatten(out, start_dim=1, end_dim=2)
-        return out
+        x = x.unsqueeze(1)
+        x = x.permute(0, 1, 3, 2)
+        x = x.to(memory_format=torch.channels_last)
+        x = torch.nn.functional.relu(self.bn1(self.conv1(x)))
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.layer5(x)
+        x = self.layer6(x)
+        x = self.layer7(x)
+        return x
 class PreHalfResNet34(torch.nn.Module):
@@ -535,16 +536,15 @@ class PreHalfResNet34(torch.nn.Module):
         :param x:
         :return:
         """
-        out = x.unsqueeze(1)
-        out = out.contiguous(memory_format=torch.channels_last)
-        out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
-        out = self.layer1(out)
-        out = self.layer2(out)
-        out = self.layer3(out)
-        out = self.layer4(out)
-        out = out.contiguous(memory_format=torch.contiguous_format)
-        out = torch.flatten(out, start_dim=1, end_dim=2)
-        return out
+        x = x.unsqueeze(1)
+        x = x.permute(0, 1, 3, 2)
+        x = x.to(memory_format=torch.channels_last)
+        x = torch.nn.functional.relu(self.bn1(self.conv1(x)))
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        return x
 class PreFastResNet34(torch.nn.Module):
@@ -557,7 +557,7 @@ class PreFastResNet34(torch.nn.Module):
         self.speaker_number = speaker_number
         self.conv1 = torch.nn.Conv2d(1, 16, kernel_size=7,
-                                     stride=(2, 1), padding=3, bias=False)
+                                     stride=(1, 2), padding=3, bias=False)
         self.bn1 = torch.nn.BatchNorm2d(16)
         # With block = [3, 4, 6, 3]
@@ -589,16 +589,15 @@ class PreFastResNet34(torch.nn.Module):
         :param x:
         :return:
         """
-        out = x.unsqueeze(1)
-        out = out.contiguous(memory_format=torch.channels_last)
-        out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
-        out = self.layer1(out)
-        out = self.layer2(out)
-        out = self.layer3(out)
-        out = self.layer4(out)
-        out = out.contiguous(memory_format=torch.contiguous_format)
-        out = torch.flatten(out, start_dim=1, end_dim=2)
-        return out
+        x = x.unsqueeze(1)
+        x = x.permute(0, 1, 3, 2)
+        x = x.to(memory_format=torch.channels_last)
+        x = torch.nn.functional.relu(self.bn1(self.conv1(x)))
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        return x
 def ResNet34():
...
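All three trunks now permute the [B, 1, F, T] input to [B, 1, T, F] and convert it to channels-last memory format before the first convolution; on recent GPUs this layout lets cuDNN pick faster kernels, and the tensor shape itself is unchanged. A minimal sketch of the pattern, with assumed sizes:

    import torch

    x = torch.randn(8, 80, 200)                  # assumed [B, F, T] mel spectrogram
    x = x.unsqueeze(1).permute(0, 1, 3, 2)       # [B, 1, T, F]
    x = x.to(memory_format=torch.channels_last)  # NHWC strides, same shape
    conv = torch.nn.Conv2d(1, 16, kernel_size=7, stride=(1, 2), padding=3, bias=False)
    y = conv(x)                                  # [8, 16, 200, 40]: time kept, frequency halved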
@@ -106,7 +106,7 @@ class SideSampler(torch.utils.data.Sampler):
         self.segment_cursors = numpy.zeros((len(self.labels_to_indices),), dtype=numpy.int)

     def __iter__(self):
         g = torch.Generator()
         g.manual_seed(self.seed + self.epoch)
         numpy.random.seed(self.seed + self.epoch)
@@ -175,7 +175,7 @@ class SideSet(Dataset):
                  overlap=0.,
                  dataset_df=None,
                  min_duration=0.165,
-                 output_format="pytorch",
+                 output_format="pytorch"
                  ):
         """
@@ -269,6 +269,8 @@ class SideSet(Dataset):
self.transform["codec"] = [] self.transform["codec"] = []
if "phone_filtering" in transforms: if "phone_filtering" in transforms:
self.transform["phone_filtering"] = [] self.transform["phone_filtering"] = []
if "stretch" in transforms:
self.transform["stretch"] = []
         self.noise_df = None
         if "add_noise" in self.transform:
@@ -416,18 +418,27 @@ class IdMapSet(Dataset):
         start = int(self.idmap.start[index] * 0.01 * self.sample_rate)
         if self.idmap.stop[index] is None:
+            nfo = torchaudio.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
             speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
+            if nfo.sample_rate != self.sample_rate:
+                speech = torchaudio.transforms.Resample(nfo.sample_rate, self.sample_rate).forward(speech)
             duration = int(speech.shape[1] - start)
         else:
-            duration = int(self.idmap.stop[index] * 0.01 * self.sample_rate) - start
+            # TODO Check if that code is still relevant with torchaudio.load() in case of sample_rate mismatch
+            nfo = torchaudio.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
+            assert nfo.sample_rate == self.sample_rate
+            conversion_rate = nfo.sample_rate // self.sample_rate
+            duration = (int(self.idmap.stop[index] * 0.01 * self.sample_rate) - start)
             # add this in case the segment is too short
             if duration <= self.min_duration * self.sample_rate:
                 middle = start + duration // 2
                 start = int(max(0, int(middle - (self.min_duration * self.sample_rate / 2))))
                 duration = int(self.min_duration * self.sample_rate)
             speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
-                                                frame_offset=start,
-                                                num_frames=duration)
+                                                frame_offset=start * conversion_rate,
+                                                num_frames=duration * conversion_rate)
+            if nfo.sample_rate != self.sample_rate:
+                speech = torchaudio.transforms.Resample(nfo.sample_rate, self.sample_rate).forward(speech)
         #speech += 10e-6 * torch.randn(speech.shape)
...
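The IdMapSet changes probe the file's native rate with torchaudio.info, scale frame_offset/num_frames by conversion_rate so the slice is read at that native rate, then resample to the target rate. The committed else branch still asserts equal rates, so conversion_rate is 1 there; the sketch below, with a hypothetical filename and rates, shows the general case the TODO hints at (idmap start/stop are in centiseconds):

    import torchaudio

    sample_rate = 16000                               # rate expected by the model
    nfo = torchaudio.info("sessions/abc.wav")         # hypothetical file stored at 32 kHz
    conversion_rate = nfo.sample_rate // sample_rate  # 2
    start = int(150 * 0.01 * sample_rate)             # start of 150 cs -> 24000 samples at 16 kHz
    duration = int(450 * 0.01 * sample_rate) - start  # 48000 samples at 16 kHz
    speech, _ = torchaudio.load("sessions/abc.wav",
                                frame_offset=start * conversion_rate,
                                num_frames=duration * conversion_rate)
    speech = torchaudio.transforms.Resample(nfo.sample_rate, sample_rate)(speech)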
@@ -36,12 +36,12 @@ import shutil
 import torch
 import tqdm
 import yaml
+#torch.autograd.set_detect_anomaly(True)
 from collections import OrderedDict
 from torch.utils.data import DataLoader
 from sklearn.model_selection import train_test_split
 from .pooling import MeanStdPooling
-from .pooling import AttentivePooling
+from .pooling import AttentivePooling, ChannelWiseCorrPooling
 from .pooling import GruPooling
 from .preprocessor import MfccFrontEnd
 from .preprocessor import MelSpecFrontEnd
@@ -522,6 +522,35 @@ class Xtractor(torch.nn.Module):
             self.stat_pooling = AttentivePooling(256, 80, global_context=True)
+            self.loss = loss
+            if self.loss == "aam":
+                self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
+                                                                int(self.speaker_number),
+                                                                s = 30,
+                                                                m = 0.2,
+                                                                easy_margin = False)
+            elif self.loss == 'aps':
+                self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number),
+                                                                   emb_dim=self.embedding_size)
+            self.preprocessor_weight_decay = 0.00002
+            self.sequence_network_weight_decay = 0.00002
+            self.stat_pooling_weight_decay = 0.00002
+            self.before_speaker_embedding_weight_decay = 0.00002
+            self.after_speaker_embedding_weight_decay = 0.000
+
+        elif model_archi == "experimental":
+            self.preprocessor = MelSpecFrontEnd()
+            self.sequence_network = PreHalfResNet34()
+            self.embedding_size = embedding_size
+            #self.embedding_size = 256
+            #self.before_speaker_embedding = torch.nn.Linear(in_features = 5120,
+            #                                                out_features = self.embedding_size)
+            self.before_speaker_embedding = torch.nn.Linear(in_features = int(64*63*5/2),
+                                                            out_features = self.embedding_size)
+            self.stat_pooling = ChannelWiseCorrPooling(in_channels=256, out_channels=64)
             self.loss = loss
             if self.loss == "aam":
                 self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
@@ -788,7 +817,6 @@ class Xtractor(torch.nn.Module):
         # Mean and Standard deviation pooling
         x = self.stat_pooling(x)
         x = self.before_speaker_embedding(x)
-
         if norm_embedding:
@@ -1005,7 +1033,7 @@ def get_network(model_opts, local_rank):
     :return:
     """
-    if model_opts["model_type"] in ["xvector", "rawnet2", "resnet34", "fastresnet34", "halfresnet34"]:
+    if model_opts["model_type"] in ["xvector", "rawnet2", "resnet34", "fastresnet34", "halfresnet34", "experimental"]:
         model = Xtractor(model_opts["speaker_number"], model_opts["model_type"], loss=model_opts["loss"]["type"], embedding_size=model_opts["embedding_size"])
     else:
         # Custom type of model
@@ -1035,24 +1063,9 @@ def get_network(model_opts, local_rank):
if name.split(".")[0] in model_opts["reset_parts"]: if name.split(".")[0] in model_opts["reset_parts"]:
param.requires_grad = False param.requires_grad = False
if model_opts["loss"]["type"] == "aam" and not (model_opts["loss"]["aam_margin"] == 0.2 and model_opts["loss"]["aam_s"] == 30): #if model_opts["loss"]["type"] == "aam" and not (model_opts["loss"]["aam_margin"] == 0.2 and model_opts["loss"]["aam_s"] == 30):
model.after_speaker_embedding.change_params(model_opts["loss"]["aam_s"], model_opts["loss"]["aam_margin"]) # model.after_speaker_embedding.change_params(model_opts["loss"]["aam_s"], model_opts["loss"]["aam_margin"])
print(f"Modified AAM: margin = {model.after_speaker_embedding.m} and s = {model.after_speaker_embedding.s}") # print(f"Modified AAM: margin = {model.after_speaker_embedding.m} and s = {model.after_speaker_embedding.s}")
if local_rank < 1:
logging.info(model)
logging.info("Model_parameters_count: {:d}".format(
sum(p.numel()
for p in model.sequence_network.parameters()
if p.requires_grad) + \
sum(p.numel()
for p in model.before_speaker_embedding.parameters()
if p.requires_grad) + \
sum(p.numel()
for p in model.stat_pooling.parameters()
if p.requires_grad)))
return model return model
@@ -1080,7 +1093,8 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
         training_df, validation_df = train_test_split(df,
                                                       test_size=dataset_opts["validation_ratio"],
                                                       stratify=stratify)
+    # TODO
     torch.manual_seed(training_opts['torch_seed'] + local_rank)
     torch.cuda.manual_seed(training_opts['torch_seed'] + local_rank)
@@ -1090,7 +1104,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
                            transform_number=dataset_opts['train']['transform_number'],
                            overlap=dataset_opts['train']['overlap'],
                            dataset_df=training_df,
-                           output_format="pytorch",
+                           output_format="pytorch"
                            )

     validation_set = SideSet(dataset_opts,
@@ -1106,20 +1120,20 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
if training_opts["multi_gpu"]: if training_opts["multi_gpu"]:
assert dataset_opts["batch_size"] % torch.cuda.device_count() == 0 assert dataset_opts["batch_size"] % torch.cuda.device_count() == 0
assert dataset_opts["batch_size"] % samples_per_speaker == 0 assert dataset_opts["batch_size"] % samples_per_speaker == 0
batch_size = dataset_opts["batch_size"]//(torch.cuda.device_count() * dataset_opts["train"]["sampler"]["examples_per_speaker"]) batch_size = dataset_opts["batch_size"]//torch.cuda.device_count()
side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'], side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
spk_count=model_opts["speaker_number"], spk_count=model_opts["speaker_number"],
examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"], examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
samples_per_speaker=dataset_opts["train"]["sampler"]["samples_per_speaker"], samples_per_speaker=dataset_opts["train"]["sampler"]["samples_per_speaker"],
batch_size=batch_size, batch_size=batch_size*torch.cuda.device_count(),
seed=training_opts['torch_seed'], seed=training_opts['torch_seed'],
rank=local_rank, rank=local_rank,
num_process=torch.cuda.device_count(), num_process=torch.cuda.device_count(),
num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"] num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"]
) )
else: else:
batch_size = dataset_opts["batch_size"] // dataset_opts["train"]["sampler"]["examples_per_speaker"] batch_size = dataset_opts["batch_size"]
side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'], side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
spk_count=model_opts["speaker_number"], spk_count=model_opts["speaker_number"],
examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"], examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
@@ -1380,8 +1394,18 @@ def xtrain(dataset_description,
     # Initialize the model
     model = get_network(model_opts, local_rank)
     if local_rank < 1:
         monitor.logger.info(model)
+        monitor.logger.info("Model_parameters_count: {:d}".format(
+            sum(p.numel()
+                for p in model.sequence_network.parameters()
+                if p.requires_grad) + \
+            sum(p.numel()
+                for p in model.before_speaker_embedding.parameters()
+                if p.requires_grad) + \
+            sum(p.numel()
+                for p in model.stat_pooling.parameters()
+                if p.requires_grad)))

     embedding_size = model.embedding_size
     aam_scheduler = None
@@ -1569,7 +1593,7 @@ def train_epoch(model,
             loss += criterion(output, target)
         elif loss_criteria == 'aps':
             output_tuple, _ = model(data, target=target)
-            loss, output = output_tuple
+            loss, no_margin_output = output_tuple
         else:
             output, _ = model(data, target=None)
             loss = criterion(output, target)
@@ -1603,7 +1627,7 @@ def train_epoch(model,
         if math.fmod(batch_idx, training_opts["log_interval"]) == 0:
             batch_size = target.shape[0]
             training_monitor.update(training_loss=loss.item(),
-                                    training_acc=100.0 * accuracy.item() / ((batch_idx + 1) * batch_size))
+                                    training_acc=100.0 * accuracy / ((batch_idx + 1) * batch_size))
             training_monitor.logger.info('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                 training_monitor.current_epoch,
...