Commit 2204c82

danielsuo authored and Flax Authors committed
[jax:benchmarks] Add tracing/lowering benchmarks for a few flax examples.
PiperOrigin-RevId: 841957927
1 parent 872d50b commit 2204c82

15 files changed

Lines changed: 1743 additions & 0 deletions

benchmarks/tracing/README.md

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
# Tracing and lowering benchmarks for Flax examples

See Flax
[documentation](https://flax.readthedocs.io/en/latest/examples/index.html) on
their examples.

## Getting started

```bash
pip install -r benchmarks/tracing/requirements.txt

# Benchmark trace and lower timing for all workloads.
python tracing_benchmark.py

# Profile a single example.
python tracing_benchmark.py --example=wmt

# Profile just tracing for a single example.
python tracing_benchmark.py --example=wmt --mode=trace
```
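
The `tracing_benchmark.py` driver itself is not shown in this README. As a rough, illustrative sketch only (not the actual script), the trace and lower phases can be timed separately with JAX's ahead-of-time (`jax.stages`) API; `fn`, `args`, and `kwargs` below stand for the values returned by an example module's `get_apply_fn_and_args`:

```python
import time

# Illustrative harness, an assumption rather than the real tracing_benchmark.py.
# Assumes `fn` is already jax.jit-wrapped, as the helper modules in this
# commit return it.
def benchmark_trace_and_lower(fn, args, kwargs):
  start = time.perf_counter()
  traced = fn.trace(*args, **kwargs)  # abstract tracing only
  trace_s = time.perf_counter() - start

  start = time.perf_counter()
  traced.lower()  # jaxpr -> StableHLO lowering
  lower_s = time.perf_counter() - start
  return trace_s, lower_s
```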

benchmarks/tracing/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

benchmarks/tracing/gemma.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gemma helper functions."""

from typing import Any

from flax import nnx
from flax.examples.gemma import train
from flax.examples.gemma import transformer as transformer_lib
from flax.examples.gemma import utils
import jax
import jax.numpy as jnp
import ml_collections
import optax


def get_fake_batch(batch_size: int) -> Any:
  """Returns fake data for the given batch size.

  Args:
    batch_size: The global batch size to generate.

  Returns:
    A properly sharded global batch of data.
  """
  rng = jax.random.PRNGKey(0)
  batch = {}
  for k in (
      'inputs',
      'inputs_position',
      'inputs_segmentation',
      'targets',
      'targets_position',
      'targets_segmentation',
  ):
    batch[k] = jax.random.randint(rng, (batch_size, 128), 0, 9999999, jnp.int32)
  return batch


def get_apply_fn_and_args(
    config: ml_collections.ConfigDict,
    vocab_size: int | None = None,
) -> tuple[Any, tuple[Any, ...], dict[str, Any]]:
  """Returns the apply function and args for the given config.

  Args:
    config: The training configuration.
    vocab_size: The vocabulary size. If None, it will be read from the config.

  Returns:
    A tuple of the apply function, args and kwargs for the apply function, and
    any metadata the training loop needs.
  """
  if vocab_size is None:
    vocab_size = config.vocab_size

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  if config.transformer_name is not None:
    model_config = transformer_lib.TransformerConfig.from_version_name(
        config.transformer_name,
        num_embed=vocab_size,
        dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
        axis_rules=config.axis_rules,
    )
  else:
    assert config.transformer_params is not None
    model_config = transformer_lib.TransformerConfig.from_dict(
        **config.transformer_params,
        num_embed=vocab_size,
        dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
        axis_rules=config.axis_rules,
    )

  # Mesh definition
  devices_array = utils.create_device_mesh(config)
  mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)

  def constructor(config: transformer_lib.TransformerConfig, key: jax.Array):
    return transformer_lib.Transformer(config, rngs=nnx.Rngs(params=key))

  learning_rate_fn = train.create_learning_rate_schedule(
      learning_rate=config.learning_rate, warmup_steps=config.warmup_steps
  )

  optimizer = optax.adamw(
      learning_rate_fn,
      b1=0.9,
      b2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay,
  )

  state, state_sharding = utils.setup_initial_state(
      constructor, optimizer, model_config, init_rng, mesh
  )
  data_sharding = jax.NamedSharding(mesh, jax.P(config.data_sharding))
  jit_train_step = jax.jit(
      train.train_step,
      in_shardings=(
          state_sharding,
          data_sharding,
      ),  # type: ignore
      out_shardings=(state_sharding, None),  # type: ignore
      static_argnames=('learning_rate_fn', 'label_smoothing'),
      donate_argnums=0,
  )

  batch = get_fake_batch(config.per_device_batch_size)
  batch = jax.tree.map(lambda x: jnp.asarray(x, device=data_sharding), batch)

  return (
      jit_train_step,
      (state, batch, learning_rate_fn, 0.0),
      dict(),
  )
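
Each helper module added in this commit exposes the same entry point, `get_apply_fn_and_args(config) -> (jitted_fn, args, kwargs)`, which the benchmark driver can trace and lower without executing the step. As a hedged illustration of that contract only (this toy module is hypothetical, not part of the commit, and `config.batch_size` is an assumed field):

```python
from typing import Any

from flax import linen as nn
import jax
import jax.numpy as jnp
import optax


class ToyMLP(nn.Module):
  """Tiny stand-in model; illustrative only."""

  @nn.compact
  def __call__(self, x):
    x = nn.relu(nn.Dense(128)(x))
    return nn.Dense(10)(x)


def get_apply_fn_and_args(config: Any) -> tuple[Any, tuple[Any, ...], dict[str, Any]]:
  """Follows the same (fn, args, kwargs) contract as gemma.py and imagenet.py."""
  model = ToyMLP()
  x = jnp.zeros((config.batch_size, 784), jnp.float32)
  y = jnp.zeros((config.batch_size,), jnp.int32)
  params = model.init(jax.random.key(0), x)

  @jax.jit
  def train_step(params, x, y):
    def loss_fn(p):
      logits = model.apply(p, x)
      return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = jax.tree.map(lambda p, g: p - 1e-3 * g, params, grads)
    return new_params, loss

  return train_step, (params, x, y), {}
```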

benchmarks/tracing/imagenet.py

Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ImageNet helper functions for benchmarking."""

import functools
from typing import Any

from flax import linen as nn
from flax.examples.imagenet import models
from flax.examples.imagenet import train
import jax
import jax.numpy as jnp
import ml_collections


def get_fake_batch(batch_size: int = 128) -> dict[str, jnp.ndarray]:
  """Generate a batch of fake ImageNet data.

  Args:
    batch_size: Number of images in the batch.

  Returns:
    A dictionary with 'image' and 'label' keys.
  """
  # ImageNet images: (batch_size, 224, 224, 3)
  images = jax.random.uniform(
      jax.random.key(0), (batch_size, 224, 224, 3), dtype=jnp.float32
  )

  # Labels: integers in [0, 1000)
  labels = jax.random.randint(
      jax.random.key(1), (batch_size,), minval=0, maxval=1000, dtype=jnp.int32
  )

  return {'image': images, 'label': labels}


class BenchmarkResNet(models.ResNet):
  """ResNetV1.5 without axis_name in BatchNorm for single-device benchmarking."""

  @nn.compact
  def __call__(self, x, train: bool = True):
    conv = functools.partial(self.conv, use_bias=False, dtype=self.dtype)
    norm = functools.partial(
        nn.BatchNorm,
        use_running_average=not train,
        momentum=0.9,
        epsilon=1e-5,
        dtype=self.dtype,
        axis_name=None,  # Changed from 'batch' to None.
    )

    x = conv(
        self.num_filters,
        (7, 7),
        (2, 2),
        padding=[(3, 3), (3, 3)],
        name='conv_init',
    )(x)
    x = norm(name='bn_init')(x)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    for i, block_size in enumerate(self.stage_sizes):
      for j in range(block_size):
        strides = (2, 2) if i > 0 and j == 0 else (1, 1)
        x = self.block_cls(
            self.num_filters * 2**i,
            strides=strides,
            conv=conv,
            norm=norm,
            act=self.act,
        )(x)
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
    x = jnp.asarray(x, self.dtype)
    return x


def get_apply_fn_and_args(
    config: ml_collections.ConfigDict,
) -> tuple[Any, tuple[Any, ...], dict[str, Any]]:
  """Returns the apply function and args for the given config.

  Args:
    config: The training configuration.

  Returns:
    A tuple of the apply function, args, and kwargs.
  """
  # Create the model (ResNet50 by default in config).
  # We use BenchmarkResNet to avoid axis_name issues under jit.
  if config.model == 'ResNet50':
    model_cls = functools.partial(
        BenchmarkResNet,
        stage_sizes=[3, 4, 6, 3],
        block_cls=models.BottleneckResNetBlock,
    )
  else:
    # Fall back to the original model if not ResNet50 (may fail if it uses
    # axis_name).
    model_cls = getattr(models, config.model)

  model = train.create_model(
      model_cls=model_cls, half_precision=config.half_precision
  )

  # Create the learning rate function (needed for train_step).
  # We use a dummy function for benchmarking.
  learning_rate_fn = lambda step: 0.1

  # Create the train state.
  rng = jax.random.key(0)
  image_size = 224
  state = train.create_train_state(
      rng, config, model, image_size, learning_rate_fn
  )

  # Generate a fake batch.
  batch = get_fake_batch(config.batch_size)

  # Return bench_train_step and its arguments.
  return (
      bench_train_step,
      (state, batch, learning_rate_fn),
      {},
  )


@functools.partial(jax.jit, static_argnums=(2,))
def bench_train_step(state, batch, learning_rate_fn):
  """Perform a single training step (JIT-compiled, no pmean)."""

  def compute_metrics(logits, labels):
    loss = train.cross_entropy_loss(logits, labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    # metrics = lax.pmean(metrics, axis_name='batch')  # Removed pmean.
    return metrics

  def loss_fn(params):
    """Loss function used for training."""
    logits, new_model_state = state.apply_fn(
        {'params': params, 'batch_stats': state.batch_stats},
        batch['image'],
        mutable=['batch_stats'],
    )
    loss = train.cross_entropy_loss(logits, batch['label'])
    weight_penalty_params = jax.tree_util.tree_leaves(params)
    weight_decay = 0.0001
    weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1)
    weight_penalty = weight_decay * 0.5 * weight_l2
    loss = loss + weight_penalty
    return loss, (new_model_state, logits)

  step = state.step
  dynamic_scale = state.dynamic_scale
  lr = learning_rate_fn(step)

  if dynamic_scale:
    grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True)
    dynamic_scale, is_fin, aux, grads = grad_fn(state.params)
  else:
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    aux, grads = grad_fn(state.params)

  new_model_state, logits = aux[1]
  metrics = compute_metrics(logits, batch['label'])
  metrics['learning_rate'] = lr

  new_state = state.apply_gradients(
      grads=grads,
      batch_stats=new_model_state['batch_stats'],
  )
  if dynamic_scale:
    new_state = new_state.replace(
        opt_state=jax.tree_util.tree_map(
            functools.partial(jnp.where, is_fin),
            new_state.opt_state,
            state.opt_state,
        ),
        params=jax.tree_util.tree_map(
            functools.partial(jnp.where, is_fin), new_state.params, state.params
        ),
        dynamic_scale=dynamic_scale,
    )
    metrics['scale'] = dynamic_scale.scale

  return new_state, metrics
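
Since `bench_train_step` is already wrapped in `jax.jit`, the artifact the benchmark lowers can be inspected directly through JAX's AOT API. The snippet below is illustrative only; `config` is assumed to be an ImageNet example config with the fields used above (`model`, `half_precision`, `batch_size`):

```python
# Illustrative only: inspect the lowered StableHLO for the jitted step.
fn, args, kwargs = get_apply_fn_and_args(config)
lowered = fn.trace(*args, **kwargs).lower()
print(lowered.as_text()[:500])  # beginning of the StableHLO module
```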
