When computing the log probabilities for

```python
prompts = (
    'the sky is blue',
    'the sky is pink',
    'the sky is bacon',
)
```

I get very different values depending on whether I use `cache=RotatingBufferCache(...)` or `cache=None`:
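To read the tables below: the last two columns are the log-probability of the next token given the prefix, and its exponential as a percentage. A minimal illustration of that arithmetic, using a rounded value from the first table:

```python
import math

# log P('▁sky' | '<s> ▁the') ≈ -0.98
# => P('▁sky' | '<s> ▁the') ≈ exp(-0.98) ≈ 0.375, i.e. the ~37% shown below
# (the printed 37.41% is computed from the unrounded log-probability).
print(math.exp(-0.98))
```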
```
use_cache=False:

[ 1] <s> -> [ 272] ▁the : -7.41 0.06%
[ 272] ▁the -> [ 7212] ▁sky : -0.98 37.41%
[ 7212] ▁sky -> [ 349] ▁is : -0.19 82.40%
[ 349] ▁is -> [ 5045] ▁blue : -0.81 44.68%

[ 1] <s> -> [ 272] ▁the : -7.41 0.06%
[ 272] ▁the -> [ 7212] ▁sky : -0.98 37.41%
[ 7212] ▁sky -> [ 349] ▁is : -0.19 82.40%
[ 349] ▁is -> [ 12937] ▁pink : -2.45 8.59%

[ 1] <s> -> [ 272] ▁the : -7.41 0.06%
[ 272] ▁the -> [ 7212] ▁sky : -0.98 37.41%
[ 7212] ▁sky -> [ 349] ▁is : -0.19 82.40%
[ 349] ▁is -> [ 287] ▁b : -5.00 0.67%
[ 287] ▁b -> [ 10364] acon : -0.04 96.17%

use_cache=True:

[ 1] <s> -> [ 272] ▁the : -9.24 0.01%
[ 272] ▁the -> [ 7212] ▁sky : -7.37 0.06%
[ 7212] ▁sky -> [ 349] ▁is : -1.16 31.46%
[ 349] ▁is -> [ 5045] ▁blue : -2.39 9.13%

[ 1] <s> -> [ 272] ▁the : -9.24 0.01%
[ 272] ▁the -> [ 7212] ▁sky : -7.37 0.06%
[ 7212] ▁sky -> [ 349] ▁is : -1.16 31.46%
[ 349] ▁is -> [ 12937] ▁pink : -4.82 0.81%

[ 1] <s> -> [ 272] ▁the : -9.24 0.01%
[ 272] ▁the -> [ 7212] ▁sky : -7.37 0.06%
[ 7212] ▁sky -> [ 349] ▁is : -1.16 31.46%
[ 349] ▁is -> [ 287] ▁b : -7.59 0.05%
[ 287] ▁b -> [ 10364] acon : -4.41 1.21%
```
The values without the cache do not make any sense to me (the values with the cache seem reasonable, though).
Why is this? How can I use the model without a cache?
Full code is in this Colab: https://colab.research.google.com/drive/1lNk_JgFFAakTRtEVkpxQ42jlGCygwfSb

Code from the Colab:
```python
# `model` and `tokenizer` are created earlier in the Colab
# (Mistral reference implementation).
import numpy as np
import torch

import mistral.cache


def get_logprobs(model, tokenizer, prompts, *, use_cache):
    """Returns `(encoded_prompts, logprobs)`, optionally using the cache."""
    encoded_prompts = [tokenizer.encode(prompt, bos=True) for prompt in prompts[:3]]
    seqlens = [len(x) for x in encoded_prompts]
    concatenated_prompts = torch.tensor(
        sum(encoded_prompts, []), device=model.device, dtype=torch.long)
    if use_cache:
        sliding_window = model.args.sliding_window
        sliding_window = min(max(seqlens), sliding_window)
        cache = mistral.cache.RotatingBufferCache(
            model.args.n_layers,
            model.args.max_batch_size,
            sliding_window,
            model.args.n_kv_heads,
            model.args.head_dim,
        )
        cache.to(device=model.device, dtype=model.dtype)
        cache.reset()
    else:
        cache = None
    prelogits = model.forward(
        concatenated_prompts,
        seqlens=seqlens,
        cache=cache,
    )
    # Log-probabilities over the vocabulary for every input position.
    logits = torch.log_softmax(prelogits, dim=-1)
    logprobs = [[] for _ in range(len(prompts))]
    offset = 0
    for i_seq, sequence in enumerate(encoded_prompts):
        # For each position, pick the log-probability of the *next* token.
        logprobs[i_seq].extend(
            [logits[offset + i, sequence[i + 1]].item() for i in range(len(sequence) - 1)])
        offset += len(sequence)
    return encoded_prompts, logprobs


def print_logprobs(id2token, encoded_prompts, logprobs):
    """Prints `(encoded_prompts, logprobs)` tokens / transition probabilities."""
    for i, t in enumerate(encoded_prompts):
        for j, (t1, t2) in enumerate(zip(t, t[1:])):
            logprob = float(logprobs[i][j])
            print(
                f'[{t1:6}] {id2token(t1):8} '
                f'-> [{t2:6}] {id2token(t2):8}: '
                f'{logprob:7.2f} '
                f'{np.exp(logprob):6.2%}'
            )
        print()


prompts = (
    'the sky is blue',
    'the sky is pink',
    'the sky is bacon',
)

for use_cache in (False, True):
    print(f'use_cache={use_cache}:\n')
    print_logprobs(
        tokenizer._model.id_to_piece,
        *get_logprobs(model, tokenizer, prompts, use_cache=use_cache))
```
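For comparison, this is the kind of cache-free scoring I would have expected to be equivalent to the `cache=None` branch above: one forward pass per prompt, with no concatenation. This is only a minimal sketch (not part of the Colab); it assumes `model.forward` accepts a single sequence with `cache=None`, and it reuses `model`, `tokenizer`, and the `torch` import from the code above. `get_logprobs_single` is just an illustrative name:

```python
def get_logprobs_single(model, tokenizer, prompts):
    """Scores each prompt in its own forward pass, without any cache.

    Illustrative sketch only, not part of the Colab.
    """
    all_logprobs = []
    for prompt in prompts:
        tokens = tokenizer.encode(prompt, bos=True)
        prelogits = model.forward(
            torch.tensor(tokens, device=model.device, dtype=torch.long),
            seqlens=[len(tokens)],
            cache=None,
        )
        logprobs = torch.log_softmax(prelogits, dim=-1)
        # Log-probability of each next token given the preceding tokens.
        all_logprobs.append(
            [logprobs[i, tokens[i + 1]].item() for i in range(len(tokens) - 1)])
    return all_logprobs
```

If this per-prompt version disagrees with the concatenated `cache=None` call above, that would point at how multiple concatenated sequences are handled when no cache is passed.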