# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ImageNet helper functions for benchmarking."""

import functools
from typing import Any

from flax import linen as nn
from flax.examples.imagenet import models
from flax.examples.imagenet import train
import jax
import jax.numpy as jnp
import ml_collections


def get_fake_batch(batch_size: int = 128) -> dict[str, jnp.ndarray]:
  """Generate a batch of fake ImageNet data.

  Args:
    batch_size: Number of images in the batch.

  Returns:
    A dictionary with 'image' and 'label' keys.
  """
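  # Fixed PRNG keys keep the fake batch identical across runs, so timing
  # differences reflect the step function rather than the input data.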
  # ImageNet images: (batch_size, 224, 224, 3)
  images = jax.random.uniform(
      jax.random.key(0), (batch_size, 224, 224, 3), dtype=jnp.float32
  )

  # Labels: integers in [0, 1000)
  labels = jax.random.randint(
      jax.random.key(1), (batch_size,), minval=0, maxval=1000, dtype=jnp.int32
  )

  return {'image': images, 'label': labels}


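# The stock flax ResNet passes axis_name='batch' to nn.BatchNorm, which makes
# the layer average batch statistics across devices with a collective and
# therefore only works under pmap with that named axis. Dropping the axis_name
# lets the same model run under a plain jax.jit on a single device.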
class BenchmarkResNet(models.ResNet):
  """ResNetV1.5 without axis_name in BatchNorm for single-device benchmarking."""

  @nn.compact
  def __call__(self, x, train: bool = True):
    conv = functools.partial(self.conv, use_bias=False, dtype=self.dtype)
    norm = functools.partial(
        nn.BatchNorm,
        use_running_average=not train,
        momentum=0.9,
        epsilon=1e-5,
        dtype=self.dtype,
        axis_name=None,  # Changed from 'batch' to None.
    )

    x = conv(
        self.num_filters,
        (7, 7),
        (2, 2),
        padding=[(3, 3), (3, 3)],
        name='conv_init',
    )(x)
    x = norm(name='bn_init')(x)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    for i, block_size in enumerate(self.stage_sizes):
      for j in range(block_size):
        strides = (2, 2) if i > 0 and j == 0 else (1, 1)
        x = self.block_cls(
            self.num_filters * 2**i,
            strides=strides,
            conv=conv,
            norm=norm,
            act=self.act,
        )(x)
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
    x = jnp.asarray(x, self.dtype)
    return x


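# Minimal forward-pass sketch for the class above (illustrative only; it
# assumes the ResNet-50 stage sizes and the constructor fields of the flax
# ImageNet example's ResNet):
#
#   model = BenchmarkResNet(
#       stage_sizes=[3, 4, 6, 3],
#       block_cls=models.BottleneckResNetBlock,
#       num_classes=1000,
#   )
#   x = jnp.ones((1, 224, 224, 3))
#   variables = model.init(jax.random.key(0), x, train=False)
#   logits = model.apply(variables, x, train=False)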
def get_apply_fn_and_args(
    config: ml_collections.ConfigDict,
) -> tuple[Any, tuple[Any, ...], dict[str, Any]]:
  """Returns the benchmark step function and its args for the given config.

  Args:
    config: The training configuration.

  Returns:
    A tuple of (step function, positional args, keyword args).
  """
  # Create model (ResNet50 by default in config).
  # We use BenchmarkResNet to avoid axis_name issues under jit.
  if config.model == 'ResNet50':
    model_cls = functools.partial(
        BenchmarkResNet,
        stage_sizes=[3, 4, 6, 3],
        block_cls=models.BottleneckResNetBlock,
    )
  else:
    # Fall back to the original model class if not ResNet50 (might fail under
    # jit if it uses axis_name).
    model_cls = getattr(models, config.model)

  model = train.create_model(
      model_cls=model_cls, half_precision=config.half_precision
  )

  # Create the learning rate function (needed by the train step). A constant
  # dummy schedule is enough for benchmarking.
  learning_rate_fn = lambda step: 0.1

  # Create train state.
  rng = jax.random.key(0)
  image_size = 224
  state = train.create_train_state(
      rng, config, model, image_size, learning_rate_fn
  )

  # Generate fake batch.
  batch = get_fake_batch(config.batch_size)

  # Return bench_train_step and its arguments.
  return (
      bench_train_step,
      (state, batch, learning_rate_fn),
      {},
  )


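# bench_train_step is jitted directly (no pmap), so there is no 'batch' axis to
# pmean over. static_argnums=(2,) makes learning_rate_fn a compile-time
# constant, so it must be hashable and a different function object triggers a
# recompile.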
@functools.partial(jax.jit, static_argnums=(2,))
def bench_train_step(state, batch, learning_rate_fn):
  """Perform a single training step (JIT-compiled, no pmean)."""

  def compute_metrics(logits, labels):
    loss = train.cross_entropy_loss(logits, labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    # metrics = lax.pmean(metrics, axis_name='batch')  # Removed pmean.
    return metrics

  def loss_fn(params):
    """Loss function used for training."""
    logits, new_model_state = state.apply_fn(
        {'params': params, 'batch_stats': state.batch_stats},
        batch['image'],
        mutable=['batch_stats'],
    )
    loss = train.cross_entropy_loss(logits, batch['label'])
    weight_penalty_params = jax.tree_util.tree_leaves(params)
    weight_decay = 0.0001
    weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1)
    weight_penalty = weight_decay * 0.5 * weight_l2
    loss = loss + weight_penalty
    return loss, (new_model_state, logits)

  step = state.step
  dynamic_scale = state.dynamic_scale
  lr = learning_rate_fn(step)

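  # Mixed-precision path: flax's DynamicScale scales the loss before
  # differentiation to avoid float16 underflow and reports whether all
  # gradients are finite so a bad step can be discarded below.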
  if dynamic_scale:
    grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True)
    dynamic_scale, is_fin, aux, grads = grad_fn(state.params)
  else:
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    aux, grads = grad_fn(state.params)

  new_model_state, logits = aux[1]
  metrics = compute_metrics(logits, batch['label'])
  metrics['learning_rate'] = lr

  new_state = state.apply_gradients(
      grads=grads,
      batch_stats=new_model_state['batch_stats'],
  )
  if dynamic_scale:
    new_state = new_state.replace(
        opt_state=jax.tree_util.tree_map(
            functools.partial(jnp.where, is_fin),
            new_state.opt_state,
            state.opt_state,
        ),
        params=jax.tree_util.tree_map(
            functools.partial(jnp.where, is_fin), new_state.params, state.params
        ),
        dynamic_scale=dynamic_scale,
    )
    metrics['scale'] = dynamic_scale.scale

  return new_state, metrics
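
# Example usage (illustrative sketch, not executed here). It assumes `config`
# is an ml_collections.ConfigDict providing the fields read above, e.g.
# config.model = 'ResNet50', config.half_precision = False, and
# config.batch_size = 128, plus whatever train.create_train_state expects:
#
#   fn, args, kwargs = get_apply_fn_and_args(config)
#   state, metrics = fn(*args, **kwargs)             # first call compiles
#   state, metrics = fn(state, *args[1:], **kwargs)  # later calls reuse the cache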