Commit 19f68ec6 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

sad

parent a69d5335
......@@ -216,8 +216,7 @@ class SAD_RNN():
def train_network(self,
nb_epochs,
train_list, uem_list,
features_server,
training_set,
model_file_format):
"""
Trains the network
......@@ -236,26 +235,18 @@ class SAD_RNN():
optimizer = optim.RMSprop(self.model.parameters())
losses = []
est_it = 0
for show in uem_list:
for seg in uem_list[show]:
est_it += seg['stop'] - seg['start']
est_it = est_it // self.batch_size // self.step
#est_it = # uem.duration // self.batch_size // self.step
est_it = training_set.len // self.batch_size
for epoch in range(nb_epochs):
it = 1
losses.append([])
gen = self._sad_generator(train_list, uem_list, features_server)
for X, Y in gen:
for batch_idx, (X, Y) in enumerate(training_set):
batch_loss = self._fit_batch(optimizer, criterion, X, Y)
losses[epoch].append(batch_loss)
sys.stdout.write("\rEpoch {}/{} ({}/{}), loss {:.5f}".format(
epoch + 1, nb_epochs, it, est_it, numpy.mean(losses[epoch])))
sys.stdout.flush()
it += 1
print()
est_it = len(losses[epoch])
#est_it = len(losses[epoch])
torch.save(self.model.state_dict(), model_file_format.format(epoch+1))
def get_labels(self, model_fn, show, features_server,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment