@@ -40,26 +40,27 @@ namespace cubool {
4040 assert (this ->getNrows () == M);
4141 assert (this ->getNcols () == N);
4242
43+ if (!accumulate) {
44+ // Clear all values
45+ this ->mMatrixImpl .zero_dim ();
46+ }
47+
4348 if (a->isMatrixEmpty () || b->isMatrixEmpty ()) {
44- // A or B has no values
49+ // Return empty matrix
4550 return ;
4651 }
4752
48- CHECK_RAISE_ERROR (accumulate, NotImplemented, " Supported only accumulated multiplication" );
53+ // Ensure csr proper csr format even if empty
54+ a->resizeStorageToDim ();
55+ b->resizeStorageToDim ();
56+ this ->resizeStorageToDim ();
4957
50- if (accumulate) {
51- // Ensure csr proper csr format even if empty
52- a->resizeStorageToDim ();
53- b->resizeStorageToDim ();
54- this ->resizeStorageToDim ();
58+ // Call backend r = c + a * b implementation, as C this is passed
59+ nsparse::spgemm_functor_t <bool , index, DeviceAlloc<index>> spgemmFunctor;
60+ auto result = spgemmFunctor (mMatrixImpl , a->mMatrixImpl , b->mMatrixImpl );
5561
56- // Call backend r = c + a * b implementation, as C this is passed
57- nsparse::spgemm_functor_t <bool , index, DeviceAlloc<index>> spgemmFunctor;
58- auto result = spgemmFunctor (mMatrixImpl , a->mMatrixImpl , b->mMatrixImpl );
59-
60- // Assign result to this
61- this ->mMatrixImpl = std::move (result);
62- }
62+ // Assign result to this
63+ this ->mMatrixImpl = std::move (result);
6364 }
6465
6566}
0 commit comments