Skip to content

Commit 0944f75

Browse files
committed
first commit
0 parents  commit 0944f75

File tree

9 files changed

+488
-0
lines changed

9 files changed

+488
-0
lines changed

README.md

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Densely Connected Time Delay Neural Network
2+
3+
PyTorch implementation of Densely Connected Time Delay Neural Network (D-TDNN) in our paper ["Densely Connected Time Delay Neural Network for Speaker Verification"](https://www.isca-speech.org/archive/Interspeech_2020/abstracts/1275.html) (INTERSPEECH 2020).
4+
5+
We provide the [pretrained models](https://github.com/yuyq96/D-TDNN/releases) which can be used in many tasks such as:
6+
7+
- Speaker Verification
8+
- Speaker Adaption for Speech Recognition
9+
- Speaker-Dependent Speech Separation
10+
- Multi-Speaker Text-to-Speech
11+
12+
![D-TDNN & D-TDNN-SS](figure/D_TDNN.png)
13+
14+
## Usage
15+
16+
Data preparation
17+
* Install [Kaldi](https://github.com/kaldi-asr/kaldi) toolkit.
18+
* Download [VoxCeleb1 test set](http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html) and unzip it.
19+
* Place `prepare_voxceleb1_test.sh` under `$kaldi_root/egs/voxceleb/v2` and change the `$datadir` and `$voxceleb1_root` in it.
20+
* Run `chmod +x prepare_voxceleb1_test.sh && ./prepare_voxceleb1_test.sh` to generate acoustic features ([30-Dim MFCCs](https://github.com/kaldi-asr/kaldi/blob/master/egs/voxceleb/v2/conf/mfcc.conf)).
21+
* Replace the `trials` under `$datadir/test_no_sil` with the [clean version](https://github.com/yuyq96/D-TDNN/releases).
22+
23+
Test
24+
```
25+
python main.py --root $datadir/test_no_sil --model D-TDNN --checkpoint model_zoo/dtdnn.pth --device cuda
26+
```
27+
28+
## Evaluation
29+
30+
VoxCeleb1-O
31+
32+
| Model | Emb. | Params (M) | Loss | Backend | EER (%) | DCF_0.01 | DCF_0.001 |
33+
| :---- | :--: | :--------: | :--: | :-----: | :-----: | :------: | :-------: |
34+
| [TDNN](https://github.com/yuyq96/D-TDNN/releases) | 512 | 4.2 | Softmax | PLDA | 2.34 | 0.28 | 0.38 |
35+
| E-TDNN | 512 | 6.1 | Softmax | PLDA | 2.08 | 0.26 | 0.41 |
36+
| F-TDNN | 512 | 12.4 | Softmax | PLDA | 1.89 | 0.21 | 0.29 |
37+
| [D-TDNN](https://github.com/yuyq96/D-TDNN/releases) | 512 | 2.8 | Softmax | Cosine | 1.81 | 0.20 | 0.28 |
38+
| D-TDNN-SS (0) | 512 | 3.0 | Softmax | Cosine | 1.55 | 0.20 | 0.30 |
39+
| D-TDNN-SS | 512 | 3.5 | Softmax | Cosine | 1.41 | 0.19 | 0.24 |
40+
| D-TDNN-SS | 128 | 3.1 | AAM-Softmax | Cosine | 1.22 | 0.13 | 0.20 |
41+
42+
## Citation
43+
44+
If you find D-TDNN helps your research, please cite
45+
```
46+
@inproceedings{DBLP:conf/interspeech/YuL20,
47+
author = {Ya-Qi Yu and
48+
Wu-Jun Li},
49+
title = {Densely Connected Time Delay Neural Network for Speaker Verification},
50+
booktitle = {Annual Conference of the International Speech Communication Association (INTERSPEECH)},
51+
year = {2020}
52+
}
53+
```

data.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import os
2+
3+
import kaldiio
4+
from torch.utils.data import Dataset
5+
6+
7+
class KaldiFeatDataset(Dataset):
8+
9+
def __init__(self, root, transform=None):
10+
super(KaldiFeatDataset, self).__init__()
11+
self.transform = transform
12+
self.feats = []
13+
with open(os.path.join(root, 'feats.scp'), 'r') as f:
14+
for line in f:
15+
utt, feats = line.split(' ')
16+
self.feats.append((feats, utt))
17+
18+
def __len__(self):
19+
return len(self.feats)
20+
21+
def __getitem__(self, index):
22+
feats, utt = self.feats[index]
23+
feats = kaldiio.load_mat(feats)
24+
if self.transform is not None:
25+
feats = self.transform(feats)
26+
return feats, utt
27+
28+
29+
class Transpose2D(object):
30+
31+
def __call__(self, a):
32+
return a.transpose((1, 0))

figure/D_TDNN.png

138 KB
Loading

main.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import argparse
2+
import os
3+
4+
import numpy as np
5+
import torch
6+
from numpy import linalg
7+
from torch.utils.data import DataLoader
8+
from tqdm import tqdm
9+
10+
from data import KaldiFeatDataset, Transpose2D
11+
from metric import compute_fnr_fpr, compute_eer, compute_c_norm
12+
from model.tdnn import TDNN
13+
from model.dtdnn import DTDNN
14+
15+
parser = argparse.ArgumentParser(description='Speaker Verification')
16+
parser.add_argument('--root', default='data', type=str)
17+
parser.add_argument('--model', default='D-TDNN', choices=['TDNN', 'D-TDNN'])
18+
parser.add_argument('--checkpoint', default=None, type=str)
19+
parser.add_argument('--device', default="cpu", choices=['cpu', 'cuda'])
20+
parser.add_argument('--pin-memory', default=True, type=bool)
21+
22+
23+
def load_model():
24+
assert os.path.isfile(args.checkpoint), "No checkpoint found at '{}'".format(args.checkpoint)
25+
print('Loading checkpoint {}'.format(args.checkpoint))
26+
state_dict = torch.load(args.checkpoint)['state_dict']
27+
if args.model == 'TDNN':
28+
model = TDNN()
29+
else:
30+
model = DTDNN()
31+
model.to(device)
32+
model.load_state_dict(state_dict)
33+
return model
34+
35+
36+
def test():
37+
model = load_model()
38+
model.eval()
39+
40+
transform = Transpose2D()
41+
dataset = KaldiFeatDataset(root=args.root, transform=transform)
42+
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=args.pin_memory)
43+
44+
utt2emb = {}
45+
for data, utt in tqdm(loader):
46+
with torch.no_grad():
47+
data = data.to(device)
48+
emb = model(data)
49+
utt2emb[utt[0]] = emb[0].cpu().numpy()
50+
51+
with open(os.path.join(args.root, 'trials'), 'r') as f:
52+
scores = []
53+
labels = []
54+
for line in f:
55+
utt1, utt2, label = line.split(' ')
56+
emb1, emb2 = utt2emb[utt1], utt2emb[utt2]
57+
score = emb1.dot(emb2) / (linalg.norm(emb1) * linalg.norm(emb2))
58+
scores.append(score)
59+
labels.append(1 if label.strip() == 'target' else 0)
60+
scores = np.array(scores)
61+
labels = np.array(labels)
62+
fnr, fpr = compute_fnr_fpr(scores, labels)
63+
eer, th = compute_eer(fnr, fpr, True, scores)
64+
print('Equal error rate is {:6f}%, at threshold {:6f}'.format(eer * 100, th))
65+
print('Minimum detection cost (0.01) is {:6f}'.format(compute_c_norm(fnr, fpr, 0.01)))
66+
print('Minimum detection cost (0.001) is {:6f}'.format(compute_c_norm(fnr, fpr, 0.001)))
67+
68+
69+
if __name__ == '__main__':
70+
args = parser.parse_args()
71+
device = torch.device(args.device)
72+
test()

metric.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import numpy as np
2+
3+
4+
def compute_fnr_fpr(scores, labels):
5+
""" computes false negative rate (FNR) and false positive rate (FPR)
6+
given trial scores and their labels.
7+
"""
8+
9+
indices = np.argsort(scores)
10+
labels = labels[indices]
11+
12+
target = (labels == 1).astype('f8')
13+
nontar = (labels == 0).astype('f8')
14+
15+
fnr = np.cumsum(target) / np.sum(target)
16+
fpr = 1 - np.cumsum(nontar) / np.sum(nontar)
17+
return fnr, fpr
18+
19+
20+
def compute_eer(fnr, fpr, requires_threshold=False, scores=None):
21+
""" computes the equal error rate (EER) given FNR and FPR values calculated
22+
for a range of operating points on the DET curve
23+
*kaldi style*
24+
"""
25+
26+
diff_miss_fa = fnr - fpr
27+
x = np.flatnonzero(diff_miss_fa >= 0)[0]
28+
eer = fnr[x - 1]
29+
if requires_threshold:
30+
assert scores is not None
31+
scores = np.sort(scores)
32+
th = scores[x]
33+
return eer, th
34+
return eer
35+
36+
37+
def compute_c_norm(fnr, fpr, p_target, c_miss=1, c_fa=1):
38+
""" computes normalized minimum detection cost function (DCF) given
39+
the costs for false accepts and false rejects as well as a priori
40+
probability for target speakers
41+
"""
42+
43+
dcf = c_miss * fnr * p_target + c_fa * fpr * (1 - p_target)
44+
c_det = np.min(dcf)
45+
c_def = min(c_miss * p_target, c_fa * (1 - p_target))
46+
return c_det/c_def

model/dtdnn.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from collections import OrderedDict
2+
3+
from torch import nn
4+
5+
from .layers import TDNNLayer, DenseTDNNBlock, TransitLayer, DenseLayer, StatsPool
6+
7+
8+
class DTDNN(nn.Module):
9+
10+
def __init__(self, feat_dim=30, embedding_size=512,
11+
growth_rate=64, bn_size=2, init_channels=128,
12+
config_str='batchnorm-relu'):
13+
super(DTDNN, self).__init__()
14+
15+
self.xvector = nn.Sequential(OrderedDict([
16+
('tdnn', TDNNLayer(feat_dim, init_channels, 5, dilation=1, padding=-1,
17+
config_str=config_str)),
18+
]))
19+
channels = init_channels
20+
for i, (num_layers, kernel_size, dilation) in enumerate(zip((6, 12), (3, 3), (1, 3))):
21+
block = DenseTDNNBlock(
22+
num_layers=num_layers,
23+
in_channels=channels,
24+
out_channels=growth_rate,
25+
bn_channels=bn_size * growth_rate,
26+
kernel_size=kernel_size,
27+
dilation=dilation,
28+
config_str=config_str
29+
)
30+
self.xvector.add_module('block%d' % (i + 1), block)
31+
channels = channels + num_layers * growth_rate
32+
self.xvector.add_module(
33+
'transit%d' % (i + 1), TransitLayer(channels, channels // 2, bias=False,
34+
config_str=config_str))
35+
channels //= 2
36+
self.xvector.add_module('stats', StatsPool())
37+
self.xvector.add_module('dense', DenseLayer(channels * 2, embedding_size, config_str='batchnorm_'))
38+
39+
def forward(self, x):
40+
return self.xvector(x)

0 commit comments

Comments
 (0)