@@ -37,6 +37,14 @@ namespace py=pybind11;
3737
3838using Map_t=torch::Dict<string, int64_t >;
3939
40+ std::unordered_map<string, int64_t > to_map (const Map_t &map) {
41+ std::unordered_map<string, int64_t > result;
42+ for (auto it = map.begin (); it != map.end (); ++it) {
43+ result[it->key ()] = it->value ();
44+ }
45+ return result;
46+ }
47+
4048inline void * data_ptr (const torch::Tensor &tensor) {
4149 if (tensor.dtype () == torch::kFloat )
4250 return reinterpret_cast <void *>(tensor.data_ptr <float >());
@@ -62,21 +70,11 @@ class __attribute__ ((visibility ("default"))) TorchJITProduct : public torch::C
6270 dbl_bwd_dict (dbl_bwd_dict_i.copy ()),
6371 kernel_dims (kernel_dims_i.copy ()),
6472 internal (kernel_plaintext,
65- KernelLaunchConfig (
66- fwd_dict. at ( " num_blocks " ),
67- fwd_dict. at ( " num_threads " ),
68- fwd_dict. at ( " smem " )
73+ to_map (fwd_dict_i),
74+ to_map (bwd_dict_i ),
75+ to_map (dbl_bwd_dict_i ),
76+ to_map (kernel_dims_i )
6977 ),
70- KernelLaunchConfig (
71- bwd_dict.at (" num_blocks" ),
72- bwd_dict.at (" num_threads" ),
73- bwd_dict.at (" smem" )
74- ),
75- KernelLaunchConfig (
76- dbl_bwd_dict.at (" num_blocks" ),
77- dbl_bwd_dict.at (" num_threads" ),
78- dbl_bwd_dict.at (" smem" )
79- )),
8078 L3_dim (kernel_dims.at (" L3_dim" )),
8179 shared_weights (kernel_dims.at (" shared_weights" )) { }
8280
@@ -225,17 +223,11 @@ class TorchJITConv : public torch::CustomClassHolder {
225223 fwd_dict (fwd_dict_i.copy()),
226224 bwd_dict (bwd_dict_i.copy()),
227225 kernel_dims (kernel_dims_i.copy()),
228- internal (kernel_plaintext,
229- KernelLaunchConfig (
230- fwd_dict.at(" num_blocks" ),
231- fwd_dict.at(" num_threads" ),
232- fwd_dict.at(" smem" )
226+ internal (kernel_plaintext,
227+ to_map (fwd_dict_i),
228+ to_map(bwd_dict_i),
229+ to_map(kernel_dims_i)
233230 ),
234- KernelLaunchConfig(
235- bwd_dict.at(" num_blocks" ),
236- bwd_dict.at(" num_threads" ),
237- bwd_dict.at(" smem" )
238- )),
239231 L3_dim(kernel_dims.at(" L3_dim" )) { }
240232
241233 tuple<tuple<string, string>, tuple<string, Map_t>, tuple<string, Map_t>, tuple<string, Map_t>> __obj_flatten__ () {
0 commit comments