
Limit the number of threads for JAX #127

@h3jia

Description


Hi, I know this is probably more of a JAX-side issue and has already been discussed there (e.g. jax-ml/jax#743, jax-ml/jax#1539 and jax-ml/jax#6790), but I'm still wondering whether you know how to limit the number of threads JAX uses. The simple snippet below shows that JAX currently does not respect the limits set with threadpool_limits:

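import jax.numpy as jnp
from threadpoolctl import threadpool_limits

ja = jnp.ones((1000, 1000))
with threadpool_limits(5):  # request at most 5 threads per native thread pool threadpoolctl detects
    for _ in range(100):
        # JAX dispatches work asynchronously; block so each matmul actually runs inside the limit
        foo = (ja @ ja).block_until_ready()
# Monitoring CPU usage during the loop (e.g. with top) shows JAX ignoring this limit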
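Since threadpoolctl apparently has no handle on the threads JAX uses, one workaround worth trying is to restrict the process's CPU affinity before JAX starts. This is a different technique from threadpoolctl: it does not change how many threads JAX creates, it only confines them to a fixed set of cores. The sketch below assumes Linux (it relies on os.sched_setaffinity), and `limit_cpu_cores` is a hypothetical helper name, not part of JAX or threadpoolctl.

```python
import os


def limit_cpu_cores(n_cores: int) -> set[int]:
    """Restrict the calling thread (and threads it starts later) to at most n_cores CPUs.

    Linux-only: os.sched_setaffinity / os.sched_getaffinity are unavailable on macOS and Windows.
    This confines where threads run; it does not reduce how many threads a library creates.
    """
    if n_cores < 1:
        raise ValueError("n_cores must be at least 1")
    # CPUs this process is currently allowed to run on
    available = sorted(os.sched_getaffinity(0))
    # pid 0 = the caller; threads created afterwards inherit this mask
    os.sched_setaffinity(0, available[:n_cores])
    return os.sched_getaffinity(0)


if __name__ == "__main__":
    # Call this BEFORE `import jax`, so the worker threads that JAX's CPU
    # backend spawns later inherit the restricted affinity mask.
    allowed = limit_cpu_cores(5)
    print(f"restricted to CPUs: {sorted(allowed)}")
```

In the snippet above, one would call limit_cpu_cores(5) at the very top of the script, before `import jax.numpy as jnp`. From the shell, launching with `taskset -c 0-4 python script.py` should have a similar effect without touching the code.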
