Commit c4d4493d authored by Anthony Larcher's avatar Anthony Larcher
Browse files

new nnet module

parent 1268a4ce
......@@ -38,6 +38,10 @@ from .clustering.hac_utils import bic_square_root
from .clustering.cc_iv import ConnectedComponent
from .nnet.wavsets import AlliesSet
from .nnet.seqtoseq import PreNet
from .nnet.seqtoseq import BLSTM
from .model_iv import ModelIV
from .diar import Diar
......
......@@ -769,7 +769,8 @@ class Diar():
if not diarization._attributes.exist('channel'):
diarization.add_attribut(new_attribut='channel', default='U')
try:
for line in fic line = re.sub('\s+',' ',line)
for line in fic:
line = re.sub('\s+',' ',line)
line = line.strip()
# logging.debug(line)
if line.startswith('#') or line.startswith(';;'):
......
......@@ -44,24 +44,44 @@ __email__ = "anthony.larcher@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reS'
class PreNet(nn.Module):
def __init(self):
def __init(self,
sample_rate=16000,
windows_duration=0.2,
frame_shift=0.01):
super(PreNet, self).__init__()
self.conv1 = nn.Conv1d(in_channels=1,
out_channels=64,
kernel_size=200,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode='zeros')
windows_length = int(sample_rate * windows_duration)
if windows_length % 2:
windows_length += 1
stride_0 = int(sample_rate * frame_shift)
self.conv0 = torch.nn.Conv1d(1, 64, windows_length, stride=stride_0, dilation=1)
self.conv1 = torch.nn.Conv1d(64, 64, 3, dilation=1)
self.conv2 = torch.nn.Conv1d(64, 64, 3, dilation=1)
self.norm0 = torch.nn.BatchNorm1d(64)
self.norm1 = torch.nn.BatchNorm1d(64)
self.norm2 = torch.nn.BatchNorm1d(64)
self.activation = torch.nn.LeakyReLU(0.2)
def forward(self, input):
output = self.conv1(input)
x = self.norm0(self.activation(self.conv0(input)))
x = self.norm1(self.activation(self.conv1(x)))
output = self.norm2(self.activation(self.conv2(x)))
return output
class preprocessingBLSTM(nn.Module):
class BLSTM(nn.Module):
"""
Bi LSTM model used for voice activity detection or speaker turn detection
"""
......
......@@ -111,15 +111,15 @@ def mdtm_to_label(mdtm_filename,
# Fill the labels with spk_idx
for segment in diarization:
start = int(segment['start']) * framerate / 100.
stop = int(segment['stop']) * framerate / 100.
start = int(segment['start']) * framerate // 100
stop = int(segment['stop']) * framerate // 100
spk_idx = speaker_dict[segment['cluster']]
label[start:stop] = spk_idx
return label
def get_segment_label(label, seg_idx, mode, duration, framerate, seg_shift, collar_duration, filter_type="gate")
def get_segment_label(label, seg_idx, mode, duration, framerate, seg_shift, collar_duration, filter_type="gate"):
# Create labels with Diracs at every speaker change detection
spk_change = numpy.zeros(label.shape, dtype=int)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment