Commit 473e89b9 authored by Florent Desnous 's avatar Florent Desnous
Browse files

get_scores() without batch

parent 19f68ec6
......@@ -184,20 +184,17 @@ class SAD_RNN():
features, _ = features_server.load(show)
x = []
X = torch.tensor([]).to(device)
X = []
for i in range(0, len(features) - self.duration, self.step):
x.append(features[i:i + self.duration])
X.append(features[i:i + self.duration])
if i + self.step > len(features) - self.duration:
pad_size = self.batch_size - len(x)
pad_size = self.batch_size - len(X)
pad = [[[0] * self.input_size] * self.duration] * pad_size
x += pad
if len(x) == self.batch_size:
x = torch.Tensor(x)
x = x.to(device)
self.model.hidden = None
X = torch.cat((X, self.model(x)))
x = []
X += pad
X = torch.Tensor(X).to(device)
self.model.hidden = None
X = self.model(X)
o = numpy.asarray(X.squeeze(2).tolist())
scores = numpy.zeros((len(o) * self.step + self.duration - self.step))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment