diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2_cp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_cp.yaml new file mode 100644 index 000000000..58593d072 --- /dev/null +++ b/config_files/training/config_lorem_ipsum_long_fsdp2_cp.yaml @@ -0,0 +1,380 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpoint_saving_path: data/checkpoints + train_dataset_path: ./data/lorem_ipsum_long.pbin + test_dataset_path: ./data/lorem_ipsum.pbin + experiments_root_path: ${modalities_env:experiments_root_path} + intervals: + training_log_interval_in_steps: 1 + checkpointing_interval_in_steps: 32 + evaluation_interval_in_steps: 32 + consistency_enforcement: + enforce_tokens_per_step_consistency: false + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 1 + sequence_length: 256 + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: num_tokens_from_packed_mem_map_dataset_continuous + config: + dataset_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + num_target_steps: # for the batch progress subscriber + component_key: number_conversion + variant_key: num_steps_from_num_tokens + config: + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + 
local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + global_num_tokens: ${settings.training_target.num_target_tokens} + sequence_length: ${settings.step_profile.sequence_length} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + training_progress: + global_num_seen_tokens: 0 + num_seen_steps: 0 + num_seen_samples: 0 + last_step: -1 + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_multi_dim_sampler + config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + data_parallel_key: dp_shard + shuffle: true + seed: 42 + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +test_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.test_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +test_dataloader: + component_key: 
data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: test + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_multi_dim_sampler + config: + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + data_parallel_key: dp_shard + shuffle: false + seed: 42 + drop_last: true + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: test_dataloader + pass_type: BY_REFERENCE + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: dcp + config: + checkpoint_path: ${settings.paths.checkpoint_saving_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: -1 + context_parallel_degree: 2 + world_size: ${settings.cuda_env.world_size} + +dp_degree: + component_key: number_conversion + variant_key: parallel_degree + config: # get the parallel degree from the device mesh + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + parallelism_methods: [dp_shard, 
dp_replicate] + +app_state: + component_key: app_state + variant_key: raw + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + lr_scheduler: + instance_key: lr_scheduler + pass_type: BY_REFERENCE + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: fsdp_model + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + multi_device_generator_policy: error + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: gpt2_cp_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +gpt2_cp_model: + component_key: model + variant_key: gpt2_cp + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head_q: 8 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} # it has to be n_head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: pytorch_flash + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + +lr_scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.01 + anneal_strategy: cos + last_epoch: ${settings.training_progress.last_step} + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [embedding, layernorm] + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp2 + config: + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: 
+ component_key: results_subscriber + variant_key: wandb + config: + global_rank: ${settings.cuda_env.global_rank} + project: modalities_dcp_tests + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: wandb_storage + config_file_path: ${settings.config_file_path} + +mfu_calculator: + component_key: mfu_calculator + variant_key: gpt2 + config: + n_layer: ${model_raw.config.n_layer} + sequence_length: ${settings.step_profile.sequence_length} + n_embd: ${model_raw.config.n_embd} + world_size: ${settings.cuda_env.world_size} + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE \ No newline at end of file diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml index 7cd58a6d5..2d8d65678 100644 --- a/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml +++ b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml @@ -10,6 +10,7 @@ settings: global_rank: ${cuda_env:RANK} world_size: ${cuda_env:WORLD_SIZE} paths: + experiments_root_path: ${modalities_env:experiments_root_path} checkpoint_saving_path: data/checkpoints train_dataset_path: ./data/lorem_ipsum_long.pbin test_dataset_path: ./data/lorem_ipsum.pbin diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2_pp_cp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_pp_cp.yaml new file mode 100644 index 000000000..2d8d65678 --- /dev/null +++ b/config_files/training/config_lorem_ipsum_long_fsdp2_pp_cp.yaml @@ -0,0 +1,466 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + experiments_root_path: 
${modalities_env:experiments_root_path} + checkpoint_saving_path: data/checkpoints + train_dataset_path: ./data/lorem_ipsum_long.pbin + test_dataset_path: ./data/lorem_ipsum.pbin + intervals: + training_log_interval_in_steps: 1 + checkpointing_interval_in_steps: 32 + evaluation_interval_in_steps: 32 + consistency_enforcement: + enforce_tokens_per_step_consistency: true + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 4 + sequence_length: 256 + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: num_tokens_from_packed_mem_map_dataset_continuous + config: + dataset_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + num_target_steps: # for the batch progress subscriber + component_key: number_conversion + variant_key: num_steps_from_num_tokens + config: + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + global_num_tokens: ${settings.training_target.num_target_tokens} + sequence_length: ${settings.step_profile.sequence_length} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + training_progress: + global_num_seen_tokens: 0 + num_seen_steps: 0 + num_seen_samples: 0 + last_step: -1 + debugging: + component_key: debugging + variant_key: settings + config: + enable_determinism: false + forward_hooks: + - instance_key: error_on_nan + pass_type: BY_REFERENCE + - instance_key: print_forward_hook + pass_type: BY_REFERENCE + 
+collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_multi_dim_sampler + config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + data_parallel_key: dp_shard + shuffle: true + seed: 42 + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +test_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.test_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +test_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: test + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: 
resumable_distributed_multi_dim_sampler + config: + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + data_parallel_key: dp_shard + shuffle: true + seed: 42 + drop_last: true + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: test_dataloader + pass_type: BY_REFERENCE + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: dcp + config: + checkpoint_path: ${settings.paths.checkpoint_saving_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 4 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +dp_degree: + component_key: number_conversion + variant_key: parallel_degree + config: # get the parallel degree from the device mesh + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + parallelism_methods: [dp_shard, dp_replicate] + +app_state: + component_key: app_state + variant_key: raw + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + lr_scheduler: + instance_key: lr_scheduler + pass_type: BY_REFERENCE + + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: 
+ component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: scheduled_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: Interleaved1F1B + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 2 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + # TODO: maybe better to pass the fsdp model and the schedule here + # instead of passing in the staged pipeline? + # If fsdp_model creates a copy, then that copy is not in the scope of + # the staged pipeline. + pipeline: + component_key: pipeline + variant_key: builder + config: + pp_stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: PP_STAGE + model_part: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + + +staged_pipeline: + component_key: pipeline + variant_key: staged + config: + whole_model: + instance_key: model_raw + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + 
input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${settings.cuda_env.local_rank} + pp_schedule_name: ${scheduled_pipeline.config.pp_schedule_name} + num_layers_per_stage: 2 + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 6 + n_head_q: 8 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} # it has to be n_head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + +lr_scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.01 + anneal_strategy: cos + last_epoch: ${settings.training_progress.last_step} + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + 
weight_decay_groups_excluded: [embedding, layernorm] + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp2 + config: + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +progress_subscriber: + component_key: progress_subscriber + # variant_key: dummy + # config: {} + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + global_rank: ${settings.cuda_env.global_rank} + project: modalities_dcp_tests + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: wandb_storage + config_file_path: ${settings.config_file_path} + +mfu_calculator: + component_key: mfu_calculator + variant_key: gpt2 + config: + n_layer: ${model_raw.config.n_layer} + sequence_length: ${settings.step_profile.sequence_length} + n_embd: ${model_raw.config.n_embd} + world_size: ${settings.cuda_env.world_size} + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +error_on_nan: + component_key: model_debugging_hook + variant_key: nan_hook + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + raise_exception: true + +print_forward_hook: + component_key: model_debugging_hook + variant_key: print_forward_hook + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + print_shape_only: true diff --git 
a/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml index 8e44e38b8..4591174b6 100644 --- a/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml +++ b/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml @@ -10,6 +10,7 @@ settings: global_rank: ${cuda_env:RANK} world_size: ${cuda_env:WORLD_SIZE} paths: + experiments_root_path: ${modalities_env:experiments_root_path} checkpoint_saving_path: data/checkpoints train_dataset_path: ./data/lorem_ipsum_long.pbin test_dataset_path: ./data/lorem_ipsum.pbin diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp_cp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp_cp.yaml new file mode 100644 index 000000000..b260b1a4e --- /dev/null +++ b/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp_cp.yaml @@ -0,0 +1,442 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + experiments_root_path: ${modalities_env:experiments_root_path} + checkpoint_saving_path: data/checkpoints + train_dataset_path: ./data/lorem_ipsum_long.pbin + test_dataset_path: ./data/lorem_ipsum.pbin + intervals: + training_log_interval_in_steps: 2 + checkpointing_interval_in_steps: 100000 + evaluation_interval_in_steps: 15 + consistency_enforcement: + enforce_tokens_per_step_consistency: true + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 4 + sequence_length: 16 + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + training_target: + num_target_tokens: + component_key: number_conversion + 
variant_key: num_tokens_from_num_steps + config: + sequence_length: ${settings.step_profile.sequence_length} + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + num_steps: ${settings.training_target.num_target_steps} + num_target_steps: 20 + training_progress: + global_num_seen_tokens: 0 + num_seen_steps: 0 + num_seen_samples: 0 + last_step: -1 + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_multi_dim_sampler + config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + data_parallel_key: dp_shard + shuffle: true + seed: 42 + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +test_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.test_dataset_path} + sequence_length: 
${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +test_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: test + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_multi_dim_sampler + config: + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + data_parallel_key: dp_shard + shuffle: true + seed: 42 + drop_last: true + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: test_dataloader + pass_type: BY_REFERENCE + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: dcp + config: + checkpoint_path: ${settings.paths.checkpoint_saving_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + tensor_parallel_degree: 2 + context_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +dp_degree: + component_key: number_conversion + 
variant_key: parallel_degree + config: # get the parallel degree from the device mesh + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + parallelism_methods: [dp_shard, dp_replicate] + +app_state: + component_key: app_state + variant_key: raw + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + lr_scheduler: + instance_key: lr_scheduler + pass_type: BY_REFERENCE + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: scheduled_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + seed: 42 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: gpipe + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 2 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + # TODO: maybe better to pass the fsdp model and the schedule here + # instead of passing in the staged pipeline? + # If fsdp_model creates a copy, then that copy is not in the scope of + # the staged pipeline. 
+ pipeline: + component_key: pipeline + variant_key: builder + config: + pp_stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: PP_STAGE + model_part: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: gpt2_tp_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +gpt2_tp_model: + component_key: model + variant_key: gpt2_tp + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + context_parallel_load_balancer: headtail + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + +staged_pipeline: + component_key: pipeline + variant_key: staged + config: + whole_model: + instance_key: model_raw + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${settings.cuda_env.local_rank} + pp_schedule_name: ${scheduled_pipeline.config.pp_schedule_name} + num_layers_per_stage: 2 + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + 
n_layer: 2 + n_head_q: 8 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: pytorch_flash + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + +lr_scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.01 + anneal_strategy: cos + last_epoch: ${settings.training_progress.last_step} + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [embedding, layernorm] + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp2 + config: + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + 
train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + global_rank: ${settings.cuda_env.global_rank} + project: modalities_dcp_tests + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: wandb_storage + config_file_path: ${settings.config_file_path} + +mfu_calculator: + component_key: mfu_calculator + variant_key: gpt2 + config: + n_layer: ${model_raw.config.n_layer} + sequence_length: ${settings.step_profile.sequence_length} + n_embd: ${model_raw.config.n_embd} + world_size: ${settings.cuda_env.world_size} + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE \ No newline at end of file diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 42a19b99a..a2720c717 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -327,6 +327,7 @@ def convert_list_to_set(cls, v: Iterable[int] | None) -> Set[int] | None: class GPT2ModelTPConfig(BaseModel): model: PydanticPytorchModuleOrListType # TODO set proper type device_mesh: PydanticDeviceMeshIFType + context_parallel_load_balancer: Literal["headtail", "ptrr"] | None = "headtail" @model_validator(mode="after") def validate_tp_mesh_existence(self) -> "GPT2ModelTPConfig": @@ -335,12 +336,36 @@ def validate_tp_mesh_existence(self) -> "GPT2ModelTPConfig": raise ValueError(f"Device mesh {self.device_mesh=} has no defined mesh_dim_names.") if ParallelismDegrees.TP.value not in mesh_dim_names: raise ValueError(f"Tensor parallelism key '{ParallelismDegrees.TP.value}' not in {self.device_mesh=}") + if ( + "context_parallel_load_balancer" in self.model_fields_set + and self.context_parallel_load_balancer is not None + and ParallelismDegrees.CP.value not in mesh_dim_names + ): + raise 
ValueError( + "context_parallel_load_balancer can only be set when context parallelism is configured in the mesh. " + f"Expected key '{ParallelismDegrees.CP.value}' in {self.device_mesh=}." + ) if ParallelismDegrees.DP_REPLICATE.value in mesh_dim_names: # TorchTitan uses replicate (i.e, plain DP) to combine DP with TP. raise ValueError("data_parallel_replicate_degree > 1 cannot be used with Tensor Parallelism.") return self +class GPT2ModelCPConfig(BaseModel): + model: PydanticPytorchModuleOrListType + device_mesh: PydanticDeviceMeshIFType + context_parallel_load_balancer: Literal["headtail", "ptrr"] | None = "headtail" + + @model_validator(mode="after") + def validate_cp_mesh_existence(self) -> "GPT2ModelCPConfig": + mesh_dim_names = self.device_mesh.mesh_dim_names + if mesh_dim_names is None: + raise ValueError(f"Device mesh {self.device_mesh=} has no defined mesh_dim_names.") + if ParallelismDegrees.CP.value not in mesh_dim_names: + raise ValueError(f"Context parallelism key '{ParallelismDegrees.CP.value}' not in {self.device_mesh=}") + return self + + class CompiledModelConfig(BaseModel): model: PydanticPytorchModuleOrListType block_names: list[str] diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py index fb9bdc0d3..33f8c24aa 100644 --- a/src/modalities/evaluator.py +++ b/src/modalities/evaluator.py @@ -13,6 +13,7 @@ from modalities.models.parallelism.pipeline_parallelism import Pipeline from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_degree from modalities.running_env.fsdp.reducer import Reducer +from modalities.trainer import apply_context_parallel_sharding_to_batch from modalities.util import TimeRecorder @@ -33,6 +34,7 @@ def __init__( """ self.progress_publisher = progress_publisher self.evaluation_result_publisher = evaluation_result_publisher + self.device_mesh = device_mesh if device_mesh is not None: self.dp_degree = get_parallel_degree( device_mesh, [ParallelismDegrees.DP_REPLICATE, 
ParallelismDegrees.DP_SHARD] @@ -62,6 +64,16 @@ def evaluate_batch( torch.Tensor | None: The loss of the batch None, if a non-last stage was processed in pipeline parallelism """ + sample_key = getattr(model[0], "sample_key", None) + context_parallel_load_balancer = getattr(model[0], "_context_parallel_load_balancer", "headtail") + apply_context_parallel_sharding_to_batch( + device_mesh=self.device_mesh, + batch=batch, + sample_key=sample_key, + target_key=loss_fun.target_key, + context_parallel_load_balancer=context_parallel_load_balancer, + ) + with torch.no_grad(): if scheduled_pipeline is not None: pp_schedule = scheduled_pipeline.pp_schedule diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 2da4979c0..d282a0ccc 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -207,8 +207,26 @@ def apply_rotary_pos_emb(self, x, cos, sin): # the rotation below work return (x * cos) + (self.rotate_half(x) * sin) + def _compute_cos_sin_from_positions( + self, position_ids: torch.Tensor, x: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + # position_ids: (B, T) or (T,) — explicit global token positions + # Returns cos, sin of shape (B or 1, 1, T, dim_model) matching x: (B, nh, T, hd) + pos = position_ids.float() + if pos.dim() == 1: + pos = pos.unsqueeze(0) # (1, T) + freqs = torch.einsum("bt,d->btd", pos, self.inv_freq.to(x.dtype)) # (B or 1, T, dim/2) + emb = torch.cat((freqs, freqs), dim=-1) # (B or 1, T, dim) + cos = emb.cos().to(x.dtype).unsqueeze(1) # (B or 1, 1, T, dim) + sin = emb.sin().to(x.dtype).unsqueeze(1) + return cos, sin + def forward( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + position_ids: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Forward pass of the RotaryTransform module. 
@@ -217,12 +235,20 @@ def forward( q (torch.Tensor): Query tensor. k (torch.Tensor): Key tensor. v (torch.Tensor): Value tensor. + position_ids (torch.Tensor | None): Optional explicit global position indices of shape + (B, T) or (T,). When provided (e.g. for context-parallel ranks that hold + non-contiguous token ranges), the correct global RoPE frequencies are computed + from these positions instead of assuming a local 0-based range. Returns: tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple containing the modified query tensor, key tensor, and value tensor. """ - self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k) + if position_ids is not None: + cos, sin = self._compute_cos_sin_from_positions(position_ids, k) + self._cos_cached, self._sin_cached = cos, sin + else: + self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k) q = self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached) k = self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached) @@ -514,7 +540,12 @@ def projection(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch @staticmethod def execute_qkv_transforms( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, qkv_transforms: nn.ModuleList, n_head_q: int + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + qkv_transforms: nn.ModuleList, + n_head_q: int, + position_ids: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Applies a series of transformations to the query, key, and value tensors. @@ -525,6 +556,8 @@ def execute_qkv_transforms( v (torch.Tensor): The value tensors. qkv_transforms (nn.ModuleList): A list of transformation modules to be applied to q, k, and v. n_head_q (int): The number of heads for the query tensors. + position_ids (torch.Tensor | None): Optional explicit global position indices forwarded + to RotaryTransform so CP ranks use correct global RoPE frequencies. 
Returns: tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -541,7 +574,10 @@ def execute_qkv_transforms( v = v.view(batch_size, sequence_length, -1, n_head_dim).transpose(1, 2).contiguous() # (B, nh_kv, T, hd) for transform in qkv_transforms: - q, k, v = transform(q, k, v) + if isinstance(transform, RotaryTransform) and position_ids is not None: + q, k, v = transform(q, k, v, position_ids=position_ids) + else: + q, k, v = transform(q, k, v) return q, k, v @@ -655,12 +691,14 @@ def execute_attention( raise NotImplementedError(f"Attention implementation {attention_impl} not supported") return y # (B, T, nh_q, hd) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, position_ids: torch.Tensor | None = None) -> torch.Tensor: """ Forward pass of the CausalSelfAttention module. Args: x (torch.Tensor): Input tensor of shape (B, T, n_embd) + position_ids (torch.Tensor | None): Optional global position indices forwarded to + RotaryTransform for correct CP-rank-aware RoPE. Returns: torch.Tensor: Output tensor of shape (B, T, n_embd), representing the output projection. @@ -669,7 +707,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: q, k, v = self.projection(x) # q: (B, T, n_embd), k: (B, T, n_embd // n_rep), v: (B, T, n_embd // n_rep) # q: (B, nh_q, T, hd), k: (B, nh_kv, T, hd), v: (B, nh_kv, T, hd) - q, k, v = CausalSelfAttention.execute_qkv_transforms(q, k, v, self.qkv_transforms, self.n_head_q) + q, k, v = CausalSelfAttention.execute_qkv_transforms( + q, k, v, self.qkv_transforms, self.n_head_q, position_ids=position_ids + ) if self.q_norm is not None and self.k_norm is not None: q = self.q_norm(q) k = self.k_norm(k) @@ -796,17 +836,19 @@ def _check_ffn_hidden_dim(self, n_embd: int, ffn_hidden: int) -> None: f"but got `n_embd = {n_embd}` and `ffn_hidden = {ffn_hidden}`." 
) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, position_ids: torch.Tensor | None = None) -> torch.Tensor: """ Forward pass of the GPT2Block. Args: x (torch.Tensor): Input tensor. + position_ids (torch.Tensor | None): Optional global position indices forwarded to + the attention layer for CP-aware RoPE. Returns: torch.Tensor: Output tensor. """ - x = x + self.attn(self.attention_norm(x)) + x = x + self.attn(self.attention_norm(x), position_ids=position_ids) x = x + self.mlp(self.ffn_norm(x)) return x @@ -971,22 +1013,29 @@ def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, t Forward pass of the GPT2LLM module. Args: - inputs (dict[str, torch.Tensor] | torch.Tensor): Input data. + inputs (dict[str, torch.Tensor] | torch.Tensor): Input data. When a dict, an optional + ``"position_ids"`` key (shape ``(B, T)`` or ``(1, T)``) may be present to supply + explicit global token positions for CP-aware RoPE. Returns: dict[str, torch.Tensor] | torch.Tensor: Model output. """ if isinstance(inputs, dict): - return {self.prediction_key: self.forward_impl(inputs[self.sample_key])} + position_ids = inputs.get("position_ids", None) + return {self.prediction_key: self.forward_impl(inputs[self.sample_key], position_ids=position_ids)} else: return self.forward_impl(inputs) - def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor: + def forward_impl(self, inputs: torch.Tensor, position_ids: torch.Tensor | None = None) -> torch.Tensor: """ Forward pass implementation of the GPT2LLM module. Args: inputs (torch.Tensor): A tensor containing input token ids. + position_ids (torch.Tensor | None): Optional explicit global position indices + of shape ``(B, T)`` or ``(1, T)``. When provided, RoPE uses these positions + instead of a local 0-based arange, enabling correct behaviour for CP ranks + that hold non-contiguous token ranges. Returns: torch.Tensor: A tensor containing output logits. 
@@ -1010,7 +1059,7 @@ def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor: h = self.transformer.drop(h) if hasattr(self.transformer, "drop") else h for layer_idx in self.transformer.h: - h = self.transformer.h[layer_idx](h) + h = self.transformer.h[layer_idx](h, position_ids=position_ids) h = self.transformer.lm_head_norm(h) if hasattr(self.transformer, "lm_head_norm") else h h = self.transformer.lm_head(h) if hasattr(self.transformer, "lm_head") else h return h diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index 62933794d..ba468e16d 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -35,15 +35,21 @@ GPT2LLM, AttentionConfig, AttentionImplementation, + CausalSelfAttention, LayerNormWrapperConfig, PositionTypes, SwiGLU, TransformerMLP, ) from modalities.models.model import ActivationType +from modalities.models.parallelism.context_parallel import apply_cp_to_sdpa_attention_forward from modalities.nn.model_initialization.initialization_if import ModelInitializationIF from modalities.running_env.env_utils import FSDP2MixedPrecisionSettings, MixedPrecisionSettings -from modalities.running_env.fsdp.device_mesh import ParallelismDegrees +from modalities.running_env.fsdp.device_mesh import ( + ParallelismDegrees, + get_mesh_for_parallelism_method, + has_parallelism_method, +) from modalities.running_env.fsdp.fsdp_auto_wrapper import FSDPTransformerAutoWrapPolicyFactory from modalities.training.activation_checkpointing.activation_checkpointing import ( ActivationCheckpointing, @@ -593,6 +599,73 @@ def register_hooks_recursively(module: nn.Module, prefix: str = ""): class GPT2ModelFactory: + @staticmethod + def _get_cp_mesh_if_enabled(device_mesh: DeviceMesh) -> DeviceMesh | None: + if not has_parallelism_method(device_mesh, ParallelismDegrees.CP): + return None + cp_mesh = get_mesh_for_parallelism_method(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.CP) 
+ return cp_mesh if cp_mesh.size() > 1 else None + + @staticmethod + def _validate_context_parallel_seq_len( + model: GPT2LLM, + cp_degree: int, + tp_degree: int = 1, + load_balancer_type: str | None = "headtail", + ) -> None: + # The "headtail" balancer splits each rank's chunk into a head and tail piece, + # requiring an extra factor of 2. Other load balancers don't impose this constraint. + headtail_factor = 2 if load_balancer_type == "headtail" else 1 + seq_len_divisor = tp_degree * cp_degree * headtail_factor + if model.sequence_length % seq_len_divisor != 0: + headtail_note = " * 2 (headtail)" if load_balancer_type == "headtail" else "" + raise ValueError( + f"For GPT2 CP runs, sequence_length must be divisible by tp_degree * cp_degree{headtail_note}. " + f"Got sequence_length={model.sequence_length}, tp_degree={tp_degree}, cp_degree={cp_degree}, " + f"load_balancer_type={load_balancer_type!r}." + ) + + @staticmethod + def _apply_context_parallel_to_gpt2_attention( + model: GPT2LLM, + cp_mesh: DeviceMesh | None, + context_parallel_load_balancer: str | None, + ) -> None: + if cp_mesh is None: + return + + if context_parallel_load_balancer not in ("headtail", "ptrr", None): + raise ValueError( + "context_parallel_load_balancer must be one of: 'headtail', 'ptrr', or None. " + f"Got {context_parallel_load_balancer}." + ) + + GPT2ModelFactory._validate_context_parallel_seq_len( + model=model, cp_degree=cp_mesh.size(), load_balancer_type=context_parallel_load_balancer + ) + + attention_modules: list[nn.Module] = [] + transformer_layers = getattr(model.transformer, "h", None) + if not isinstance(transformer_layers, nn.ModuleDict): + raise TypeError( + "Context parallelism requires model.transformer.h to be an nn.ModuleDict of GPT2 blocks. " + f"Got type {type(transformer_layers).__name__}." 
+ ) + + for _, transformer_block in transformer_layers.named_children(): + attn_module = getattr(transformer_block, "attn", None) + if not isinstance(attn_module, CausalSelfAttention): + continue + if attn_module.attention_impl != AttentionImplementation.PYTORCH_FLASH: + raise NotImplementedError( + "Context parallelism currently supports only attention_implementation='pytorch_flash' " + "for GPT2 in this codebase." + ) + attention_modules.append(attn_module) + + apply_cp_to_sdpa_attention_forward(attention_modules=attention_modules, cp_mesh=cp_mesh) + setattr(model, "_context_parallel_load_balancer", context_parallel_load_balancer) + @staticmethod def get_gpt2_model( sample_key: str, @@ -653,8 +726,41 @@ def get_gpt2_model( return model @staticmethod - def get_gpt2_tensor_parallelized_model(model: GPT2LLM, device_mesh: DeviceMesh) -> nn.Module: + def get_gpt2_context_parallelized_model( + model: GPT2LLM, + device_mesh: DeviceMesh, + context_parallel_load_balancer: str | None = "headtail", + ) -> nn.Module: + cp_mesh = GPT2ModelFactory._get_cp_mesh_if_enabled(device_mesh=device_mesh) + GPT2ModelFactory._apply_context_parallel_to_gpt2_attention( + model=model, + cp_mesh=cp_mesh, + context_parallel_load_balancer=context_parallel_load_balancer, + ) + return model + + @staticmethod + def get_gpt2_tensor_parallelized_model( + model: GPT2LLM, + device_mesh: DeviceMesh, + context_parallel_load_balancer: str | None = "headtail", + ) -> nn.Module: tp_mesh = device_mesh[ParallelismDegrees.TP.value] + cp_mesh = GPT2ModelFactory._get_cp_mesh_if_enabled(device_mesh=device_mesh) + + if cp_mesh is not None: + GPT2ModelFactory._validate_context_parallel_seq_len( + model=model, + cp_degree=cp_mesh.size(), + tp_degree=tp_mesh.size(), + load_balancer_type=context_parallel_load_balancer, + ) + GPT2ModelFactory._apply_context_parallel_to_gpt2_attention( + model=model, + cp_mesh=cp_mesh, + context_parallel_load_balancer=context_parallel_load_balancer, + ) + model_tp_plan = { # Row-wise 
parallelism might seem counterintuitive here, # but the embedding layer has weight shape (vocab_size, n_embd). diff --git a/src/modalities/models/parallelism/context_parallel.py b/src/modalities/models/parallelism/context_parallel.py new file mode 100644 index 000000000..aa06186ac --- /dev/null +++ b/src/modalities/models/parallelism/context_parallel.py @@ -0,0 +1,121 @@ +from collections.abc import Sequence + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh + +from modalities.models.gpt2.gpt2_model import AttentionImplementation, CausalSelfAttention + + +# Some portions of this implementation are inspired, adapted, or refactored +# from Meta's open-source project TorchTitan, +# licensed under the BSD 3-Clause License. +def apply_cp_to_sdpa_attention_forward(attention_modules: Sequence[nn.Module], cp_mesh: DeviceMesh) -> None: + """Patch CausalSelfAttention.execute_attention to route SDPA through DTensor CP dispatch. + + The patch is class-level (not per-instance). `attention_modules` is used only as a + guard: if the list is empty no patching happens, since the model has no SDPA layers to + wrap. The individual module objects are not modified. + + It must run before tensor-parallel wrappers so CP logic executes inside local + tensor regions. + """ + if len(attention_modules) == 0: + return + + # Detect re-entry. A second call with the same mesh is a no-op; a different mesh + # would silently use the wrong mesh because cp_mesh is captured by closure. + existing_mesh = getattr(CausalSelfAttention, "_cp_mesh", None) + if getattr(CausalSelfAttention, "_cp_execute_attention_wrapped", False): + if existing_mesh is not cp_mesh: + raise RuntimeError( + "apply_cp_to_sdpa_attention_forward has already patched CausalSelfAttention " + "with a different cp_mesh. Re-patching with a new mesh is not supported." 
+ ) + return + + try: + from torch.distributed.tensor import DTensor, Shard + from torch.distributed.tensor.experimental._attention import _enable_context_parallel_dispatcher + except (ImportError, ModuleNotFoundError) as exc: + raise RuntimeError( + "Context parallelism requires PyTorch experimental DTensor attention APIs. " + "Install a PyTorch build that provides " + "torch.distributed.tensor.experimental._attention." + ) from exc + + _enable_context_parallel_dispatcher() + + original_execute_attention = CausalSelfAttention.execute_attention + + def cp_execute_attention(cls, q, k, v, dropout, attention_impl): + if attention_impl != AttentionImplementation.PYTORCH_FLASH: + return original_execute_attention(q, k, v, dropout, attention_impl) + + placement = [Shard(2)] + if not isinstance(q, DTensor): + q = DTensor.from_local(q, cp_mesh, placement, run_check=False) + if not isinstance(k, DTensor): + k = DTensor.from_local(k, cp_mesh, placement, run_check=False) + if not isinstance(v, DTensor): + v = DTensor.from_local(v, cp_mesh, placement, run_check=False) + + output = original_execute_attention(q, k, v, dropout, attention_impl) + return output.to_local() if isinstance(output, DTensor) else output + + CausalSelfAttention.execute_attention = classmethod(cp_execute_attention) + setattr(CausalSelfAttention, "_cp_execute_attention_wrapped", True) + setattr(CausalSelfAttention, "_cp_mesh", cp_mesh) + + +def shard_tensor_buffers_for_context_parallel( + cp_mesh: DeviceMesh, + buffers: tuple[torch.Tensor, ...], + seq_dims: tuple[int, ...], + load_balancer_type: str | None = "headtail", + shard_impl=None, +) -> tuple[torch.Tensor, ...]: + """Shard tensor buffers across CP ranks along sequence dimensions. + + This mirrors TorchTitan's input sharding pattern while keeping the current + codebase focused on plain tensor inputs/targets (no attention mask sharding yet). 
+ """ + if len(buffers) != len(seq_dims): + raise ValueError(f"Expected len(buffers) == len(seq_dims), got {len(buffers)} and {len(seq_dims)}.") + if len(buffers) == 0: + return tuple() + + if shard_impl is None: + try: + from torch.distributed.tensor.experimental._attention import _context_parallel_shard, _HeadTailLoadBalancer + except (ImportError, ModuleNotFoundError) as exc: + raise RuntimeError( + "Context parallel input sharding requires PyTorch experimental DTensor attention APIs." + ) from exc + shard_impl = _context_parallel_shard + + if load_balancer_type == "headtail": + seq_len = buffers[0].size(seq_dims[0]) + cp_world_size = cp_mesh.size() + load_balancer = _HeadTailLoadBalancer(seq_len, cp_world_size, cp_mesh.device_type) + elif load_balancer_type is None: + load_balancer = None + elif load_balancer_type == "ptrr": + raise ValueError( + "PTRR load balancing is not supported for plain tensor input/target sharding without block masks." + ) + else: + raise ValueError( + f"Invalid load_balancer_type '{load_balancer_type}'. Must be one of: 'headtail', 'ptrr', or None" + ) + else: + # Tests can inject shard_impl and bypass private PyTorch imports. 
+ load_balancer = None + + sharded = shard_impl( + mesh=cp_mesh, + buffers=buffers, + seq_dims=seq_dims, + load_balancer=load_balancer, + ) + return tuple(sharded) diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 26df9b432..b78c69f2f 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -46,6 +46,7 @@ FSDPWrappedModelConfig, GPT2LLMCollateFnConfig, GPT2MFUCalculatorConfig, + GPT2ModelCPConfig, GPT2ModelTPConfig, LinearLRSchedulerConfig, LinearWarmupCosineAnnealingLRSchedulerConfig, @@ -187,6 +188,9 @@ class ComponentEntity: COMPONENTS = [ # models ComponentEntity("model", "gpt2", GPT2ModelFactory.get_gpt2_model, GPT2LLMConfig), + ComponentEntity( + "model", "gpt2_cp", maybe_model_list(GPT2ModelFactory.get_gpt2_context_parallelized_model), GPT2ModelCPConfig + ), ComponentEntity( "model", "gpt2_tp", maybe_model_list(GPT2ModelFactory.get_gpt2_tensor_parallelized_model), GPT2ModelTPConfig ), diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index c715a01fa..20bbb18f6 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -1,6 +1,6 @@ +import gc from datetime import datetime from enum import Enum -import gc from typing import Callable, Optional import torch @@ -16,8 +16,14 @@ from modalities.logging_broker.publisher import MessagePublisher from modalities.loss_functions import Loss from modalities.models.model import model_predict_batch +from modalities.models.parallelism.context_parallel import shard_tensor_buffers_for_context_parallel from modalities.models.parallelism.pipeline_parallelism import Pipeline -from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_degree +from modalities.running_env.fsdp.device_mesh import ( + ParallelismDegrees, + get_mesh_for_parallelism_method, + get_parallel_degree, + has_parallelism_method, +) from modalities.running_env.fsdp.reducer import Reducer from 
modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF from modalities.training.training_progress import TrainingProgress @@ -51,6 +57,72 @@ class ThroughputAggregationKeys(Enum): FORWARD_BACKWARD_TIME = "FORWARD_BACKWARD_TIME" +def apply_context_parallel_sharding_to_batch( + device_mesh: DeviceMesh | None, + batch: DatasetBatch, + sample_key: str | None, + target_key: str, + context_parallel_load_balancer: str | None = "headtail", +) -> None: + """Shard the sequence dimension of a batch in-place for context parallelism. + + When a sample_key is provided, also derives and shards ``position_ids`` (global + token positions) so that RotaryTransform receives the correct frequencies for + each CP rank's non-contiguous token range. + """ + if device_mesh is None or not has_parallelism_method(device_mesh, ParallelismDegrees.CP): + return + + cp_mesh = get_mesh_for_parallelism_method(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.CP) + if cp_mesh.size() <= 1: + return + + buffer_keys: list[tuple[str, str]] = [] + buffers: list[torch.Tensor] = [] + seq_dims: list[int] = [] + + if sample_key is not None and sample_key in batch.samples: + if batch.samples[sample_key].device.type != "cuda": + batch.samples[sample_key] = batch.samples[sample_key].to(torch.cuda.current_device(), non_blocking=True) + # Build global position_ids before sharding so they carry the full-sequence range. + # After HeadTail sharding they hold the correct global indices for each CP rank's + # local tokens, which RotaryTransform uses instead of a local 0-based arange. 
+ full_seq_len = batch.samples[sample_key].shape[1] + position_ids = torch.arange(full_seq_len, device=batch.samples[sample_key].device, dtype=torch.long).unsqueeze( + 0 + ) # (1, T) + buffer_keys.append(("sample", "position_ids")) + buffers.append(position_ids) + seq_dims.append(1) + + buffer_keys.append(("sample", sample_key)) + buffers.append(batch.samples[sample_key]) + seq_dims.append(1) + + if target_key in batch.targets: + if batch.targets[target_key].device.type != "cuda": + batch.targets[target_key] = batch.targets[target_key].to(torch.cuda.current_device(), non_blocking=True) + buffer_keys.append(("target", target_key)) + buffers.append(batch.targets[target_key]) + seq_dims.append(1) + + if not buffers: + return + + sharded_buffers = shard_tensor_buffers_for_context_parallel( + cp_mesh=cp_mesh, + buffers=tuple(buffers), + seq_dims=tuple(seq_dims), + load_balancer_type=context_parallel_load_balancer, + ) + + for (kind, key), tensor in zip(buffer_keys, sharded_buffers, strict=True): + if kind == "sample": + batch.samples[key] = tensor + else: + batch.targets[key] = tensor + + class Trainer: def __init__( self, @@ -92,6 +164,7 @@ def __init__( """ self.gc = GarbageCollection(gc_freq=10) self.global_rank = global_rank + self.device_mesh = device_mesh if device_mesh is not None: self.dp_degree = get_parallel_degree( device_mesh, [ParallelismDegrees.DP_REPLICATE, ParallelismDegrees.DP_SHARD] @@ -126,6 +199,21 @@ def _get_num_train_steps_done(micro_batch_id: int, gradient_acc_steps: int) -> i """ return (micro_batch_id + 1) // gradient_acc_steps + def _apply_context_parallel_sharding_to_batch_( + self, + batch: DatasetBatch, + sample_key: str | None, + target_key: str, + context_parallel_load_balancer: str | None = "headtail", + ) -> None: + apply_context_parallel_sharding_to_batch( + device_mesh=self.device_mesh, + batch=batch, + sample_key=sample_key, + target_key=target_key, + context_parallel_load_balancer=context_parallel_load_balancer, + ) + def 
_train_batch( self, batch: DatasetBatch, @@ -159,6 +247,15 @@ def _train_batch( - gradient_norm_score (Optional[torch.Tensor]): The gradient norm score, if a training step was performed otherwise return None. """ + sample_key = getattr(model_parts[0], "sample_key", None) + context_parallel_load_balancer = getattr(model_parts[0], "_context_parallel_load_balancer", "headtail") + self._apply_context_parallel_sharding_to_batch_( + batch=batch, + sample_key=sample_key, + target_key=loss_fun.target_key, + context_parallel_load_balancer=context_parallel_load_balancer, + ) + if scheduled_pipeline is not None: pp_schedule = scheduled_pipeline.pp_schedule # Pipeline Parallel forward / backward inside step() call @@ -388,7 +485,7 @@ def train( self.gc.run(step_count=training_progress.num_seen_steps_total) evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total) checkpointing_callback(training_progress=training_progress) - + profiler_cm.step() @staticmethod diff --git a/tests/end2end_tests/system_tests/test_context_parallel_parity.py b/tests/end2end_tests/system_tests/test_context_parallel_parity.py new file mode 100644 index 000000000..16efc2418 --- /dev/null +++ b/tests/end2end_tests/system_tests/test_context_parallel_parity.py @@ -0,0 +1,197 @@ +import json +import os +from pathlib import Path + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from modalities.__main__ import Main, load_app_config_dict +from modalities.batch import EvaluationResultBatch +from modalities.config.config import ProcessGroupBackendType +from modalities.config.instantiation_models import TrainingComponentsInstantiationModel +from modalities.logging_broker.messages import Message +from tests.end2end_tests.custom_components import ( + MultiProcessingCudaEnv, + SaveAllResultSubscriber, + SaveAllResultSubscriberConfig, +) + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="This e2e test requires 2 GPUs.", +) 
+class TestContextParallelParity: + @staticmethod + def _build_config_dict(base_config_path: Path, experiments_root_path: Path, experiment_id: str, cp_enabled: bool): + config_dict = load_app_config_dict( + config_file_path=base_config_path, + experiments_root_path=experiments_root_path, + experiment_id=experiment_id, + ) + + # Keep runs tiny and deterministic for parity checks. + config_dict["settings"]["intervals"]["training_log_interval_in_steps"] = 1 + config_dict["settings"]["intervals"]["checkpointing_interval_in_steps"] = 4 + config_dict["settings"]["intervals"]["evaluation_interval_in_steps"] = 4 + config_dict["settings"]["step_profile"]["gradient_accumulation_steps"] = 1 + config_dict["settings"]["step_profile"]["local_train_micro_batch_size"] = 1 + config_dict["settings"]["step_profile"]["sequence_length"] = 256 + + # Use SDPA backend in both runs to isolate CP impact. + config_dict["model_raw"]["config"]["attention_implementation"] = "pytorch_flash" + + # Remove conversion components and use explicit target settings for stable tiny runs. + config_dict["settings"]["training_target"]["num_target_steps"] = 4 + if cp_enabled: + # dp_degree becomes 1 when cp=2 and world_size=2. 
+ config_dict["settings"]["training_target"]["num_target_tokens"] = 1024 + config_dict["device_mesh"]["config"]["context_parallel_degree"] = 2 + config_dict["device_mesh"]["config"]["data_parallel_shard_degree"] = -1 + + config_dict["gpt2_cp_model"] = { + "component_key": "model", + "variant_key": "gpt2_cp", + "config": { + "model": {"instance_key": "model_raw", "pass_type": "BY_REFERENCE"}, + "device_mesh": {"instance_key": "device_mesh", "pass_type": "BY_REFERENCE"}, + "context_parallel_load_balancer": "headtail", + }, + } + config_dict["fsdp_model"]["config"]["model"] = { + "instance_key": "gpt2_cp_model", + "pass_type": "BY_REFERENCE", + } + + sampler_cfg = config_dict["train_dataloader"]["config"]["batch_sampler"]["config"]["sampler"] + sampler_cfg["variant_key"] = "resumable_distributed_multi_dim_sampler" + sampler_cfg["config"] = { + "dataset": {"instance_key": "train_dataset", "pass_type": "BY_REFERENCE"}, + "device_mesh": {"instance_key": "device_mesh", "pass_type": "BY_REFERENCE"}, + "data_parallel_key": "dp_shard", + "shuffle": True, + "seed": 42, + "drop_last": True, + "skip_num_global_samples": 0, + } + else: + # dp_degree is world_size=2 in non-CP run. 
+ config_dict["settings"]["training_target"]["num_target_tokens"] = 2048 + config_dict["device_mesh"]["config"]["context_parallel_degree"] = 1 + config_dict["fsdp_model"]["config"]["model"] = { + "instance_key": "model_raw", + "pass_type": "BY_REFERENCE", + } + config_dict.pop("gpt2_cp_model", None) + + return config_dict + + @staticmethod + def _run_training_and_write_losses( + process_id: int, + world_size: int, + rdvz_port: int, + base_config_path: Path, + experiments_root_path: Path, + experiment_id: str, + cp_enabled: bool, + output_path: Path, + ): + torch.manual_seed(20) + torch.cuda.manual_seed(20) + + with MultiProcessingCudaEnv( + process_group_backend=ProcessGroupBackendType.nccl, + global_rank=process_id, + local_rank=process_id, + world_size=world_size, + rdvz_port=rdvz_port, + ): + config_dict = TestContextParallelParity._build_config_dict( + base_config_path=base_config_path, + experiments_root_path=experiments_root_path, + experiment_id=experiment_id, + cp_enabled=cp_enabled, + ) + + main_obj = Main( + base_config_path, + experiments_root_path=experiments_root_path, + experiment_id=experiment_id, + ) + main_obj.config_dict = config_dict + main_obj.add_custom_component( + component_key="results_subscriber", + variant_key="save_all", + custom_component=SaveAllResultSubscriber, + custom_config=SaveAllResultSubscriberConfig, + ) + + components: TrainingComponentsInstantiationModel = main_obj.build_components( + components_model_type=TrainingComponentsInstantiationModel + ) + main_obj.run(components) + + if dist.get_rank() == 0: + messages: list[Message[EvaluationResultBatch]] = components.evaluation_subscriber.message_list + losses = [float(m.payload.losses["train loss avg"].value) for m in messages] + with open(output_path, "w", encoding="utf-8") as f: + json.dump(losses, f) + + @staticmethod + def test_cp_vs_non_cp_loss_parity(tmp_path: Path): + working_dir = Path(os.path.dirname(__file__)) + base_config_path = working_dir / "configs" / 
"fsdp2_gpt2_train_num_steps_8.yaml" + + non_cp_losses_path = tmp_path / "non_cp_losses.json" + cp_losses_path = tmp_path / "cp_losses.json" + + world_size = 2 + + mp.spawn( + TestContextParallelParity._run_training_and_write_losses, + args=( + world_size, + 24831, + base_config_path, + tmp_path, + "parity_non_cp", + False, + non_cp_losses_path, + ), + nprocs=world_size, + join=True, + ) + + mp.spawn( + TestContextParallelParity._run_training_and_write_losses, + args=( + world_size, + 24832, + base_config_path, + tmp_path, + "parity_cp", + True, + cp_losses_path, + ), + nprocs=world_size, + join=True, + ) + + with open(non_cp_losses_path, "r", encoding="utf-8") as f: + non_cp_losses = json.load(f) + with open(cp_losses_path, "r", encoding="utf-8") as f: + cp_losses = json.load(f) + + assert len(non_cp_losses) >= 4 + assert len(cp_losses) >= 4 + + # Compare aligned prefix for parity signal. + n = min(len(non_cp_losses), len(cp_losses)) + diffs = [abs(non_cp_losses[i] - cp_losses[i]) for i in range(n)] + + # We allow moderate drift; this is a regression guard against severe mismatch. 
+ assert diffs[-1] < 0.5, f"Final loss diverged too much: {diffs[-1]}" + assert (sum(diffs) / len(diffs)) < 0.35, f"Average loss delta too high: {sum(diffs) / len(diffs)}" diff --git a/tests/fsdp2_parallelization/cp_test_configs/cp_config.yaml b/tests/fsdp2_parallelization/cp_test_configs/cp_config.yaml new file mode 100644 index 000000000..def587d1b --- /dev/null +++ b/tests/fsdp2_parallelization/cp_test_configs/cp_config.yaml @@ -0,0 +1,98 @@ +model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: gpt2_cp_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: FP_32 + reduce_dtype: FP_32 + block_names: [GPT2Block] + +gpt2_cp_model: + component_key: model + variant_key: gpt2_cp + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + context_parallel_load_balancer: headtail + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: false + use_weight_tying: false + sample_key: input_ids + poe_type: NOPE + sequence_length: 128 + prediction_key: logits + vocab_size: 50304 + n_layer: 2 + n_head_q: 4 + n_head_kv: 2 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: false + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: pytorch_flash + activation_type: swiglu + attention_norm_config: + norm_type: 
layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: 1 + context_parallel_degree: 2 + tensor_parallel_degree: 1 + world_size: 2 diff --git a/tests/fsdp2_parallelization/cp_test_configs/cp_tp_config.yaml b/tests/fsdp2_parallelization/cp_test_configs/cp_tp_config.yaml new file mode 100644 index 000000000..3d79aa28f --- /dev/null +++ b/tests/fsdp2_parallelization/cp_test_configs/cp_tp_config.yaml @@ -0,0 +1,98 @@ +model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: gpt2_tp_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: FP_32 + reduce_dtype: FP_32 + block_names: [GPT2Block] + +gpt2_tp_model: + component_key: model + variant_key: gpt2_tp + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + context_parallel_load_balancer: headtail + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: false + use_weight_tying: false + sample_key: input_ids + poe_type: NOPE + sequence_length: 128 + prediction_key: logits + vocab_size: 50304 + n_layer: 2 + n_head_q: 4 + 
n_head_kv: 2 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: false + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: pytorch_flash + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: 1 + context_parallel_degree: 2 + tensor_parallel_degree: 2 + world_size: 4 diff --git a/tests/fsdp2_parallelization/cp_test_configs/fsdp2_4gpu_config.yaml b/tests/fsdp2_parallelization/cp_test_configs/fsdp2_4gpu_config.yaml new file mode 100644 index 000000000..d698fe136 --- /dev/null +++ b/tests/fsdp2_parallelization/cp_test_configs/fsdp2_4gpu_config.yaml @@ -0,0 +1,86 @@ +model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: FP_32 + reduce_dtype: FP_32 + block_names: [GPT2Block] + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: false + use_weight_tying: 
false + sample_key: input_ids + poe_type: NOPE + sequence_length: 128 + prediction_key: logits + vocab_size: 50304 + n_layer: 2 + n_head_q: 4 + n_head_kv: 2 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: false + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: pytorch_flash + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: -1 + context_parallel_degree: 1 + tensor_parallel_degree: 1 + world_size: 4 diff --git a/tests/fsdp2_parallelization/cp_test_configs/fsdp2_config.yaml b/tests/fsdp2_parallelization/cp_test_configs/fsdp2_config.yaml new file mode 100644 index 000000000..4ccadab44 --- /dev/null +++ b/tests/fsdp2_parallelization/cp_test_configs/fsdp2_config.yaml @@ -0,0 +1,86 @@ +model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: FP_32 + reduce_dtype: FP_32 + block_names: [GPT2Block] + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: 
${model_raw.config.n_layer} + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: false + use_weight_tying: false + sample_key: input_ids + poe_type: NOPE + sequence_length: 128 + prediction_key: logits + vocab_size: 50304 + n_layer: 2 + n_head_q: 4 + n_head_kv: 2 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: false + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: pytorch_flash + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: -1 + context_parallel_degree: 1 + tensor_parallel_degree: 1 + world_size: 2 diff --git a/tests/fsdp2_parallelization/test_context_parallelism.py b/tests/fsdp2_parallelization/test_context_parallelism.py new file mode 100644 index 000000000..98ac84019 --- /dev/null +++ b/tests/fsdp2_parallelization/test_context_parallelism.py @@ -0,0 +1,331 @@ +from pathlib import Path + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn.functional as F +from pydantic import BaseModel + +from modalities.__main__ import Main +from modalities.config.config import ProcessGroupBackendType +from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticFSDP2ModuleType, PydanticPipelineType +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees +from 
tests.end2end_tests.custom_components import MultiProcessingCudaEnv +from tests.utility import find_free_port + +_CONFIG_DIR = Path("tests/fsdp2_parallelization/cp_test_configs") + +_PARITY_CONFIGS = { + "cp_2gpu": { + "fsdp2_config": _CONFIG_DIR / "fsdp2_config.yaml", + "cp_config": _CONFIG_DIR / "cp_config.yaml", + "world_size": 2, + }, + "cp_tp_4gpu": { + "fsdp2_config": _CONFIG_DIR / "fsdp2_4gpu_config.yaml", + "cp_config": _CONFIG_DIR / "cp_tp_config.yaml", + "world_size": 4, + }, + "cp_pp_4gpu": { + "fsdp2_config": _CONFIG_DIR / "fsdp2_4gpu_nope_config.yaml", + "cp_config": _CONFIG_DIR / "cp_pp_config.yaml", + "world_size": 4, + }, + "cp_tp_pp_8gpu": { + "fsdp2_config": _CONFIG_DIR / "fsdp2_8gpu_nope_config.yaml", + "cp_config": _CONFIG_DIR / "cp_tp_pp_config.yaml", + "world_size": 8, + }, +} + + +class _Components(BaseModel): + model: PydanticFSDP2ModuleType + device_mesh: PydanticDeviceMeshIFType + + +class _PPComponents(BaseModel): + scheduled_pipeline: PydanticPipelineType + device_mesh: PydanticDeviceMeshIFType + + +def _build(config_path: Path, tmp_path: Path) -> _Components: + return Main(config_path, experiments_root_path=tmp_path).build_components(components_model_type=_Components) + + +def _build_pp(config_path: Path, tmp_path: Path) -> _PPComponents: + return Main(config_path, experiments_root_path=tmp_path).build_components(components_model_type=_PPComponents) + + +def _fixed_input(batch_size: int, seq_len: int, vocab_size: int, device: torch.device) -> torch.Tensor: + torch.manual_seed(0) + return torch.randint(0, vocab_size, (batch_size, seq_len), device=device) + + +def _headtail_input( + inp: torch.Tensor, + cp_rank: int, + cp_degree: int, +) -> tuple[torch.Tensor, int, int]: + """Return (local_input, head_start, tail_start) for HeadTail-sharded CP. 
+ + The HeadTail load balancer interleaves each CP rank's chunk into a head + portion (the first 1/(2·cp) of the full sequence) and a tail portion (the + last 1/(2·cp) of the full sequence, counted from the end), so that every + rank gets an equal mix of early and late tokens. + + Rank r receives: + head = inp[:, r * chunk : (r+1) * chunk] + tail = inp[:, (2*cp-1-r) * chunk : (2*cp-r) * chunk] + + Both portions are concatenated as the rank's local input. + + NOTE: Test configs use RoPE with explicit position_ids passed to the CP model + so that each rank uses the correct global positions instead of a local 0-based + arange. The position_ids are constructed from head_start / tail_start below. + """ + seq_len = inp.shape[1] + chunk = seq_len // (2 * cp_degree) + + head_start = cp_rank * chunk + tail_start = (2 * cp_degree - 1 - cp_rank) * chunk + + local_input = torch.cat( + [ + inp[:, head_start : head_start + chunk], + inp[:, tail_start : tail_start + chunk], + ], + dim=1, + ) + return local_input, head_start, tail_start + + +def _run_cp_logit_match_impl( + process_id: int, + fsdp2_config: Path, + cp_config: Path, + world_size: int, + port: int, + tmp_path: Path, +) -> None: + """Worker run by mp.spawn: verifies that CP (or CP+TP) logits match the FSDP2 baseline. + + The HeadTail dispatcher in PyTorch's experimental ring-attention API works with + HeadTail-interleaved input. Each rank's chunk [head_start, tail_start] maps to + specific global token indices; the ring correctly gates cross-rank K/V using that + interleaving. Feeding sequentially sharded input would apply wrong causal masks + and produce wrong logits. + + Strategy (no all_gather needed): + 1. Build FSDP2 baseline BEFORE the CP class-level patch is applied. + 2. Run FSDP2 forward on the full sequence → reference logits for every token. + 3. Build CP model (applies the class-level patch). + 4. Construct HeadTail-sharded input for this rank; run CP forward. + 5. 
Each rank independently compares its local CP logits against the matching + rows of the FSDP2 reference. + """ + with MultiProcessingCudaEnv( + process_group_backend=ProcessGroupBackendType.nccl, + global_rank=process_id, + local_rank=process_id, + world_size=world_size, + rdvz_port=port, + ): + vocab_size = 50304 + seq_len = 128 + batch_size = 2 + device = torch.device(f"cuda:{process_id}") + + # Build FSDP2 baseline BEFORE the CP class-level patch is applied. + torch.manual_seed(42) + fsdp2 = _build(fsdp2_config, tmp_path) + + inp = _fixed_input(batch_size, seq_len, vocab_size, device) + out_fsdp2 = fsdp2.model({"input_ids": inp})["logits"].float() + + # Build CP (or CP+TP) model. This applies the class-level CP patch. + torch.manual_seed(42) + cp = _build(cp_config, tmp_path) + cp_mesh = cp.device_mesh[ParallelismDegrees.CP.value] + cp_rank = dist.get_rank(cp_mesh.get_group()) + cp_degree = cp_mesh.size() + + chunk = seq_len // (2 * cp_degree) + input_cp, head_start, tail_start = _headtail_input(inp, cp_rank, cp_degree) + ref = torch.cat( + [ + out_fsdp2[:, head_start : head_start + chunk, :], + out_fsdp2[:, tail_start : tail_start + chunk, :], + ], + dim=1, + ) + + # Global position indices for this rank's HeadTail-sharded tokens so RoPE uses + # the correct frequencies instead of a local 0-based arange. 
+ position_ids = torch.cat( + [ + torch.arange(head_start, head_start + chunk, device=device), + torch.arange(tail_start, tail_start + chunk, device=device), + ] + ).unsqueeze( + 0 + ) # (1, 2*chunk) + + out_cp_local = cp.model({"input_ids": input_cp, "position_ids": position_ids})["logits"].float() + + assert out_cp_local.shape == ref.shape, f"Shape mismatch: CP={out_cp_local.shape}, ref={ref.shape}" + assert torch.allclose(out_cp_local, ref, atol=1e-5, rtol=1e-4), ( + f"Logit mismatch on CP rank {cp_rank}: " f"max abs diff = {(out_cp_local - ref).abs().max().item():.2e}" + ) + + +def _run_cp_pp_loss_match_impl( + process_id: int, + fsdp2_config: Path, + cp_config: Path, + world_size: int, + port: int, + tmp_path: Path, +) -> None: + """Worker run by mp.spawn: verifies that CP+PP (or CP+TP+PP) losses match the FSDP2 baseline. + + Strategy: + 1. Build FSDP2 (no CP, no PP) baseline BEFORE the CP class-level patch. + Both the FSDP2 and CP+PP configs initialize the full model (staged_pipeline + uses 'initialized_model' as whole_model) so all ranks share identical weights. + 2. Run FSDP2 forward on the full sequence → reference logits for every token. + 3. Build CP+PP model (applies the class-level CP patch). + 4. Construct HeadTail-sharded input/target for this CP rank; run PP schedule forward. + 5. On the last PP stage, compare the CP+PP per-rank loss against the reference loss + computed from slicing the FSDP2 logits to this CP rank's HeadTail token subset. + """ + with MultiProcessingCudaEnv( + process_group_backend=ProcessGroupBackendType.nccl, + global_rank=process_id, + local_rank=process_id, + world_size=world_size, + rdvz_port=port, + ): + vocab_size = 50304 + seq_len = 128 + batch_size = 2 + device = torch.device(f"cuda:{process_id}") + + # Build FSDP2 baseline BEFORE the CP class-level patch is applied. 
+ torch.manual_seed(42) + fsdp2 = _build(fsdp2_config, tmp_path) + + # Generate a fixed (seq_len+1)-length sequence so we have both input_ids and + # the shifted target_ids required for the CLM cross-entropy loss. + torch.manual_seed(0) + full_seq = torch.randint(0, vocab_size, (batch_size, seq_len + 1), device=device) + inp = full_seq[:, :seq_len] # (batch, seq_len) — input_ids + target = full_seq[:, 1:] # (batch, seq_len) — target_ids + + # Full-sequence forward pass on the FSDP2 baseline. + with torch.no_grad(): + out_fsdp2 = fsdp2.model({"input_ids": inp})["logits"].float() # (batch, seq_len, vocab) + + # Build CP+PP (or CP+TP+PP) model. This applies the class-level CP patch. + torch.manual_seed(42) + cp_pp = _build_pp(cp_config, tmp_path) + + cp_mesh = cp_pp.device_mesh[ParallelismDegrees.CP.value] + cp_rank = dist.get_rank(cp_mesh.get_group()) + cp_degree = cp_mesh.size() + + chunk = seq_len // (2 * cp_degree) + input_cp, head_start, tail_start = _headtail_input(inp, cp_rank, cp_degree) + targets_cp = torch.cat( + [ + target[:, head_start : head_start + chunk], + target[:, tail_start : tail_start + chunk], + ], + dim=1, + ) # (batch, 2*chunk) + + # Reference CE loss over this CP rank's token subset, computed from FSDP2 logits. + ref_logits = torch.cat( + [ + out_fsdp2[:, head_start : head_start + chunk, :], + out_fsdp2[:, tail_start : tail_start + chunk, :], + ], + dim=1, + ) # (batch, 2*chunk, vocab) + ref_loss = F.cross_entropy( + ref_logits.reshape(-1, vocab_size), + targets_cp.reshape(-1).long(), + ) + + # Run CP+PP forward (eval = no-grad forward only). 
+ scheduled_pipeline = cp_pp.scheduled_pipeline + pp_schedule = scheduled_pipeline.pp_schedule + targets_pp, losses = (targets_cp.contiguous(), []) if scheduled_pipeline.has_last_pp_stage else (None, None) + with torch.no_grad(): + if scheduled_pipeline.has_first_pp_stage: + pp_schedule.eval(input_cp.contiguous(), target=targets_pp, losses=losses) + else: + pp_schedule.eval(target=targets_pp, losses=losses) + + if scheduled_pipeline.has_last_pp_stage: + pp_loss = torch.mean(torch.stack(losses)).to(losses[0].device).float() + assert torch.allclose(pp_loss, ref_loss, atol=1e-5, rtol=1e-4), ( + f"Loss mismatch on CP rank {cp_rank}: " + f"PP loss = {pp_loss.item():.6f}, ref = {ref_loss.item():.6f}, " + f"abs diff = {(pp_loss - ref_loss).abs().item():.2e}" + ) + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="This test requires at least 2 GPUs", +) +class TestContextParallelism: + def test_cp_output_matches_fsdp2_baseline(self, tmp_path: Path): + cfg = _PARITY_CONFIGS["cp_2gpu"] + mp.spawn( + _run_cp_logit_match_impl, + args=(cfg["fsdp2_config"], cfg["cp_config"], cfg["world_size"], find_free_port(), tmp_path), + nprocs=cfg["world_size"], + join=True, + ) + + @pytest.mark.skipif( + torch.cuda.device_count() < 4, + reason="This test requires at least 4 GPUs", + ) + def test_cp_tp_output_matches_fsdp2_baseline(self, tmp_path: Path): + cfg = _PARITY_CONFIGS["cp_tp_4gpu"] + mp.spawn( + _run_cp_logit_match_impl, + args=(cfg["fsdp2_config"], cfg["cp_config"], cfg["world_size"], find_free_port(), tmp_path), + nprocs=cfg["world_size"], + join=True, + ) + + @pytest.mark.skipif( + torch.cuda.device_count() < 4, + reason="This test requires at least 4 GPUs", + ) + def test_cp_pp_output_matches_fsdp2_baseline(self, tmp_path: Path): + cfg = _PARITY_CONFIGS["cp_pp_4gpu"] + mp.spawn( + _run_cp_pp_loss_match_impl, + args=(cfg["fsdp2_config"], cfg["cp_config"], cfg["world_size"], find_free_port(), tmp_path), + nprocs=cfg["world_size"], + join=True, + ) + + 
@pytest.mark.skipif( + torch.cuda.device_count() < 8, + reason="This test requires at least 8 GPUs", + ) + def test_cp_tp_pp_output_matches_fsdp2_baseline(self, tmp_path: Path): + cfg = _PARITY_CONFIGS["cp_tp_pp_8gpu"] + mp.spawn( + _run_cp_pp_loss_match_impl, + args=(cfg["fsdp2_config"], cfg["cp_config"], cfg["world_size"], find_free_port(), tmp_path), + nprocs=cfg["world_size"], + join=True, + ) diff --git a/tests/models/parallelism/test_context_parallel.py b/tests/models/parallelism/test_context_parallel.py new file mode 100644 index 000000000..912df8eff --- /dev/null +++ b/tests/models/parallelism/test_context_parallel.py @@ -0,0 +1,214 @@ +import sys +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +import torch +import torch.nn as nn + +from modalities.models.gpt2.gpt2_model import CausalSelfAttention +from modalities.models.model_factory import GPT2ModelFactory +from modalities.models.parallelism.context_parallel import ( + apply_cp_to_sdpa_attention_forward, + shard_tensor_buffers_for_context_parallel, +) + +# ── helpers ────────────────────────────────────────────────────────────────── + +_UNSET = object() + + +class _DummyMesh: + device_type = "cuda" + + def size(self) -> int: + return 2 + + +# ── fixture: isolate class-level CP patch across tests ─────────────────────── + + +@pytest.fixture() +def cp_class_patch_isolated(): + """Restore CausalSelfAttention to its exact pre-test state after any test + that touches (or simulates) the class-level CP patch.""" + original_method = CausalSelfAttention.execute_attention + saved_wrapped = getattr(CausalSelfAttention, "_cp_execute_attention_wrapped", _UNSET) + saved_mesh = getattr(CausalSelfAttention, "_cp_mesh", _UNSET) + yield + CausalSelfAttention.execute_attention = original_method + for attr, saved in [("_cp_execute_attention_wrapped", saved_wrapped), ("_cp_mesh", saved_mesh)]: + if saved is _UNSET: + if hasattr(CausalSelfAttention, attr): + 
delattr(CausalSelfAttention, attr) + else: + setattr(CausalSelfAttention, attr, saved) + + +# ── shard_tensor_buffers_for_context_parallel ──────────────────────────────── + + +def test_shard_tensor_buffers_len_mismatch_raises() -> None: + with pytest.raises(ValueError, match=r"len\(buffers\) == len\(seq_dims\)"): + shard_tensor_buffers_for_context_parallel( + cp_mesh=_DummyMesh(), + buffers=(torch.ones(2, 8),), + seq_dims=(1, 1), + shard_impl=lambda **_: (), + ) + + +def test_shard_tensor_buffers_ptrr_requires_masks_raises() -> None: + with pytest.raises(ValueError, match="PTRR load balancing is not supported"): + shard_tensor_buffers_for_context_parallel( + cp_mesh=_DummyMesh(), + buffers=(torch.ones(2, 8),), + seq_dims=(1,), + load_balancer_type="ptrr", + ) + + +def test_shard_tensor_buffers_uses_injected_impl() -> None: + called = {} + + def fake_shard_impl(*, mesh, buffers, seq_dims, load_balancer): + called["mesh"] = mesh + called["seq_dims"] = seq_dims + called["load_balancer"] = load_balancer + return tuple(t[:, : t.shape[1] // 2].contiguous() for t in buffers) + + sample = torch.arange(16, dtype=torch.int64).view(2, 8) + target = torch.arange(16, dtype=torch.int64).view(2, 8) + + sharded_sample, sharded_target = shard_tensor_buffers_for_context_parallel( + cp_mesh=_DummyMesh(), + buffers=(sample, target), + seq_dims=(1, 1), + load_balancer_type=None, + shard_impl=fake_shard_impl, + ) + + assert called["seq_dims"] == (1, 1) + assert called["load_balancer"] is None + assert sharded_sample.shape == (2, 4) + assert sharded_target.shape == (2, 4) + + +# ── apply_cp_to_sdpa_attention_forward ─────────────────────────────────────── + + +def _get_execute_attention_fn() -> object: + """Return the raw function behind the execute_attention classmethod descriptor. + + Classmethod descriptors return a new bound-method wrapper on every attribute + access, so `is` comparisons on `CausalSelfAttention.execute_attention` always + fail. 
Comparing the underlying `__func__` (or the descriptor stored in + `__dict__`) gives a stable identity. + """ + descriptor = CausalSelfAttention.__dict__["execute_attention"] + return getattr(descriptor, "__func__", descriptor) + + +def test_apply_cp_empty_modules_does_not_patch(cp_class_patch_isolated) -> None: + """An empty attention_modules list must leave the class untouched.""" + original_fn = _get_execute_attention_fn() + apply_cp_to_sdpa_attention_forward(attention_modules=[], cp_mesh=object()) + assert _get_execute_attention_fn() is original_fn + assert not getattr(CausalSelfAttention, "_cp_execute_attention_wrapped", False) + + +def test_apply_cp_reentry_same_mesh_is_noop(cp_class_patch_isolated) -> None: + """A second call with the same mesh object must return silently.""" + mesh = object() + setattr(CausalSelfAttention, "_cp_execute_attention_wrapped", True) + setattr(CausalSelfAttention, "_cp_mesh", mesh) + original_fn = _get_execute_attention_fn() + + apply_cp_to_sdpa_attention_forward(attention_modules=[MagicMock(spec=nn.Module)], cp_mesh=mesh) + + assert _get_execute_attention_fn() is original_fn + + +def test_apply_cp_reentry_different_mesh_raises(cp_class_patch_isolated) -> None: + """A second call with a different mesh must raise RuntimeError.""" + mesh_a = object() + mesh_b = object() + setattr(CausalSelfAttention, "_cp_execute_attention_wrapped", True) + setattr(CausalSelfAttention, "_cp_mesh", mesh_a) + + with pytest.raises(RuntimeError, match="already patched.*different cp_mesh"): + apply_cp_to_sdpa_attention_forward(attention_modules=[MagicMock(spec=nn.Module)], cp_mesh=mesh_b) + + +def test_apply_cp_patches_classmethod_and_sets_flags(cp_class_patch_isolated) -> None: + """First call must replace execute_attention and set both guard flags.""" + mock_dtensor_module = MagicMock() + mock_attn_module = MagicMock() + + with patch.dict( + sys.modules, + { + "torch.distributed.tensor": mock_dtensor_module, + "torch.distributed.tensor.experimental": 
MagicMock(), + "torch.distributed.tensor.experimental._attention": mock_attn_module, + }, + ): + original_fn = _get_execute_attention_fn() + mesh = object() + apply_cp_to_sdpa_attention_forward(attention_modules=[MagicMock(spec=nn.Module)], cp_mesh=mesh) + + assert getattr(CausalSelfAttention, "_cp_execute_attention_wrapped", False) is True + assert getattr(CausalSelfAttention, "_cp_mesh", None) is mesh + assert _get_execute_attention_fn() is not original_fn + mock_attn_module._enable_context_parallel_dispatcher.assert_called_once() + + +# ── GPT2ModelFactory._validate_context_parallel_seq_len ────────────────────── + + +def _model(seq_len: int) -> SimpleNamespace: + return SimpleNamespace(sequence_length=seq_len) + + +def test_validate_seq_len_headtail_valid() -> None: + # 256 divisible by cp=2 * tp=1 * 2 = 4 ✓ + GPT2ModelFactory._validate_context_parallel_seq_len(_model(256), cp_degree=2, load_balancer_type="headtail") + + +def test_validate_seq_len_headtail_invalid() -> None: + # 100 % (3 * 2) = 100 % 6 ≠ 0 → should raise + with pytest.raises(ValueError, match="divisible"): + GPT2ModelFactory._validate_context_parallel_seq_len(_model(100), cp_degree=3, load_balancer_type="headtail") + + +def test_validate_seq_len_none_valid_odd_multiple() -> None: + # 6 % (cp=2 * tp=1 * 1) = 0 → passes with None. + # Would fail with headtail (6 % 4 ≠ 0) — this is the bug-fix regression case. 
+ GPT2ModelFactory._validate_context_parallel_seq_len(_model(6), cp_degree=2, load_balancer_type=None) + + +def test_validate_seq_len_none_invalid() -> None: + # 9 % (cp=2) ≠ 0 → raises even without headtail factor + with pytest.raises(ValueError, match="divisible"): + GPT2ModelFactory._validate_context_parallel_seq_len(_model(9), cp_degree=2, load_balancer_type=None) + + +def test_validate_seq_len_headtail_with_tp_valid() -> None: + # 256 % (tp=2 * cp=2 * 2) = 256 % 8 = 0 ✓ + GPT2ModelFactory._validate_context_parallel_seq_len( + _model(256), cp_degree=2, tp_degree=2, load_balancer_type="headtail" + ) + + +def test_validate_seq_len_headtail_with_tp_invalid() -> None: + # 256 % (tp=3 * cp=2 * 2) = 256 % 12 ≠ 0 + with pytest.raises(ValueError, match="divisible"): + GPT2ModelFactory._validate_context_parallel_seq_len( + _model(256), cp_degree=2, tp_degree=3, load_balancer_type="headtail" + ) + + +def test_validate_seq_len_none_with_tp_valid_headtail_would_fail() -> None: + # 9 % (tp=1 * cp=3 * 1) = 9 % 3 = 0 → passes with None. + # headtail would need 9 % (1*3*2=6) = 9 % 6 ≠ 0 → fails. + GPT2ModelFactory._validate_context_parallel_seq_len(_model(9), cp_degree=3, tp_degree=1, load_balancer_type=None) diff --git a/tests/test_registry_components.py b/tests/test_registry_components.py new file mode 100644 index 000000000..89756f222 --- /dev/null +++ b/tests/test_registry_components.py @@ -0,0 +1,7 @@ +from modalities.registry.components import COMPONENTS + + +def test_gpt2_cp_component_is_registered() -> None: + assert any( + component.component_key == "model" and component.variant_key == "gpt2_cp" for component in COMPONENTS + ), "Expected model variant 'gpt2_cp' to be registered in COMPONENTS." 
diff --git a/tests/test_trainer_context_parallel.py b/tests/test_trainer_context_parallel.py new file mode 100644 index 000000000..da231b3ce --- /dev/null +++ b/tests/test_trainer_context_parallel.py @@ -0,0 +1,246 @@ +import pytest +import torch + +from modalities.batch import DatasetBatch +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees +from modalities.trainer import Trainer + + +class _DummyPublisher: + def publish_message(self, **_kwargs): + return None + + +class _DummyGradientClipper: + def clip_gradients(self): + return torch.tensor(0.0) + + +class _DummyProfiler: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + +class _MeshEntry: + def __init__(self, degree: int): + self._degree = degree + + def size(self): + return self._degree + + def get_coordinate(self): + return [0] + + +class _DummyDeviceMesh: + mesh_dim_names = [ParallelismDegrees.DP_SHARD.value, ParallelismDegrees.CP.value] + + def __getitem__(self, key: str): + if key in (ParallelismDegrees.DP_SHARD.value, ParallelismDegrees.CP.value): + return _MeshEntry(2) + raise KeyError(key) + + def size(self, index: int): + return 2 + + +class _DummyDeviceMeshCPDegreeOne: + """Device mesh that has a CP dimension, but with degree 1 (effectively disabled).""" + + mesh_dim_names = [ParallelismDegrees.DP_SHARD.value, ParallelismDegrees.CP.value] + + def __getitem__(self, key: str): + if key == ParallelismDegrees.DP_SHARD.value: + return _MeshEntry(2) + if key == ParallelismDegrees.CP.value: + return _MeshEntry(1) + raise KeyError(key) + + def size(self, index: int): + return 2 + + +class _DummyDeviceMeshNoCP: + """Device mesh that has no CP dimension at all.""" + + mesh_dim_names = [ParallelismDegrees.DP_SHARD.value] + + def __getitem__(self, key: str): + if key == ParallelismDegrees.DP_SHARD.value: + return _MeshEntry(2) + raise KeyError(key) + + def size(self, index: int): + return 2 + + +# ── helpers 
────────────────────────────────────────────────────────────────── + + +def _make_trainer(device_mesh) -> Trainer: + return Trainer( + global_rank=0, + progress_publisher=_DummyPublisher(), + evaluation_result_publisher=_DummyPublisher(), + gradient_acc_steps=1, + global_num_tokens_per_train_step=1, + device_mesh=device_mesh, + num_seen_train_steps=0, + global_num_seen_tokens=0, + num_target_steps=1, + num_target_tokens=1, + gradient_clipper=_DummyGradientClipper(), + profiler=_DummyProfiler(), + ) + + +def _cuda_batch(seq_len: int = 8) -> DatasetBatch: + dev = torch.cuda.current_device() + return DatasetBatch( + samples={"input_ids": torch.ones(1, seq_len, device=dev, dtype=torch.long)}, + targets={"target_ids": torch.ones(1, seq_len, device=dev, dtype=torch.long)}, + ) + + +# ── CP sharding: load-balancer forwarding ──────────────────────────────────── + + +def test_trainer_passes_model_cp_load_balancer_none(monkeypatch): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for context-parallel trainer sharding test") + + captured = {} + + def _fake_shard(*, cp_mesh, buffers, seq_dims, load_balancer_type, shard_impl=None): + captured["load_balancer_type"] = load_balancer_type + return buffers + + monkeypatch.setattr("modalities.trainer.shard_tensor_buffers_for_context_parallel", _fake_shard) + + trainer = _make_trainer(_DummyDeviceMesh()) + batch = _cuda_batch() + trainer._apply_context_parallel_sharding_to_batch_( + batch=batch, sample_key="input_ids", target_key="target_ids", context_parallel_load_balancer=None + ) + + assert captured.get("load_balancer_type") is None + + +def test_trainer_passes_model_cp_load_balancer_headtail(monkeypatch): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for context-parallel trainer sharding test") + + captured = {} + + def _fake_shard(*, cp_mesh, buffers, seq_dims, load_balancer_type, shard_impl=None): + captured["load_balancer_type"] = load_balancer_type + return buffers + + 
monkeypatch.setattr("modalities.trainer.shard_tensor_buffers_for_context_parallel", _fake_shard) + + trainer = _make_trainer(_DummyDeviceMesh()) + batch = _cuda_batch() + trainer._apply_context_parallel_sharding_to_batch_( + batch=batch, sample_key="input_ids", target_key="target_ids", context_parallel_load_balancer="headtail" + ) + + assert captured.get("load_balancer_type") == "headtail" + + +# ── CP sharding: selective buffer inclusion ─────────────────────────────────── + + +def test_cp_sharding_with_sample_key_none_only_shards_target(monkeypatch): + """When sample_key is None (e.g. non-first PP stage), only the target enters the shard call.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is required") + + captured = {} + + def _fake_shard(*, cp_mesh, buffers, seq_dims, load_balancer_type, shard_impl=None): + captured["num_buffers"] = len(buffers) + return buffers + + monkeypatch.setattr("modalities.trainer.shard_tensor_buffers_for_context_parallel", _fake_shard) + + trainer = _make_trainer(_DummyDeviceMesh()) + batch = _cuda_batch() + trainer._apply_context_parallel_sharding_to_batch_( + batch=batch, sample_key=None, target_key="target_ids", context_parallel_load_balancer=None + ) + + assert captured.get("num_buffers") == 1 + + +def test_cp_sharding_with_sample_key_includes_position_ids(monkeypatch): + """When sample_key is provided, position_ids and the sample both enter the shard call, + and position_ids is written back to batch.samples.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is required") + + captured = {} + + def _fake_shard(*, cp_mesh, buffers, seq_dims, load_balancer_type, shard_impl=None): + captured["num_buffers"] = len(buffers) + return buffers + + monkeypatch.setattr("modalities.trainer.shard_tensor_buffers_for_context_parallel", _fake_shard) + + trainer = _make_trainer(_DummyDeviceMesh()) + batch = _cuda_batch(seq_len=8) + trainer._apply_context_parallel_sharding_to_batch_( + batch=batch, sample_key="input_ids", 
target_key="target_ids", context_parallel_load_balancer=None + ) + + # position_ids + input_ids + target_ids + assert captured.get("num_buffers") == 3 + assert "position_ids" in batch.samples + assert batch.samples["position_ids"].shape == (1, 8) + + +# ── CP sharding: early-exit conditions ─────────────────────────────────────── + + +def test_cp_sharding_skipped_when_cp_mesh_degree_one(monkeypatch): + """cp_mesh.size() == 1 must result in no sharding call.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is required") + + called = {"count": 0} + + def _fake_shard(**_kwargs): + called["count"] += 1 + return _kwargs["buffers"] + + monkeypatch.setattr("modalities.trainer.shard_tensor_buffers_for_context_parallel", _fake_shard) + + trainer = _make_trainer(_DummyDeviceMeshCPDegreeOne()) + trainer._apply_context_parallel_sharding_to_batch_( + batch=_cuda_batch(), sample_key="input_ids", target_key="target_ids", context_parallel_load_balancer=None + ) + + assert called["count"] == 0 + + +def test_cp_sharding_skipped_when_no_cp_in_mesh(monkeypatch): + """A mesh with no CP dimension must skip sharding entirely.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is required") + + called = {"count": 0} + + def _fake_shard(**_kwargs): + called["count"] += 1 + return _kwargs["buffers"] + + monkeypatch.setattr("modalities.trainer.shard_tensor_buffers_for_context_parallel", _fake_shard) + + trainer = _make_trainer(_DummyDeviceMeshNoCP()) + trainer._apply_context_parallel_sharding_to_batch_( + batch=_cuda_batch(), sample_key="input_ids", target_key="target_ids", context_parallel_load_balancer=None + ) + + assert called["count"] == 0