Commit 552805f1 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

cleaning

parent cd7ea34d
......@@ -539,7 +539,7 @@ class SideSet(Dataset):
if self.output_format == "pytorch":
return torch.from_numpy(sig).type(torch.FloatTensor), torch.from_numpy(speaker_idx).type(torch.LongTensor)
else:
return sig, speaker_idx
return sig.astype(numpy.float32), speaker_idx
def __len__(self):
"""
......
......@@ -31,10 +31,12 @@ import traceback
import logging
import matplotlib.pyplot as plt
import multiprocessing
import os
import numpy
import pandas
import pickle
import shutil
import sys
import time
import torch
import tqdm
......@@ -62,6 +64,7 @@ import torch.distributed as dist
import torch.multiprocessing as mp
os.environ['MKL_THREADING_LAYER'] = 'GNU'
__license__ = "LGPL"
__author__ = "Anthony Larcher"
......@@ -448,7 +451,7 @@ class Xtractor(torch.nn.Module):
elif k.startswith("ctrans"):
segmental_layers.append((k, torch.nn.ConvTranspose1d(input_size,
cfg["segmental"][k][":"],
cfg["segmental"][k]["output_channels"],
kernel_size=cfg["segmental"][k]["kernel_size"],
dilation=cfg["segmental"][k]["dilation"])))
elif k.startswith("activation"):
......@@ -647,10 +650,9 @@ def xtrain(speaker_number,
if num_thread is None:
num_thread = multiprocessing.cpu_count()
logging.critical(f"Start process at {time.strftime('%H:%M:%S', time.localtime())}")
logging.critical(f"Use {num_thread} cpus")
logging.critical(f"Start process at {time.strftime('%H:%M:%S', time.localtime())}")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
......@@ -740,7 +742,7 @@ def xtrain(speaker_number,
Then we provide those two
"""
if write_batches_to_disk:
if write_batches_to_disk or dataset_params["batch_size"] > 1:
output_format = "numpy"
else:
output_format = "pytorch"
......@@ -1007,16 +1009,16 @@ def cross_validation(model, validation_loader, device):
with torch.no_grad():
for batch_idx, (data, target) in enumerate(validation_loader):
batch_size = target.shape[0]
target = target.squeeze()
target = target.squeeze().to(device)
data = data.squeeze().to(device)
if loss_criteria == "aam":
output = model(data.to(device), target=target)
output = model(data, target=target)
else:
output = model(data.to(device), target=None)
output = model(data, target=None)
accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
accuracy += (torch.argmax(output.data, 1) == target).sum()
loss += criterion(output, target.to(device))
loss += criterion(output, target)
return 100. * accuracy.cpu().numpy() / ((batch_idx + 1) * batch_size), \
loss.cpu().numpy() / ((batch_idx + 1) * batch_size)
......@@ -1455,29 +1457,3 @@ def xtime(model, training_set, validation_set,
def example(rank, world_size):
    """Minimal DistributedDataParallel demo: one linear layer, one step.

    Spawned once per worker process. `rank` identifies this worker (and
    selects its device); `world_size` is the total number of workers in
    the gloo process group.
    """
    # Join the default process group before wrapping the model.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Plain local module placed on this worker's device.
    net = torch.nn.Linear(10, 10).to(rank)
    # DDP wrapper so gradients are all-reduced across workers.
    wrapped = DDP(net, device_ids=[rank])

    # Loss and optimizer over the wrapped parameters.
    criterion = torch.nn.MSELoss()
    sgd = torch.optim.SGD(wrapped.parameters(), lr=0.001)

    # One forward/backward/update cycle on random data.
    predictions = wrapped(torch.randn(20, 10).to(rank))
    targets = torch.randn(20, 10).to(rank)
    criterion(predictions, targets).backward()
    sgd.step()
def main():
    """Launch two worker processes, each running `example`."""
    n_procs = 2
    # Each spawned worker receives (rank, n_procs) as its arguments.
    mp.spawn(example, args=(n_procs,), nprocs=n_procs, join=True)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment