Skip to content

smc_steer

smc_steer(model, n_particles, n_beam) async

Modified sequential Monte Carlo algorithm that uses without-replacement resampling, as described in our workshop abstract (https://arxiv.org/abs/2306.03081).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Model | The model to perform inference on. | required |
| n_particles | int | Number of particles to maintain. | required |
| n_beam | int | Number of continuations to consider for each particle. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| particles | list[Model] | The completed particles after inference. |

Source code in hfppl/inference/smc_steer.py
async def smc_steer(model, n_particles, n_beam):
    """
    Run SMC steering: a sequential Monte Carlo variant with without-replacement
    resampling, following [our workshop abstract](https://arxiv.org/abs/2306.03081).

    Each round expands every unfinished particle into `n_beam` copies, steps all
    unfinished copies concurrently, and then picks the next population with
    `resample_optimal` (called with a target of `n_particles`). Rounds repeat
    until every particle reports `done_stepping()`.

    Args:
        model (hfppl.modeling.Model): The model to perform inference on.
        n_particles (int): Number of particles to maintain.
        n_beam (int): Number of continuations to consider for each particle.

    Returns:
        particles (list[hfppl.modeling.Model]): The completed particles after inference.
    """
    # Seed the population with independent deep copies of the model.
    particles = []
    for _ in range(n_particles):
        particles.append(copy.deepcopy(model))

    for particle in particles:
        particle.start()  # TODO: allow to be async?

    while any(not p.done_stepping() for p in particles):
        # Size of the expanded population: finished particles are kept once,
        # unfinished ones contribute n_beam entries each.
        n_finished = sum(p.done_stepping() for p in particles)
        n_total = n_finished + (n_particles - n_finished) * n_beam

        # Log-space terms reused throughout this round.
        log_n_total = np.log(n_total)
        log_n_particles = np.log(n_particles)
        log_n_beam = np.log(n_beam)
        finished_shift = log_n_total - log_n_particles
        unfinished_shift = finished_shift - log_n_beam

        # Build the expanded population.
        expanded = []
        for p in particles:
            p.untwist()
            expanded.append(p)
            if p.done_stepping():
                p.weight += finished_shift
                continue
            p.weight += unfinished_shift
            # Clone only after the weight update so each copy carries the
            # adjusted weight.
            for _ in range(n_beam - 1):
                expanded.append(copy.deepcopy(p))

        # Advance every unfinished member of the expanded population concurrently.
        pending = [p.step() for p in expanded if not p.done_stepping()]
        await asyncio.gather(*pending)

        # Resample from the normalized weights of the expanded population.
        log_weights = np.array([p.weight for p in expanded])
        log_weight_total = logsumexp(log_weights)
        probabilities = softmax(log_weights)
        det_indices, stoch_indices, c = resample_optimal(probabilities, n_particles)
        survivors = np.concatenate((det_indices, stoch_indices))
        particles = [expanded[i] for i in survivors]

        # Deterministic picks keep their own weight, rescaled by N / N_total
        # (an additive shift in log space).
        det_shift = log_n_particles - log_n_total
        for i in det_indices:
            expanded[i].weight += det_shift

        # Stochastic picks all receive the same weight:
        # (total / c) * N / N_total, expressed in log space.
        for i in stoch_indices:
            expanded[i].weight = log_weight_total - np.log(c) + log_n_particles - log_n_total

    # Every particle has finished stepping.
    return particles