Commit 67bd00fb authored by Anthony Larcher's avatar Anthony Larcher
Browse files

add sincnet

parent 8caeb64d
......@@ -82,9 +82,10 @@ class Xtractor(torch.nn.Module):
"""
super(Xtractor, self).__init__()
self.speaker_number = speaker_number
self.feature_size = 24
self.feature_size = None
if model_archi is None:
self.feature_size = 30
self.activation = torch.nn.ReLU()
self.sequence_network = torch.nn.Sequential(OrderedDict([
......@@ -123,8 +124,33 @@ class Xtractor(torch.nn.Module):
with open(model_archi, 'r') as fh:
cfg = yaml.load(fh, Loader=yaml.FullLoader)
"""
Prepare Preprocessor
"""
if "preprocessor" in cfg:
if cfg['preprocessor']["type"] == "sincnet":
self.sincnet = SincNet(
waveform_normalize=cfg['preprocessor']["waveform_normalize"],
sample_rate=cfg['preprocessor']["sample_rate"],
min_low_hz=cfg['preprocessor']["min_low_hz"],
min_band_hz=cfg['preprocessor']["min_band_hz"],
out_channels=cfg['preprocessor']["out_channels"],
kernel_size=cfg['preprocessor']["kernel_size"],
stride=cfg['preprocessor']["stride"],
max_pool=cfg['preprocessor']["max_pool"],
instance_normalize=cfg['preprocessor']["instance_normalize"],
activation=cfg['preprocessor']["activation"],
dropout=cfg['preprocessor']["dropout"]
)
self.feature_size = self.sincnet_.dimension
"""
Prepapre sequence network
"""
# Get Feature size
self.feature_size = cfg["feature_size"]
if self.feature_size is None:
self.feature_size = cfg["feature_size"]
input_size = self.feature_size
# Get activation function
......@@ -156,6 +182,9 @@ class Xtractor(torch.nn.Module):
self.sequence_network = torch.nn.Sequential(OrderedDict(segmental_layers))
self.sequence_network_weight_decay = cfg["segmental"]["weight_decay"]
"""
Prepapre last part of the network (after pooling)
"""
# Create sequential object for the second part of the network
input_size = input_size * 2
before_embedding_layers = []
......@@ -390,6 +419,7 @@ def cross_validation(model, validation_loader, device):
criterion = torch.nn.CrossEntropyLoss()
for batch_idx, (data, target) in enumerate(validation_loader):
batch_size = target.shape[0]
target = target.squeeze()
output = model(data.to(device))
accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment