1+ import jax
2+ import jax .numpy as jnp
3+ import numpy as np
4+
5+ def reorder_jax_helper (schedule , weights_in , direction , has_batch_dim ):
6+ assert direction in ["forward" , "backward" ]
7+
8+ specs = schedule .weight_reordering_info (weights_in , has_batch_dim )
9+ weights_out = jnp .zeros_like (weights_in )
10+
11+ for spec in specs :
12+ parent_range = spec ["parent_range" ]
13+ parent_shape = spec ["parent_shape" ]
14+ weights_subrange = spec ["weights_subrange" ]
15+ child_range = spec ["child_range" ]
16+ transpose_perm = spec ["transpose_perm" ]
17+
18+ if direction == "forward" :
19+ reshape_size = spec ["reshape_size" ]
20+
21+ sliced_weights = weights_in [parent_range ].reshape (parent_shape )[
22+ weights_subrange
23+ ]
24+
25+ value_to_assign = sliced_weights .transpose (transpose_perm ).reshape (reshape_size )
26+ weights_out = weights_out .at [child_range ].set (value_to_assign )
27+
28+ elif direction == "backward" :
29+ transpose_child_shape = spec ["transpose_child_shape" ]
30+ child_shape = spec ["child_shape" ]
31+
32+ sliced_weights = (
33+ weights_in [child_range ]
34+ .reshape (transpose_child_shape )
35+ .transpose (transpose_perm )
36+ )
37+
38+ value_to_insert = sliced_weights .flatten ().reshape (child_shape )
39+
40+ slab = weights_out [parent_range ]
41+ slab_reshaped = slab .reshape (parent_shape )
42+ slab_reshaped = slab_reshaped .at [weights_subrange ].set (value_to_insert )
43+ weights_out = weights_out .at [parent_range ].set (slab_reshaped .reshape (slab .shape ))
44+
45+ return weights_out
46+
47+ def reorder_numpy_jax_helper (schedule , weights_in , direction , has_batch_dim ):
48+ weights_in_jax = jnp .array (weights_in )
49+ result = reorder_jax_helper (schedule , weights_in_jax , direction , has_batch_dim )
50+ return np .array (result )
51+
52+ def reorder_jax (schedule , weights_in , direction , has_batch_dim ):
53+ if isinstance (weights_in , (jnp .ndarray , jax .Array )):
54+ return reorder_jax_helper (schedule , weights_in , direction , has_batch_dim )
55+ else :
56+ return reorder_numpy_jax_helper (schedule , weights_in , direction , has_batch_dim )
0 commit comments