Commit 3a64e622 authored by Anthony Larcher
Browse files

cleaning

parent 19953b86
......@@ -468,7 +468,9 @@ def read_audio(input_file_name, framerate=None):
def write_label(label,
output_file_name,
selected_label='speech',
frame_per_second=100):
frame_per_second=100,
show=None,
format="mdtm"):
"""Save labels in ALIZE format
:param output_file_name: name of the file to write to
......@@ -483,11 +485,30 @@ def write_label(label,
# append 0 at the beginning of the list, append the last index to the list
idx = [0] + (numpy.arange(len(bits))[bits] + 1).tolist() + [len(label)]
framerate = decimal.Decimal(1) / decimal.Decimal(frame_per_second)
# for each pair of indexes (idx[i] and idx[i+1]), create a segment
with open(output_file_name, 'w') as fid:
if format == "lab":
# for each pair of indexes (idx[i] and idx[i+1]), create a segment
with open(output_file_name, 'w') as fid:
for i in range(~label[0], len(idx) - 1, 2):
fid.write('{} {} {}\n'.format(str(idx[i]*framerate),
str(idx[i + 1]*framerate), selected_label))
else:
# write in MDTM format
lst = []
for i in range(~label[0], len(idx) - 1, 2):
fid.write('{} {} {}\n'.format(str(idx[i]*framerate),
str(idx[i + 1]*framerate), selected_label))
gender = 'U'
env = 'U'
channel = 'U'
start = idx[i]*framerate
stop = idx[i + 1]*framerate
lst.append('{:s} 1 {:.2f} {:.2f} {:s} {:s} {:s} {:s}\n'.format(
show, start, stop - start, gender,
channel, env, "speech"))
with open(output_file_name, 'w', encoding="utf8") as fid:
for line in lst:
fid.write(line)
def read_label(input_file_name, selected_label='speech', frame_per_second=100):
......
......@@ -337,7 +337,6 @@ class SincNet(torch.nn.Module):
bias=True,
)
else:
print(f"smaple_rate ={self.sample_rate} ")
conv1d = SincConv1d(
out_channels,
kernel_size,
......
......@@ -645,6 +645,13 @@ class IdMapSet(Dataset):
_transform.append(MFCC())
if "CMVN" in t:
_transform.append(CMVN())
if 'add_noise' in t:
self.add_noise[:] = 1
numpy.random.shuffle(self.add_noise)
_transform.append(AddNoise(noise_db_csv="list/musan.csv",
snr_min_max=[5.0, 15.0],
noise_root_path="./data/musan/"))
self.transforms = transforms.Compose(_transform)
def __getitem__(self, index):
......
......@@ -102,13 +102,6 @@ class GuruMeditation (torch.autograd.detect_anomaly):
pdb.set_trace()
def select_n_random(data, labels, n=100):
'''
Selects n random datapoints and their corresponding labels from a dataset
......@@ -1003,8 +996,8 @@ def cross_validation(model, validation_loader, device):
criterion = torch.nn.CrossEntropyLoss()
with torch.no_grad():
for batch_idx, (data, target) in enumerate(validation_loader):
batch_size = target.shape[0]
target = target.squeeze()
batch_size = target.shape[0]
if loss_criteria == "aam":
output = model(data.to(device), target=target)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment