
Caching in LLaMPPL

LLaMPPL performs two kinds of caching to improve performance.

Log probability caching

Next-token log probabilities are cached whenever they are computed. If different particles make exactly the same log probability query, the Transformer is run only once for that query. This is primarily beneficial when:

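  • particles are cloned during resampling: each clone starts from the same token sequence as the particle it was copied from, so the clones issue the same next-token query and only the first one needs to run the Transformer.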

  • cloned particles happen to sample the same next token: if the next-token distribution is concentrated, it is likely that multiple copies of a particle will sample the same next token. Log probability caching allows them to sample the following token using only a single call to the language model.

The log probability cache can be cleared using the lm.clear_cache() method. Note that this method will also clear the KV cache.
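The idea behind log probability caching can be illustrated with a minimal sketch. This is a toy memoization table keyed by the token sequence, not LLaMPPL's actual implementation; ToyLogprobCache, toy_model, and the four-token vocabulary are hypothetical names and values introduced only for illustration.

```python
import math


class ToyLogprobCache:
    """Toy memoization of next-token log probabilities (illustrative only)."""

    def __init__(self, model_fn):
        self.model_fn = model_fn  # hypothetical stand-in for a Transformer forward pass
        self.cache = {}           # maps a tuple of tokens to its log probabilities
        self.model_calls = 0      # counts how often the "model" is actually run

    def next_token_logprobs(self, tokens):
        key = tuple(tokens)
        if key not in self.cache:
            # Cache miss: run the model once and remember the result.
            self.model_calls += 1
            self.cache[key] = self.model_fn(key)
        return self.cache[key]

    def clear_cache(self):
        self.cache.clear()


def toy_model(tokens):
    # Fake model: a uniform distribution over a 4-token vocabulary.
    return [math.log(0.25)] * 4


lm = ToyLogprobCache(toy_model)
particle = [1, 2, 3]
clones = [list(particle) for _ in range(5)]  # clones share the same context
for clone in clones:
    lm.next_token_logprobs(clone)
calls_after_clones = lm.model_calls
```

In this sketch the five clones trigger a single call to the toy model, because they all query the same token sequence. After clear_cache(), the next query for that sequence runs the model again.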

Key-value caching

Key-value (KV) caching stores the key and value vectors computed by each layer of a Transformer, so that they can be reused when new tokens are appended to a previously evaluated sequence.

In principle, key-value caching is most useful when:

  • There is a long common prompt from which all particles are generating. In this case, the prompt's tokens can be evaluated just once by the language model, and each subsequent call only has to pay for the new tokens generated after the prompt.

  • Generations from the model are very long. In this case, it may be worth paying the memory cost to cache different key-value sequences for each particle, to speed up future next-token queries.

Currently, only the first use case is well-supported by the LLaMPPL library, via the lm.cache_kv(prompt) method. This method computes and caches key and value vectors for every token in prompt. Future calls to lm.next_token_logprobs and lm.next_token_logprobs_unbatched will automatically recognize when prompt is a prefix of the new query and evaluate only the tokens that follow it. Multiple prompts can be cached, and lm.clear_kv_cache() can be used to clear the KV cache without clearing the log probability cache.
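The prefix-matching behaviour described above can be sketched with a toy data structure. This is only an illustration of the idea under simplifying assumptions, not how LLaMPPL stores key and value vectors; ToyPrefixKVCache and its helper methods longest_cached_prefix and tokens_to_evaluate are hypothetical, and the cached "KV" entries are placeholder strings rather than tensors.

```python
class ToyPrefixKVCache:
    """Toy prefix cache: tracks which token prefixes have been evaluated."""

    def __init__(self):
        self.cached_prefixes = {}  # tuple of tokens -> placeholder per-token "KV" entries

    def cache_kv(self, prompt_tokens):
        key = tuple(prompt_tokens)
        # A real cache would hold per-layer key/value tensors here.
        self.cached_prefixes[key] = [f"kv({t})" for t in key]

    def longest_cached_prefix(self, tokens):
        tokens = tuple(tokens)
        best = ()
        for prefix in self.cached_prefixes:
            if len(prefix) > len(best) and tokens[:len(prefix)] == prefix:
                best = prefix
        return best

    def tokens_to_evaluate(self, tokens):
        # Only tokens after the longest cached prefix need a forward pass.
        return len(tokens) - len(self.longest_cached_prefix(tokens))

    def clear_kv_cache(self):
        self.cached_prefixes.clear()


cache = ToyPrefixKVCache()
prompt = [10, 11, 12, 13]
cache.cache_kv(prompt)

query = prompt + [42, 43]                            # prompt plus two generated tokens
new_tokens = cache.tokens_to_evaluate(query)         # only the generated tokens
unrelated = cache.tokens_to_evaluate([99, 10, 11])   # no cached prefix matches
```

Here a query that extends the cached prompt only pays for the two tokens generated after it, while a query that does not start with a cached prompt pays for every token.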

Because lm.cache_kv is not a batched call, it is not well-suited to caching different strings for different particles. Rather, it is best used in the __init__ method of a model (or even outside of a model) on fixed prompt strings that every particle will share.
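The pattern of caching a shared prompt once, at model construction time, might look like the following sketch. It uses a stub in place of a real language model so that it runs on its own; StubLM, ToyModel, and the clone method are hypothetical stand-ins and do not reflect LLaMPPL's actual classes. Only the method name cache_kv is taken from the text above.

```python
import copy


class StubLM:
    """Stub language model that only records cache_kv calls (illustrative)."""

    def __init__(self):
        self.cache_kv_calls = []

    def cache_kv(self, prompt):
        self.cache_kv_calls.append(prompt)


class ToyModel:
    """Toy particle model: caches the shared prompt once, in __init__."""

    def __init__(self, lm, prompt):
        self.lm = lm
        self.prompt = prompt
        self.generated = []
        # One unbatched call for the fixed prompt that every particle shares.
        self.lm.cache_kv(prompt)

    def clone(self):
        # copy.copy does not re-run __init__, so the prompt is not cached again.
        new = copy.copy(self)
        new.generated = list(self.generated)  # each particle keeps its own tokens
        return new


lm = StubLM()
model = ToyModel(lm, "Once upon a time")
particles = [model.clone() for _ in range(8)]
```

In this sketch all eight particles share one language model object and one cached prompt, and cache_kv is called exactly once no matter how many particles are created.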