Commit 97d5d11d authored by Félix Michaud's avatar Félix Michaud
Browse files

training 1 species

parent 1ca52594
......@@ -216,14 +216,14 @@ class Dataset(data.Dataset):
print("** WARNING ** No data loaded from " + path)
return dict_classes
def mixed_audio_augmentation(self, audio, sampling_rate):
def get_noise(self):
classe_noise = random.randint(0, len(list(self.dict_noises.keys()))-1)
classe_noise = list(self.dict_noises.keys())[classe_noise]
#random natural noise augmentation
filename_noise = self.dict_noises[classe_noise][random.randint(0, len(self.dict_noises[classe_noise])-1)]
# audio_noise, sr = load_file(filename_noise)
# coeff = int(np.ceil(np.max(audio)*random.choice([1, 2, 3, 4, 5, 6, 7])))
# noisy_audio = audio + (audio_noise)*coeff
return filename_noise
def data_augment(audio):
#random pitch shifting
# step_pitch = random.uniform(-0.001, 0.001)
# mod_audio = librosa.effects.pitch_shift(noisy_audio, sampling_rate, n_steps=step_pitch)
......@@ -296,10 +296,13 @@ class Dataset(data.Dataset):
'Build mixed mask'
if self.augmentation:
# file_noise = self.mixed_audio_augmentation(audio_mix, sr)
n_noise = np.random.normal(loc=0, scale=1, size=(1, max_time*sr))
n_noise = librosa.to_mono(n_noise)
snr = np.random.randint(-10, 5) #-10/5 for natural noise, 30/50
if random.randint(0, 1) == 1:
n_noise = self.get_noise()
snr = np.random.randint(-10, 5)
else:
n_noise = np.random.normal(loc=0, scale=1, size=(1, max_time*sr))
n_noise = librosa.to_mono(n_noise)
snr = np.random.randint(30, 50) #-10/5 for natural noise, 30/50
audio_mix = _add_noise(audio_mix, n_noise, snr, sr)
mag_mix, phase_mix = _stft(audio_mix)
......
......@@ -17,7 +17,7 @@ from unet import UNet5
import torch.nn as nn
#from tensorboardX import SummaryWriter
from matplotlib import image as mpimg
from Dataloader import Dataset
from Dataloader_solo import Dataset
import matplotlib
import numpy as np
import collections
......@@ -143,7 +143,7 @@ if __name__ == '__main__':
args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
args.starting_training_time = time.time()
args.save_per_batchs = 500
args.nb_classes = 3
args.nb_classes = 1
args.mode = 'train'
args.lr_sounds = 1e-5
args.saved_model = '5_5000'
......@@ -165,7 +165,7 @@ if __name__ == '__main__':
#Dataset loading
root = './data_sound/trainset/'
ext = '.wav'
train_classes = Dataset(root, nb_classes=args.nb_classes, path_background="./data_sound/noises/")
train_classes = Dataset(root, nb_classes=args.nb_classes, nb_classes_noise=3, path_background="./data_sound/noises/")
loader_train = torch.utils.data.DataLoader(
train_classes,
batch_size = args.batch_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment