Improve integration with external samplers#2203
Conversation
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
| def logdensity_fn(position: PositionDict) -> jax.Array: | ||
| return -bound_potential(position) | ||
|
|
||
| return LogDensityInfo( |
There was a problem hiding this comment.
I guess you can simply add a bound flag and return additional log_density_fn (or better name) in initialize model
| model_args: tuple[Any, ...] = (), | ||
| model_kwargs: Optional[dict[str, Any]] = None, | ||
| return_deterministic: bool = True, | ||
| batch_ndims: int = 1, |
There was a problem hiding this comment.
If this only introduces batch_ndims, let's add it in constrain_fn
| return jax.tree.map(lambda s: jnp.zeros(s.shape, s.dtype), info_shape) | ||
|
|
||
|
|
||
| class ExternalKernel(MCMCKernel): |
There was a problem hiding this comment.
Why do we need this class? The example seems to indicate that users will have more flexibility when subclass the mcmckernel directly, rather than constructing build_mclmc
There was a problem hiding this comment.
You are right, this helper class is less flexible (just experimenting with a friendly API). I have now modified the example to inhering directly from MCMCKernel in Added it in 1329c58
|
Thank you for the feedback @fehiepsi . Let me know how this iteration looks. I am happy to keep iterating until we get the best level of abstraction :) |
Motivated by #2124 and specifically by #2124 (comment), we suggest a cleaner API to integrate with other samplers. This includes MCLMC from Blackjax. This is illustrated in the notebook.
(most from the code changes come from the notebook)