Skip to content

fix[next]: Fix StaticArgs injection#2659

Open
SF-N wants to merge 4 commits into
GridTools:mainfrom
SF-N:empty_static_args
Open

fix[next]: Fix StaticArgs injection#2659
SF-N wants to merge 4 commits into
GridTools:mainfrom
SF-N:empty_static_args

Conversation

@SF-N

@SF-N SF-N commented Jun 16, 2026

Copy link
Copy Markdown
Contributor

This pull request updates the compilation logic in CompiledProgram to ensure that static argument descriptors are only injected when static arguments are actually provided.

@SF-N SF-N requested a review from tehrengruber June 16, 2026 12:20
@SF-N SF-N changed the title Fix StaticArgs injection fix[next]: Fix StaticArgs injection Jun 16, 2026
Comment on lines +669 to +670
# Only inject a `StaticArg` descriptor mapping when static arguments were
# actually given.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Only inject a `StaticArg` descriptor mapping when static arguments were
# actually given.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpicking: This is just reiterating what the code does anyway.

# the dict directly. Note that we don't need to check any args, since the pool checks
# this on compile anyway.
if "_compiled_programs" not in self.__dict__:
static_params: tuple[str, ...] = ()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes here are not doing anything, the other path where the argument descriptors are built is in _make_compiled_programs_pool and that needs to be kept in sync with CompiledProgramsPool.compile

Comment on lines +131 to +146
empty_static_args = {}
decorator_path_testee = compile_testee.with_backend(cartesian_case.backend)
decorator_path_testee.compile(
offset_provider=cartesian_case.offset_provider, **empty_static_args
)
decorator_path_mapping = decorator_path_testee._compiled_programs.argument_descriptor_mapping
assert arguments.StaticArg not in decorator_path_mapping

compiled_program_path_testee = compile_testee.with_backend(cartesian_case.backend)
compiled_program_path_pool = compiled_program_path_testee._make_compiled_programs_pool(
static_params=(), static_domains=False
)
compiled_program_path_pool.argument_descriptor_mapping = None
compiled_program_path_pool.compile(offset_providers=[cartesian_case.offset_provider])

assert compiled_program_path_pool.argument_descriptor_mapping == decorator_path_mapping

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
empty_static_args = {}
decorator_path_testee = compile_testee.with_backend(cartesian_case.backend)
decorator_path_testee.compile(
offset_provider=cartesian_case.offset_provider, **empty_static_args
)
decorator_path_mapping = decorator_path_testee._compiled_programs.argument_descriptor_mapping
assert arguments.StaticArg not in decorator_path_mapping
compiled_program_path_testee = compile_testee.with_backend(cartesian_case.backend)
compiled_program_path_pool = compiled_program_path_testee._make_compiled_programs_pool(
static_params=(), static_domains=False
)
compiled_program_path_pool.argument_descriptor_mapping = None
compiled_program_path_pool.compile(offset_providers=[cartesian_case.offset_provider])
assert compiled_program_path_pool.argument_descriptor_mapping == decorator_path_mapping
empty_static_args = {}
precompiled_testee = compile_testee.with_backend(cartesian_case.backend)
precompiled_testee.compile(
offset_provider=cartesian_case.offset_provider, **empty_static_args
)
precompiled_arg_descr_mapping = decorator_path_testee._compiled_programs.argument_descriptor_mapping
jit_testee = compile_testee.with_backend(cartesian_case.backend)
jit_testee(...)
jit_arg_descr_mapping = decorator_path_testee._compiled_programs.argument_descriptor_mapping
assert precompiled_arg_descr_mapping == jit_arg_descr_mapping

assert np.allclose(kwargs["out"].ndarray, args[0].ndarray + args[1].ndarray)


def test_compile_with_empty_static_args_matches_no_static_args(cartesian_case, compile_testee):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_compile_with_empty_static_args_matches_no_static_args(cartesian_case, compile_testee):
def test_precompile_and_jit_have_same_arg_descr_mapping(cartesian_case, compile_testee):

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants