
GOLDTrainer VLM support#5461

Open
Strongich wants to merge 12 commits into huggingface:main from Strongich:gold_vlm_support

Conversation

@Strongich Strongich commented Apr 6, 2026

What does this PR do?

Adds VLM support to GOLDTrainer:

  • Same-family VLM distillation: same <image_pad> tokens and the same vision-encoder family -> JSD loss
  • Cross-architecture VLM distillation: images are processed separately through each model's processor to handle differing image token formats
  • vLLM support for both
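For context on the same-family path, the following is a minimal sketch of a Jensen-Shannon divergence distillation loss over next-token logits, not the trainer's actual implementation. JSD requires the student and teacher to share a vocabulary, which is why it only applies to same-family pairs:

```python
import torch
import torch.nn.functional as F


def jsd_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    """Jensen-Shannon divergence between student and teacher token
    distributions. Assumes a shared vocabulary (same-family distillation);
    this is an illustrative sketch, not GOLDTrainer's loss."""
    p = F.softmax(student_logits, dim=-1)
    q = F.softmax(teacher_logits, dim=-1)
    m = 0.5 * (p + q)
    # F.kl_div(input, target) computes KL(target || input) with log-prob input,
    # so these two calls are KL(p || m) and KL(q || m).
    kl_pm = F.kl_div(m.log(), p, reduction="batchmean")
    kl_qm = F.kl_div(m.log(), q, reduction="batchmean")
    return 0.5 * (kl_pm + kl_qm)
```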

Motivation

The GOLD algorithm has no theoretical constraints against VLM-to-VLM distillation -- the barriers were purely engineering (incompatible image token formats, different tokenizers, raw image handling through the dataloader).

Key changes

  • GOLDTrainer detects VLM datasets and uses an identity collator to preserve raw PIL images through the dataloader
  • For cross-architecture pairs, a _teacher_processor is stored and used in compute_loss to build teacher-compatible vision tensors from raw images
  • Auto-resolves teacher_tokenizer_name_or_path
  • Added examples/scripts/gold_vlm.py with two documented usage examples (same-family JSD + vLLM, cross-family ULD)
  • Added tests for the VLM collator (label masking, completion preservation), cross-architecture detection (rejects JSD, stores the teacher processor for different architectures, skips it for the same architecture), VLM + vLLM init (copied from the LLM example), and rejecting an LLM teacher paired with a vision dataset
  • VLM handling (identity collator, raw image storage, vLLM multimodal path) is borrowed, where possible, from SFTTrainer and GRPOTrainer
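The identity-collator idea above can be sketched in a few lines: the dataloader's collate step simply passes examples through unchanged, so raw PIL images survive until each model's processor collates them on the fly. This is an illustration of the pattern, not the PR's actual collator:

```python
from typing import Any


def identity_collator(features: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Pass examples through the dataloader unchanged so raw PIL images are
    preserved; per-model tensor collation happens later inside the training
    step. Illustrative sketch only."""
    return features
```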

Note

Looking for feedback:

  • I'm not fully confident that the current approach for storing and passing raw images through the pipeline is optimal (especially in _fill_buffer). The same goes for the overall design: two different collators and two separate generation flows (_generate_on_policy_vlm_raw vs _generate_on_policy_for_slices). I'd appreciate feedback from anyone with more experience in this area.
  • I didn't add VLM usage examples to docs/source/gold_trainer.md -- will add if that's desirable, just let me know.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

AI writing disclosure

We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.

  • No AI usage: the PR was written entirely by a human.
  • AI-assisted: some parts were suggested or improved by AI, but the PR was written and reviewed by a human.
  • AI-generated: the PR was mostly or fully generated by an AI tool.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.


Note

Medium Risk
Adds new VLM-specific collation, buffering, and generation paths (including vLLM) plus cross-architecture teacher processing, which changes core training-loop behavior and could affect distillation correctness/performance on multimodal datasets.

Overview
Enables vision-language distillation in GOLDTrainer by detecting vision datasets, validating VLM student/teacher compatibility, and switching to a VLM-aware pipeline that preserves raw images through the dataloader (identity collator + on-the-fly collation).
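The vision-dataset detection mentioned above could look roughly like the heuristic below: an example counts as multimodal if any chat message carries an image content part. The "messages"/"content"/"type" field names follow the common conversational format; this is an assumption for illustration, not the trainer's actual check:

```python
def is_vision_example(example: dict) -> bool:
    """Heuristic sketch: treat an example as a vision example if any chat
    message contains an image content part. Field names follow the common
    conversational format and are illustrative assumptions."""
    for message in example.get("messages", []):
        content = message.get("content")
        if isinstance(content, list) and any(
            isinstance(part, dict) and part.get("type") == "image"
            for part in content
        ):
            return True
    return False
```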

Adds DataCollatorForVisionLanguageChatML and updates training/generation (_fill_buffer, new _generate_on_policy_vlm_raw, multimodal forward kwargs, prompt-length handling) to support both same-architecture JSD and cross-architecture ULD where the teacher can re-process images via a stored _teacher_processor.

Extends config defaults (remove_unused_columns=False), auto-resolves teacher tokenizer for ULD, adds a runnable examples/scripts/gold_vlm.py, and significantly expands test coverage for VLM collation, init validation, cross-architecture behavior, and VLM+vLLM integration.

Reviewed by Cursor Bugbot for commit fd3be85.

# Models
# ──────────────────────────────────────────────
student_model = AutoModelForImageTextToText.from_pretrained(cli_args.student_model_name, dtype=torch.bfloat16)
teacher_model = AutoModelForImageTextToText.from_pretrained(cli_args.teacher_model_name, dtype=torch.bfloat16)

Example script uses wrong dtype parameter name

Low Severity

AutoModelForImageTextToText.from_pretrained is called with dtype=torch.bfloat16 instead of the correct torch_dtype=torch.bfloat16. The dtype kwarg is not a recognized parameter for from_pretrained, so the models will silently load in their default precision (float32) instead of bfloat16, increasing memory usage and potentially causing dtype mismatches during training.


Reviewed by Cursor Bugbot for commit 9a1f345.

Author

Not a bug: torch_dtype is deprecated (this warning is well known).
Maybe I should add version checking, like here
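The version check mentioned here could look roughly like the helper below, which picks the precision kwarg name based on the installed transformers version. The 4.56.0 cutoff (when `torch_dtype` was deprecated in favor of `dtype`) is an assumption; verify it against the transformers release notes before relying on it:

```python
from packaging import version


def precision_kwarg(transformers_version: str) -> str:
    """Return the from_pretrained precision kwarg name for a given
    transformers version. The 4.56.0 cutoff is an assumption based on the
    discussion above."""
    if version.parse(transformers_version) >= version.parse("4.56.0"):
        return "dtype"
    return "torch_dtype"
```

Usage would then be along the lines of `model_kwargs = {precision_kwarg(transformers.__version__): torch.bfloat16}`.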


@cursor (cursor bot) left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.

There are 3 total unresolved issues (including 2 from previous reviews).


Reviewed by Cursor Bugbot for commit 7c96055.

@Strongich (Author)
@kashif @qgallouedec I think you might be interested in this PR; looking forward to hearing from you.

