lmcontext

LMContext

Represents a generation-in-progress from a language model.

The state tracks two pieces of information:

  • A sequence of tokens — the ever-growing context for the language model.
  • A current mask — a set of tokens that have not yet been ruled out as the next token.

Storing a mask enables sub-token generation: models can use LMContext to sample the next token in stages, first deciding, e.g., whether to use an upper-case or lower-case first letter, and only later deciding which upper-case or lower-case token to generate.

The state of a LMContext can be advanced in two ways:

  1. Sampling, observing, or intervening on the next_token() distribution. This adds a token to the growing sequence of tokens. Supports auto-batching.
  2. Sampling, observing, or intervening on the mask_dist(mask) distribution for a given mask (a set of token ids). This changes the current mask. (A short sketch of this two-stage pattern follows below.)
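
For concreteness, here is a minimal, hedged sketch of the two-stage (sub-token) pattern inside an hfppl model. It assumes the usual hfppl Model API (an async step method with self.sample and self.finish) and builds the capitalized-token mask directly from the HuggingFace tokenizer; the model class and mask construction are illustrative, not part of LMContext itself.

from hfppl import Model, LMContext

class CapitalizedNext(Model):
    """Illustrative model: first decide *whether* the next token starts with an
    upper-case letter, then sample the concrete token from the narrowed mask."""

    def __init__(self, lm, prompt):
        super().__init__()
        self.context = LMContext(lm, prompt)
        # Token ids whose decoded text begins with an upper-case letter
        # (built from the tokenizer here purely for illustration).
        self.capitalized = {
            i for i in range(len(lm.tokenizer))
            if lm.tokenizer.decode([i]).lstrip()[:1].isupper()
        }

    async def step(self):
        # Stage 1: sample whether the next token is capitalized; this narrows
        # the current mask of the LMContext without committing to a token yet.
        await self.sample(self.context.mask_dist(self.capitalized))
        # Stage 2: sample the actual next token from the narrowed distribution;
        # this appends a token id to self.context.tokens.
        await self.sample(self.context.next_token())
        if self.context.tokens[-1] == self.context.lm.tokenizer.eos_token_id:
            self.finish()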

Attributes:

  lm (CachedCausalLM): the language model for which this is a context.

  tokens (list[int]): the underlying sequence of tokens, including the prompt, in this context.

  next_token_logprobs (numpy.array): log probabilities for the next token. Unlike the log probabilities reported by CachedCausalLM.next_token_logprobs, these are rescaled for this LMContext's temperature parameter and for any active masks. This vector is managed by the LMContext object internally; do not mutate.

  temp (float): temperature for the next-token distribution (0 < temp < float('inf')).

  model_mask (set[int]): the set of tokens that have not been ruled out as the next token. This mask is managed by the LMContext object internally; do not mutate.

  show_prompt (bool): whether the string representation of this LMContext includes the initial prompt. Defaults to False.
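
As a quick, hedged illustration of reading these attributes (it assumes lm is an already-loaded hfppl.llms.CachedCausalLM and that LMContext is exported from the package top level):

import numpy as np
from hfppl import LMContext

ctx = LMContext(lm, "My favorite physicist is probably")
print(len(ctx.tokens))                    # number of prompt token ids
probs = np.exp(ctx.next_token_logprobs)   # temperature-rescaled next-token probabilities
print(probs.sum())                        # ~1.0: a normalized distribution over the vocabulary
print(ctx.token_count)                    # 0: nothing generated beyond the prompt yet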

Source code in hfppl/distributions/lmcontext.py
class LMContext:
    """Represents a generation-in-progress from a language model.

    The state tracks two pieces of information:

    * A sequence of tokens — the ever-growing context for the language model.
    * A *current mask* — a set of tokens that have not yet been ruled out as the next token.

    Storing a mask enables _sub-token_ generation: models can use `LMContext` to sample
    the next token in _stages_, first deciding, e.g., whether to use an upper-case or lower-case
    first letter, and only later deciding which upper-case or lower-case token to generate.

    The state of a `LMContext` can be advanced in two ways:

    1. Sampling, observing, or intervening the `next_token()` distribution. This causes a token
    to be added to the growing sequence of tokens. Supports auto-batching.
    2. Sampling, observing, or intervening the `mask_dist(mask)` distribution for a given mask (set of
    token ids). This changes the current mask.

    Attributes:
        lm (hfppl.llms.CachedCausalLM): the language model for which this is a context
        tokens (list[int]): the underlying sequence of tokens, including prompt, in this context
        next_token_logprobs (numpy.array): numpy array holding the log probabilities for the next token. Unlike the log probabilities reported by `CachedCausalLM.next_token_logprobs`, these probabilities are rescaled for this `LMContext`'s temperature parameter, and for any active masks. This vector is managed by the `LMContext` object internally; do not mutate.
        temp (float): temperature for next-token distribution (0 < temp < float('inf'))
        model_mask (set[int]): set of tokens that have not been ruled out as the next token. This mask is managed by the `LMContext` object internally; do not mutate.
        show_prompt (bool): controls whether the string representation of this `LMContext` includes the initial prompt or not. Defaults to `False`.
    """

    def __init__(self, lm, prompt, temp=1.0, show_prompt=False, show_eos=True):
        """Create a new `LMContext` with a given prompt and temperature.

        Args:
            lm (hfppl.llms.CachedCausalLM): the language model for which this is a context.
            prompt (str): a string with which to initialize the context. Will be tokenized using `lm.tokenizer`.
            temp (float): temperature for next-token distribution (0 < temp < float('inf'))
        """
        self.lm = lm
        self.tokens = lm.tokenizer.encode(prompt)
        self.next_token_logprobs = log_softmax(
            lm.next_token_logprobs_unbatched(self.tokens) / temp
        )
        self.temp = temp
        self.model_mask = lm.masks.ALL_TOKENS
        self.prompt_string_length = len(lm.tokenizer.decode(self.tokens))
        self.prompt_token_count = len(self.tokens)
        self.show_prompt = show_prompt
        self.show_eos = show_eos

    def next_token(self):
        """Distribution over the next token.

        Sampling or observing from this distribution advances the state of this `LMContext` instance.
        """
        return LMNextToken(self)

    def mask_dist(self, mask):
        """Bernoulli distribution, with probability of True equal to the probability that the next token of this `LMContext` belongs
        to the given mask.

        Sampling or observing from this distribution modifies the state of this `LMContext` instance, so that
        the `next_token()` distribution either *will* (if True) or *will not* (if False) generate a token from
        the given mask.

        Args:
            mask: a `set(int)` specifying which token ids are included within the mask.
        """
        return LMTokenMask(self, mask)

    @property
    def token_count(self):
        return len(self.tokens) - self.prompt_token_count

    def __str__(self):
        full_string = self.lm.tokenizer.decode(self.tokens)
        if not self.show_prompt:
            full_string = full_string[self.prompt_string_length :]
        if not self.show_eos and full_string.endswith(self.lm.tokenizer.eos_token):
            full_string = full_string[: -len(self.lm.tokenizer.eos_token)]
        return full_string

    def __deepcopy__(self, memo):
        cpy = type(self).__new__(type(self))

        for k, v in self.__dict__.items():
            if k in set(["lm"]):
                setattr(cpy, k, v)
            else:
                setattr(cpy, k, copy.deepcopy(v, memo))

        return cpy

__init__(lm, prompt, temp=1.0, show_prompt=False, show_eos=True)

Create a new LMContext with a given prompt and temperature.

Parameters:

  lm (CachedCausalLM): the language model for which this is a context. Required.

  prompt (str): a string with which to initialize the context; it will be tokenized using lm.tokenizer. Required.

  temp (float): temperature for the next-token distribution (0 < temp < float('inf')). Defaults to 1.0.

  show_prompt (bool): whether the string representation of this LMContext includes the initial prompt. Defaults to False.

  show_eos (bool): whether the string representation includes a trailing end-of-sequence token. Defaults to True.
Source code in hfppl/distributions/lmcontext.py
def __init__(self, lm, prompt, temp=1.0, show_prompt=False, show_eos=True):
    """Create a new `LMContext` with a given prompt and temperature.

    Args:
        lm (hfppl.llms.CachedCausalLM): the language model for which this is a context.
        prompt (str): a string with which to initialize the context. Will be tokenized using `lm.tokenizer`.
        temp (float): temperature for next-token distribution (0 < temp < float('inf'))
    """
    self.lm = lm
    self.tokens = lm.tokenizer.encode(prompt)
    self.next_token_logprobs = log_softmax(
        lm.next_token_logprobs_unbatched(self.tokens) / temp
    )
    self.temp = temp
    self.model_mask = lm.masks.ALL_TOKENS
    self.prompt_string_length = len(lm.tokenizer.decode(self.tokens))
    self.prompt_token_count = len(self.tokens)
    self.show_prompt = show_prompt
    self.show_eos = show_eos
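
A hedged sketch of constructing contexts at different temperatures (the model id and prompt are placeholders; CachedCausalLM.from_pretrained is assumed to be available as in the library's examples):

import numpy as np
from hfppl import CachedCausalLM, LMContext

lm = CachedCausalLM.from_pretrained("gpt2")
sharp = LMContext(lm, "The capital of France is", temp=0.7)
flat = LMContext(lm, "The capital of France is", temp=1.5)
# Lower temperature concentrates probability mass on the top token;
# higher temperature spreads it out.
print(np.exp(sharp.next_token_logprobs).max())
print(np.exp(flat.next_token_logprobs).max())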

mask_dist(mask)

Bernoulli distribution, with probability of True equal to the probability that the next token of this LMContext belongs to the given mask.

Sampling or observing from this distribution modifies the state of this LMContext instance, so that the next_token() distribution either will (if True) or will not (if False) generate a token from the given mask.

Parameters:

  mask (set[int]): the set of token ids included in the mask. Required.
Source code in hfppl/distributions/lmcontext.py
def mask_dist(self, mask):
    """Bernoulli distribution, with probability of True equal to the probability that the next token of this `LMContext` belongs
    to the given mask.

    Sampling or observing from this distribution modifies the state of this `LMContext` instance, so that
    the `next_token()` distribution either *will* (if True) or *will not* (if False) generate a token from
    the given mask.

    Args:
        mask: a `set(int)` specifying which token ids are included within the mask.
    """
    return LMTokenMask(self, mask)
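
A minimal, hedged usage sketch: inside an hfppl Model's step, observe that the next token lies outside a given mask before sampling it (the Model API with self.observe, self.sample, and self.finish is assumed; the stopping rule is illustrative):

from hfppl import Model, LMContext

class NoEarlyStop(Model):
    def __init__(self, lm, prompt):
        super().__init__()
        self.context = LMContext(lm, prompt)

    async def step(self):
        eos = self.context.lm.tokenizer.eos_token_id
        # Observing False conditions on the next token lying *outside* {eos},
        # so the next_token() draw below cannot end the generation here.
        await self.observe(self.context.mask_dist({eos}), False)
        await self.sample(self.context.next_token())
        if self.context.token_count >= 20:  # illustrative stopping rule
            self.finish()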

next_token()

Distribution over the next token.

Sampling or observing from this distribution advances the state of this LMContext instance.

Source code in hfppl/distributions/lmcontext.py
def next_token(self):
    """Distribution over the next token.

    Sampling or observing from this distribution advances the state of this `LMContext` instance.
    """
    return LMNextToken(self)
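
An end-to-end hedged sketch: sample tokens until the end-of-sequence token and run a small SMC inference over the model (hfppl's smc_standard and its asyncio-based entry point are assumed, as in the library's examples; the model id and particle count are placeholders):

import asyncio
from hfppl import CachedCausalLM, LMContext, Model, smc_standard

class Unconditional(Model):
    def __init__(self, lm, prompt):
        super().__init__()
        self.context = LMContext(lm, prompt)

    async def step(self):
        await self.sample(self.context.next_token())
        # tokens now ends with the freshly sampled token id.
        if self.context.tokens[-1] == self.context.lm.tokenizer.eos_token_id:
            self.finish()

lm = CachedCausalLM.from_pretrained("gpt2")
particles = asyncio.run(smc_standard(Unconditional(lm, "Once upon a time,"), 4))
for p in particles:
    print(str(p.context))  # omits the prompt, since show_prompt defaults to False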