
Why does cache=None produce different outputs? #88

@andsteing

When computing the log probabilities for

prompts = (
    'the sky is blue',
    'the sky is pink',
    'the sky is bacon',
)

I get very different values depending on whether I use cache=RotatingBufferCache(...) or cache=None:

use_cache=False:

[     1] <s>      -> [   272] ▁the    :   -7.41  0.06%
[   272] ▁the     -> [  7212] ▁sky    :   -0.98 37.41%
[  7212] ▁sky     -> [   349] ▁is     :   -0.19 82.40%
[   349] ▁is      -> [  5045] ▁blue   :   -0.81 44.68%

[     1] <s>      -> [   272] ▁the    :   -7.41  0.06%
[   272] ▁the     -> [  7212] ▁sky    :   -0.98 37.41%
[  7212] ▁sky     -> [   349] ▁is     :   -0.19 82.40%
[   349] ▁is      -> [ 12937] ▁pink   :   -2.45  8.59%

[     1] <s>      -> [   272] ▁the    :   -7.41  0.06%
[   272] ▁the     -> [  7212] ▁sky    :   -0.98 37.41%
[  7212] ▁sky     -> [   349] ▁is     :   -0.19 82.40%
[   349] ▁is      -> [   287] ▁b      :   -5.00  0.67%
[   287] ▁b       -> [ 10364] acon    :   -0.04 96.17%

use_cache=True:

[     1] <s>      -> [   272] ▁the    :   -9.24  0.01%
[   272] ▁the     -> [  7212] ▁sky    :   -7.37  0.06%
[  7212] ▁sky     -> [   349] ▁is     :   -1.16 31.46%
[   349] ▁is      -> [  5045] ▁blue   :   -2.39  9.13%

[     1] <s>      -> [   272] ▁the    :   -9.24  0.01%
[   272] ▁the     -> [  7212] ▁sky    :   -7.37  0.06%
[  7212] ▁sky     -> [   349] ▁is     :   -1.16 31.46%
[   349] ▁is      -> [ 12937] ▁pink   :   -4.82  0.81%

[     1] <s>      -> [   272] ▁the    :   -9.24  0.01%
[   272] ▁the     -> [  7212] ▁sky    :   -7.37  0.06%
[  7212] ▁sky     -> [   349] ▁is     :   -1.16 31.46%
[   349] ▁is      -> [   287] ▁b      :   -7.59  0.05%
[   287] ▁b       -> [ 10364] acon    :   -4.41  1.21%
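
For reference, the last column in these printouts is simply exp(logprob), as computed in print_logprobs below:

import numpy as np

np.exp(-0.98)  # 0.3753; the table shows 37.41% because -0.98 is the logprob rounded for display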

The values without cache do not make any sense; they look as if the model could already see the tokens it is being scored on (e.g. ▁b -> acon gets 96.17% without cache vs 1.21% with cache). The values with cache seem reasonable.

Why is this? How can I use the model without cache?

Full code is in this Colab: https://colab.research.google.com/drive/1lNk_JgFFAakTRtEVkpxQ42jlGCygwfSb

Code from the Colab:
import numpy as np
import torch

import mistral.cache  # RotatingBufferCache from the mistral-src reference code


def get_logprobs(model, tokenizer, prompts, *, use_cache):
  """Returns `(encoded_prompts, logprobs)`, optionally using the cache."""

  # Encode each prompt with a BOS token, then flatten all prompts into a
  # single 1-D tensor; `seqlens` records the individual prompt lengths.
  encoded_prompts = [tokenizer.encode(prompt, bos=True) for prompt in prompts[:3]]
  seqlens = [len(x) for x in encoded_prompts]
  concatenated_prompts = torch.tensor(sum(encoded_prompts, []), device=model.device, dtype=torch.long)

  if use_cache:
    # The buffer only needs to hold the longest prompt (capped by the
    # model's sliding window).
    sliding_window = min(max(seqlens), model.args.sliding_window)

    cache = mistral.cache.RotatingBufferCache(
        model.args.n_layers,
        model.args.max_batch_size,
        sliding_window,
        model.args.n_kv_heads,
        model.args.head_dim,
    )
    cache.to(device=model.device, dtype=model.dtype)
    cache.reset()
  else:
    cache = None

  prelogits = model.forward(
      concatenated_prompts,
      seqlens=seqlens,
      cache=cache,
  )

  # Log-probability that the model assigns to each observed next token.
  all_logprobs = torch.log_softmax(prelogits, dim=-1)
  logprobs = [[] for _ in range(len(encoded_prompts))]
  offset = 0
  for i_seq, sequence in enumerate(encoded_prompts):
    logprobs[i_seq].extend([all_logprobs[offset + i, sequence[i + 1]].item() for i in range(len(sequence) - 1)])
    offset += len(sequence)

  return encoded_prompts, logprobs


def print_logprobs(id2token, encoded_prompts, logprobs):
  """Prints the `(encoded_prompts, logprobs)` tokens / transition probabilities."""
  for i, t in enumerate(encoded_prompts):
    for j, (t1, t2) in enumerate(zip(t, t[1:])):
      # logprob is log p(t2 | prefix up to t1); the last column shows exp(logprob).
      logprob = float(logprobs[i][j])
      print(
          f'[{t1:6}] {id2token(t1):8} '
          f'-> [{t2:6}] {id2token(t2):8}: '
          f'{logprob:7.2f} '
          f'{np.exp(logprob):6.2%}'
      )
    print()


prompts = (
    'the sky is blue',
    'the sky is pink',
    'the sky is bacon',
)

for use_cache in (False, True):
  print(f'use_cache={use_cache}:\n')
  print_logprobs(tokenizer._model.id_to_piece, *get_logprobs(model, tokenizer, prompts, use_cache=use_cache))
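
One way to narrow this down (an untested sketch, reusing only the functions above) would be to score each prompt in its own forward pass, so the concatenated sequences cannot interact at all:

# Untested: one forward() call per prompt, so the no-cache path never sees
# more than a single sequence.
for prompt in prompts:
  encoded, logprobs = get_logprobs(model, tokenizer, (prompt,), use_cache=False)
  print_logprobs(tokenizer._model.id_to_piece, encoded, logprobs)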
