Commit e8289e54 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

minor

parent b9ef3f9f
......@@ -29,6 +29,7 @@ import h5py
import logging
import sys
import numpy
import pickle
import torch
import torch.optim as optim
import torch.multiprocessing as mp
......@@ -495,7 +496,7 @@ def xtrain_new(args):
# Split the training data in train and cv
total_seg_df = pickle.load(open(args.batch_training_list, "rb"))
cv_portion = 0.05
cv_portion = 0.007
idx = numpy.arange(len(total_seg_df))
numpy.random.shuffle(idx)
train_seg_df = total_seg_df.iloc[idx[:int((1-cv_portion)*len(idx))]].reset_index()
......@@ -588,7 +589,7 @@ def train_epoch_new(model, epoch, train_seg_df, args):
torch.manual_seed(args.seed)
train_set = VoxDataset(train_seg_df, 500)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=15)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=15)
optimizer = optim.Adam([{'params': model.frame_conv0.parameters(), 'weight_decay': args.l2_frame},
......@@ -618,7 +619,6 @@ def train_epoch_new(model, epoch, train_seg_df, args):
epoch, batch_idx + 1, train_loader.__len__(),
100. * batch_idx / train_loader.__len__(), loss.item(),
100.0 * accuracy.item() / ((batch_idx + 1) * args.batch_size)))
return model
......@@ -649,12 +649,14 @@ def cross_validation_new(args, model, cv_seg_df):
:return:
"""
cv_set = VoxDataset(cv_seg_df, 500)
cv_loader = DataLoader(cv_set, batch_size=1, shuffle=False, num_workers=1)
model.eval()
device = torch.device("cuda:0")
model.to(device)
accuracy = 0.0
for batch_idx, (data, target) in enumerate(cv_set):
print(cv_set.__len__())
for batch_idx, (data, target) in enumerate(cv_loader):
output = model(data.to(device))
accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
return 100. * accuracy.cpu().numpy() / ((batch_idx + 1) * args.batch_size)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment