Commit 19f68ec6 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

sad

parent a69d5335
...@@ -216,8 +216,7 @@ class SAD_RNN(): ...@@ -216,8 +216,7 @@ class SAD_RNN():
def train_network(self, def train_network(self,
nb_epochs, nb_epochs,
train_list, uem_list, training_set,
features_server,
model_file_format): model_file_format):
""" """
Trains the network Trains the network
...@@ -236,26 +235,18 @@ class SAD_RNN(): ...@@ -236,26 +235,18 @@ class SAD_RNN():
optimizer = optim.RMSprop(self.model.parameters()) optimizer = optim.RMSprop(self.model.parameters())
losses = [] losses = []
est_it = 0 est_it = training_set.len // self.batch_size
for show in uem_list:
for seg in uem_list[show]:
est_it += seg['stop'] - seg['start']
est_it = est_it // self.batch_size // self.step
#est_it = # uem.duration // self.batch_size // self.step
for epoch in range(nb_epochs): for epoch in range(nb_epochs):
it = 1 it = 1
losses.append([]) losses.append([])
gen = self._sad_generator(train_list, uem_list, features_server) for batch_idx, (X, Y) in enumerate(training_set):
for X, Y in gen:
batch_loss = self._fit_batch(optimizer, criterion, X, Y) batch_loss = self._fit_batch(optimizer, criterion, X, Y)
losses[epoch].append(batch_loss) losses[epoch].append(batch_loss)
sys.stdout.write("\rEpoch {}/{} ({}/{}), loss {:.5f}".format( sys.stdout.write("\rEpoch {}/{} ({}/{}), loss {:.5f}".format(
epoch + 1, nb_epochs, it, est_it, numpy.mean(losses[epoch]))) epoch + 1, nb_epochs, it, est_it, numpy.mean(losses[epoch])))
sys.stdout.flush() sys.stdout.flush()
it += 1 it += 1
print() #est_it = len(losses[epoch])
est_it = len(losses[epoch])
torch.save(self.model.state_dict(), model_file_format.format(epoch+1)) torch.save(self.model.state_dict(), model_file_format.format(epoch+1))
def get_labels(self, model_fn, show, features_server, def get_labels(self, model_fn, show, features_server,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment