# -*- coding: utf-8 -*-
#
# This file is part of s4d.
#
# s4d is a python package for speaker diarization.
# Home page: http://www-lium.univ-lemans.fr/s4d/
#
# s4d is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# s4d is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with s4d.  If not, see <http://www.gnu.org/licenses/>.


"""
Copyright 2014-2020 Anthony Larcher
"""

import os
import sys
import numpy
import random
import h5py
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset
import logging


__license__ = "LGPL"
__author__ = "Anthony Larcher"
__copyright__ = "Copyright 2015-2020 Anthony Larcher"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reStructuredText'




class PreNet(nn.Module):
    """
    Convolutional front-end that turns a raw waveform into a sequence
    of 64-dimensional frame-level features
    """
    def __init__(self,
                 sample_rate=16000,
                 windows_duration=0.2,
                 frame_shift=0.01):
        super(PreNet, self).__init__()

        windows_length = int(sample_rate * windows_duration)
        if windows_length % 2:
            windows_length += 1
        stride_0 = int(sample_rate * frame_shift)

        self.conv0 = torch.nn.Conv1d(1, 64, windows_length, stride=stride_0, dilation=1)
        self.conv1 = torch.nn.Conv1d(64, 64, 3, dilation=1)
        self.conv2 = torch.nn.Conv1d(64, 64, 3, dilation=1)

        self.norm0 = torch.nn.BatchNorm1d(64)
        self.norm1 = torch.nn.BatchNorm1d(64)
        self.norm2 = torch.nn.BatchNorm1d(64)

        self.activation = torch.nn.LeakyReLU(0.2)


    def forward(self, input):

        x = self.norm0(self.activation(self.conv0(input)))
        x = self.norm1(self.activation(self.conv1(x)))
        output = self.norm2(self.activation(self.conv2(x)))

        return output

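# A minimal shape sketch for PreNet (illustration only, not part of the original
# module); the frame counts below assume the default 16 kHz settings:
#
#     prenet = PreNet(sample_rate=16000, windows_duration=0.2, frame_shift=0.01)
#     waveform = torch.randn(4, 1, 32000)   # 4 clips of 2 seconds of raw audio
#     features = prenet(waveform)           # -> torch.Size([4, 64, 177])
#
# The first convolution acts as a learnable framing stage (3200-sample window,
# 160-sample shift); the two following 3-tap convolutions refine the 64-channel
# frame-level representation.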



class BLSTM(nn.Module):
    """
    Bi LSTM model used for voice activity detection or speaker turn detection
    """
    def __init__(self,
                 input_size,
                 lstm_1,
                 lstm_2,
                 linear_1,
                 linear_2,
                 output_size=1):
        """

        :param input_size:
        :param lstm_1:
        :param lstm_2:
        :param linear_1:
        :param linear_2:
        :param output_size:
        """
        super(BLSTM, self).__init__()

        self.lstm_1 = nn.LSTM(input_size, lstm_1 // 2, bidirectional=True, batch_first=True)
        self.lstm_2 = nn.LSTM(lstm_1, lstm_2 // 2, bidirectional=True, batch_first=True)
        self.linear_1 = nn.Linear(lstm_2, linear_1)
        self.linear_2 = nn.Linear(linear_1, linear_2)
        self.output = nn.Linear(linear_2, output_size)
        self.hidden = None

    def forward(self, inputs):
        """

        :param inputs:
        :return:
        """
        if self.hidden is None:
            hidden_1, hidden_2 = None, None
        else:
            hidden_1, hidden_2 = self.hidden
        tmp, hidden_1 = self.lstm_1(inputs, hidden_1)
        x, hidden_2 = self.lstm_2(tmp, hidden_2)
        self.hidden = (hidden_1, hidden_2)
        x = torch.tanh(self.linear_1(x))
        x = torch.tanh(self.linear_2(x))
        x = torch.sigmoid(self.output(x))
        return x

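# A minimal usage sketch for BLSTM as a frame-level voice activity detector
# (illustration only; the 64-dimensional input and the layer sizes are
# assumptions matching the SeqToSeq defaults below):
#
#     vad = BLSTM(input_size=64, lstm_1=64, lstm_2=40, linear_1=40, linear_2=10)
#     features = torch.randn(4, 200, 64)    # (batch, frames, features)
#     posteriors = vad(features)            # -> torch.Size([4, 200, 1]), values in (0, 1)
#
# Note that the hidden state is cached in self.hidden between calls, so successive
# forward passes must use batches of the same size.
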
class SeqToSeq(nn.Module):
    """
    Sequence-to-sequence model chaining the PreNet front-end and the BLSTM sequence model
    """

    def __init__(self):
        super(SeqToSeq, self).__init__()
        self.preprocessor = PreNet(sample_rate=16000,
                                   windows_duration=0.2,
                                   frame_shift=0.01)

        # PreNet outputs 64 channels per frame, so the BLSTM input size is 64
        self.sequence_model = BLSTM(input_size=64,
                                    lstm_1=64,
                                    lstm_2=40,
                                    linear_1=40,
                                    linear_2=10)

    def forward(self, input):
        x = self.preprocessor(input)
        # swap (batch, channels, frames) to (batch, frames, channels) for the batch_first BLSTM
        x = x.permute(0, 2, 1)
        output = self.sequence_model(x)
        return output
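

if __name__ == "__main__":
    # Quick smoke test (illustration only, not part of the original module):
    # push two seconds of random 16 kHz audio through the full pipeline and
    # check the shape of the frame-level output.
    model = SeqToSeq()
    waveform = torch.randn(4, 1, 32000)    # (batch, 1, num_samples)
    frame_scores = model(waveform)         # (batch, num_frames, 1)
    print(frame_scores.shape)              # torch.Size([4, 177, 1])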