Commit cd7ea34d authored by Anthony Larcher

ddp

parent cb585357
@@ -58,6 +58,10 @@ from .loss import l2_norm
import torch.autograd.profiler as profiler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.multiprocessing as mp
__license__ = "LGPL"
__author__ = "Anthony Larcher"
@@ -645,6 +649,9 @@ def xtrain(speaker_number,
logging.critical(f"Start process at {time.strftime('%H:%M:%S', time.localtime())}")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Start from scratch
@@ -1445,3 +1452,32 @@ def xtime(model, training_set, validation_set,
print(torch.cuda.memory_summary(ii))
logging.critical(f"Best accuracy {best_accuracy * 100.} obtained at epoch {best_accuracy_epoch}")
def example(rank, world_size):
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    model = torch.nn.Linear(10, 10).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)
    # forward pass
    outputs = ddp_model(torch.randn(20, 10).to(rank))
    labels = torch.randn(20, 10).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()
    # update parameters
    optimizer.step()


def main():
    world_size = 2
    mp.spawn(example,
             args=(world_size,),
             nprocs=world_size,
             join=True)
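
A minimal sketch of how this could be launched (not part of the commit): torch.multiprocessing.spawn re-imports the main module in each child process, so the call to main() should be guarded by the usual entry-point check.

    # Hypothetical entry point for running the DDP example directly.
    if __name__ == "__main__":
        main()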