@@ -371,6 +371,110 @@ def test_decode(self):
371
371
assert stream .step (tokenizer , 2 ) == " is"
372
372
assert stream .step (tokenizer , 3 ) == " john"
373
373
374
+ stream = DecodeStream (ids = [0 , 1 , 2 ])
375
+ assert stream .step (tokenizer , 3 ) == " john"
376
+
377
+ def test_decode_stream_fallback (self ):
378
+ tokenizer = Tokenizer .from_pretrained ("gpt2" )
379
+ # tokenizer.decode([255]) fails because its a fallback
380
+ # tokenizer.encode("อั").ids = [19567, 255, 19567, 109]
381
+ stream = DecodeStream ()
382
+ stream .step (tokenizer , [19567 ])
383
+ stream .step (tokenizer , [255 ])
384
+ stream .step (tokenizer , [19567 ])
385
+ out = stream .step (tokenizer , [109 ])
386
+ assert out == "ั"
387
+
388
+ stream = DecodeStream ()
389
+ out = stream .step (tokenizer , [19567 , 255 , 19567 , 109 ])
390
+ assert out == "อั"
391
+ stream = DecodeStream ()
392
+ stream .step (tokenizer , [19567 ])
393
+ out = stream .step (tokenizer , [255 , 19567 , 109 ])
394
+ assert out == "อั"
395
+
396
+ stream = DecodeStream ()
397
+ stream .step (tokenizer , [19567 ])
398
+ first_out = stream .step (tokenizer , [255 ])
399
+ assert first_out == "อ"
400
+ # since we emitted the 'อ', we can't produce 'อั'
401
+ out = stream .step (tokenizer , [19567 , 109 ])
402
+ assert out == "ั"
403
+
404
+ stream = DecodeStream ([19567 , 255 , 19567 ])
405
+ # the stream's prefix is 'อ�' which is invalid, thus all ids are kept for the next step
406
+ out = stream .step (tokenizer , [109 ])
407
+ assert out == "อั"
408
+
409
+ def test_decode_skip_special_tokens (self ):
410
+ tokenizer = Tokenizer .from_pretrained ("hf-internal-testing/Llama-3.1-8B-Instruct" )
411
+
412
+ stream = DecodeStream ([40 ])
413
+ out = stream .step (tokenizer , [2846 , 40 , 40 , 40 ])
414
+ assert out == "'mIII"
415
+
416
+ stream = DecodeStream (
417
+ [
418
+ 128000 ,
419
+ 128006 ,
420
+ 9125 ,
421
+ 128007 ,
422
+ 271 ,
423
+ 38766 ,
424
+ 1303 ,
425
+ 33025 ,
426
+ 2696 ,
427
+ 25 ,
428
+ 6790 ,
429
+ 220 ,
430
+ 2366 ,
431
+ 18 ,
432
+ 198 ,
433
+ 15724 ,
434
+ 2696 ,
435
+ 25 ,
436
+ 220 ,
437
+ 1627 ,
438
+ 10263 ,
439
+ 220 ,
440
+ 2366 ,
441
+ 19 ,
442
+ 271 ,
443
+ 9514 ,
444
+ 527 ,
445
+ 264 ,
446
+ 11190 ,
447
+ 18328 ,
448
+ 13 ,
449
+ 128009 ,
450
+ 128006 ,
451
+ 882 ,
452
+ 128007 ,
453
+ 271 ,
454
+ 15339 ,
455
+ 11 ,
456
+ 1268 ,
457
+ 527 ,
458
+ 499 ,
459
+ 30 ,
460
+ 128009 ,
461
+ 128006 ,
462
+ 78191 ,
463
+ 128007 ,
464
+ 271 ,
465
+ ]
466
+ )
467
+ out = stream .step (tokenizer , 40 )
468
+ assert out == "I"
469
+
470
+ stream = DecodeStream ([40 ])
471
+ out = stream .step (tokenizer , 2846 )
472
+ assert out == "'m"
473
+
474
+ stream = DecodeStream ([40 ])
475
+ out = stream .step (tokenizer , [2846 , 40 , 40 , 40 ])
476
+ assert out == "'mIII"
477
+
374
478
def test_decode_stream (self ):
375
479
vocab = [
376
480
("<unk>" , 0.0 ),
0 commit comments