Commit de84eb1

Flax Authors committed
Merge pull request google#5057 from jorisSchaller:feat/nnx_pooling
PiperOrigin-RevId: 842388144
2 parents b8aa99d + 004b932

8 files changed: 140 additions & 102 deletions

docs_nnx/api_reference/flax.nnx/nn/index.rst (1 addition, 0 deletions)

@@ -14,6 +14,7 @@ See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/index.html>`__ for
   linear
   lora
   normalization
+  pooling
   recurrent
   stochastic

docs_nnx/api_reference/flax.nnx/nn/pooling.rst (new file, 10 additions, 0 deletions)

@@ -0,0 +1,10 @@
+Pooling
+------------------------
+
+.. automodule:: flax.nnx
+.. currentmodule:: flax.nnx
+
+.. autofunction:: avg_pool
+.. autofunction:: max_pool
+.. autofunction:: min_pool
+.. autofunction:: pool
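For orientation, these are the four array-level pooling functions the new page documents. A minimal usage sketch (a non-authoritative example, assuming the post-merge flax.nnx namespace and the window_shape/strides signature exercised in the tests below):

import jax.numpy as jnp
from flax import nnx

x = jnp.arange(16.0).reshape((1, 4, 4, 1))  # NHWC: batch, height, width, channels
y_max = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))  # shape (1, 2, 2, 1)
y_min = nnx.min_pool(x, window_shape=(2, 2), strides=(2, 2))  # shape (1, 2, 2, 1)
y_avg = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))  # shape (1, 2, 2, 1)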

flax/core/nn/__init__.py (1 addition, 1 deletion)

@@ -35,7 +35,7 @@
   swish as swish,
   tanh as tanh,
 )
-from flax.linen.pooling import (avg_pool as avg_pool, max_pool as max_pool)
+from flax.pooling import avg_pool as avg_pool, max_pool as max_pool
 from .attention import (
   dot_product_attention as dot_product_attention,
   multi_head_dot_product_attention as multi_head_dot_product_attention,

flax/linen/__init__.py (1 addition, 1 deletion)

@@ -123,7 +123,7 @@
   SpectralNorm as SpectralNorm,
   WeightNorm as WeightNorm,
 )
-from .pooling import (avg_pool as avg_pool, max_pool as max_pool, pool as pool)
+from ..pooling import avg_pool as avg_pool, max_pool as max_pool, pool as pool
 from .recurrent import (
   Bidirectional as Bidirectional,
   ConvLSTMCell as ConvLSTMCell,
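Because flax.linen re-exports these names from the relocated module, existing nn.max_pool-style call sites are unaffected. A quick identity check (a sketch, assuming the post-merge layout):

import flax.pooling
from flax import linen as nn

# The linen aliases should be the very same function objects as flax.pooling's.
assert nn.max_pool is flax.pooling.max_pool
assert nn.avg_pool is flax.pooling.avg_pool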

flax/nnx/__init__.py (4 additions, 4 deletions)

@@ -13,10 +13,10 @@
 # limitations under the License.

 from flax.core.spmd import logical_axis_rules as logical_axis_rules
-from flax.linen.pooling import avg_pool as avg_pool
-from flax.linen.pooling import max_pool as max_pool
-from flax.linen.pooling import min_pool as min_pool
-from flax.linen.pooling import pool as pool
+from flax.pooling import avg_pool as avg_pool
+from flax.pooling import max_pool as max_pool
+from flax.pooling import min_pool as min_pool
+from flax.pooling import pool as pool
 from flax.typing import Initializer as Initializer

 from .bridge import wrappers as wrappers
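The generic pool function rides along into flax.nnx as well; it folds a custom reducer over every window, exactly as the test_pool_custom_reduce case below exercises. A hedged sketch of product pooling built on it:

import jax.numpy as jnp
from flax import nnx

x = jnp.full((1, 3, 3, 1), 2.0)
# pool(inputs, init, reduce_fn, window_shape, strides, padding): each 2x2
# window multiplies its four elements together, starting from init=1.0.
y = nnx.pool(x, 1.0, lambda a, b: a * b, (2, 2), (1, 1), 'VALID')
# Every window holds four 2.0s, so each output element is 2.0**4 == 16.0.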
flax/linen/pooling.py → flax/pooling.py: file renamed without changes.

tests/linen/linen_test.py (0 additions, 96 deletions)

@@ -40,102 +40,6 @@ def check_eq(xs, ys):
   )


-class PoolTest(parameterized.TestCase):
-  def test_pool_custom_reduce(self):
-    x = jnp.full((1, 3, 3, 1), 2.0)
-    mul_reduce = lambda x, y: x * y
-    y = nn.pooling.pool(x, 1.0, mul_reduce, (2, 2), (1, 1), 'VALID')
-    np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0**4))
-
-  @parameterized.parameters(
-      {'count_include_pad': True}, {'count_include_pad': False}
-  )
-  def test_avg_pool(self, count_include_pad):
-    x = jnp.full((1, 3, 3, 1), 2.0)
-    pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad)
-    y = pool(x)
-    np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0))
-    y_grad = jax.grad(lambda x: pool(x).sum())(x)
-    expected_grad = jnp.array(
-        [
-            [0.25, 0.5, 0.25],
-            [0.5, 1.0, 0.5],
-            [0.25, 0.5, 0.25],
-        ]
-    ).reshape((1, 3, 3, 1))
-    np.testing.assert_allclose(y_grad, expected_grad)
-
-  @parameterized.parameters(
-      {'count_include_pad': True}, {'count_include_pad': False}
-  )
-  def test_avg_pool_no_batch(self, count_include_pad):
-    x = jnp.full((3, 3, 1), 2.0)
-    pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad)
-    y = pool(x)
-    np.testing.assert_allclose(y, np.full((2, 2, 1), 2.0))
-    y_grad = jax.grad(lambda x: pool(x).sum())(x)
-    expected_grad = jnp.array(
-        [
-            [0.25, 0.5, 0.25],
-            [0.5, 1.0, 0.5],
-            [0.25, 0.5, 0.25],
-        ]
-    ).reshape((3, 3, 1))
-    np.testing.assert_allclose(y_grad, expected_grad)
-
-  def test_max_pool(self):
-    x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
-    pool = lambda x: nn.max_pool(x, (2, 2))
-    expected_y = jnp.array(
-        [
-            [4.0, 5.0],
-            [7.0, 8.0],
-        ]
-    ).reshape((1, 2, 2, 1))
-    y = pool(x)
-    np.testing.assert_allclose(y, expected_y)
-    y_grad = jax.grad(lambda x: pool(x).sum())(x)
-    expected_grad = jnp.array(
-        [
-            [0.0, 0.0, 0.0],
-            [0.0, 1.0, 1.0],
-            [0.0, 1.0, 1.0],
-        ]
-    ).reshape((1, 3, 3, 1))
-    np.testing.assert_allclose(y_grad, expected_grad)
-
-  @parameterized.parameters(
-      {'count_include_pad': True}, {'count_include_pad': False}
-  )
-  def test_avg_pool_padding_same(self, count_include_pad):
-    x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1))
-    pool = lambda x: nn.avg_pool(
-        x, (2, 2), padding='SAME', count_include_pad=count_include_pad
-    )
-    y = pool(x)
-    if count_include_pad:
-      expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape(
-          (1, 2, 2, 1)
-      )
-    else:
-      expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape(
-          (1, 2, 2, 1)
-      )
-    np.testing.assert_allclose(y, expected_y)
-
-  def test_pooling_variable_batch_dims(self):
-    x = jnp.zeros((1, 8, 32, 32, 3), dtype=jnp.float32)
-    y = nn.max_pool(x, (2, 2), (2, 2))
-
-    assert y.shape == (1, 8, 16, 16, 3)
-
-  def test_pooling_no_batch_dims(self):
-    x = jnp.zeros((32, 32, 3), dtype=jnp.float32)
-    y = nn.max_pool(x, (2, 2), (2, 2))
-
-    assert y.shape == (16, 16, 3)
-
-
 class NormalizationTest(parameterized.TestCase):
   def test_layer_norm_mask(self):
     key = random.key(0)

tests/pooling_test.py (new file, 123 additions, 0 deletions)

@@ -0,0 +1,123 @@
+# Copyright 2024 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for flax.pooling."""
+
+from absl.testing import absltest
+from flax.pooling import pool, avg_pool, max_pool
+import numpy as np
+import jax.numpy as jnp
+from absl.testing import parameterized
+import jax
+
+jax.config.parse_flags_with_absl()
+
+
+class PoolTest(parameterized.TestCase):
+  def test_pool_custom_reduce(self):
+    x = jnp.full((1, 3, 3, 1), 2.0)
+    mul_reduce = lambda x, y: x * y
+    y = pool(x, 1.0, mul_reduce, (2, 2), (1, 1), 'VALID')
+    np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0**4))
+
+  @parameterized.parameters(
+      {'count_include_pad': True}, {'count_include_pad': False}
+  )
+  def test_avg_pool(self, count_include_pad):
+    x = jnp.full((1, 3, 3, 1), 2.0)
+    pool = lambda x: avg_pool(x, (2, 2), count_include_pad=count_include_pad)
+    y = pool(x)
+    np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0))
+    y_grad = jax.grad(lambda x: pool(x).sum())(x)
+    expected_grad = jnp.array(
+        [
+            [0.25, 0.5, 0.25],
+            [0.5, 1.0, 0.5],
+            [0.25, 0.5, 0.25],
+        ]
+    ).reshape((1, 3, 3, 1))
+    np.testing.assert_allclose(y_grad, expected_grad)
+
+  @parameterized.parameters(
+      {'count_include_pad': True}, {'count_include_pad': False}
+  )
+  def test_avg_pool_no_batch(self, count_include_pad):
+    x = jnp.full((3, 3, 1), 2.0)
+    pool = lambda x: avg_pool(x, (2, 2), count_include_pad=count_include_pad)
+    y = pool(x)
+    np.testing.assert_allclose(y, np.full((2, 2, 1), 2.0))
+    y_grad = jax.grad(lambda x: pool(x).sum())(x)
+    expected_grad = jnp.array(
+        [
+            [0.25, 0.5, 0.25],
+            [0.5, 1.0, 0.5],
+            [0.25, 0.5, 0.25],
+        ]
+    ).reshape((3, 3, 1))
+    np.testing.assert_allclose(y_grad, expected_grad)
+
+  def test_max_pool(self):
+    x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
+    pool = lambda x: max_pool(x, (2, 2))
+    expected_y = jnp.array(
+        [
+            [4.0, 5.0],
+            [7.0, 8.0],
+        ]
+    ).reshape((1, 2, 2, 1))
+    y = pool(x)
+    np.testing.assert_allclose(y, expected_y)
+    y_grad = jax.grad(lambda x: pool(x).sum())(x)
+    expected_grad = jnp.array(
+        [
+            [0.0, 0.0, 0.0],
+            [0.0, 1.0, 1.0],
+            [0.0, 1.0, 1.0],
+        ]
+    ).reshape((1, 3, 3, 1))
+    np.testing.assert_allclose(y_grad, expected_grad)
+
+  @parameterized.parameters(
+      {'count_include_pad': True}, {'count_include_pad': False}
+  )
+  def test_avg_pool_padding_same(self, count_include_pad):
+    x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1))
+    pool = lambda x: avg_pool(
+        x, (2, 2), padding='SAME', count_include_pad=count_include_pad
+    )
+    y = pool(x)
+    if count_include_pad:
+      expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape(
+          (1, 2, 2, 1)
+      )
+    else:
+      expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape(
+          (1, 2, 2, 1)
+      )
+    np.testing.assert_allclose(y, expected_y)
+
+  def test_pooling_variable_batch_dims(self):
+    x = jnp.zeros((1, 8, 32, 32, 3), dtype=jnp.float32)
+    y = max_pool(x, (2, 2), (2, 2))
+
+    assert y.shape == (1, 8, 16, 16, 3)
+
+  def test_pooling_no_batch_dims(self):
+    x = jnp.zeros((32, 32, 3), dtype=jnp.float32)
+    y = max_pool(x, (2, 2), (2, 2))
+
+    assert y.shape == (16, 16, 3)
+
+if __name__ == '__main__':
+  absltest.main()
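The test_avg_pool_padding_same expectations above fall straight out of the window sums: on the 2x2 input [[1, 2], [3, 4]] with a 2x2 window and SAME padding, the four windows sum to 10 (all four values), 6 (2 + 4), 7 (3 + 4), and 4 (4 alone). With count_include_pad=True every sum is divided by the full window size of 4; with False it is divided by the number of non-padding elements (4, 2, 2, 1). A standalone sketch of the same arithmetic, assuming the new flax.pooling import path:

import jax.numpy as jnp
from flax.pooling import avg_pool

x = jnp.array([[1.0, 2.0], [3.0, 4.0]]).reshape((1, 2, 2, 1))

# Padding counted in the denominator: every window sum is divided by 4.
y_incl = avg_pool(x, (2, 2), padding='SAME', count_include_pad=True)
# y_incl.ravel() == [10/4, 6/4, 7/4, 4/4] == [2.5, 1.5, 1.75, 1.0]

# Padding excluded: each sum is divided by its count of real elements.
y_excl = avg_pool(x, (2, 2), padding='SAME', count_include_pad=False)
# y_excl.ravel() == [10/4, 6/2, 7/2, 4/1] == [2.5, 3.0, 3.5, 4.0]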
