
Add tool calling support to RLOOTrainer #5395

Draft
qgallouedec wants to merge 10 commits into main from rloo-tool-call

Conversation

@qgallouedec
Member

@qgallouedec qgallouedec commented Mar 27, 2026

Mirrors the tool calling integration already present in GRPOTrainer:

  • RLOOTrainer.__init__ accepts a tools argument (list of sync/async callables)
  • _tool_call_loop and _get_tool_suffix_ids copied verbatim from GRPOTrainer
  • tool_mask (1 = model-generated, 0 = tool result) applied to completion mask in loss and KL computations
  • max_tool_calling_iterations moved from training section to generation section in both RLOOConfig and GRPOConfig
  • Two tests added to test_rloo_trainer.py: test_training_with_tools (sync + async variants) and test_training_with_malformed_tool_calls

The key RLOO-specific difference: logprobs are discarded (`_generate_single_turn` returns None), so the `logprobs is not None` guards in `_tool_call_loop` are no-ops.
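To make the masking concrete, here is a minimal pure-Python sketch (not the PR's actual code, which operates on PyTorch tensors) of how multiplying the completion mask by tool_mask drops tool-result tokens from a masked mean loss:

```python
def masked_mean_loss(per_token_loss, completion_mask, tool_mask):
    """Illustrative sketch: average the per-token loss only over
    model-generated completion tokens; tool-result tokens (tool_mask == 0)
    drop out of both the numerator and the normalizer."""
    keep = [c * t for c, t in zip(completion_mask, tool_mask)]
    total = sum(l * k for l, k in zip(per_token_loss, keep))
    denom = sum(keep) or 1.0  # avoid division by zero on all-masked rows
    return total / denom

# Five completion tokens; tokens 2-3 come from a tool result and are masked
# out, so the (deliberately large) losses there do not leak into training.
loss = masked_mean_loss([0.5, 0.5, 9.0, 9.0, 0.5], [1, 1, 1, 1, 1], [1, 1, 0, 0, 1])
print(loss)  # 0.5
```

In the trainer, the same product mask would also gate the KL term, matching the bullet list above.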


Note

Medium Risk
Adds multi-turn tool execution and new masking paths in RLOOTrainer generation/loss/KL calculations, which can affect training dynamics and sequence handling. Also introduces new runtime requirements (Transformers>=5 and jmespath) when tools are enabled.

Overview
Adds tool-calling support to RLOOTrainer. The trainer now accepts a tools list (sync or async callables), validates transformers>=5.0.0 and jmespath, and runs a multi-turn tool execution loop during generation (with optional max_tool_calling_iterations).

Tool-result tokens are tracked via a new tool_mask and are excluded from completion-length stats, logprob/KL calculations, and loss/entropy masking; new tool call/failure frequency metrics are logged. Config docs/fields move max_tool_calling_iterations into the generation section for both RLOOConfig and GRPOConfig, and tests add coverage for tool calling (including malformed calls) plus minor changes to reduce memory usage in several trainer tests.
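As an illustrative sketch of the kind of call/failure frequency metrics mentioned above (the metric keys and normalization below are assumptions, not the actual TRL names):

```python
def tool_metrics(num_tool_calls, num_failures, num_completions):
    """Hypothetical sketch of tool call/failure frequency metrics; the key
    names are made up for illustration, not TRL's actual logged keys."""
    return {
        "tools/call_frequency": num_tool_calls / max(num_completions, 1),
        "tools/failure_frequency": num_failures / max(num_tool_calls, 1),
    }

metrics = tool_metrics(num_tool_calls=12, num_failures=3, num_completions=8)
print(metrics)  # {'tools/call_frequency': 1.5, 'tools/failure_frequency': 0.25}
```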

Written by Cursor Bugbot for commit 2297245. This will update automatically on new commits. Configure here.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment thread trl/trainer/rloo_trainer.py

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 26c4f9bbed

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread trl/trainer/rloo_trainer.py
Comment on lines +408 to +410
generation_batch_size = args.per_device_train_batch_size * args.steps_per_generation
self._sync_tool_dicts = [{} for _ in range(generation_batch_size)]
self._async_tool_dicts = [{} for _ in range(generation_batch_size)]


P2 Badge Size tool lookup storage to active batch size

Tool lookup tables are preallocated using per_device_train_batch_size * steps_per_generation, but _tool_call_loop indexes them by the current batch position (idx_with_tool). During evaluation, per_device_eval_batch_size may legitimately be larger than the train generation batch, so a tool call in a higher index will hit IndexError at lookup. This makes tool-enabled eval fragile for valid configs; build per-call dicts or size storage by the active batch.
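A sketch of the suggested fix, sizing the lookup tables by the batch actually being processed (the helper name and structure here are hypothetical, not the PR's code):

```python
import inspect

def build_tool_dicts(batch_size, tools):
    """Hypothetical helper: build per-sequence sync/async tool lookup tables
    sized to the active batch (train *or* eval), so indexing by the current
    batch position can never overrun the storage."""
    sync_tools = {t.__name__: t for t in tools if not inspect.iscoroutinefunction(t)}
    async_tools = {t.__name__: t for t in tools if inspect.iscoroutinefunction(t)}
    return ([dict(sync_tools) for _ in range(batch_size)],
            [dict(async_tools) for _ in range(batch_size)])

def lookup(x):            # toy sync tool
    return x

async def alookup(x):     # toy async tool
    return x

# Rebuilt per generation call with the current batch size, instead of a
# fixed per_device_train_batch_size * steps_per_generation preallocation.
sync_dicts, async_dicts = build_tool_dicts(4, [lookup, alookup])
print(len(sync_dicts), sorted(sync_dicts[0]), sorted(async_dicts[0]))
```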


Comment thread trl/trainer/rloo_trainer.py Outdated
Comment thread tests/test_rloo_trainer.py Outdated
Comment thread trl/trainer/rloo_trainer.py

@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 2 potential issues.


# When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to
# EOS (not EOS + newline).
last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.eos_token_id)
prefix_ids = prefix_ids[: last_eos_idx + 1]


Missing EOS guard crashes _get_tool_suffix_ids for some models

High Severity

_get_tool_suffix_ids uses max() on a generator that yields nothing when prefix_ids contains no EOS token, raising ValueError: max() arg is an empty sequence. The corresponding GRPOTrainer code safely builds a list and guards with if eos_positions: before trimming, with a comment explaining that "Templates that don't use EOS as end-of-turn (e.g. Gemma uses <end_of_turn>) skip this trimming." This guard is missing in the RLOO copy, so any model whose chat template doesn't include the EOS token in turn boundaries will crash during tool calling.
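The guarded trimming the review points to can be sketched like this (illustrative, not the GRPOTrainer source verbatim):

```python
def trim_to_last_eos(prefix_ids, eos_token_id):
    """Only trim when an EOS token is actually present; templates that don't
    use EOS as end-of-turn (e.g. Gemma) leave the ids untouched instead of
    crashing on max() over an empty sequence."""
    eos_positions = [i for i, tok in enumerate(prefix_ids) if tok == eos_token_id]
    if eos_positions:  # the guard missing from the RLOO copy
        prefix_ids = prefix_ids[: eos_positions[-1] + 1]
    return prefix_ids

trimmed = trim_to_last_eos([5, 2, 7, 2, 9], eos_token_id=2)   # -> [5, 2, 7, 2]
untouched = trim_to_last_eos([5, 7, 9], eos_token_id=2)       # no EOS: unchanged
```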


Triggered by project rule: ../.ai/AGENTS.md

if name in sync_tool_dict:
tool_call_results.append((name, sync_tool_dict[name](**function["arguments"])))
elif name in async_tool_dict:
async_coros.append((name, async_tool_dict[name](**function["arguments"])))


Missing string-arguments parsing in tool call loop

High Severity

_tool_call_loop passes function["arguments"] directly to **kwargs without checking if it's a string. GRPOTrainer's version (lines 1597–1616) includes argument normalization that handles models (e.g., Gemma) returning arguments as strings instead of dicts — including JSON parsing, brace-wrapping, and regex fallback. Without this, **function["arguments"] on a string raises TypeError, which gets silently caught as a tool failure, producing incorrect results.
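The missing normalization can be approximated like this (a sketch of the behavior the review attributes to GRPOTrainer, not a verbatim copy; the regex fallback in particular is an assumption):

```python
import json
import re

def normalize_arguments(arguments):
    """Coerce tool-call arguments to a dict. Some models (e.g. Gemma) emit
    arguments as a string, and **arguments on a string raises TypeError."""
    if isinstance(arguments, dict):
        return arguments
    if isinstance(arguments, str):
        text = arguments.strip()
        for candidate in (text, "{" + text + "}"):  # also try brace-wrapping
            try:
                parsed = json.loads(candidate)
                if isinstance(parsed, dict):
                    return parsed
            except json.JSONDecodeError:
                pass
        # Regex fallback for key="value" style strings
        pairs = re.findall(r'(\w+)\s*=\s*"([^"]*)"', text)
        if pairs:
            return dict(pairs)
    raise ValueError(f"Could not parse tool arguments: {arguments!r}")

parsed_json = normalize_arguments('{"city": "Paris"}')   # already valid JSON
parsed_wrapped = normalize_arguments('"city": "Paris"')  # fixed by brace-wrapping
```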


Triggered by project rule: ../.ai/AGENTS.md

@qgallouedec qgallouedec marked this pull request as draft April 7, 2026 12:24
