@@ -866,6 +866,85 @@ def test_generator(self):
866866 dpnp .stack (map (lambda x : x , dpnp .ones ((3 , 2 ))))
867867
868868
869+ # numpy.unstack() is available since numpy >= 2.1
870+ @testing .with_requires ("numpy>=2.1" )
871+ class TestUnstack :
872+ def test_non_array_input (self ):
873+ with pytest .raises (TypeError ):
874+ dpnp .unstack (1 )
875+
876+ @pytest .mark .parametrize (
877+ "input" , [([1 , 2 , 3 ],), [dpnp .int32 (1 ), dpnp .int32 (2 ), dpnp .int32 (3 )]]
878+ )
879+ def test_scalar_input (self , input ):
880+ with pytest .raises (TypeError ):
881+ dpnp .unstack (input )
882+
883+ @pytest .mark .parametrize ("dtype" , get_all_dtypes ())
884+ def test_0d_array_input (self , dtype ):
885+ np_a = numpy .array (1 , dtype = dtype )
886+ dp_a = dpnp .array (np_a , dtype = dtype )
887+
888+ with pytest .raises (ValueError ):
889+ numpy .unstack (np_a )
890+ with pytest .raises (ValueError ):
891+ dpnp .unstack (dp_a )
892+
893+ @pytest .mark .parametrize ("dtype" , get_all_dtypes ())
894+ def test_1d_array (self , dtype ):
895+ np_a = numpy .array ([1 , 2 , 3 ], dtype = dtype )
896+ dp_a = dpnp .array (np_a , dtype = dtype )
897+
898+ np_res = numpy .unstack (np_a )
899+ dp_res = dpnp .unstack (dp_a )
900+ assert len (dp_res ) == len (np_res )
901+ for dp_arr , np_arr in zip (dp_res , np_res ):
902+ assert_array_equal (dp_arr .asnumpy (), np_arr )
903+
904+ @pytest .mark .parametrize ("dtype" , get_all_dtypes ())
905+ def test_2d_array (self , dtype ):
906+ np_a = numpy .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = dtype )
907+ dp_a = dpnp .array (np_a , dtype = dtype )
908+
909+ np_res = numpy .unstack (np_a , axis = 0 )
910+ dp_res = dpnp .unstack (dp_a , axis = 0 )
911+ assert len (dp_res ) == len (np_res )
912+ for dp_arr , np_arr in zip (dp_res , np_res ):
913+ assert_array_equal (dp_arr .asnumpy (), np_arr )
914+
915+ @pytest .mark .parametrize ("axis" , [0 , 1 , - 1 ])
916+ @pytest .mark .parametrize ("dtype" , get_all_dtypes ())
917+ def test_2d_array_axis (self , axis , dtype ):
918+ np_a = numpy .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = dtype )
919+ dp_a = dpnp .array (np_a , dtype = dtype )
920+
921+ np_res = numpy .unstack (np_a , axis = axis )
922+ dp_res = dpnp .unstack (dp_a , axis = axis )
923+ assert len (dp_res ) == len (np_res )
924+ for dp_arr , np_arr in zip (dp_res , np_res ):
925+ assert_array_equal (dp_arr .asnumpy (), np_arr )
926+
927+ @pytest .mark .parametrize ("axis" , [2 , - 3 ])
928+ @pytest .mark .parametrize ("dtype" , get_all_dtypes ())
929+ def test_invalid_axis (self , axis , dtype ):
930+ np_a = numpy .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = dtype )
931+ dp_a = dpnp .array (np_a , dtype = dtype )
932+
933+ with pytest .raises (AxisError ):
934+ numpy .unstack (np_a , axis = axis )
935+ with pytest .raises (AxisError ):
936+ dpnp .unstack (dp_a , axis = axis )
937+
938+ @pytest .mark .parametrize ("dtype" , get_all_dtypes ())
939+ def test_empty_array (self , dtype ):
940+ np_a = numpy .array ([], dtype = dtype )
941+ dp_a = dpnp .array (np_a , dtype = dtype )
942+
943+ np_res = numpy .unstack (np_a )
944+ dp_res = dpnp .unstack (dp_a )
945+ assert len (dp_res ) == len (np_res )
946+
947+
869948class TestVstack :
870949 def test_non_iterable (self ):
871950 assert_raises (TypeError , dpnp .vstack , 1 )
0 commit comments