Commit abee958

New stream (#1856)
* update
* update
* updates
* up
* oikay
* use stream input
* nice all test pass?
* fmt
* dev
* rename
* simplify a hell lot
* proper testing
* fix inti
* fix test
* nits
* make clippy happy now
* fmt fml
* remove the prints
* fix gate
1 parent 95b882a commit abee958
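
In short, a DecodeStream can now be created with an initial list of token ids and stepped with either a single id or a list of ids. A minimal usage sketch, based on the tests added in this commit (gpt2 tokenizer and ids taken from test_decode_stream_fallback below; illustrative, not official documentation):

from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

tokenizer = Tokenizer.from_pretrained("gpt2")

# Seed the stream with ids that were already consumed ...
stream = DecodeStream(ids=[19567, 255, 19567])
# ... then step with a single id or a list of ids; text is returned once it decodes cleanly.
print(stream.step(tokenizer, [109]))  # "อั", per test_decode_stream_fallback below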

File tree

7 files changed: +164 -21 lines changed

  bindings/python/py_src/tokenizers/decoders/__init__.pyi
  bindings/python/src/decoders.rs
  bindings/python/src/processors.rs
  bindings/python/tests/bindings/test_tokenizer.py
  tokenizers/src/models/unigram/model.rs
  tokenizers/src/models/unigram/trie.rs
  tokenizers/src/tokenizer/mod.rs

bindings/python/py_src/tokenizers/decoders/__init__.pyi

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ class DecodeStream:
     Class needed for streaming decode

     """
-    def __init__(self, skip_special_tokens):
+    def __init__(self, ids=None, skip_special_tokens=False):
         pass

 class Decoder:

bindings/python/src/decoders.rs

Lines changed: 30 additions & 7 deletions
@@ -646,21 +646,44 @@ pub struct PyDecodeStream {
     prefix_index: usize,
 }

+#[derive(Clone)]
+enum StreamInput {
+    Id(u32),
+    Ids(Vec<u32>),
+}
+
+impl FromPyObject<'_> for StreamInput {
+    fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
+        if let Ok(id) = obj.extract::<u32>() {
+            Ok(StreamInput::Id(id))
+        } else if let Ok(ids) = obj.extract::<Vec<u32>>() {
+            Ok(StreamInput::Ids(ids))
+        } else {
+            Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
+                "StreamInput must be either an integer or a list of integers",
+            ))
+        }
+    }
+}
+
 #[pymethods]
 impl PyDecodeStream {
     #[new]
-    #[pyo3(signature = (skip_special_tokens), text_signature = "(self, skip_special_tokens)")]
-    fn new(skip_special_tokens: bool) -> Self {
+    #[pyo3(signature = (ids=None, skip_special_tokens=false), text_signature = "(self, ids=None, skip_special_tokens=False)")]
+    fn new(ids: Option<Vec<u32>>, skip_special_tokens: Option<bool>) -> Self {
         PyDecodeStream {
-            skip_special_tokens,
-            ids: vec![],
-            prefix: "".to_string(),
+            skip_special_tokens: skip_special_tokens.unwrap_or(false),
+            ids: ids.unwrap_or_default(),
+            prefix: String::new(),
             prefix_index: 0,
         }
     }
-
     #[pyo3(signature = (tokenizer, id), text_signature = "(self, tokenizer, id)")]
-    fn step(&mut self, tokenizer: &PyTokenizer, id: u32) -> PyResult<Option<String>> {
+    fn step(&mut self, tokenizer: &PyTokenizer, id: StreamInput) -> PyResult<Option<String>> {
+        let id: Vec<u32> = match id {
+            StreamInput::Id(id) => vec![id],
+            StreamInput::Ids(ids) => ids,
+        };
         ToPyResult(tk::tokenizer::step_decode_stream(
             &tokenizer.tokenizer,
             id,
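
On the Python side, StreamInput means DecodeStream.step still accepts a single integer id and now also accepts a list of ids; any other type is rejected with a TypeError. A brief illustrative sketch (gpt2 ids borrowed from the test file below):

from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

tokenizer = Tokenizer.from_pretrained("gpt2")
stream = DecodeStream()

stream.step(tokenizer, 19567)              # single id, as before
stream.step(tokenizer, [255, 19567, 109])  # list of ids, new in this commit
# stream.step(tokenizer, "oops")           # would raise TypeError from the StreamInput extraction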

bindings/python/src/processors.rs

Lines changed: 4 additions & 4 deletions
@@ -341,7 +341,7 @@ impl PyBertProcessing {
     }

     #[getter]
-    fn get_sep(self_: PyRef<Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
+    fn get_sep(self_: PyRef<'_, Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
         let py = self_.py();
         let (tok, id) = getter!(self_, Bert, get_sep_copy());
         PyTuple::new(
@@ -358,7 +358,7 @@ impl PyBertProcessing {
     }

     #[getter]
-    fn get_cls(self_: PyRef<Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
+    fn get_cls(self_: PyRef<'_, Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
         let py = self_.py();
         let (tok, id) = getter!(self_, Bert, get_cls_copy());
         PyTuple::new(
@@ -422,7 +422,7 @@ impl PyRobertaProcessing {
     }

     #[getter]
-    fn get_sep(self_: PyRef<Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
+    fn get_sep(self_: PyRef<'_, Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
         let py = self_.py();
         let (tok, id) = getter!(self_, Roberta, get_sep_copy());
         PyTuple::new(
@@ -439,7 +439,7 @@ impl PyRobertaProcessing {
     }

     #[getter]
-    fn get_cls(self_: PyRef<Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
+    fn get_cls(self_: PyRef<'_, Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
         let py = self_.py();
         let (tok, id) = getter!(self_, Roberta, get_cls_copy());
         PyTuple::new(

bindings/python/tests/bindings/test_tokenizer.py

Lines changed: 104 additions & 0 deletions
@@ -371,6 +371,110 @@ def test_decode(self):
         assert stream.step(tokenizer, 2) == " is"
         assert stream.step(tokenizer, 3) == " john"

+        stream = DecodeStream(ids=[0, 1, 2])
+        assert stream.step(tokenizer, 3) == " john"
+
+    def test_decode_stream_fallback(self):
+        tokenizer = Tokenizer.from_pretrained("gpt2")
+        # tokenizer.decode([255]) fails because its a fallback
+        # tokenizer.encode("อั").ids = [19567, 255, 19567, 109]
+        stream = DecodeStream()
+        stream.step(tokenizer, [19567])
+        stream.step(tokenizer, [255])
+        stream.step(tokenizer, [19567])
+        out = stream.step(tokenizer, [109])
+        assert out == "ั"
+
+        stream = DecodeStream()
+        out = stream.step(tokenizer, [19567, 255, 19567, 109])
+        assert out == "อั"
+        stream = DecodeStream()
+        stream.step(tokenizer, [19567])
+        out = stream.step(tokenizer, [255, 19567, 109])
+        assert out == "อั"
+
+        stream = DecodeStream()
+        stream.step(tokenizer, [19567])
+        first_out = stream.step(tokenizer, [255])
+        assert first_out == "อ"
+        # since we emitted the 'อ', we can't produce 'อั'
+        out = stream.step(tokenizer, [19567, 109])
+        assert out == "ั"
+
+        stream = DecodeStream([19567, 255, 19567])
+        # the stream's prefix is 'อ�' which is invalid, thus all ids are kept for the next step
+        out = stream.step(tokenizer, [109])
+        assert out == "อั"
+
+    def test_decode_skip_special_tokens(self):
+        tokenizer = Tokenizer.from_pretrained("hf-internal-testing/Llama-3.1-8B-Instruct")
+
+        stream = DecodeStream([40])
+        out = stream.step(tokenizer, [2846, 40, 40, 40])
+        assert out == "'mIII"
+
+        stream = DecodeStream(
+            [
+                128000,
+                128006,
+                9125,
+                128007,
+                271,
+                38766,
+                1303,
+                33025,
+                2696,
+                25,
+                6790,
+                220,
+                2366,
+                18,
+                198,
+                15724,
+                2696,
+                25,
+                220,
+                1627,
+                10263,
+                220,
+                2366,
+                19,
+                271,
+                9514,
+                527,
+                264,
+                11190,
+                18328,
+                13,
+                128009,
+                128006,
+                882,
+                128007,
+                271,
+                15339,
+                11,
+                1268,
+                527,
+                499,
+                30,
+                128009,
+                128006,
+                78191,
+                128007,
+                271,
+            ]
+        )
+        out = stream.step(tokenizer, 40)
+        assert out == "I"
+
+        stream = DecodeStream([40])
+        out = stream.step(tokenizer, 2846)
+        assert out == "'m"
+
+        stream = DecodeStream([40])
+        out = stream.step(tokenizer, [2846, 40, 40, 40])
+        assert out == "'mIII"
+
     def test_decode_stream(self):
         vocab = [
             ("<unk>", 0.0),

tokenizers/src/models/unigram/model.rs

Lines changed: 1 addition & 1 deletion
@@ -356,7 +356,7 @@ impl Unigram {
     }

     /// Iterate of vocabulary of the model as a pair of `(token, score)`.
-    pub fn iter(&self) -> UnigramIterator {
+    pub fn iter(&self) -> UnigramIterator<'_> {
         UnigramIterator { model: self, i: 0 }
     }


tokenizers/src/models/unigram/trie.rs

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ impl<Label: Eq + Hash + Copy> Trie<Label> {
         node.is_leaf = true;
     }

-    pub fn common_prefix_search<T>(&self, iterator: T) -> TrieIterator<Label, T>
+    pub fn common_prefix_search<T>(&self, iterator: T) -> TrieIterator<'_, Label, T>
     where
         T: Iterator<Item = Label>,
     {

tokenizers/src/tokenizer/mod.rs

Lines changed: 23 additions & 7 deletions
@@ -1041,8 +1041,12 @@ pub struct DecodeStream<'tok, M, N, PT, PP, D> {

 #[derive(thiserror::Error, Debug)]
 pub enum DecodeStreamError {
-    #[error("Invalid prefix encountered")]
-    InvalidPrefix,
+    #[error("Invalid prefix encountered while decoding stream. Token ID: {token_id}, Expected prefix: '{expected_prefix}', Actual string: '{actual_string}'")]
+    InvalidPrefix {
+        token_id: u32,
+        expected_prefix: String,
+        actual_string: String,
+    },
 }

 impl<'tok, M, N, PT, PP, D> DecodeStream<'tok, M, N, PT, PP, D>
@@ -1067,7 +1071,7 @@ where
     pub fn step(&mut self, id: u32) -> Result<Option<String>> {
         step_decode_stream(
             self.tokenizer,
-            id,
+            vec![id],
             self.skip_special_tokens,
             &mut self.ids,
             &mut self.prefix,
@@ -1079,7 +1083,7 @@ where
 /// Internal function exposed only to bypass python limitations
 pub fn step_decode_stream<M, N, PT, PP, D>(
     tokenizer: &TokenizerImpl<M, N, PT, PP, D>,
-    id: u32,
+    token_ids: Vec<u32>,
     skip_special_tokens: bool,
     ids: &mut Vec<u32>,
     prefix: &mut String,
@@ -1092,12 +1096,25 @@ where
     PP: PostProcessor,
     D: Decoder,
 {
-    ids.push(id);
+    if prefix.is_empty() && !ids.is_empty() {
+        let new_prefix = tokenizer.decode(ids, skip_special_tokens)?;
+        if !new_prefix.ends_with('�') {
+            *prefix = new_prefix;
+            *prefix_index = ids.len();
+        }
+    }
+
+    ids.extend(token_ids);
     let string = tokenizer.decode(ids.as_slice(), skip_special_tokens)?;
     if string.len() > prefix.len() && !string.ends_with('�') {
         if !(string.starts_with(&*prefix)) {
-            return Err(Box::new(DecodeStreamError::InvalidPrefix));
+            return Err(Box::new(DecodeStreamError::InvalidPrefix {
+                token_id: *ids.last().unwrap(),
+                expected_prefix: prefix.clone(),
+                actual_string: string,
+            }));
         }
+
         let new_text = &string[prefix.len()..].to_string();
         let new_prefix_index = ids.len() - *prefix_index;
         *ids = ids.drain(*prefix_index..).collect();
@@ -1108,7 +1125,6 @@ where
         Ok(None)
     }
 }
-
 impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
 where
     M: Model,
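
The reworked step_decode_stream above first tries to rebuild a valid prefix from any pre-seeded ids (kept only if it does not end in U+FFFD), then appends the incoming ids, decodes, and emits the new text once the decoded string cleanly extends the prefix. A rough Python transliteration of that control flow, for illustration only; decode stands in for tokenizer.decode, and the trailing prefix bookkeeping not visible in the hunk is filled in from the surrounding function:

REPLACEMENT_CHAR = "\ufffd"  # '�'

def step_decode_stream(decode, token_ids, state):
    # state is a dict with "ids" (list of ints), "prefix" (str), "prefix_index" (int).
    # If the stream was seeded with ids but has no prefix yet, try to build one,
    # keeping it only when it decodes cleanly.
    if not state["prefix"] and state["ids"]:
        new_prefix = decode(state["ids"])
        if not new_prefix.endswith(REPLACEMENT_CHAR):
            state["prefix"] = new_prefix
            state["prefix_index"] = len(state["ids"])

    state["ids"].extend(token_ids)
    string = decode(state["ids"])
    if len(string) > len(state["prefix"]) and not string.endswith(REPLACEMENT_CHAR):
        if not string.startswith(state["prefix"]):
            raise ValueError(f"invalid prefix: expected {state['prefix']!r}, got {string!r}")
        new_text = string[len(state["prefix"]):]
        new_prefix_index = len(state["ids"]) - state["prefix_index"]
        # Keep only the ids after the old prefix and re-derive the prefix from them
        # (this last part mirrors the surrounding function, not shown in the hunk).
        state["ids"] = state["ids"][state["prefix_index"]:]
        state["prefix"] = decode(state["ids"])
        state["prefix_index"] = new_prefix_index
        return new_text
    return None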
