
Commit 3139e96

Author: Flax Team (committed)
Copybara import of the project:

-- 004b932 by jorisSchaller <71265553+jorisSchaller@users.noreply.github.com>:

Refactor pooling operation
- Pooling operation moved out of Linen
- Pooling documentation duplicated in nnx

PiperOrigin-RevId: 842406713
1 parent de84eb1 commit 3139e96

8 files changed: 102 additions & 140 deletions
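
For quick orientation, here is a minimal usage sketch of the pooling helpers this commit touches, assuming the flax.linen re-export paths shown in the diffs below; the input shape is illustrative and not taken from the commit:

import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((1, 8, 8, 3))                       # NHWC feature maps (illustrative shape)
y_avg = nn.avg_pool(x, (2, 2), strides=(2, 2))   # -> shape (1, 4, 4, 3)
y_max = nn.max_pool(x, (2, 2), strides=(2, 2))   # -> shape (1, 4, 4, 3)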


docs_nnx/api_reference/flax.nnx/nn/index.rst

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@ See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/index.html>`__ for
   linear
   lora
   normalization
-  pooling
   recurrent
   stochastic

docs_nnx/api_reference/flax.nnx/nn/pooling.rst

Lines changed: 0 additions & 10 deletions
This file was deleted.

flax/core/nn/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@
   swish as swish,
   tanh as tanh,
 )
-from flax.pooling import avg_pool as avg_pool, max_pool as max_pool
+from flax.linen.pooling import (avg_pool as avg_pool, max_pool as max_pool)
 from .attention import (
   dot_product_attention as dot_product_attention,
   multi_head_dot_product_attention as multi_head_dot_product_attention,

flax/linen/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@
   SpectralNorm as SpectralNorm,
   WeightNorm as WeightNorm,
 )
-from ..pooling import avg_pool as avg_pool, max_pool as max_pool, pool as pool
+from .pooling import (avg_pool as avg_pool, max_pool as max_pool, pool as pool)
 from .recurrent import (
   Bidirectional as Bidirectional,
   ConvLSTMCell as ConvLSTMCell,
File renamed without changes.

flax/nnx/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -13,10 +13,10 @@
 # limitations under the License.

 from flax.core.spmd import logical_axis_rules as logical_axis_rules
-from flax.pooling import avg_pool as avg_pool
-from flax.pooling import max_pool as max_pool
-from flax.pooling import min_pool as min_pool
-from flax.pooling import pool as pool
+from flax.linen.pooling import avg_pool as avg_pool
+from flax.linen.pooling import max_pool as max_pool
+from flax.linen.pooling import min_pool as min_pool
+from flax.linen.pooling import pool as pool
 from flax.typing import Initializer as Initializer

 from .bridge import wrappers as wrappers
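
Per the hunk above, the same four helpers (avg_pool, max_pool, min_pool, pool) are re-exported from the flax.nnx namespace, so NNX code can call them directly. A minimal sketch; the input shape is illustrative, not from the commit:

import jax.numpy as jnp
from flax import nnx

x = jnp.ones((1, 8, 8, 3))
y = nnx.max_pool(x, (2, 2), strides=(2, 2))   # -> shape (1, 4, 4, 3)
z = nnx.min_pool(x, (2, 2), strides=(2, 2))   # -> shape (1, 4, 4, 3)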

tests/linen/linen_test.py

Lines changed: 96 additions & 0 deletions
@@ -40,6 +40,102 @@ def check_eq(xs, ys):
   )
 
 
+class PoolTest(parameterized.TestCase):
+  def test_pool_custom_reduce(self):
+    x = jnp.full((1, 3, 3, 1), 2.0)
+    mul_reduce = lambda x, y: x * y
+    y = nn.pooling.pool(x, 1.0, mul_reduce, (2, 2), (1, 1), 'VALID')
+    np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0**4))
+
+  @parameterized.parameters(
+    {'count_include_pad': True}, {'count_include_pad': False}
+  )
+  def test_avg_pool(self, count_include_pad):
+    x = jnp.full((1, 3, 3, 1), 2.0)
+    pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad)
+    y = pool(x)
+    np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0))
+    y_grad = jax.grad(lambda x: pool(x).sum())(x)
+    expected_grad = jnp.array(
+      [
+        [0.25, 0.5, 0.25],
+        [0.5, 1.0, 0.5],
+        [0.25, 0.5, 0.25],
+      ]
+    ).reshape((1, 3, 3, 1))
+    np.testing.assert_allclose(y_grad, expected_grad)
+
+  @parameterized.parameters(
+    {'count_include_pad': True}, {'count_include_pad': False}
+  )
+  def test_avg_pool_no_batch(self, count_include_pad):
+    x = jnp.full((3, 3, 1), 2.0)
+    pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad)
+    y = pool(x)
+    np.testing.assert_allclose(y, np.full((2, 2, 1), 2.0))
+    y_grad = jax.grad(lambda x: pool(x).sum())(x)
+    expected_grad = jnp.array(
+      [
+        [0.25, 0.5, 0.25],
+        [0.5, 1.0, 0.5],
+        [0.25, 0.5, 0.25],
+      ]
+    ).reshape((3, 3, 1))
+    np.testing.assert_allclose(y_grad, expected_grad)
+
+  def test_max_pool(self):
+    x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
+    pool = lambda x: nn.max_pool(x, (2, 2))
+    expected_y = jnp.array(
+      [
+        [4.0, 5.0],
+        [7.0, 8.0],
+      ]
+    ).reshape((1, 2, 2, 1))
+    y = pool(x)
+    np.testing.assert_allclose(y, expected_y)
+    y_grad = jax.grad(lambda x: pool(x).sum())(x)
+    expected_grad = jnp.array(
+      [
+        [0.0, 0.0, 0.0],
+        [0.0, 1.0, 1.0],
+        [0.0, 1.0, 1.0],
+      ]
+    ).reshape((1, 3, 3, 1))
+    np.testing.assert_allclose(y_grad, expected_grad)
+
+  @parameterized.parameters(
+    {'count_include_pad': True}, {'count_include_pad': False}
+  )
+  def test_avg_pool_padding_same(self, count_include_pad):
+    x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1))
+    pool = lambda x: nn.avg_pool(
+      x, (2, 2), padding='SAME', count_include_pad=count_include_pad
+    )
+    y = pool(x)
+    if count_include_pad:
+      expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape(
+        (1, 2, 2, 1)
+      )
+    else:
+      expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape(
+        (1, 2, 2, 1)
+      )
+    np.testing.assert_allclose(y, expected_y)
+
+  def test_pooling_variable_batch_dims(self):
+    x = jnp.zeros((1, 8, 32, 32, 3), dtype=jnp.float32)
+    y = nn.max_pool(x, (2, 2), (2, 2))
+
+    assert y.shape == (1, 8, 16, 16, 3)
+
+  def test_pooling_no_batch_dims(self):
+    x = jnp.zeros((32, 32, 3), dtype=jnp.float32)
+    y = nn.max_pool(x, (2, 2), (2, 2))
+
+    assert y.shape == (16, 16, 3)
+
+
 class NormalizationTest(parameterized.TestCase):
   def test_layer_norm_mask(self):
     key = random.key(0)
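
The expected values in test_avg_pool_padding_same follow from how SAME padding interacts with count_include_pad: with a 2x2 window and stride 1 on a 2x2 input, every output cell except the top-left one pools a window that hangs over the padded edge; count_include_pad=True always divides the window sum by 4, while count_include_pad=False divides by the number of real (non-padded) elements. A hand-computed check (my own sketch, not part of the test file, assuming the padding zeros sit on the trailing edge, which matches the expected values):

import numpy as np

x = np.array([[1.0, 2.0], [3.0, 4.0]])
padded = np.pad(x, ((0, 1), (0, 1)))          # add a zero row and column at the end
window_sums = np.array(
  [[padded[i:i + 2, j:j + 2].sum() for j in range(2)] for i in range(2)]
)                                             # [[10., 6.], [7., 4.]]
real_counts = np.array([[4, 2], [2, 1]])      # non-padded elements per 2x2 window

print(window_sums / 4)                        # count_include_pad=True  -> [[2.5, 1.5], [1.75, 1.0]]
print(window_sums / real_counts)              # count_include_pad=False -> [[2.5, 3.0], [3.5, 4.0]]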

tests/pooling_test.py

Lines changed: 0 additions & 123 deletions
This file was deleted.
