Commit 83a17da7 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

sad

parent 19f68ec6
......@@ -63,6 +63,7 @@ class SAD_Dataset(Dataset):
start, stop = 0, len(features)
for i in range(start, min(stop, len(features)) - self.duration, self.step):
self.segments.append((show, i, i + self.duration))
self.input_size = features.shape[1]
if shuffle:
random.shuffle(self.segments)
......@@ -242,11 +243,10 @@ class SAD_RNN():
for batch_idx, (X, Y) in enumerate(training_set):
batch_loss = self._fit_batch(optimizer, criterion, X, Y)
losses[epoch].append(batch_loss)
sys.stdout.write("\rEpoch {}/{} ({}/{}), loss {:.5f}".format(
epoch + 1, nb_epochs, it, est_it, numpy.mean(losses[epoch])))
sys.stdout.write("\rEpoch {}/{}, loss {:.5f}".format(
epoch + 1, nb_epochs, numpy.mean(losses[epoch])))
sys.stdout.flush()
it += 1
#est_it = len(losses[epoch])
torch.save(self.model.state_dict(), model_file_format.format(epoch+1))
def get_labels(self, model_fn, show, features_server,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment