Commit 2952f1a2 authored by Anthony Larcher
Browse files

add options

parent c6f046c4
......@@ -190,5 +190,5 @@ __maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reStructuredText'
__version__="1.3.6.5"
__version__="1.3.6.6"
......@@ -37,7 +37,7 @@ import numpy
import os
if "DISPLAY" not in os.environ:
matplotlib.use('PDF', warn=False, force=True)
matplotlib.use('PDF')
import matplotlib.pyplot as mpl
import scipy
from collections import namedtuple
......
......@@ -445,7 +445,8 @@ def read_audio(input_file_name, framerate=None):
raise TypeError("Unknown extension of audio file")
# Convert to 16 bit encoding if needed
sig *= (2**(15-sampwidth))
#if not sampwidth == 2:
# sig *= (2**(15-sampwidth))
if framerate > read_framerate:
print("Warning in read_audio, up-sampling function is not implemented yet!")
......
......@@ -112,7 +112,7 @@ class StatDataset(Dataset):
self.len = self.idmap.leftids.shape[0]
def __getitem__(self, index):
data, _ = self.fs.load(self.idmap.rightids[index])
data, _ = self.fs.load(self.idmap.rightids[index], start=self.idmap.start[index], stop=self.idmap.stop[index])
data = (data - data.mean(0)) / data.std(0)
data = data.reshape((1, data.shape[0], data.shape[1])).transpose(0, 2, 1).astype(numpy.float32)
return self.idmap.leftids[index], self.idmap.rightids[index], torch.from_numpy(data).type(torch.FloatTensor)
......@@ -120,7 +120,6 @@ class StatDataset(Dataset):
def __len__(self):
return self.len
class VoxDataset(Dataset):
"""
......
......@@ -266,7 +266,7 @@ def train_epoch(model, epoch, train_seg_df, speaker_dict, optimizer, args):
if "TemporalMask" in t:
a = int(t.split("(")[1].split(")")[0])
train_transform.append(TemporalMask(a))
train_set = VoxDataset(train_seg_df, speaker_dict, 500, transform=transforms.Compose(train_transform),
train_set = VoxDataset(train_seg_df, speaker_dict, args.duration, transform=transforms.Compose(train_transform),
spec_aug_ratio=args.spec_aug, temp_aug_ratio=args.temp_aug)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=15)
......
......@@ -417,8 +417,6 @@ class StatServer:
"""
assert self.validate(), "Error: wrong StatServer format"
file_already_exist = os.path.exists(output_file_name)
start = copy.deepcopy(self.start)
start[numpy.isnan(self.start.astype('float'))] = -1
start = start.astype('int32', copy=False)
......@@ -429,8 +427,10 @@ class StatServer:
with h5py.File(output_file_name, mode) as f:
ds_already_exist = prefix in f
# If the file doesn't exist before, create it
if mode == "w" or not file_already_exist:
if mode == "w" or not ds_already_exist:
f.create_dataset(prefix+"modelset", data=self.modelset.astype('S'),
maxshape=(None,),
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment