1+ /* *********************************************************************************/
2+ /* MIT License */
3+ /* */
4+ /* Copyright (c) 2020, 2021 JetBrains-Research */
5+ /* */
6+ /* Permission is hereby granted, free of charge, to any person obtaining a copy */
7+ /* of this software and associated documentation files (the "Software"), to deal */
8+ /* in the Software without restriction, including without limitation the rights */
9+ /* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */
10+ /* copies of the Software, and to permit persons to whom the Software is */
11+ /* furnished to do so, subject to the following conditions: */
12+ /* */
13+ /* The above copyright notice and this permission notice shall be included in all */
14+ /* copies or substantial portions of the Software. */
15+ /* */
16+ /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */
17+ /* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */
18+ /* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */
19+ /* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */
20+ /* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */
21+ /* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */
22+ /* SOFTWARE. */
23+ /* *********************************************************************************/
24+
25+ #include < cuda/cuda_vector.hpp>
26+ #include < core/error.hpp>
27+ #include < utils/data_utils.hpp>
28+
29+ namespace cubool {
30+
31+ CudaVector::CudaVector (size_t nrows, CudaInstance &instance)
32+ : mVectorImpl (nrows), mInstance (instance) {
33+
34+ }
35+
36+ void CudaVector::setElement (index i) {
37+ RAISE_ERROR (NotImplemented, " This function is not supported for this vector class" );
38+ }
39+
40+ void CudaVector::build (const index *rows, size_t nvals, bool isSorted, bool noDuplicates) {
41+ if (nvals == 0 ) {
42+ // Empty vector, no values (but preserve dim)
43+ mVectorImpl = VectorImplType (getNrows ());
44+ return ;
45+ }
46+
47+ // Validate data, sort, remove duplicates and etc.
48+ std::vector<index> data;
49+ DataUtils::buildVectorFromData (getNrows (), rows, nvals, data, isSorted, noDuplicates);
50+
51+ // Transfer data to GPU
52+ thrust::device_vector<index, DeviceAlloc<index>> deviceData (data.size ());
53+ thrust::copy (data.begin (), data.end (), deviceData.begin ());
54+
55+ // New vec instance
56+ mVectorImpl = VectorImplType (std::move (deviceData), getNrows (), data.size ());
57+ }
58+
59+ void CudaVector::extract (index *rows, size_t &nvals) {
60+ assert (nvals >= getNvals ());
61+
62+ nvals = getNvals ();
63+
64+ if (nvals > 0 ) {
65+ assert (rows);
66+
67+ // Transfer data from GPU
68+ thrust::copy (mVectorImpl .m_rows_index .begin (), mVectorImpl .m_rows_index .end (), rows);
69+ }
70+ }
71+
72+ void CudaVector::extractSubVector (const VectorBase &otherBase, index i, index nrows, bool checkTime) {
73+ RAISE_ERROR (NotImplemented, " This function is not implemented" );
74+
75+ }
76+
77+ void CudaVector::clone (const VectorBase &otherBase) {
78+ auto other = dynamic_cast <const CudaVector*>(&otherBase);
79+
80+ CHECK_RAISE_ERROR (other != nullptr , InvalidArgument, " Passed vector does not belong to vector class" );
81+ CHECK_RAISE_ERROR (other != this , InvalidArgument, " Vectors must differ" );
82+
83+ assert (this ->getNrows () == other->getNrows ());
84+ this ->mVectorImpl = other->mVectorImpl ;
85+ }
86+
87+ void CudaVector::reduce (index &result, bool checkTime) {
88+ result = getNvals ();
89+ }
90+
91+ void CudaVector::reduceMatrix (const struct MatrixBase &matrix, bool transpose, bool checkTime) {
92+ RAISE_ERROR (NotImplemented, " This function is not implemented" );
93+
94+ }
95+
96+ void CudaVector::eWiseAdd (const VectorBase &aBase, const VectorBase &bBase, bool checkTime) {
97+ RAISE_ERROR (NotImplemented, " This function is not implemented" );
98+
99+ }
100+
101+ void CudaVector::multiplyVxM (const VectorBase &vBase, const struct MatrixBase &mBase , bool checkTime) {
102+ RAISE_ERROR (NotImplemented, " This function is not implemented" );
103+
104+ }
105+
106+ void CudaVector::multiplyMxV (const struct MatrixBase &mBase , const VectorBase &vBase, bool checkTime) {
107+ RAISE_ERROR (NotImplemented, " This function is not implemented" );
108+
109+ }
110+
111+ index CudaVector::getNrows () const {
112+ return mVectorImpl .m_rows ;
113+ }
114+
115+ index CudaVector::getNvals () const {
116+ return mVectorImpl .m_vals ;
117+ }
118+
119+ }
0 commit comments