@@ -83,16 +83,29 @@ def __call__(
8383 ) -> jax .numpy .ndarray :
8484 return self .forward (X , Y , W )
8585
def forward_cpu(self, L1_in, L2_in, L3_out, weights) -> None:
    """Evaluate the forward pass on CPU buffers, writing into ``L3_out``.

    The numpy inputs are converted to jax arrays, ``self.forward`` is
    evaluated, and the result is copied back into the caller-provided
    output buffer in place.
    """
    x = jax.numpy.asarray(L1_in)
    y = jax.numpy.asarray(L2_in)
    w = jax.numpy.asarray(weights)
    # Copy the device result back into the caller's numpy buffer.
    L3_out[:] = np.asarray(self.forward(x, y, w))
93+
def backward_cpu(
    self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad
) -> None:
    """Evaluate the backward (VJP) pass on CPU buffers, in place.

    Linearizes ``self.forward`` at the given inputs via ``jax.vjp``,
    pulls ``L3_grad`` back through it, and writes the resulting input
    gradients into ``L1_grad``, ``L2_grad`` and ``weights_grad``.
    """
    primals = (
        jax.numpy.asarray(L1_in),
        jax.numpy.asarray(L2_in),
        jax.numpy.asarray(weights),
    )
    # self.forward is already a (X, Y, W) callable; no wrapper lambda needed.
    _, vjp_fn = jax.vjp(self.forward, *primals)
    grads = vjp_fn(jax.numpy.asarray(L3_grad))
    # Copy each device gradient back into its caller-provided buffer.
    for dst, src in zip((L1_grad, L2_grad, weights_grad), grads):
        dst[:] = np.asarray(src)
109+
110+
111+
0 commit comments