Commit 63209176 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

debug new xtrain

parents 748a7a3e 196e47a9
......@@ -34,7 +34,6 @@ import os
import sys
# Read environment variable if it exists
SIDEKIT_CONFIG={"libsvm":True,
"mpi":False,
......
......@@ -40,7 +40,6 @@ import torch
import warnings
import sidekit.frontend
from ..sidekit_io import init_logging
from ..sidekit_wrappers import check_path_existance
__license__ = "LGPL"
......
......@@ -132,7 +132,7 @@ class SideSampler(torch.utils.data.Sampler):
class SideSet(Dataset):
def __init__(self,
data_set_yaml,
dataset,
set_type="train",
chunk_per_segment=1,
transform_number=1,
......@@ -148,10 +148,10 @@ class SideSet(Dataset):
:param chunk_per_segment: number of chunks to select for each segment
default is 1 and -1 means select all possible chunks
"""
with open(data_set_yaml, "r") as fh:
dataset = yaml.load(fh, Loader=yaml.FullLoader)
#with open(data_set_yaml, "r") as fh:
# dataset = yaml.load(fh, Loader=yaml.FullLoader)
self.data_path = dataset["data_root_directory"]
self.data_path = dataset["data_path"]
self.sample_rate = int(dataset["sample_rate"])
self.data_file_extension = dataset["data_file_extension"]
self.transformation = ''
......@@ -163,8 +163,8 @@ class SideSet(Dataset):
self.duration = dataset["train"]["duration"]
self.transformation = dataset["train"]["transformation"]
else:
self.duration = dataset["eval"]["duration"]
self.transformation = dataset["eval"]["transformation"]
self.duration = dataset["valid"]["duration"]
self.transformation = dataset["valid"]["transformation"]
self.sample_number = int(self.duration * self.sample_rate)
self.overlap = int(overlap * self.sample_rate)
......
This diff is collapsed.
......@@ -418,7 +418,7 @@ def h5merge(output_filename, input_filename_list):
fo.create_dataset(key, data=value)
def init_logging(level=logging.INFO, filename=None):
def init_logging(level=logging.DEBUG, filename=None):
"""
Initialize a logger
......@@ -426,18 +426,35 @@ def init_logging(level=logging.INFO, filename=None):
:param filename: name of the output file
"""
numpy.set_printoptions(linewidth=250, precision=4)
frm = '%(asctime)s - %(levelname)s - %(message)s'
root = logging.getLogger()
if root.handlers:
for handler in root.handlers:
root.removeHandler(handler)
logging.basicConfig(format=frm, level=level)
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
datefmt='%m-%d %H:%M',
filename='./myapp.log',
filemode='w')
if filename is not None:
fh = logging.FileHandler(filename)
fh.setFormatter(logging.Formatter(frm))
fh.setFormatter(logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s'))
fh.setLevel(level)
root.addHandler(fh)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
# set a format which is simpler for console use
formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
# tell the handler to use this format
console.setFormatter(formatter)
# add the handler to the root logger
logging.getLogger('').addHandler(console)
def write_matrix_hdf5(M, filename):
with h5py.File(filename, "w") as h5f:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment