
hfppl

Probabilistic programming with HuggingFace Transformer models.

Bernoulli

Bases: Distribution

A Bernoulli distribution.

Source code in hfppl/distributions/bernoulli.py
class Bernoulli(Distribution):
    """A Bernoulli distribution."""

    def __init__(self, p):
        """Create a Bernoulli distribution.

        Args:
            p: the probability-of-True for the Bernoulli distribution.
        """
        self.p = p

    async def sample(self):
        b = np.random.rand() < self.p
        return (b, await self.log_prob(b))

    async def log_prob(self, value):
        return np.log(self.p) if value else np.log1p(-self.p)

    async def argmax(self, idx):
        return (self.p > 0.5) if idx == 0 else (self.p < 0.5)

__init__(p)

Create a Bernoulli distribution.

Parameters:

  • p: the probability-of-True for the Bernoulli distribution. (required)

Source code in hfppl/distributions/bernoulli.py
def __init__(self, p):
    """Create a Bernoulli distribution.

    Args:
        p: the probability-of-True for the Bernoulli distribution.
    """
    self.p = p
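
As a usage sketch: sample() is a coroutine and returns both the value and its log probability. The import path below follows the source location shown above and is an assumption about how the package is laid out.

import asyncio
from hfppl.distributions.bernoulli import Bernoulli

async def demo():
    coin = Bernoulli(0.7)
    b, logp = await coin.sample()   # b is True with probability 0.7; logp is log(0.7) or log(0.3)
    print(b, logp)

asyncio.run(demo())

Inside a LLaMPPL model, the same distribution would typically be used as await self.sample(Bernoulli(0.7)) (see the Model class below).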

CachedCausalLM

Wrapper around a HuggingFace causal language model, with support for caching.

Attributes:

  • model: the underlying HuggingFace model.
  • tokenizer: the underlying HuggingFace tokenizer.
  • device (str): the PyTorch device identifier (e.g. "cpu" or "cuda:0") on which the model is loaded.
  • cache (TokenTrie): the cache of previously evaluated log probabilities and key/value vectors.
  • vocab (list[str]): a list mapping token ids to strings.
  • batch_size (int): when auto-batching, maximum number of queries to process in one batch.
  • timeout (float): number of seconds to wait since the last query before processing the current batch of queries, even if not full.

Source code in hfppl/llms.py
class CachedCausalLM:
    """Wrapper around a HuggingFace causal language model, with support for caching.

    Attributes:
        model: the underlying HuggingFace model.
        tokenizer: the underlying HuggingFace tokenizer.
        device (str): the PyTorch device identifier (e.g. "cpu" or "cuda:0") on which the model is loaded.
        cache (hfppl.llms.TokenTrie): the cache of previously evaluated log probabilities and key/value vectors.
        vocab (list[str]): a list mapping token ids to strings.
        batch_size (int): when auto-batching, maximum number of queries to process in one batch.
        timeout (float): number of seconds to wait since last query before processing the current batch of queries, even if not full.
    """

    @classmethod
    def from_pretrained(cls, model_id, auth_token=False, load_in_8bit=True):
        """Create a [`CachedCausalLM`][hfppl.llms.CachedCausalLM] from a pretrained HuggingFace model.

        Args:
            model_id (str): the string identifier of the model in HuggingFace's model library.
            auth_token (str): a HuggingFace API key. Only necessary if using private models, e.g. Meta's Llama models, which require authorization.
            load_in_8bit (bool): whether to use the `bitsandbytes` library to load the model in 8-bit quantized form.

        Returns:
            model (hfppl.llms.CachedCausalLM): the LLaMPPL-compatible interface to the HuggingFace model.
        """
        bnb_config = BitsAndBytesConfig(load_in_8bit=load_in_8bit)

        if not auth_token:
            tok = AutoTokenizer.from_pretrained(model_id)
            mod = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                quantization_config=bnb_config,
            )
        else:
            tok = AutoTokenizer.from_pretrained(model_id, token=auth_token)
            mod = AutoModelForCausalLM.from_pretrained(
                model_id,
                token=auth_token,
                device_map="auto",
                quantization_config=bnb_config,
            )

        return CachedCausalLM(mod, tok)

    @torch.no_grad()
    def __init__(self, hf_model, hf_tokenizer, batch_size=20):
        """
        Create a `CachedCausalLM` from a loaded HuggingFace model and tokenizer.

        Args:
            hf_model: a HuggingFace `CausalLM`.
            hf_tokenizer: a HuggingFace `Tokenizer`.
            batch_size (int): when auto-batching, maximum number of queries to process in one batch.
        """
        self.model = hf_model
        self.tokenizer = hf_tokenizer
        self.device = hf_model.device

        # TODO: remove required BOS token
        if self.tokenizer.bos_token_id is None:
            raise RuntimeError(
                "Causal LM has no BOS token, distribution of first word unclear"
            )

        # Evaluate BOS token
        logits = self.model(
            torch.tensor([[self.tokenizer.bos_token_id]]).to(self.model.device)
        ).logits[0][0]
        logprobs = torch.log_softmax(logits, 0)

        self.cache = TokenTrie(None, logprobs.cpu().numpy())

        # Cache vocabulary
        bos_len = len(self.tokenizer.decode([self.tokenizer.bos_token_id]))
        self.vocab = [
            self.tokenizer.decode([self.tokenizer.bos_token_id, i])[bos_len:]
            for i in range(len(hf_tokenizer.vocab))
        ]

        # Precompute useful masks
        self.masks = Masks(self)

        # Queries to be batched. Each query is a sequence of tokens,
        # and a Future to be called when the query is resolved.
        self.queries = []
        self.batch_size = batch_size
        self.timeout = 0.02
        self.timer = None

    def __deepcopy__(self, memo):
        return self

    def clear_cache(self):
        """Clear the cache of log probabilities and key/value pairs."""
        self.cache = TokenTrie(None, self.cache.logprobs)

    def clear_kv_cache(self):
        """Clear any key and value vectors from the cache."""
        self.cache.clear_kv_cache()

    def reset_async_queries(self):
        """Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing
        to completion."""
        self.queries = []

    @torch.no_grad()
    def cache_kv(self, prompt_tokens):
        """Cache the key and value vectors for a prompt. Future queries that have this prompt as a prefix will only run the LLM on new tokens.

        Args:
            prompt_tokens (list[int]): token ids for the prompt to cache.
        """
        result = self.model(torch.tensor([prompt_tokens]).to(self.device))

        node = self.cache.extend_cache(1, prompt_tokens, result.logits[0], 0)
        node.past_key_values = result.past_key_values

    @torch.no_grad()
    def batch_evaluate_queries(self):

        queries, self.queries = self.queries, []
        if len(queries) == 0:
            return

        past_example = next((q.past for q in queries if q.past), False)
        max_past_length = max(q.past_len for q in queries)
        max_query_length = max(len(q.prompt) for q in queries)

        padding_token_id = (
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else 0
        )

        input_ids = torch.tensor(
            [q.prompt_padded(padding_token_id, max_query_length) for q in queries]
        ).to(self.device)
        attn_masks = torch.tensor(
            [q.attention_mask(max_past_length, max_query_length) for q in queries]
        ).to(self.device)
        posn_ids = torch.tensor(
            [q.position_ids(max_past_length, max_query_length) for q in queries]
        ).to(self.device)
        if past_example:
            pasts = [
                [
                    torch.cat(
                        (
                            *(
                                q.past_padded(
                                    layer,
                                    j,
                                    max_past_length,
                                    past_example[0][0].dtype,
                                    self.device,
                                    past_example[0][0].shape,
                                )
                                for q in queries
                            ),
                        ),
                        dim=0,
                    )
                    for j in range(2)
                ]
                for layer in range(len(past_example))
            ]
        else:
            pasts = None

        results = self.model(
            input_ids,
            attention_mask=attn_masks,
            position_ids=posn_ids,
            past_key_values=pasts,
            use_cache=pasts is not None,
        )

        for i, q in enumerate(queries):
            q.future.set_result(results.logits[i])

    @torch.no_grad()
    def add_query(self, query, future, past):
        self.queries.append(Query(query, future, past))

        if self.timer:
            self.timer.cancel()
            self.timer = None
        if len(self.queries) >= self.batch_size:
            self.batch_evaluate_queries()
        else:
            self.timer = asyncio.get_running_loop().call_later(
                self.timeout, lambda: self.batch_evaluate_queries()
            )

    def walk_cache(self, token_ids):
        # Walk while tokens can be found
        node = self.cache
        next_token_index = 1

        past = None
        base = 0
        while next_token_index < len(token_ids):
            if node.past_key_values is not None:
                past = node.past_key_values
                base = next_token_index
            if node.has_token(token_ids[next_token_index]):
                node = node.get_token(token_ids[next_token_index])
                next_token_index += 1
            else:
                break

        return node, next_token_index, past, base

    @torch.no_grad()
    async def next_token_logprobs(self, token_ids):
        """Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with `await`.

        Args:
            token_ids (list[int]): a list of token ids starting with `tokenizer.bos_token_id`, representing a prompt to the language model.

        Returns:
            logprobs (numpy.array): a numpy array of `len(vocab)`, with the language model's log (normalized) probabilities for the next token following the prompt.
        """

        # Ensure that token list begins with BOS
        assert token_ids[0] == self.tokenizer.bos_token_id

        node, next_token_index, past, base = self.walk_cache(token_ids)

        # If we processed all tokens, then we're done.
        if next_token_index == len(token_ids):
            return node.logprobs

        # Create a future with the prompt
        future = asyncio.get_running_loop().create_future()
        self.add_query(token_ids[base:], future, past)
        logits = await future

        # Create new nodes
        node = node.extend_cache(next_token_index, token_ids, logits, base)

        return node.logprobs

    @torch.no_grad()
    def next_token_logprobs_unbatched(self, token_ids):
        """Request log probabilities of next token. Not asynchronous, and does not support auto-batching.

        Args:
            token_ids (list[int]): a list of token ids starting with `tokenizer.bos_token_id`, representing a prompt to the language model.

        Returns:
            logprobs (numpy.array): a numpy array of `len(vocab)`, with the language model's log (normalized) probabilities for the next token following the prompt.
        """

        # Ensure that token list begins with BOS
        assert token_ids[0] == self.tokenizer.bos_token_id

        # Walk while tokens can be found
        node, next_token_index, past, base = self.walk_cache(token_ids)

        if next_token_index == len(token_ids):
            return node.logprobs

        logits = self.model(
            torch.tensor([token_ids[base:]]).to(self.device),
            past_key_values=node.past_key_values,
            use_cache=node.past_key_values is not None,
        ).logits[0]

        node = node.extend_cache(next_token_index, token_ids, logits, base)

        return node.logprobs

__init__(hf_model, hf_tokenizer, batch_size=20)

Create a CachedCausalLM from a loaded HuggingFace model and tokenizer.

Parameters:

  • hf_model: a HuggingFace CausalLM. (required)
  • hf_tokenizer: a HuggingFace Tokenizer. (required)
  • batch_size (int): when auto-batching, maximum number of queries to process in one batch. (default: 20)

Source code in hfppl/llms.py
@torch.no_grad()
def __init__(self, hf_model, hf_tokenizer, batch_size=20):
    """
    Create a `CachedCausalLM` from a loaded HuggingFace model and tokenizer.

    Args:
        hf_model: a HuggingFace `CausalLM`.
        hf_tokenizer: a HuggingFace `Tokenizer`.
        batch_size (int): when auto-batching, maximum number of queries to process in one batch.
    """
    self.model = hf_model
    self.tokenizer = hf_tokenizer
    self.device = hf_model.device

    # TODO: remove required BOS token
    if self.tokenizer.bos_token_id is None:
        raise RuntimeError(
            "Causal LM has no BOS token, distribution of first word unclear"
        )

    # Evaluate BOS token
    logits = self.model(
        torch.tensor([[self.tokenizer.bos_token_id]]).to(self.model.device)
    ).logits[0][0]
    logprobs = torch.log_softmax(logits, 0)

    self.cache = TokenTrie(None, logprobs.cpu().numpy())

    # Cache vocabulary
    bos_len = len(self.tokenizer.decode([self.tokenizer.bos_token_id]))
    self.vocab = [
        self.tokenizer.decode([self.tokenizer.bos_token_id, i])[bos_len:]
        for i in range(len(hf_tokenizer.vocab))
    ]

    # Precompute useful masks
    self.masks = Masks(self)

    # Queries to be batched. Each query is a sequence of tokens,
    # and a Future to be called when the query is resolved.
    self.queries = []
    self.batch_size = batch_size
    self.timeout = 0.02
    self.timer = None

cache_kv(prompt_tokens)

Cache the key and value vectors for a prompt. Future queries that have this prompt as a prefix will only run the LLM on new tokens.

Parameters:

  • prompt_tokens (list[int]): token ids for the prompt to cache. (required)

Source code in hfppl/llms.py
@torch.no_grad()
def cache_kv(self, prompt_tokens):
    """Cache the key and value vectors for a prompt. Future queries that have this prompt as a prefix will only run the LLM on new tokens.

    Args:
        prompt_tokens (list[int]): token ids for the prompt to cache.
    """
    result = self.model(torch.tensor([prompt_tokens]).to(self.device))

    node = self.cache.extend_cache(1, prompt_tokens, result.logits[0], 0)
    node.past_key_values = result.past_key_values

clear_cache()

Clear the cache of log probabilities and key/value pairs.

Source code in hfppl/llms.py
def clear_cache(self):
    """Clear the cache of log probabilities and key/value pairs."""
    self.cache = TokenTrie(None, self.cache.logprobs)

clear_kv_cache()

Clear any key and value vectors from the cache.

Source code in hfppl/llms.py
def clear_kv_cache(self):
    """Clear any key and value vectors from the cache."""
    self.cache.clear_kv_cache()

from_pretrained(model_id, auth_token=False, load_in_8bit=True) classmethod

Create a CachedCausalLM from a pretrained HuggingFace model.

Parameters:

  • model_id (str): the string identifier of the model in HuggingFace's model library. (required)
  • auth_token (str): a HuggingFace API key. Only necessary if using private models, e.g. Meta's Llama models, which require authorization. (default: False)
  • load_in_8bit (bool): whether to use the bitsandbytes library to load the model in 8-bit quantized form. (default: True)

Returns:

  • model (CachedCausalLM): the LLaMPPL-compatible interface to the HuggingFace model.

Source code in hfppl/llms.py
@classmethod
def from_pretrained(cls, model_id, auth_token=False, load_in_8bit=True):
    """Create a [`CachedCausalLM`][hfppl.llms.CachedCausalLM] from a pretrained HuggingFace model.

    Args:
        model_id (str): the string identifier of the model in HuggingFace's model library.
        auth_token (str): a HuggingFace API key. Only necessary if using private models, e.g. Meta's Llama models, which require authorization.
        load_in_8bit (bool): whether to use the `bitsandbytes` library to load the model in 8-bit quantized form.

    Returns:
        model (hfppl.llms.CachedCausalLM): the LLaMPPL-compatible interface to the HuggingFace model.
    """
    bnb_config = BitsAndBytesConfig(load_in_8bit=load_in_8bit)

    if not auth_token:
        tok = AutoTokenizer.from_pretrained(model_id)
        mod = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            quantization_config=bnb_config,
        )
    else:
        tok = AutoTokenizer.from_pretrained(model_id, token=auth_token)
        mod = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=auth_token,
            device_map="auto",
            quantization_config=bnb_config,
        )

    return CachedCausalLM(mod, tok)
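
For illustration, a minimal loading sketch. The model id and the choice to disable 8-bit quantization are arbitrary examples, not recommendations; pass auth_token for gated models.

from hfppl.llms import CachedCausalLM

lm = CachedCausalLM.from_pretrained("gpt2", load_in_8bit=False)
print(lm.device)       # device chosen by device_map="auto"
print(len(lm.vocab))   # one decoded string per token id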

next_token_logprobs(token_ids) async

Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with await.

Parameters:

  • token_ids (list[int]): a list of token ids starting with tokenizer.bos_token_id, representing a prompt to the language model. (required)

Returns:

  • logprobs (numpy.array): a numpy array of len(vocab), with the language model's log (normalized) probabilities for the next token following the prompt.

Source code in hfppl/llms.py
@torch.no_grad()
async def next_token_logprobs(self, token_ids):
    """Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with `await`.

    Args:
        token_ids (list[int]): a list of token ids starting with `tokenizer.bos_token_id`, representing a prompt to the language model.

    Returns:
        logprobs (numpy.array): a numpy array of `len(vocab)`, with the language model's log (normalized) probabilities for the next token following the prompt.
    """

    # Ensure that token list begins with BOS
    assert token_ids[0] == self.tokenizer.bos_token_id

    node, next_token_index, past, base = self.walk_cache(token_ids)

    # If we processed all tokens, then we're done.
    if next_token_index == len(token_ids):
        return node.logprobs

    # Create a future with the prompt
    future = asyncio.get_running_loop().create_future()
    self.add_query(token_ids[base:], future, past)
    logits = await future

    # Create new nodes
    node = node.extend_cache(next_token_index, token_ids, logits, base)

    return node.logprobs

next_token_logprobs_unbatched(token_ids)

Request log probabilities of next token. Not asynchronous, and does not support auto-batching.

Parameters:

  • token_ids (list[int]): a list of token ids starting with tokenizer.bos_token_id, representing a prompt to the language model. (required)

Returns:

  • logprobs (numpy.array): a numpy array of len(vocab), with the language model's log (normalized) probabilities for the next token following the prompt.

Source code in hfppl/llms.py
@torch.no_grad()
def next_token_logprobs_unbatched(self, token_ids):
    """Request log probabilities of next token. Not asynchronous, and does not support auto-batching.

    Args:
        token_ids (list[int]): a list of token ids starting with `tokenizer.bos_token_id`, representing a prompt to the language model.

    Returns:
        logprobs (numpy.array): a numpy array of `len(vocab)`, with the language model's log (normalized) probabilities for the next token following the prompt.
    """

    # Ensure that token list begins with BOS
    assert token_ids[0] == self.tokenizer.bos_token_id

    # Walk while tokens can be found
    node, next_token_index, past, base = self.walk_cache(token_ids)

    if next_token_index == len(token_ids):
        return node.logprobs

    logits = self.model(
        torch.tensor([token_ids[base:]]).to(self.device),
        past_key_values=node.past_key_values,
        use_cache=node.past_key_values is not None,
    ).logits[0]

    node = node.extend_cache(next_token_index, token_ids, logits, base)

    return node.logprobs
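
A sketch of querying the next-token distribution for a short prompt, with the CachedCausalLM constructed as in the earlier sketch. The prompt text is arbitrary; note the required leading BOS token (some tokenizers already prepend it, in which case the explicit prefix should be dropped).

import numpy as np
from hfppl.llms import CachedCausalLM

lm = CachedCausalLM.from_pretrained("gpt2", load_in_8bit=False)
token_ids = [lm.tokenizer.bos_token_id, *lm.tokenizer.encode("The capital of France is")]
logprobs = lm.next_token_logprobs_unbatched(token_ids)   # numpy array over the vocabulary
top = int(np.argmax(logprobs))
print(lm.vocab[top], logprobs[top])

Within a model or other async code, the batched variant is used the same way: logprobs = await lm.next_token_logprobs(token_ids).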

reset_async_queries()

Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing to completion.

Source code in hfppl/llms.py
def reset_async_queries(self):
    """Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing
    to completion."""
    self.queries = []

Distribution

Abstract base class for a distribution.

Source code in hfppl/distributions/distribution.py
class Distribution:
    """Abstract base class for a distribution."""

    async def sample(self):
        """Generate a random sample from the distribution.

        Returns:
            x: a value randomly sampled from the distribution."""
        raise NotImplementedError()

    async def log_prob(self, x):
        """Compute the log probability of a value under this distribution,
        or the log probability density if the distribution is continuous.

        Args:
            x: the point at which to evaluate the log probability.
        Returns:
            logprob (float): the log probability of `x`."""
        raise NotImplementedError()

    async def argmax(self, n):
        """Return the nth most probable outcome under this distribution (assuming this is a discrete distribution).

        Args:
            n (int): which value to return, indexed from most probable (n=0) to least probable (n=|support|).
        Returns:
            x: the nth most probable outcome from this distribution."""
        raise NotImplementedError()

argmax(n) async

Return the nth most probable outcome under this distribution (assuming this is a discrete distribution).

Parameters:

  • n (int): which value to return, indexed from most probable (n=0) to least probable (n=|support|). (required)

Returns:

  • x: the nth most probable outcome from this distribution.

Source code in hfppl/distributions/distribution.py
async def argmax(self, n):
    """Return the nth most probable outcome under this distribution (assuming this is a discrete distribution).

    Args:
        n (int): which value to return, indexed from most probable (n=0) to least probable (n=|support|).
    Returns:
        x: the nth most probable outcome from this distribution."""
    raise NotImplementedError()

log_prob(x) async

Compute the log probability of a value under this distribution, or the log probability density if the distribution is continuous.

Parameters:

  • x: the point at which to evaluate the log probability. (required)

Returns:

  • logprob (float): the log probability of x.

Source code in hfppl/distributions/distribution.py
async def log_prob(self, x):
    """Compute the log probability of a value under this distribution,
    or the log probability density if the distribution is continuous.

    Args:
        x: the point at which to evaluate the log probability.
    Returns:
        logprob (float): the log probability of `x`."""
    raise NotImplementedError()

sample() async

Generate a random sample from the distribution.

Returns:

  • x: a value randomly sampled from the distribution.

Source code in hfppl/distributions/distribution.py
async def sample(self):
    """Generate a random sample from the distribution.

    Returns:
        x: a value randomly sampled from the distribution."""
    raise NotImplementedError()
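
Since the built-in distributions all follow this interface, a custom distribution is a small amount of code. The sketch below is illustrative (the class name and import path are assumptions); sample() returns a (value, log probability) pair, matching Bernoulli above.

import numpy as np
from hfppl.distributions.distribution import Distribution

class LabeledCoin(Distribution):
    """A coin whose two outcomes carry arbitrary labels."""

    def __init__(self, labels, p):
        self.labels = labels   # e.g. ("heads", "tails")
        self.p = p             # probability of labels[0]

    async def sample(self):
        x = self.labels[0] if np.random.rand() < self.p else self.labels[1]
        return x, await self.log_prob(x)

    async def log_prob(self, x):
        return np.log(self.p) if x == self.labels[0] else np.log1p(-self.p)

    async def argmax(self, n):
        ordered = self.labels if self.p >= 0.5 else tuple(reversed(self.labels))
        return ordered[n]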

Geometric

Bases: Distribution

A Geometric distribution.

Source code in hfppl/distributions/geometric.py
class Geometric(Distribution):
    """A Geometric distribution."""

    def __init__(self, p):
        """Create a Geometric distribution.

        Args:
            p: the rate of the Geometric distribution.
        """
        self.p = p

    async def sample(self):
        n = np.random.geometric(self.p)
        return n, await self.log_prob(n)

    async def log_prob(self, value):
        return np.log(self.p) + np.log(1 - self.p) * (value - 1)

    async def argmax(self, idx):
        return idx - 1  # Most likely outcome is 0, then 1, etc.

__init__(p)

Create a Geometric distribution.

Parameters:

  • p: the rate of the Geometric distribution. (required)

Source code in hfppl/distributions/geometric.py
def __init__(self, p):
    """Create a Geometric distribution.

    Args:
        p: the rate of the Geometric distribution.
    """
    self.p = p
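
A usage sketch, assuming an async context: sample() returns the count together with its log probability, and the count starts at 1 (the support of numpy's geometric sampler). The import path follows the source location shown above.

import asyncio
from hfppl.distributions.geometric import Geometric

async def demo():
    n, logp = await Geometric(0.3).sample()   # e.g. how many items to generate
    print(n, logp)

asyncio.run(demo())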

LMContext

Represents a generation-in-progress from a language model.

The state tracks two pieces of information:

  • A sequence of tokens — the ever-growing context for the language model.
  • A current mask — a set of tokens that have not yet been ruled out as the next token.

Storing a mask enables sub-token generation: models can use LMContext to sample the next token in stages, first deciding, e.g., whether to use an upper-case or lower-case first letter, and only later deciding which upper-case or lower-case token to generate.

The state of a LMContext can be advanced in two ways:

  1. Sampling, observing, or intervening the next_token() distribution. This causes a token to be added to the growing sequence of tokens. Supports auto-batching.
  2. Sampling, observing, or intervening the mask_dist(mask) distribution for a given mask (set of token ids). This changes the current mask.

Attributes:

  • lm (CachedCausalLM): the language model for which this is a context.
  • tokens (list[int]): the underlying sequence of tokens, including the prompt, in this context.
  • next_token_logprobs (numpy.array): numpy array holding the log probabilities for the next token. Unlike the log probabilities reported by CachedCausalLM.next_token_logprobs, these probabilities are rescaled for this LMContext's temperature parameter, and for any active masks. This vector is managed by the LMContext object internally; do not mutate.
  • temp (float): temperature for the next-token distribution (0 < temp < float('inf')).
  • model_mask (set[int]): the set of tokens that have not been ruled out as the next token. This mask is managed by the LMContext object internally; do not mutate.
  • show_prompt (bool): controls whether the string representation of this LMContext includes the initial prompt. Defaults to False.

Source code in hfppl/distributions/lmcontext.py
class LMContext:
    """Represents a generation-in-progress from a language model.

    The state tracks two pieces of information:

    * A sequence of tokens — the ever-growing context for the language model.
    * A *current mask* — a set of tokens that have not yet been ruled out as the next token.

    Storing a mask enables _sub-token_ generation: models can use `LMContext` to sample
    the next token in _stages_, first deciding, e.g., whether to use an upper-case or lower-case
    first letter, and only later deciding which upper-case or lower-case token to generate.

    The state of a `LMContext` can be advanced in two ways:

    1. Sampling, observing, or intervening the `next_token()` distribution. This causes a token
    to be added to the growing sequence of tokens. Supports auto-batching.
    2. Sampling, observing, or intervening the `mask_dist(mask)` distribution for a given mask (set of
    token ids). This changes the current mask.

    Attributes:
        lm (hfppl.llms.CachedCausalLM): the language model for which this is a context
        tokens (list[int]): the underlying sequence of tokens, including prompt, in this context
        next_token_logprobs (numpy.array): numpy array holding the log probabilities for the next token. Unlike the log probabilities reported by `CachedCausalLM.next_token_logprobs`, these probabilities are rescaled for this `LMContext`'s temperature parameter, and for any active masks. This vector is managed by the `LMContext` object internally; do not mutate.
        temp (float): temperature for next-token distribution (0 < temp < float('inf'))
        model_mask (set[int]): set of tokens that have not been ruled out as the next token. This mask is managed by the `LMContext` object internally; do not mutate.
        show_prompt (bool): controls whether the string representation of this `LMContext` includes the initial prompt or not. Defaults to `False`.
    """

    def __init__(self, lm, prompt, temp=1.0, show_prompt=False, show_eos=True):
        """Create a new `LMContext` with a given prompt and temperature.

        Args:
            lm (hfppl.llms.CachedCausalLM): the language model for which this is a context.
            prompt (str): a string with which to initialize the context. Will be tokenized using `lm.tokenizer`.
            temp (float): temperature for next-token distribution (0 < temp < float('inf'))
        """
        self.lm = lm
        self.tokens = lm.tokenizer.encode(prompt)
        self.next_token_logprobs = log_softmax(
            lm.next_token_logprobs_unbatched(self.tokens) / temp
        )
        self.temp = temp
        self.model_mask = lm.masks.ALL_TOKENS
        self.prompt_string_length = len(lm.tokenizer.decode(self.tokens))
        self.prompt_token_count = len(self.tokens)
        self.show_prompt = show_prompt
        self.show_eos = show_eos

    def next_token(self):
        """Distribution over the next token.

        Sampling or observing from this distribution advances the state of this `LMContext` instance.
        """
        return LMNextToken(self)

    def mask_dist(self, mask):
        """Bernoulli distribution, with probability of True equal to the probability that the next token of this `LMContext` belongs
        to the given mask.

        Sampling or observing from this distribution modifies the state of this `LMContext` instance, so that
        the `next_token()` distribution either *will* (if True) or *will not* (if False) generate a token from
        the given mask.

        Args:
            mask: a `set(int)` specifying which token ids are included within the mask.
        """
        return LMTokenMask(self, mask)

    @property
    def token_count(self):
        return len(self.tokens) - self.prompt_token_count

    def __str__(self):
        full_string = self.lm.tokenizer.decode(self.tokens)
        if not self.show_prompt:
            full_string = full_string[self.prompt_string_length :]
        if not self.show_eos and full_string.endswith(self.lm.tokenizer.eos_token):
            full_string = full_string[: -len(self.lm.tokenizer.eos_token)]
        return full_string

    def __deepcopy__(self, memo):
        cpy = type(self).__new__(type(self))

        for k, v in self.__dict__.items():
            if k in set(["lm"]):
                setattr(cpy, k, v)
            else:
                setattr(cpy, k, copy.deepcopy(v, memo))

        return cpy

__init__(lm, prompt, temp=1.0, show_prompt=False, show_eos=True)

Create a new LMContext with a given prompt and temperature.

Parameters:

  • lm (CachedCausalLM): the language model for which this is a context. (required)
  • prompt (str): a string with which to initialize the context. Will be tokenized using lm.tokenizer. (required)
  • temp (float): temperature for the next-token distribution (0 < temp < float('inf')). (default: 1.0)

Source code in hfppl/distributions/lmcontext.py
def __init__(self, lm, prompt, temp=1.0, show_prompt=False, show_eos=True):
    """Create a new `LMContext` with a given prompt and temperature.

    Args:
        lm (hfppl.llms.CachedCausalLM): the language model for which this is a context.
        prompt (str): a string with which to initialize the context. Will be tokenized using `lm.tokenizer`.
        temp (float): temperature for next-token distribution (0 < temp < float('inf'))
    """
    self.lm = lm
    self.tokens = lm.tokenizer.encode(prompt)
    self.next_token_logprobs = log_softmax(
        lm.next_token_logprobs_unbatched(self.tokens) / temp
    )
    self.temp = temp
    self.model_mask = lm.masks.ALL_TOKENS
    self.prompt_string_length = len(lm.tokenizer.decode(self.tokens))
    self.prompt_token_count = len(self.tokens)
    self.show_prompt = show_prompt
    self.show_eos = show_eos

mask_dist(mask)

Bernoulli distribution, with probability of True equal to the probability that the next token of this LMContext belongs to the given mask.

Sampling or observing from this distribution modifies the state of this LMContext instance, so that the next_token() distribution either will (if True) or will not (if False) generate a token from the given mask.

Parameters:

  • mask: a set(int) specifying which token ids are included within the mask. (required)

Source code in hfppl/distributions/lmcontext.py
def mask_dist(self, mask):
    """Bernoulli distribution, with probability of True equal to the probability that the next token of this `LMContext` belongs
    to the given mask.

    Sampling or observing from this distribution modifies the state of this `LMContext` instance, so that
    the `next_token()` distribution either *will* (if True) or *will not* (if False) generate a token from
    the given mask.

    Args:
        mask: a `set(int)` specifying which token ids are included within the mask.
    """
    return LMTokenMask(self, mask)

next_token()

Distribution over the next token.

Sampling or observing from this distribution advances the state of this LMContext instance.

Source code in hfppl/distributions/lmcontext.py
def next_token(self):
    """Distribution over the next token.

    Sampling or observing from this distribution advances the state of this `LMContext` instance.
    """
    return LMNextToken(self)
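
A sketch of how next_token() and mask_dist() are typically combined inside a model's step method. The surrounding Model subclass (with self.context set to an LMContext in its constructor) is assumed; the mask name comes from the Masks class documented below.

async def step(self):
    # Condition the next token to start a new word, then sample it.
    await self.observe(self.context.mask_dist(self.context.lm.masks.STARTS_NEW_WORD), True)
    token = await self.sample(self.context.next_token())
    if token == self.context.lm.tokenizer.eos_token_id:
        self.finish()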

LogCategorical

Bases: Distribution

A Categorical distribution parameterized by unnormalized log probabilities (logits).

Source code in hfppl/distributions/logcategorical.py
class LogCategorical(Distribution):
    """A Geometric distribution."""

    def __init__(self, logits):
        """Create a Categorical distribution from unnormalized log probabilities (logits).
        Given an array of logits, takes their `softmax` and samples an integer in `range(len(logits))`
        from the resulting categorical.

        Args:
            logits (np.array): a numpy array of unnormalized log probabilities.
        """
        self.log_probs = log_softmax(logits)

    async def sample(self):
        n = np.random.choice(len(self.log_probs), p=np.exp(self.log_probs))
        return n, await self.log_prob(n)

    async def log_prob(self, value):
        return self.log_probs[value]

    async def argmax(self, idx):
        return np.argsort(self.log_probs)[-idx]

__init__(logits)

Create a Categorical distribution from unnormalized log probabilities (logits). Given an array of logits, takes their softmax and samples an integer in range(len(logits)) from the resulting categorical.

Parameters:

  • logits (np.array): a numpy array of unnormalized log probabilities. (required)

Source code in hfppl/distributions/logcategorical.py
def __init__(self, logits):
    """Create a Categorical distribution from unnormalized log probabilities (logits).
    Given an array of logits, takes their `softmax` and samples an integer in `range(len(logits))`
    from the resulting categorical.

    Args:
        logits (np.array): a numpy array of unnormalized log probabilities.
    """
    self.log_probs = log_softmax(logits)
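
A small sampling sketch, assuming an async context; the import path follows the source location shown above.

import asyncio
import numpy as np
from hfppl.distributions.logcategorical import LogCategorical

async def demo():
    logits = np.array([1.0, 0.0, -1.0])                 # unnormalized log probabilities
    idx, logp = await LogCategorical(logits).sample()   # idx is an integer in range(3)
    print(idx, logp)

asyncio.run(demo())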

Masks

Source code in hfppl/llms.py
class Masks:
    def __init__(self, lm):
        self.ALL_TOKENS = set(range(len(lm.vocab)))
        self.STARTS_NEW_WORD = set(
            i
            for (i, v) in enumerate(lm.vocab)
            if v[0] == " "
            and len(v) > 1
            and v[1] not in string.whitespace
            and v[1] not in string.punctuation
        )
        self.CONTINUES_CURRENT_WORD = set(
            i
            for (i, v) in enumerate(lm.vocab)
            if all(c in "'" or c.isalpha() for c in v)
        )
        self.MID_PUNCTUATION = set(
            i for (i, v) in enumerate(lm.vocab) if v in (",", ":", ";", "-", '"')
        )
        self.END_PUNCTUATION = set(
            i for (i, v) in enumerate(lm.vocab) if v in (".", "!", "?")
        )
        self.PUNCTUATION = self.MID_PUNCTUATION | self.END_PUNCTUATION
        self.CONTAINS_WHITESPACE = set(
            i
            for (i, v) in enumerate(lm.vocab)
            if any(c in string.whitespace for c in v)
        )

        self.MAX_TOKEN_LENGTH = self.precompute_token_length_masks(lm)

    def precompute_token_length_masks(self, lm):
        """Precompute masks for tokens of different lengths.

        Each mask is a set of token ids that are of the given length or shorter."""
        max_token_length = max([len(t) for t in lm.vocab])

        masks = defaultdict(lambda: self.ALL_TOKENS)
        masks[0] = set([lm.tokenizer.eos_token_id])
        for token_length in range(1, max_token_length + 1):
            masks[token_length] = set(
                i
                for (i, v) in enumerate(lm.vocab)
                if len(v) <= token_length and i != lm.tokenizer.eos_token_id
            )

        return masks

precompute_token_length_masks(lm)

Precompute masks for tokens of different lengths.

Each mask is a set of token ids that are of the given length or shorter.

Source code in hfppl/llms.py
def precompute_token_length_masks(self, lm):
    """Precompute masks for tokens of different lengths.

    Each mask is a set of token ids that are of the given length or shorter."""
    max_token_length = max([len(t) for t in lm.vocab])

    masks = defaultdict(lambda: self.ALL_TOKENS)
    masks[0] = set([lm.tokenizer.eos_token_id])
    for token_length in range(1, max_token_length + 1):
        masks[token_length] = set(
            i
            for (i, v) in enumerate(lm.vocab)
            if len(v) <= token_length and i != lm.tokenizer.eos_token_id
        )

    return masks
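
These masks are exposed on a CachedCausalLM as lm.masks and are meant to be passed to LMContext.mask_dist (documented above). A small inspection sketch, reusing an lm object created via CachedCausalLM.from_pretrained:

word_starts = lm.masks.STARTS_NEW_WORD       # token ids whose string begins a new word
end_punct = lm.masks.END_PUNCTUATION         # token ids for ".", "!", "?"
short_tokens = lm.masks.MAX_TOKEN_LENGTH[3]  # token ids whose string has at most 3 characters (EOS excluded)
print(len(word_starts), len(end_punct), len(short_tokens))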

Model

Base class for all LLaMPPL models.

Your models should subclass this class. Minimally, you should provide an __init__ method that calls super().__init__(), and a step method.

Source code in hfppl/modeling.py
class Model:
    """Base class for all LLaMPPL models.

    Your models should subclass this class. Minimally, you should provide an `__init__` method
    that calls `super().__init__()`, and a `step` method.
    """

    def __init__(self):
        self.weight = 0.0
        self.finished = False
        self.mode = "sample"
        self.beam_idx = 0
        self.force_eos = False
        self.twist_amount = 0.0

    def reset(self):
        self.weight = 0.0
        self.finished = False
        self.mode = "sample"
        self.beam_idx = 0
        self.force_eos = False
        self.twist_amount = 0.0

    def immutable_properties(self):
        """Return a `set[str]` of properties that LLaMPPL may assume do not change during execution of `step`.
        This set is empty by default but can be overridden by subclasses to speed up inference.

        Returns:
            properties (set[str]): a set of immutable property names"""
        return set()

    def __deepcopy__(self, memo):
        cpy = type(self).__new__(type(self))
        immutable = self.immutable_properties()

        for k, v in self.__dict__.items():
            if k in immutable:
                setattr(cpy, k, v)
            else:
                setattr(cpy, k, copy.deepcopy(v, memo))

        return cpy

    def twist(self, amt):
        """Multiply this particle's weight by `exp(amt)`, but divide it back out before the next `step`.

        Use this method to provide heuristic guidance about whether a particle is "on the right track"
        without changing the ultimate target distribution.

        Args:
            amt: the logarithm of the amount by which to (temporarily) multiply this particle's weight.
        """
        self.twist_amount += amt
        self.score(amt)

    def untwist(self):
        self.score(-self.twist_amount)
        self.twist_amount = 0.0

    def finish(self):
        self.untwist()
        self.finished = True

    def done_stepping(self):
        return self.finished

    async def step(self):
        """Defines the computation performed in each step of the model.

        All subclasses should override this method."""

        if not self.done_stepping():
            raise NotImplementedError("Model.step() must be implemented by subclasses")

    def __str__(self):
        return "Particle"

    async def start(self):
        pass

    def score(self, score):
        """Multiply this particle's weight by `exp(score)`.

        The `score` method is a low-level way to change the target distribution.
        For many use cases, it is sufficient to use `sample`, `observe`, `condition`,
        and `twist`, all of which are implemented in terms of `score`.

        Args:
            score: logarithm of the amount by which the particle's weight should be multiplied.
        """
        self.weight += score

    def condition(self, b):
        """Constrain a given Boolean expression to be `True`.

        If the condition is False, the particle's weight is set to zero and `self.finish()`
        is called, so that no further `step` calls are made.

        Args:
            b: the Boolean expression whose value is constrained to be True.
        """
        if not b:
            self.score(float("-inf"))
            self.finish()

    async def intervene(self, dist, x):
        """Force the distribution to take on the value `x`, but do not _condition_ on this result.

        This is useful primarily with distributions that have side effects (e.g., modifying some state).
        For example, a model with the code

        ```python
        token_1 = await self.sample(self.stateful_lm.next_token())
        await self.observe(self.stateful_lm.next_token(), token_2)
        ```

        encodes a posterior inference problem, to find `token_1` values that *likely preceded* `token_2`. By contrast,

        ```python
        token_1 = await self.sample(stateful_lm.next_token())
        await self.intervene(self.stateful_lm.next_token(), token_2)
        ```

        encodes a much easier task: freely generate `token_1` and then force-feed `token_2` as the following token.

        Args:
            dist (hfppl.distributions.distribution.Distribution): the distribution on which to intervene.
            x: the value to intervene with.
        """
        await dist.log_prob(x)
        return x

    async def observe(self, dist, x):
        """Condition the model on the value `x` being sampled from the distribution `dist`.

        For discrete distributions `dist`, `await self.observe(dist, x)` specifies the same constraint as
        ```
        val = await self.sample(dist)
        self.condition(val == x)
        ```
        but can be much more efficient.

        Args:
            dist: a `Distribution` object from which to observe
            x: the value observed from `dist`
        """
        p = await dist.log_prob(x)
        self.score(p)
        return x

    async def sample(self, dist, proposal=None):
        """Extend the model with a sample from a given `Distribution`, with support for autobatching.
        If specified, the Distribution `proposal` is used during inference to generate informed hypotheses.

        Args:
            dist: the `Distribution` object from which to sample
            proposal: if provided, inference algorithms will use this `Distribution` object to generate proposed samples, rather than `dist`.
              However, importance weights will be adjusted so that the target posterior is independent of the proposal.

        Returns:
            value: the value sampled from the distribution.
        """
        # Special logic for beam search
        # if self.mode == "beam":
        #     d = dist if proposal is None else proposal
        #     x, w = d.argmax(self.beam_idx)
        #     if proposal is not None:
        #         self.score(dist.log_prob(x))
        #     else:
        #         self.score(w)
        #     return x

        if proposal is None:
            x, _ = await dist.sample()
            return x
        else:
            x, q = await proposal.sample()
            p = await dist.log_prob(x)
            self.score(p - q)
            return x

    async def call(self, submodel):
        return await submodel.run_with_parent(self)

    def string_for_serialization(self):
        """Return a string representation of the particle for serialization purposes.

        Returns:
            str: a string representation of the particle.
        """
        return str(self)
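
Putting the pieces together, a minimal sketch of a Model subclass that generates up to a fixed number of tokens. The class and variable names are illustrative, and the top-level imports assume the package re-exports these names; LMContext and CachedCausalLM are documented elsewhere on this page.

from hfppl import Model, LMContext

class ShortGeneration(Model):
    def __init__(self, lm, prompt, max_tokens=10):
        # lm is a CachedCausalLM, e.g. from CachedCausalLM.from_pretrained(...)
        super().__init__()
        self.context = LMContext(lm, prompt)
        self.max_tokens = max_tokens

    async def step(self):
        token = await self.sample(self.context.next_token())
        # Stop on EOS or once the token budget is exhausted.
        if token == self.context.lm.tokenizer.eos_token_id or self.context.token_count >= self.max_tokens:
            self.finish()

    def string_for_serialization(self):
        return str(self.context)

Instances of such a model are advanced by repeatedly awaiting step() until done_stepping() returns True, which is what the library's inference algorithms do for each particle in a population.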

condition(b)

Constrain a given Boolean expression to be True.

If the condition is False, the particle's weight is set to zero and self.finish() is called, so that no further step calls are made.

Parameters:

  • b: the Boolean expression whose value is constrained to be True. (required)

Source code in hfppl/modeling.py
def condition(self, b):
    """Constrain a given Boolean expression to be `True`.

    If the condition is False, the particle's weight is set to zero and `self.finish()`
    is called, so that no further `step` calls are made.

    Args:
        b: the Boolean expression whose value is constrained to be True.
    """
    if not b:
        self.score(float("-inf"))
        self.finish()

immutable_properties()

Return a set[str] of properties that LLaMPPL may assume do not change during execution of step. This set is empty by default but can be overridden by subclasses to speed up inference.

Returns:

  • properties (set[str]): a set of immutable property names.

Source code in hfppl/modeling.py
def immutable_properties(self):
    """Return a `set[str]` of properties that LLaMPPL may assume do not change during execution of `step`.
    This set is empty by default but can be overridden by subclasses to speed up inference.

    Returns:
        properties (set[str]): a set of immutable property names"""
    return set()

intervene(dist, x) async

Force the distribution to take on the value x, but do not condition on this result.

This is useful primarily with distributions that have side effects (e.g., modifying some state). For example, a model with the code

token_1 = await self.sample(self.stateful_lm.next_token())
await self.observe(self.stateful_lm.next_token(), token_2)

encodes a posterior inference problem, to find token_1 values that likely preceded token_2. By contrast,

token_1 = await self.sample(stateful_lm.next_token())
await self.intervene(self.stateful_lm.next_token(), token_2)

encodes a much easier task: freely generate token_1 and then force-feed token_2 as the following token.

Parameters:

  • dist (Distribution): the distribution on which to intervene. (required)
  • x: the value to intervene with. (required)

Source code in hfppl/modeling.py
async def intervene(self, dist, x):
    """Force the distribution to take on the value `x`, but do not _condition_ on this result.

    This is useful primarily with distributions that have side effects (e.g., modifying some state).
    For example, a model with the code

    ```python
    token_1 = await self.sample(self.stateful_lm.next_token())
    await self.observe(self.stateful_lm.next_token(), token_2)
    ```

    encodes a posterior inference problem, to find `token_1` values that *likely preceded* `token_2`. By contrast,

    ```python
    token_1 = await self.sample(stateful_lm.next_token())
    await self.intervene(self.stateful_lm.next_token(), token_2)
    ```

    encodes a much easier task: freely generate `token_1` and then force-feed `token_2` as the following token.

    Args:
        dist (hfppl.distributions.distribution.Distribution): the distribution on which to intervene.
        x: the value to intervene with.
    """
    await dist.log_prob(x)
    return x

observe(dist, x) async

Condition the model on the value x being sampled from the distribution dist.

For discrete distributions dist, await self.observe(dist, x) specifies the same constraint as

val = await self.sample(dist)
self.condition(val == x)
but can be much more efficient.

Parameters:

  • dist: a Distribution object from which to observe. (required)
  • x: the value observed from dist. (required)

Source code in hfppl/modeling.py
async def observe(self, dist, x):
    """Condition the model on the value `x` being sampled from the distribution `dist`.

    For discrete distributions `dist`, `await self.observe(dist, x)` specifies the same constraint as
    ```
    val = await self.sample(dist)
    self.condition(val == x)
    ```
    but can be much more efficient.

    Args:
        dist: a `Distribution` object from which to observe
        x: the value observed from `dist`
    """
    p = await dist.log_prob(x)
    self.score(p)
    return x

sample(dist, proposal=None) async

Extend the model with a sample from a given Distribution, with support for autobatching. If specified, the Distribution proposal is used during inference to generate informed hypotheses.

Parameters:

  • dist: the Distribution object from which to sample. (required)
  • proposal: if provided, inference algorithms will use this Distribution object to generate proposed samples, rather than dist. However, importance weights will be adjusted so that the target posterior is independent of the proposal. (default: None)

Returns:

  • value: the value sampled from the distribution.

Source code in hfppl/modeling.py
async def sample(self, dist, proposal=None):
    """Extend the model with a sample from a given `Distribution`, with support for autobatching.
    If specified, the Distribution `proposal` is used during inference to generate informed hypotheses.

    Args:
        dist: the `Distribution` object from which to sample
        proposal: if provided, inference algorithms will use this `Distribution` object to generate proposed samples, rather than `dist`.
          However, importance weights will be adjusted so that the target posterior is independent of the proposal.

    Returns:
        value: the value sampled from the distribution.
    """
    # Special logic for beam search
    # if self.mode == "beam":
    #     d = dist if proposal is None else proposal
    #     x, w = d.argmax(self.beam_idx)
    #     if proposal is not None:
    #         self.score(dist.log_prob(x))
    #     else:
    #         self.score(w)
    #     return x

    if proposal is None:
        x, _ = await dist.sample()
        return x
    else:
        x, q = await proposal.sample()
        p = await dist.log_prob(x)
        self.score(p - q)
        return x
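
For example, inside a step method one could keep a broad target distribution while proposing from a more concentrated one; the importance weight is adjusted by p - q exactly as in the implementation above. The specific distributions here are illustrative (Geometric is documented earlier on this page).

# Target Geometric(0.1), but propose lengths from Geometric(0.5).
n = await self.sample(Geometric(0.1), proposal=Geometric(0.5))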

score(score)

Multiply this particle's weight by exp(score).

The score method is a low-level way to change the target distribution. For many use cases, it is sufficient to use sample, observe, condition, and twist, all of which are implemented in terms of score.

Parameters:

  • score: logarithm of the amount by which the particle's weight should be multiplied. (required)

Source code in hfppl/modeling.py
def score(self, score):
    """Multiply this particle's weight by `exp(score)`.

    The `score` method is a low-level way to change the target distribution.
    For many use cases, it is sufficient to use `sample`, `observe`, `condition`,
    and `twist`, all of which are implemented in terms of `score`.

    Args:
        score: logarithm of the amount by which the particle's weight should be multiplied.
    """
    self.weight += score

step() async

Defines the computation performed in each step of the model.

All subclasses should override this method.

Source code in hfppl/modeling.py
async def step(self):
    """Defines the computation performed in each step of the model.

    All subclasses should override this method."""

    if not self.done_stepping():
        raise NotImplementedError("Model.step() must be implemented by subclasses")

string_for_serialization()

Return a string representation of the particle for serialization purposes.

Returns:

  • str: a string representation of the particle.

Source code in hfppl/modeling.py
def string_for_serialization(self):
    """Return a string representation of the particle for serialization purposes.

    Returns:
        str: a string representation of the particle.
    """
    return str(self)

twist(amt)

Multiply this particle's weight by exp(amt), but divide it back out before the next step.

Use this method to provide heuristic guidance about whether a particle is "on the right track" without changing the ultimate target distribution.

Parameters:

  • amt: the logarithm of the amount by which to (temporarily) multiply this particle's weight. (required)

Source code in hfppl/modeling.py
def twist(self, amt):
    """Multiply this particle's weight by `exp(amt)`, but divide it back out before the next `step`.

    Use this method to provide heuristic guidance about whether a particle is "on the right track"
    without changing the ultimate target distribution.

    Args:
        amt: the logarithm of the amount by which to (temporarily) multiply this particle's weight.
    """
    self.twist_amount += amt
    self.score(amt)

Query

A query to a language model, waiting to be batched.

Source code in hfppl/llms.py
class Query:
    """A query to a language model, waiting to be batched."""

    def __init__(self, prompt, future, past=None):
        self.prompt = prompt
        self.future = future
        self.past = past

        if self.past is not None:
            self.past_len = past[0][0].shape[
                2
            ]  # layers, key or value, batch size, num heads, num tokens, head repr length
        else:
            self.past_len = 0

    @torch.no_grad()
    def past_padded(self, layer, j, to_length, dtype, device, past_shape):

        if self.past is not None:
            return torch.cat(
                (
                    self.past[layer][j],
                    torch.zeros(
                        1,
                        past_shape[1],
                        to_length - self.past_len,
                        past_shape[3],
                        dtype=dtype,
                        device=device,
                    ),
                ),
                dim=2,
            )
        else:
            return torch.zeros(
                1, past_shape[1], to_length, past_shape[3], dtype=dtype, device=device
            )

    def prompt_padded(self, pad_token, to_length):
        return [*self.prompt, *[pad_token for _ in range(to_length - len(self.prompt))]]

    def attention_mask(self, total_past_length, total_seq_length):
        return [
            *[1 for _ in range(self.past_len)],
            *[0 for _ in range(total_past_length - self.past_len)],
            *[1 for _ in range(len(self.prompt))],
            *[0 for _ in range(total_seq_length - len(self.prompt))],
        ]

    def position_ids(self, total_past_length, total_seq_length):
        return [
            *range(self.past_len, self.past_len + len(self.prompt)),
            *[0 for _ in range(total_seq_length - len(self.prompt))],
        ]

Token

Class representing a token.

Attributes:

  • lm (CachedCausalLM): the language model for which this is a Token.
  • token_id (int): the integer token id (an index into the vocabulary).
  • token_str (str): the string which the token represents, equal to lm.vocab[token_id].

Source code in hfppl/llms.py
class Token:
    """Class representing a token.

    Attributes:
        lm (hfppl.llms.CachedCausalLM): the language model for which this is a Token.
        token_id (int): the integer token id (an index into the vocabulary).
        token_str (str): a string, which the token represents—equal to `lm.vocab[token_id]`.
    """

    def __init__(self, lm, token_id, token_str):
        self.lm = lm
        self.token_id = token_id
        self.token_str = token_str

    # Adding tokens
    def __add__(self, other):
        s = TokenSequence(self.lm, [self.token_id])
        s += other
        return s

    def __radd__(self, other):
        s = TokenSequence(self.lm, [self.token_id])
        return other + s

    # Support checking for EOS
    def __eq__(self, other):
        if isinstance(other, Token):
            return self.lm is other.lm and self.token_id == other.token_id
        elif isinstance(other, int):
            return self.token_id == other
        else:
            return self.token_str == other

    def __str__(self):
        return self.token_str

    def __repr__(self):
        return f"<{self.token_str}|{self.token_id}>"
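
Because `__eq__` accepts ints and strings as well as other `Token`s, a sampled token can be compared directly against an end-of-sequence id. A brief sketch, assuming `lm` is an already-loaded `CachedCausalLM`:

```python
# Assumes `lm` is a loaded CachedCausalLM (see CachedCausalLM.from_pretrained).
tok = Token(lm, lm.tokenizer.eos_token_id, lm.vocab[lm.tokenizer.eos_token_id])

tok == lm.tokenizer.eos_token_id   # True: comparison against an integer id
tok == lm.vocab[tok.token_id]      # True: comparison against the token string
seq = tok + tok                    # addition produces a TokenSequence
```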

TokenCategorical

Bases: Distribution

Source code in hfppl/distributions/tokencategorical.py
class TokenCategorical(Distribution):

    def __init__(self, lm, logits):
        """Create a Categorical distribution whose values are Tokens, not integers.
        Given a language model `lm` and an array of unnormalized log probabilities (of length `len(lm.vocab)`),
        uses softmax to normalize them and samples a Token from the resulting categorical.

        Args:
            lm (hfppl.llms.CachedCausalLM): the language model whose vocabulary is to be generated from.
            logits (np.array): a numpy array of unnormalized log probabilities.
        """
        self.lm = lm
        self.log_probs = log_softmax(logits)
        if self.lm.tokenizer.vocab_size != len(logits):
            raise RuntimeError(
                f"TokenCategorical: vocab size is {self.lm.tokenizer.vocab_size} but provided {len(logits)} logits."
            )

    async def sample(self):
        n = np.random.choice(len(self.log_probs), p=(np.exp(self.log_probs)))
        return (
            Token(self.lm, n, self.lm.tokenizer.convert_ids_to_tokens(n)),
            self.log_probs[n],
        )

    async def log_prob(self, value):
        return self.log_probs[value.token_id]

    async def argmax(self, idx):
        tok = torch.argsort(self.log_probs)[-idx]
        return (
            Token(self.lm, tok, self.lm.tokenizer.convert_ids_to_tokens(tok)),
            self.log_probs[tok],
        )
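
Inside a model, a `TokenCategorical` can be sampled like any other distribution. A hedged sketch, assuming `self` is a `Model`, `lm` is a loaded `CachedCausalLM`, and `logits` is a numpy array of unnormalized log probabilities of length `lm.tokenizer.vocab_size`:

```python
# Hypothetical use inside a Model's step method.
token = await self.sample(TokenCategorical(lm, logits))
print(token.token_id, str(token))   # integer id and its string form
```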

__init__(lm, logits)

Create a Categorical distribution whose values are Tokens rather than integers. Given a language model lm and an array of unnormalized log probabilities (of length len(lm.vocab)), the distribution normalizes them with softmax and samples a Token from the resulting categorical.

Parameters:

Name Type Description Default
lm CachedCausalLM

the language model whose vocabulary the sampled Tokens are drawn from.

required
logits array

a numpy array of unnormalized log probabilities.

required
Source code in hfppl/distributions/tokencategorical.py
def __init__(self, lm, logits):
    """Create a Categorical distribution whose values are Tokens, not integers.
    Given a language model `lm` and an array of unnormalized log probabilities (of length `len(lm.vocab)`),
    uses softmax to normalize them and samples a Token from the resulting categorical.

    Args:
        lm (hfppl.llms.CachedCausalLM): the language model whose vocabulary is to be generated from.
        logits (np.array): a numpy array of unnormalized log probabilities.
    """
    self.lm = lm
    self.log_probs = log_softmax(logits)
    if self.lm.tokenizer.vocab_size != len(logits):
        raise RuntimeError(
            f"TokenCategorical: vocab size is {self.lm.tokenizer.vocab_size} but provided {len(logits)} logits."
        )

TokenSequence

A sequence of tokens.

Supports addition (via + or mutating +=) with:

  • other TokenSequence instances (concatenation)
  • individual tokens, represented as integers or Token instances
  • strings, which are tokenized by lm.tokenizer

Attributes:

Name Type Description
lm CachedCausalLM

the language model whose vocabulary the tokens come from.

seq list[Token]

the sequence of tokens.

Source code in hfppl/llms.py
class TokenSequence:
    """A sequence of tokens.

    Supports addition (via `+` or mutating `+=`) with:

    * other `TokenSequence` instances (concatenation)
    * individual tokens, represented as integers or `Token` instances
    * strings, which are tokenized by `lm.tokenizer`

    Attributes:
        lm (hfppl.llms.CachedCausalLM): the language model whose vocabulary the tokens come from.
        seq (list[hfppl.llms.Token]): the sequence of tokens."""

    def __init__(self, lm, seq=None):
        """Create a `TokenSequence` from a language model and a sequence.

        Args:
            lm (hfppl.llms.CachedCausalLM): the language model whose vocabulary the tokens come from.
            seq (str | list[int]): the sequence of token ids, or a string which will be automatically tokenized. Defaults to the singleton sequence containing a bos token.
        """
        self.lm = lm
        if seq is None:
            self.seq = [lm.tokenizer.bos_token_id]
        elif isinstance(seq, str):
            self.seq = self.lm.tokenizer.encode(seq)
        else:
            self.seq = seq

    def __str__(self):
        return self.lm.tokenizer.decode(self.seq)

    def __iadd__(self, other):
        if isinstance(other, Token):
            assert other.lm is self.lm
            self.seq.append(other.token_id)
        elif isinstance(other, TokenSequence):
            assert other.lm is self.lm
            self.seq.extend(other.seq)
        elif isinstance(other, str):
            self.seq.extend(self.lm.tokenizer.encode(other, add_special_tokens=False))
        elif isinstance(other, int):
            self.seq.append(other)
        else:
            raise RuntimeError(f"Addition not supported on {type(other)}")
        return self

    def __radd__(self, other):
        if isinstance(other, Token):
            assert other.lm is self.lm
            return TokenSequence(self.lm, [other.token_id, *self.seq])
        elif isinstance(other, TokenSequence):
            assert other.lm is self.lm
            return TokenSequence(self.lm, other.seq + self.seq)
        elif isinstance(other, str):
            return TokenSequence(
                self.lm,
                self.lm.tokenizer.encode(other, add_special_tokens=False) + self.seq,
            )
        elif isinstance(other, int):
            return TokenSequence(self.lm, [other, *self.seq])
        else:
            raise RuntimeError(f"Addition not supported on {type(other)}")

    def __add__(self, other):
        s = TokenSequence(self.lm, self.seq)
        s += other
        return s
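
Sequences can be built incrementally from strings, raw token ids, `Token`s, and other sequences. A brief sketch, assuming `lm` is a loaded `CachedCausalLM`:

```python
# Assumes `lm` is a loaded CachedCausalLM.
seq = TokenSequence(lm)      # starts with the BOS token
seq += "The quick brown"     # strings are tokenized without adding special tokens
seq += lm.tokenizer.encode(" fox", add_special_tokens=False)[0]  # a raw token id
print(str(seq))              # decodes the full sequence back to text
```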

__init__(lm, seq=None)

Create a TokenSequence from a language model and a sequence.

Parameters:

Name Type Description Default
lm CachedCausalLM

the language model whose vocabulary the tokens come from.

required
seq str | list[int]

the sequence of token ids, or a string which will be automatically tokenized. Defaults to the singleton sequence containing the BOS token.

None
Source code in hfppl/llms.py
def __init__(self, lm, seq=None):
    """Create a `TokenSequence` from a language model and a sequence.

    Args:
        lm (hfppl.llms.CachedCausalLM): the language model whose vocabulary the tokens come from.
        seq (str | list[int]): the sequence of token ids, or a string which will be automatically tokenized. Defaults to the singleton sequence containing a bos token.
    """
    self.lm = lm
    if seq is None:
        self.seq = [lm.tokenizer.bos_token_id]
    elif isinstance(seq, str):
        self.seq = self.lm.tokenizer.encode(seq)
    else:
        self.seq = seq

TokenTrie

Class used internally to cache language model results.

Source code in hfppl/llms.py
class TokenTrie:
    """Class used internally to cache language model results."""

    # Trie of tokens.

    def __init__(self, parent=None, logprobs=None):
        self.children = {}  # maps token ID to child
        self.logprobs = logprobs  # for next token
        self.past_key_values = None

    def __repr__(self):
        return (
            f"{'*' if self.past_key_values is not None else ''}["
            + ", ".join(
                [
                    f"{node_id}: {node.__repr__()}"
                    for (node_id, node) in self.children.items()
                ]
            )
            + "]"
        )

    def clear_kv_cache(self):
        self.past_key_values = None
        for child, node in self.children.items():
            node.clear_kv_cache()

    def has_token(self, token_id):
        return token_id in self.children

    def get_token(self, token_id):
        return self.children[token_id]

    def add_token(self, token_id, logprobs=None):
        self.children[token_id] = TokenTrie(self, logprobs)
        return self.children[token_id]

    def extend_cache(self, next_token_index, token_ids, logits, base):
        node = self

        for j in range(next_token_index, len(token_ids)):
            token_id = token_ids[j]
            token_logits = logits[j - base]
            token_logprobs = torch.log_softmax(token_logits, 0)

            node = node.add_token(token_id, token_logprobs.cpu().numpy())

        return node
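
Each node is keyed by token id and optionally stores next-token log probabilities and cached key/value vectors for the prefix ending there. A small sketch of a manual cache walk, assuming `lm` is a loaded `CachedCausalLM` (whose `cache` attribute is a `TokenTrie`):

```python
# Walk the cache along a tokenized prompt; assumes `lm` is a loaded CachedCausalLM.
node = lm.cache
for token_id in lm.tokenizer.encode("An example prompt"):
    if not node.has_token(token_id):
        break                        # prefix not cached beyond this point
    node = node.get_token(token_id)

cached_logprobs = node.logprobs      # None if this node was never evaluated
```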

Transformer

Bases: Distribution

Source code in hfppl/distributions/transformer.py
class Transformer(Distribution):

    def __init__(self, lm, prompt, temp=1.0):
        """Create a Categorical distribution whose values are Tokens, with probabilities given
        by a language model. Supports auto-batching.

        Args:
            lm (hfppl.llms.CachedCausalLM): the language model.
            prompt (str | hfppl.llms.TokenSequence): the sequence of tokens to use as the prompt. If a string, `lm.tokenizer` is used to encode it.
            temp (float): temperature at which to generate (0 < `temp` < `float('inf')`).
        """
        self.lm = lm
        self.temp = temp

        # prompt will be a list of ints
        if isinstance(prompt, str):
            prompt = self.lm.tokenizer.encode(prompt)
        elif isinstance(prompt, TokenSequence):
            prompt = prompt.seq

        self.prompt = prompt

    async def log_prob(self, x):
        log_probs = await self.lm.next_token_logprobs(self.prompt)
        log_probs = log_probs / self.temp

        if isinstance(x, Token):
            x = x.token_id

        return log_probs[x]

    async def sample(self):
        log_probs = await self.lm.next_token_logprobs(self.prompt)
        log_probs = log_probs / self.temp
        probs = np.exp(log_probs)
        token_id = np.random.choice(len(probs), p=(probs))
        logprob = log_probs[token_id]
        return (
            Token(self.lm, token_id, self.lm.tokenizer.convert_ids_to_tokens(token_id)),
            logprob,
        )
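
Within a model, `Transformer` is typically used to sample the next token given a prompt, or to score one. A hedged sketch, assuming `self` is a `Model` and `lm` is a loaded `CachedCausalLM`:

```python
# Hypothetical use inside a Model's step method.
dist = Transformer(lm, "My favorite physicist is probably")
token = await self.sample(dist)     # a Token drawn from the next-token distribution
lp = await dist.log_prob(token)     # log probability of a given token under the same prompt
```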

__init__(lm, prompt, temp=1.0)

Create a Categorical distribution whose values are Tokens, with probabilities given by a language model. Supports auto-batching.

Parameters:

Name Type Description Default
lm CachedCausalLM

the language model.

required
prompt str | TokenSequence

the sequence of tokens to use as the prompt. If a string, lm.tokenizer is used to encode it.

required
temp float

temperature at which to generate (0 < temp < float('inf')).

1.0
Source code in hfppl/distributions/transformer.py
def __init__(self, lm, prompt, temp=1.0):
    """Create a Categorical distribution whose values are Tokens, with probabilities given
    by a language model. Supports auto-batching.

    Args:
        lm (hfppl.llms.CachedCausalLM): the language model.
        prompt (str | hfppl.llms.TokenSequence): the sequence of tokens to use as the prompt. If a string, `lm.tokenizer` is used to encode it.
        temp (float): temperature at which to generate (0 < `temp` < `float('inf')`).
    """
    self.lm = lm
    self.temp = temp

    # prompt will be a list of ints
    if isinstance(prompt, str):
        prompt = self.lm.tokenizer.encode(prompt)
    elif isinstance(prompt, TokenSequence):
        prompt = prompt.seq

    self.prompt = prompt

log_softmax(nums)

Compute log(softmax(nums)).

Parameters:

Name Type Description Default
nums

a vector or numpy array of unnormalized log probabilities.

required

Returns:

Type Description

np.array: an array of log (normalized) probabilities.

Source code in hfppl/util.py
def log_softmax(nums):
    """Compute log(softmax(nums)).

    Args:
        nums: a vector or numpy array of unnormalized log probabilities.

    Returns:
        np.array: an array of log (normalized) probabilities.
    """
    return nums - logsumexp(nums)
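
A quick sanity check: exponentiating the result recovers probabilities that sum to one. A minimal sketch, assuming the `hfppl.util` import path shown above:

```python
import numpy as np
from hfppl.util import log_softmax  # import path assumed from the source listing above

lp = log_softmax(np.array([1.0, 2.0, 3.0]))
probs = np.exp(lp)    # normalized probabilities
probs.sum()           # ~1.0
```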

sample_word(self, context, max_tokens=5, allow_punctuation=True) async

Sample a word from the LMContext object context.

Source code in hfppl/chunks.py
@submodel
async def sample_word(self, context, max_tokens=5, allow_punctuation=True):
    """Sample a word from the `LMContext` object `context`."""
    last_token = context.lm.vocab[context.tokens[-1]] if len(context.tokens) > 0 else ""
    last_character = last_token[-1] if len(last_token) > 0 else ""
    needs_space = last_character not in string.whitespace and last_character not in [
        "-",
        "'",
        '"',
    ]
    if needs_space:
        starts_word_mask = context.lm.masks.STARTS_NEW_WORD
    else:
        starts_word_mask = context.lm.masks.CONTINUES_CURRENT_WORD

    # Force model to start a new word
    await self.observe(context.mask_dist(starts_word_mask), True)

    word = ""
    num_tokens = 0
    while True:
        token = await self.sample(context.next_token())
        word += context.lm.vocab[token.token_id]
        num_tokens += 1

        if num_tokens == max_tokens:
            await self.observe(
                context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD), False
            )
            break

        if not (
            await self.sample(
                context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD)
            )
        ):
            break

    # Sample punctuation, if desired
    punctuation = ""
    if allow_punctuation and await self.sample(
        context.mask_dist(context.lm.masks.PUNCTUATION)
    ):
        punctuation_token = await self.sample(context.next_token())
        punctuation = context.lm.vocab[punctuation_token.token_id]

    return word, punctuation
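
Because `sample_word` is wrapped with `@submodel`, it is invoked from a model with `self.call`. A hedged sketch, assuming `self` is a `Model` and `context` is an `LMContext`:

```python
# Hypothetical use inside a Model's step method.
word, punctuation = await self.call(sample_word(context, max_tokens=5))
```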

sample_word_2(self, context, max_chars=None, allow_mid_punctuation=True, allow_end_punctuation=True) async

Sample a word from the LMContext object context.

Unlike sample_word() above, this method allows for character-level control over the length of the word. It also allows for control over the presence of punctuation in the middle and at the end of the word.

Parameters:

Name Type Description Default
max_chars int

Maximum number of characters in the word. If None, the model will sample a word of any length.

None
allow_mid_punctuation bool

If True, the model may sample punctuation in the middle of the word.

True
allow_end_punctuation bool

If True, the model may sample punctuation at the end of the word.

True

Returns:

Type Description

Tuple[str, str]: the sampled word and its trailing punctuation (empty string if none).

Source code in hfppl/chunks.py
@submodel
async def sample_word_2(
    self,
    context,
    max_chars: int = None,
    allow_mid_punctuation: bool = True,
    allow_end_punctuation: bool = True,
):
    """Sample a word from the `LMContext` object `context`.

    Unlike sample_word() above, this method allows for character-level control over the length of the word.
    It also allows for control over the presence of punctuation in the middle and at the end of the word.

    Args:
        max_chars (int): Maximum number of characters in the word. If None, the model will sample a word of any length.
        allow_mid_punctuation (bool): If True, the model may sample punctuation in the middle of the word.
        allow_end_punctuation (bool): If True, the model may sample punctuation at the end of the word.

    Returns:
        Tuple[str, str]: The sampled word and punctuation
    """

    # This approach sometimes breaks with max_chars = 1
    if max_chars is not None:
        assert max_chars > 1

    last_token = context.lm.vocab[context.tokens[-1]] if len(context.tokens) > 0 else ""
    last_character = last_token[-1] if len(last_token) > 0 else ""
    needs_space = last_character not in string.whitespace and last_character not in [
        "-",
        "'",
        '"',
    ]
    if needs_space:
        starts_word_mask = context.lm.masks.STARTS_NEW_WORD
    else:
        starts_word_mask = context.lm.masks.CONTINUES_CURRENT_WORD

    # Force model to start a new word
    await self.observe(context.mask_dist(starts_word_mask), True)

    word = ""
    while True:
        # Force model to sample a token with an appropriate number of characters
        if max_chars is not None:
            await self.observe(
                context.mask_dist(
                    context.lm.masks.MAX_TOKEN_LENGTH[max_chars - len(word.strip())]
                ),
                True,
            )

        token = await self.sample(context.next_token())
        word += context.lm.vocab[token.token_id]

        # If we ran out of chars, break
        if max_chars is not None and len(word.strip()) >= max_chars:
            await self.observe(
                context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD), False
            )
            break

        # If the model wants to end the word, break
        if not (
            await self.sample(
                context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD)
            )
        ):
            break

    # Sample punctuation, if desired
    punctuation = ""

    mask = set()
    if allow_mid_punctuation:
        mask = mask | context.lm.masks.MID_PUNCTUATION
    if allow_end_punctuation:
        mask = mask | context.lm.masks.END_PUNCTUATION

    if mask and await self.sample(context.mask_dist(mask)):
        punctuation_token = await self.sample(context.next_token())
        punctuation = context.lm.vocab[punctuation_token.token_id]

    return word, punctuation
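
Usage mirrors `sample_word`, with the extra character-level and punctuation controls. A hedged sketch, assuming `self` is a `Model` and `context` is an `LMContext`:

```python
# Hypothetical use inside a Model's step method; words are capped at 8 characters.
word, punctuation = await self.call(
    sample_word_2(context, max_chars=8, allow_end_punctuation=False)
)
```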

submodel(f)

Decorator to create a SubModel implementation from an async function.

For example:

@submodel
async def sample_two_tokens(self, context):
    token1 = await self.sample(context.next_token())
    token2 = await self.sample(context.next_token())
    return token1, token2

This SubModel can then be used from another model or submodel, using the syntax await self.call(sample_two_tokens(context)).

Source code in hfppl/modeling.py
def submodel(f):
    """Decorator to create a SubModel implementation from an async function.

    For example:

    ```python
    @submodel
    async def sample_two_tokens(self, context):
        token1 = await self.sample(context.next_token())
        token2 = await self.sample(context.next_token())
        return token1, token2
    ```

    This SubModel can then be used from another model or submodel, using the syntax `await self.call(sample_two_tokens(context))`.
    """

    @functools.wraps(f, updated=())  # unclear if this is the best way to do it
    class SubModelImpl(SubModel):
        def __init__(self, *args, **kwargs):
            super().__init__()
            self.args = args
            self.kwargs = kwargs

        async def forward(self):
            return await f(self, *self.args, **self.kwargs)

    return SubModelImpl
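
Concretely, the decorator turns the async function into a `SubModel` subclass: calling `sample_two_tokens(context)` constructs an instance, and `await self.call(...)` runs its `forward` method. A hedged sketch of a model that uses it (`MyModel` is illustrative, and `Model`/`LMContext` are assumed to be available from the hfppl package):

```python
# Illustrative model using the sample_two_tokens submodel defined above.
# Assumes Model and LMContext are available from the hfppl package.
class MyModel(Model):
    def __init__(self, lm, prompt):
        super().__init__()
        self.context = LMContext(lm, prompt)

    async def step(self):
        token1, token2 = await self.call(sample_two_tokens(self.context))
        self.finish()
```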