Modified sequential Monte Carlo algorithm that uses without-replacement resampling,
as described in [our workshop abstract](https://arxiv.org/abs/2306.03081).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Model` | The model to perform inference on. | *required* |
| `n_particles` | `int` | Number of particles to maintain. | *required* |
| `n_beam` | `int` | Number of continuations to consider for each particle. | *required* |
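The way `n_particles` and `n_beam` interact is easiest to see in the weight bookkeeping at the start of each step of the source listing. Here is a minimal sketch that reproduces that arithmetic in isolation. `beam_expansion` is a hypothetical helper written for illustration and is not part of hfppl.

```python
import math


def beam_expansion(n_particles: int, n_finished: int, n_beam: int):
    """Hypothetical helper: N' and the log-weight offsets smc_steer applies."""
    # N' = finished particles kept once + n_beam copies of each unfinished one
    n_total = n_finished + (n_particles - n_finished) * n_beam
    # A finished particle's log weight is shifted by log(N'/N)
    finished_offset = math.log(n_total) - math.log(n_particles)
    # Each copy of an unfinished particle is shifted by log(N' / (N * n_beam))
    unfinished_offset = finished_offset - math.log(n_beam)
    return n_total, finished_offset, unfinished_offset


# 4 particles, 1 already finished, 3 continuations per unfinished particle
n_total, off_done, off_open = beam_expansion(n_particles=4, n_finished=1, n_beam=3)
print(n_total)  # 10
# The n_beam copies of an unfinished particle together carry the same extra
# factor N'/N as a single finished particle
print(math.isclose(3 * math.exp(off_open), math.exp(off_done)))  # True
```

The split into `n_beam` equally weighted copies means that, before stepping, every original particle's total mass is scaled by the same factor `N'/N`, whether or not it has finished.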
Returns:
| Name | Type | Description |
|---|---|---|
| `particles` | `list[Model]` | The completed particles after inference. |
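The returned particles carry log weights in their `weight` attribute (the source listing updates them with `np.log` terms). One way to use the result, assuming each particle has a finite `weight`, is to normalize those log weights into probabilities over the completions. `FakeParticle` and `normalized_weights` are hypothetical names used only for this sketch; a real run would pass the list returned by `smc_steer`.

```python
import math
from dataclasses import dataclass


@dataclass
class FakeParticle:
    """Hypothetical stand-in exposing only `weight`, a log weight."""
    weight: float


def normalized_weights(particles):
    """Convert particles' log weights into probabilities that sum to 1."""
    log_ws = [p.weight for p in particles]
    m = max(log_ws)  # subtract the max before exponentiating, for stability
    exps = [math.exp(w - m) for w in log_ws]
    total = sum(exps)
    return [e / total for e in exps]


particles = [FakeParticle(math.log(1.0)), FakeParticle(math.log(3.0))]
print(normalized_weights(particles))  # roughly [0.25, 0.75]
```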
Source code in `hfppl/inference/smc_steer.py`

```python
import asyncio
import copy

import numpy as np

# `logsumexp`, `softmax` and `resample_optimal` are helpers defined elsewhere
# in hfppl; their imports are not shown in this excerpt.


async def smc_steer(model, n_particles, n_beam):
    """
    Modified sequential Monte Carlo algorithm that uses without-replacement resampling,
    as described in [our workshop abstract](https://arxiv.org/abs/2306.03081).

    Args:
        model (hfppl.modeling.Model): The model to perform inference on.
        n_particles (int): Number of particles to maintain.
        n_beam (int): Number of continuations to consider for each particle.

    Returns:
        particles (list[hfppl.modeling.Model]): The completed particles after inference.
    """
    # Create n_particles copies of the model
    particles = [copy.deepcopy(model) for _ in range(n_particles)]
    await asyncio.gather(*[p.start() for p in particles])

    while any(map(lambda p: not p.done_stepping(), particles)):
        # Count the number of finished particles
        n_finished = sum(map(lambda p: p.done_stepping(), particles))
        # N' (n_total): finished particles are kept once, each unfinished
        # particle is expanded into n_beam copies; N is n_particles
        n_total = n_finished + (n_particles - n_finished) * n_beam

        # Create a super-list of particles that has n_beam copies of each
        # unfinished particle
        super_particles = []
        for p in particles:
            p.untwist()
            super_particles.append(p)
            if p.done_stepping():
                p.weight += np.log(n_total) - np.log(n_particles)
            else:
                p.weight += np.log(n_total) - np.log(n_particles) - np.log(n_beam)
                super_particles.extend([copy.deepcopy(p) for _ in range(n_beam - 1)])

        # Step each super-particle
        await asyncio.gather(
            *[p.step() for p in super_particles if not p.done_stepping()]
        )

        # Use optimal resampling to resample
        W = np.array([p.weight for p in super_particles])
        W_tot = logsumexp(W)
        W_normalized = softmax(W)
        det_indices, stoch_indices, c = resample_optimal(W_normalized, n_particles)
        particles = [
            super_particles[i] for i in np.concatenate((det_indices, stoch_indices))
        ]
        # For deterministic particles: w = w * N/N'
        for i in det_indices:
            super_particles[i].weight += np.log(n_particles) - np.log(n_total)
        # For stochastic particles: w = (W_tot / c) * N/N', where 1/c equals the
        # summed normalized weight of the stochastic pool divided by the number
        # of stochastic slots
        for i in stoch_indices:
            super_particles[i].weight = (
                W_tot - np.log(c) + np.log(n_particles) - np.log(n_total)
            )

    # Return the particles
    return particles
```
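`resample_optimal` is called in the listing but not shown on this page. To make the returned triple `(det_indices, stoch_indices, c)` concrete, here is a minimal sketch of Fearnhead-Clifford-style optimal resampling under these assumptions: the input weights are normalized (as `softmax(W)` produces), and more than `n` of them are positive. This is an independent reimplementation for illustration, not hfppl's code; `find_c` and the `rng` parameter are hypothetical.

```python
import numpy as np


def find_c(weights: np.ndarray, n: int) -> float:
    """Solve sum_i min(c * w_i, 1) == n for c.

    Assumes `weights` are normalized and more than `n` of them are positive.
    """
    w = np.sort(weights)[::-1]  # largest weights first
    for k in range(n):
        # Suppose the k largest weights are kept deterministically; the other
        # n - k slots are then filled from the rest, so c * rest == n - k.
        rest = w[k:].sum()
        c = (n - k) / rest
        # Accept k only if it agrees with the resulting threshold 1/c.
        if c * w[k] < 1 and (k == 0 or c * w[k - 1] >= 1):
            return float(c)
    raise ValueError("need more than n positive weights")


def resample_optimal(weights, n, rng=None):
    """Sketch of optimal resampling without replacement.

    Returns (det_indices, stoch_indices, c): particles with c * w >= 1 are kept
    once deterministically; n - len(det) others are drawn without replacement.
    """
    rng = np.random.default_rng() if rng is None else rng
    weights = np.asarray(weights, dtype=float)
    if len(weights) <= n:
        # Nothing needs to be dropped: keep every particle deterministically.
        return np.arange(len(weights)), np.array([], dtype=int), np.inf
    c = find_c(weights, n)
    det = np.flatnonzero(c * weights >= 1)
    pool = np.flatnonzero(c * weights < 1)
    n_stoch = n - len(det)
    # Systematic sampling over the pool with stride 1/c. Every pooled weight is
    # below 1/c, so no pooled particle can be selected twice.
    cumulative = np.cumsum(weights[pool])
    points = rng.uniform(0, 1 / c) + np.arange(n_stoch) / c
    # Clamp guards against a point rounding just past the last cumulative sum
    idx = np.minimum(np.searchsorted(cumulative, points), len(pool) - 1)
    return det, pool[idx], c


w = np.array([0.5, 0.2, 0.15, 0.1, 0.05])
det, stoch, c = resample_optimal(w, 3, np.random.default_rng(0))
# Survivors' weights: deterministic ones keep w_i, stochastic ones get 1/c each
print(det, stoch, c, w[det].sum() + len(stoch) / c)
```

In this example `c` is 4 (up to rounding), so only index 0 is kept deterministically, two of indices 1 to 4 are drawn, and the surviving weights 0.5 + 2/4 sum back to 1. That total-weight preservation is what the listing relies on when it assigns `W_tot - log(c)` to every stochastic survivor before applying the `N/N'` factor.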