Skip to content

Commit 9068cd0

Browse files
Add i32-type
1 parent 6165347 commit 9068cd0

1 file changed

Lines changed: 18 additions & 1 deletion

File tree

gpu.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,8 @@ struct TensorPool {
190190

191191
enum NumType {
192192
kf16, // (experimental)
193-
kf32
193+
kf32,
194+
ki32
194195
};
195196

196197
/**
@@ -202,6 +203,8 @@ inline size_t sizeBytes(const NumType &type) {
202203
return sizeof(uint16_t);
203204
case kf32:
204205
return sizeof(float);
206+
case ki32:
207+
return sizeof(int32_t);
205208
default:
206209
LOG(kDefLog, kError, "Invalid NumType in size calculation.");
207210
return 0;
@@ -217,6 +220,8 @@ inline std::string toString(NumType type) {
217220
return "f16";
218221
case kf32:
219222
return "f32";
223+
case ki32:
224+
return "i32";
220225
default:
221226
LOG(kDefLog, kError, "Invalid NumType in string conversion.");
222227
return "unknown";
@@ -635,6 +640,18 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
635640
return tensor;
636641
}
637642

643+
inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
644+
const int32_t *data) {
645+
assert(dtype == ki32);
646+
Tensor tensor =
647+
createTensor(ctx.pool, ctx.device, shape, dtype,
648+
WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
649+
WGPUBufferUsage_CopySrc);
650+
wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data,
651+
tensor.data.size);
652+
return tensor;
653+
}
654+
638655
/**
639656
* @brief Overload of the tensor factory function to instantiate a tensor on
640657
* the GPU with a given shape, data type. This overload also takes initial

0 commit comments

Comments
 (0)