Skip to content

Obscure NotImplementedError for Categorical #545

@rtbs-dev

Description

@rtbs-dev

Somewhat new to numpyro, though more familiar with Jax, so apologies if this is a known issue.

Modelling the boilerplate off of the baseball and time-series forcasting examples, working on a network inference problem (see here for an older jax version with discussion)

Setup looks like:

@jit 
def jax_squareform(edgelist, n=n_nodes):
    """edgelist to adj. matrix"""
    empty = np.zeros((n,n))
    half = index_add(empty, index[np.triu_indices(n,1)], edgelist)
    full = half+half.T
    return full


def spread_jax(p,u_init,T):
    """
    p: transmission probability matrix
    u_init: initial infection node states
    T: num. iterations to observe at
    """
    def scan_fn(u, t):
        u_add = lax.tanh(p@u)
        u_p = 1-(1-u)*(1-u_add)
        return u_p, u_add
    u_end, u_adds = lax.scan(
        scan_fn, u_init, np.arange(T) 
    )
    return u_end, u_adds


def diff_kg(infections):
    n_cascades, n_nodes  = infections.shape
    n_edges = n_nodes*(n_nodes-1)//2 # complete graph
    
    # beta hyperpriors
    u = ny.sample("u", dist.Uniform(np.zeros(n_edges), 
                                         np.ones(n_edges)))
    v = ny.sample("v", dist.Gamma(np.ones(n_edges),
                                       20*np.ones(n_edges)))
    ## Bayesian Inference and Decision Theory, Dr. Laskey (GMU)
    Λ = ny.sample("Λ", dist.Beta(u*v, (1-u)*v))
    s_ij = jax_squareform(Λ)  # adjacency matrix to recover via inference
    
    
    with ny.plate("n_cascades", n_cascades):
        # infer source node
        ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))  
        x0 = ny.sample("x0", dist.Categorical(ϕ))
        
        # simulate ode and realize
#         infectious = spread_jax(s_ij, x0, 0, 5)
        infectious, hist = spread_jax(s_ij, x0, 5)
        numpyro.sample("obs", dist.Bernoulli(probs=infectious), 
                       obs=infections)

kernel = ny.infer.NUTS(diff_kg)
mcmc = ny.infer.MCMC(kernel, num_warmup=1500, num_samples=3000)
mcmc.run(PRNGKey(0), infections)
mcmc.print_summary()
samples = mcmc.get_samples()

Where infections is an array with columns as nodes (0=susceptible, 1=infected) and rows as unique observations, simulated from a "ground-truth" network and different source nodes.

Running based on documentation examples results in the following error that I'm having quite a hard time parsing (sorry for the wall of text):

Details
KeyError                                  Traceback (most recent call last)
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/distributions/transforms.py in __call__(self, constraint)
    526         try:
--> 527             factory = self._registry[type(constraint)]
    528         except KeyError:

KeyError: <class 'numpyro.distributions.constraints._IntegerInterval'>

During handling of the above exception, another exception occurred:

NotImplementedError                       Traceback (most recent call last)
<ipython-input-33-5d6906d300b6> in <module>
      1 kernel = ny.infer.NUTS(diff_kg)
      2 mcmc = ny.infer.MCMC(kernel, num_warmup=1500, num_samples=3000)
----> 3 mcmc.run(PRNGKey(0), cascades)
      4 mcmc.print_summary()
      5 samples = mcmc.get_samples()

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
   1194         collect_fields = tuple(set(('z', 'diverging') + tuple(extra_fields)))
   1195         if self.num_chains == 1:
-> 1196             states_flat, last_state = self._single_chain_mcmc(rng_key, init_state, init_params,
   1197                                                               args, kwargs, collect_fields)
   1198             states = tree_map(lambda x: x[np.newaxis, ...], states_flat)

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, collect_fields)
   1067     def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, collect_fields=('z',)):
   1068         if init_state is None:
-> 1069             init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
   1070                                            model_args=args, model_kwargs=kwargs)
   1071         if self.postprocess_fn is None:

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/mcmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    506         # Find valid initial params
    507         if self._model and not init_params:
--> 508             init_params, is_valid = find_valid_initial_params(rng_key, self._model,
    509                                                               init_strategy=self._init_strategy,
    510                                                               param_as_improper=True,

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, param_as_improper, model_args, model_kwargs)
    370     # Handle possible vectorization
    371     if rng_key.ndim == 1:
--> 372         init_params, is_valid = _find_valid_params(rng_key)
    373     else:
    374         init_params, is_valid = lax.map(_find_valid_params, rng_key)

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key_)
    359 
    360     def _find_valid_params(rng_key_):
--> 361         _, _, prototype_params, is_valid = init_state = body_fn((0, rng_key_, None, None))
    362         # Early return if valid params found.
    363         if not_jax_tracer(is_valid):

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/util.py in body_fn(state)
    329         # Use `block` to not record sample primitives in `init_loc_fn`.
    330         seeded_model = substitute(model, substitute_fn=block(seed(init_strategy, subkey)))
--> 331         model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    332         constrained_values, inv_transforms = {}, {}
    333         for k, v in model_trace.items():

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
    147         :return: `OrderedDict` containing the execution trace.
    148         """
--> 149         self(*args, **kwargs)
    150         return self.trace
    151 

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     61     def __call__(self, *args, **kwargs):
     62         with self:
---> 63             return self.fn(*args, **kwargs)
     64 
     65 

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     61     def __call__(self, *args, **kwargs):
     62         with self:
---> 63             return self.fn(*args, **kwargs)
     64 
     65 

<ipython-input-26-070713a497d6> in diff_kg(infections)
     40         # infer source node
     41         ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))
---> 42         x0 = ny.sample("x0", dist.Categorical(ϕ))
     43 
     44         # simulate ode and realize

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in sample(name, fn, obs, rng_key, sample_shape)
    103 
    104     # ...and use apply_stack to send it to the Messengers
--> 105     msg = apply_stack(initial_msg)
    106     return msg['value']
    107 

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in apply_stack(msg)
     20     pointer = 0
     21     for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 22         handler.process_message(msg)
     23         # When a Messenger sets the "stop" field of a message,
     24         # it prevents any Messengers above it on the stack from being applied.

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/handlers.py in process_message(self, msg)
    430                 msg['value'] = self.param_map[msg['name']]
    431         else:
--> 432             base_value = self.substitute_fn(msg) if self.substitute_fn \
    433                 else self.base_param_map.get(msg['name'], None)
    434             if base_value is not None:

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     61     def __call__(self, *args, **kwargs):
     62         with self:
---> 63             return self.fn(*args, **kwargs)
     64 
     65 

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     61     def __call__(self, *args, **kwargs):
     62         with self:
---> 63             return self.fn(*args, **kwargs)
     64 
     65 

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/util.py in _init_to_uniform(site, radius, skip_param)
    226             fn = site['fn']
    227         value = numpyro.sample('_init', fn, sample_shape=site['kwargs']['sample_shape'])
--> 228         base_transform = biject_to(fn.support)
    229         unconstrained_value = numpyro.sample('_unconstrained_init', dist.Uniform(-radius, radius),
    230                                              sample_shape=np.shape(base_transform.inv(value)))

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/distributions/transforms.py in __call__(self, constraint)
    527             factory = self._registry[type(constraint)]
    528         except KeyError:
--> 529             raise NotImplementedError
    530 
    531         return factory(constraint)

NotImplementedError: 

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions