Commit ac9c7030 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

new nnet

parent 1268a4ce
......@@ -38,6 +38,10 @@ from .clustering.hac_utils import bic_square_root
from .clustering.cc_iv import ConnectedComponent
from .nnet.wavsets import AlliesSet
from .nnet.seqtoseq import PreNet
from .nnet.seqtoseq import BLSTM
from .model_iv import ModelIV
from .diar import Diar
......@@ -57,4 +61,4 @@ __maintainer__ = "Sylvain Meignier"
__email__ = "sylvain.meignierr@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reStructuredText'
__version__ = "0.1.4.4"
__version__ = "0.1.4.5"
# -*- coding: utf-8 -*-
#
# This file is part of s4d.
#
# s4d is a python package for speaker diarization.
# Home page: http://www-lium.univ-lemans.fr/s4d/
#
# s4d is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# s4d is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with s4d. If not, see <http://www.gnu.org/licenses/>.
"""
Copyright 2014-2020 Anthony Larcher
"""
from .wavsets import AlliesSet
from .seqtoseq import PreNet
from .seqtoseq import BLSTM
\ No newline at end of file
......@@ -45,20 +45,34 @@ __status__ = "Production"
__docformat__ = 'reS'
class PreNet(nn.Module):
def __init(self):
def __init(self,
sample_rate=16000,
windows_duration=0.2,
frame_shift=0.01):
super(PreNet, self).__init__()
self.conv1 = nn.Conv1d(in_channels=1,
out_channels=64,
kernel_size=200,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode='zeros')
windows_length = int(sample_rate * windows_duration)
if windows_length % 2:
windows_length += 1
stride_0 = int(sample_rate * frame_shift)
self.conv0 = torch.nn.Conv1d(1, 64, windows_length, stride=stride_0, dilation=1)
self.conv1 = torch.nn.Conv1d(64, 64, 3, dilation=1)
self.conv2 = torch.nn.Conv1d(64, 64, 3, dilation=1)
self.norm0 = torch.nn.BatchNorm1d(64)
self.norm1 = torch.nn.BatchNorm1d(64)
self.norm2 = torch.nn.BatchNorm1d(64)
self.activation = torch.nn.LeakyReLU(0.2)
def forward(self, input):
output = self.conv1(input)
x = self.norm0(self.activation(self.conv0(input)))
x = self.norm1(self.activation(self.conv1(x)))
output = self.norm2(self.activation(self.conv2(x)))
return output
class preprocessingBLSTM(nn.Module):
......
......@@ -119,7 +119,7 @@ def mdtm_to_label(mdtm_filename,
return label
def get_segment_label(label, seg_idx, mode, duration, framerate, seg_shift, collar_duration, filter_type="gate")
def get_segment_label(label, seg_idx, mode, duration, framerate, seg_shift, collar_duration, filter_type="gate"):
# Create labels with Diracs at every speaker change detection
spk_change = numpy.zeros(label.shape, dtype=int)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment