Commit ce459b3

[cli] support campplus_200k and eres2net_200k models of damo (wenet-e2e#281)
* [cli] support campplus_200k_common and eres2net_200k_common models of damo
* [cli] fix typo
1 parent 8619edb commit ce459b3

File tree

4 files changed: +101 additions, -25 deletions

wespeaker/cli/hub.py
wespeaker/cli/speaker.py
wespeaker/cli/utils.py
wespeaker/models/eres2net.py


wespeaker/cli/hub.py

Lines changed: 2 additions & 0 deletions
@@ -72,6 +72,8 @@ class Hub(object):
     Assets = {
         "chinese": "cnceleb_resnet34.tar.gz",
         "english": "voxceleb_resnet221_LM.tar.gz",
+        "campplus": "campplus_cn_common_200k.tar.gz",
+        "eres2net": "eres2net_cn_commom_200k.tar.gz",
     }

     def __init__(self) -> None:

wespeaker/cli/speaker.py

Lines changed: 14 additions & 3 deletions
@@ -52,6 +52,7 @@ def __init__(self, model_dir: str):
         self.resample_rate = 16000
         self.apply_vad = False
         self.device = torch.device('cpu')
+        self.wavform_norm = False

         # diarization parmas
         self.diar_num_spks = None

@@ -64,6 +65,9 @@ def __init__(self, model_dir: str):
         self.diar_batch_size = 32
         self.diar_subseg_cmn = True

+    def set_wavform_norm(self, wavform_norm: bool):
+        self.wavform_norm = wavform_norm
+
     def set_resample_rate(self, resample_rate: int):
         self.resample_rate = resample_rate

@@ -132,7 +136,8 @@ def extract_embedding_feats(self, fbanks, batch_size, subseg_cmn):
         return embeddings

     def extract_embedding(self, audio_path: str):
-        pcm, sample_rate = torchaudio.load(audio_path, normalize=False)
+        pcm, sample_rate = torchaudio.load(audio_path,
+                                           normalize=self.wavform_norm)
         if self.apply_vad:
             # TODO(Binbin Zhang): Refine the segments logic, here we just
             # suppose there is only silence at the start/end of the speech

@@ -160,7 +165,6 @@ def extract_embedding(self, audio_path: str):
         feats = feats.to(self.device)
         self.model.eval()
         with torch.no_grad():
-            # _, outputs = self.model(feats)
             outputs = self.model(feats)
             outputs = outputs[-1] if isinstance(outputs, tuple) else outputs
         embedding = outputs[0].to(torch.device('cpu'))

@@ -301,7 +305,14 @@ def load_model_local(model_dir: str) -> Speaker:
 def main():
     args = get_args()
     if args.pretrain == "":
-        model = load_model(args.language)
+        if args.campplus:
+            model = load_model("campplus")
+            model.set_wavform_norm(True)
+        elif args.eres2net:
+            model = load_model("eres2net")
+            model.set_wavform_norm(True)
+        else:
+            model = load_model(args.language)
     else:
         model = load_model_local(args.pretrain)
     model.set_resample_rate(args.resample_rate)
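
The functional change on the extraction path is that torchaudio.load now takes its normalize argument from the new wavform_norm switch, which main() turns on for the campplus and eres2net models. A minimal sketch of what the switch changes follows; it is not part of the commit, and the file path is made up for illustration.

    # Sketch: effect of the wavform_norm switch on torchaudio.load
    import torchaudio

    wav_path = "example.wav"  # hypothetical 16 kHz mono wav file

    # Default wespeaker behaviour (wavform_norm = False): raw integer PCM,
    # e.g. int16 samples in [-32768, 32767].
    pcm_raw, sr = torchaudio.load(wav_path, normalize=False)

    # With set_wavform_norm(True), as done for the campplus/eres2net models:
    # float32 samples scaled to [-1.0, 1.0].
    pcm_norm, sr = torchaudio.load(wav_path, normalize=True)

    print(pcm_raw.dtype, pcm_norm.dtype)  # e.g. torch.int16 torch.float32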

wespeaker/cli/utils.py

Lines changed: 10 additions & 0 deletions
@@ -37,6 +37,16 @@ def get_args():
                         ],
                         default='chinese',
                         help='language type')
+    parser.add_argument(
+        '--campplus',
+        action='store_true',
+        help='whether to use the damo/speech_campplus_sv_zh-cn_16k-common model'
+    )
+    parser.add_argument(
+        '--eres2net',
+        action='store_true',
+        help='whether to use the damo/speech_eres2net_sv_zh-cn_16k-common model'
+    )
     parser.add_argument('-p',
                         '--pretrain',
                         type=str,
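
Both new flags are plain store_true switches; the model selection itself happens in main() in wespeaker/cli/speaker.py, shown above. For reference, a rough Python-side equivalent of what the flags trigger, assuming load_model is importable from wespeaker.cli.speaker (where main() calls it) and with a hypothetical input wav:

    # Sketch of the Python-side equivalent of "--campplus" / "--eres2net"
    from wespeaker.cli.speaker import load_model

    model = load_model("campplus")   # or load_model("eres2net")
    model.set_wavform_norm(True)     # the damo models use normalized waveforms
    embedding = model.extract_embedding("utt1.wav")  # utterance-level embedding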

wespeaker/models/eres2net.py

Lines changed: 75 additions & 22 deletions
@@ -103,14 +103,20 @@ def forward(self, x, ds_y):


 class BasicBlockERes2Net(nn.Module):
-    expansion = 2

-    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
+    def __init__(self,
+                 in_planes,
+                 planes,
+                 stride=1,
+                 baseWidth=32,
+                 scale=2,
+                 expansion=2):
         super(BasicBlockERes2Net, self).__init__()
         width = int(math.floor(planes * (baseWidth / 64.0)))
         self.conv1 = conv1x1(in_planes, width * scale, stride)
         self.bn1 = nn.BatchNorm2d(width * scale)
         self.nums = scale
+        self.expansion = expansion

         convs = []
         bns = []

@@ -162,14 +168,20 @@ def forward(self, x):


 class BasicBlockERes2Net_diff_AFF(nn.Module):
-    expansion = 2

-    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
+    def __init__(self,
+                 in_planes,
+                 planes,
+                 stride=1,
+                 baseWidth=32,
+                 scale=2,
+                 expansion=2):
         super(BasicBlockERes2Net_diff_AFF, self).__init__()
         width = int(math.floor(planes * (baseWidth / 64.0)))
         self.conv1 = conv1x1(in_planes, width * scale, stride)
         self.bn1 = nn.BatchNorm2d(width * scale)
         self.nums = scale
+        self.expansion = expansion

         # to meet the torch.jit.script export requirements
         self.conv2_1 = conv3x3(width, width)

@@ -232,6 +244,9 @@ class ERes2Net(nn.Module):
     def __init__(self,
                  m_channels,
                  num_blocks,
+                 baseWidth=32,
+                 scale=2,
+                 expansion=2,
                  block=BasicBlockERes2Net,
                  block_fuse=BasicBlockERes2Net_diff_AFF,
                  feat_dim=80,

@@ -244,6 +259,7 @@ def __init__(self,
         self.embed_dim = embed_dim
         self.stats_dim = int(feat_dim / 8) * m_channels * 8
         self.two_emb_layer = two_emb_layer
+        self.expansion = expansion

         self.conv1 = nn.Conv2d(1,
                                m_channels,

@@ -255,48 +271,59 @@ def __init__(self,
         self.layer1 = self._make_layer(block,
                                        m_channels,
                                        num_blocks[0],
-                                       stride=1)
+                                       stride=1,
+                                       baseWidth=baseWidth,
+                                       scale=scale,
+                                       expansion=expansion)
         self.layer2 = self._make_layer(block,
                                        m_channels * 2,
                                        num_blocks[1],
-                                       stride=2)
+                                       stride=2,
+                                       baseWidth=baseWidth,
+                                       scale=scale,
+                                       expansion=expansion)
         self.layer3 = self._make_layer(block_fuse,
                                        m_channels * 4,
                                        num_blocks[2],
-                                       stride=2)
+                                       stride=2,
+                                       baseWidth=baseWidth,
+                                       scale=scale,
+                                       expansion=expansion)
         self.layer4 = self._make_layer(block_fuse,
                                        m_channels * 8,
                                        num_blocks[3],
-                                       stride=2)
+                                       stride=2,
+                                       baseWidth=baseWidth,
+                                       scale=scale,
+                                       expansion=expansion)

         # Downsampling module for each layer
-        self.layer1_downsample = nn.Conv2d(m_channels * 2,
-                                           m_channels * 4,
+        self.layer1_downsample = nn.Conv2d(m_channels * expansion,
+                                           m_channels * expansion * 2,
                                            kernel_size=3,
                                            stride=2,
                                            padding=1,
                                            bias=False)
-        self.layer2_downsample = nn.Conv2d(m_channels * 4,
-                                           m_channels * 8,
+        self.layer2_downsample = nn.Conv2d(m_channels * expansion * 2,
+                                           m_channels * expansion * 4,
                                            kernel_size=3,
                                            padding=1,
                                            stride=2,
                                            bias=False)
-        self.layer3_downsample = nn.Conv2d(m_channels * 8,
-                                           m_channels * 16,
+        self.layer3_downsample = nn.Conv2d(m_channels * expansion * 4,
+                                           m_channels * expansion * 8,
                                            kernel_size=3,
                                            padding=1,
                                            stride=2,
                                            bias=False)

         # Bottom-up fusion module
-        self.fuse_mode12 = AFF(channels=m_channels * 4)
-        self.fuse_mode123 = AFF(channels=m_channels * 8)
-        self.fuse_mode1234 = AFF(channels=m_channels * 16)
+        self.fuse_mode12 = AFF(channels=m_channels * expansion * 2)
+        self.fuse_mode123 = AFF(channels=m_channels * expansion * 4)
+        self.fuse_mode1234 = AFF(channels=m_channels * expansion * 8)

         self.pool = getattr(pooling_layers,
-                            pooling_func)(in_dim=self.stats_dim *
-                                          block.expansion)
+                            pooling_func)(in_dim=self.stats_dim * expansion)
         self.pool_out_dim = self.pool.get_out_dim()
         self.seg_1 = nn.Linear(self.pool_out_dim, embed_dim)
         if self.two_emb_layer:

@@ -306,12 +333,21 @@ def __init__(self,
             self.seg_bn_1 = nn.Identity()
             self.seg_2 = nn.Identity()

-    def _make_layer(self, block, planes, num_blocks, stride):
+    def _make_layer(self,
+                    block,
+                    planes,
+                    num_blocks,
+                    stride,
+                    baseWidth=32,
+                    scale=2,
+                    expansion=2):
         strides = [stride] + [1] * (num_blocks - 1)
         layers = []
         for stride in strides:
-            layers.append(block(self.in_planes, planes, stride))
-            self.in_planes = planes * block.expansion
+            layers.append(
+                block(self.in_planes, planes, stride, baseWidth, scale,
+                      expansion))
+            self.in_planes = planes * self.expansion
         return nn.Sequential(*layers)

     def forward(self, x):

@@ -362,6 +398,23 @@ def ERes2Net34_Large(feat_dim,
                     two_emb_layer=two_emb_layer)


+def ERes2Net34_aug(feat_dim,
+                   embed_dim,
+                   pooling_func='TSTP',
+                   two_emb_layer=False,
+                   expansion=4,
+                   baseWidth=24,
+                   scale=3):
+    return ERes2Net(64, [3, 4, 6, 3],
+                    expansion=expansion,
+                    baseWidth=baseWidth,
+                    scale=scale,
+                    feat_dim=feat_dim,
+                    embed_dim=embed_dim,
+                    pooling_func=pooling_func,
+                    two_emb_layer=two_emb_layer)
+
+
 if __name__ == '__main__':
     x = torch.zeros(1, 200, 80)
     model = ERes2Net34_Base(feat_dim=80, embed_dim=512, two_emb_layer=False)
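
The block-level expansion constant becomes a constructor argument so that the downsampling and fusion channel widths track it, and the new ERes2Net34_aug factory builds the wider variant (expansion=4, baseWidth=24, scale=3). Below is a quick smoke test in the style of the existing __main__ block, but using the new factory; the embed_dim of 512 is copied from that block, and the output shape is only the expected one, since this diff does not show the forward() body.

    # Sketch: instantiate the new ERes2Net34_aug factory and run one forward pass
    import torch
    from wespeaker.models.eres2net import ERes2Net34_aug

    model = ERes2Net34_aug(feat_dim=80, embed_dim=512, two_emb_layer=False)
    model.eval()
    x = torch.zeros(1, 200, 80)  # (batch, frames, feat_dim), as in __main__
    with torch.no_grad():
        out = model(x)
    # As in cli/speaker.py, the model may return a tuple of embeddings.
    emb = out[-1] if isinstance(out, tuple) else out
    print(emb.shape)  # expected: torch.Size([1, 512])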
