1111
1212#include " hdf5.h"
1313#include " hdf5_hl.h"
14- #include " xtensor/xadapt.hpp"
15- #include " xtensor/xarray.hpp"
14+ #include " openmc/tensor.h"
1615
1716#include " openmc/array.h"
1817#include " openmc/error.h"
@@ -166,24 +165,19 @@ void read_attribute(hid_t obj_id, const char* name, vector<T>& vec)
166165 read_attr (obj_id, name, H5TypeMap<T>::type_id, vec.data ());
167166}
168167
169- // Generic array version
168+ // Tensor version
170169template <typename T>
171- void read_attribute (hid_t obj_id, const char * name, xt::xarray <T>& arr )
170+ void read_attribute (hid_t obj_id, const char * name, tensor::Tensor <T>& tensor )
172171{
173- // Get shape of attribute array
172+ // Get shape of attribute
174173 auto shape = attribute_shape (obj_id, name);
175174
176- // Allocate new array to read data into
177- std::size_t size = 1 ;
178- for (const auto x : shape)
179- size *= x;
180- vector<T> buffer (size);
175+ // Resize tensor and read data directly
176+ vector<size_t > tshape (shape.begin (), shape.end ());
177+ tensor.resize (tshape);
181178
182179 // Read data from attribute
183- read_attr (obj_id, name, H5TypeMap<T>::type_id, buffer.data ());
184-
185- // Adapt array into xarray
186- arr = xt::adapt (buffer, shape);
180+ read_attr (obj_id, name, H5TypeMap<T>::type_id, tensor.data ());
187181}
188182
189183// overload for std::string
@@ -290,61 +284,32 @@ void read_dataset(
290284}
291285
292286template <typename T>
293- void read_dataset (hid_t dset, xt::xarray <T>& arr , bool indep = false )
287+ void read_dataset (hid_t dset, tensor::Tensor <T>& tensor , bool indep = false )
294288{
295289 // Get shape of dataset
296290 vector<hsize_t > shape = object_shape (dset);
297291
298- // Allocate space in the array to read data into
299- std::size_t size = 1 ;
300- for (const auto x : shape)
301- size *= x;
302- arr.resize (shape);
292+ // Resize tensor and read data directly
293+ vector<size_t > tshape (shape.begin (), shape.end ());
294+ tensor.resize (tshape);
303295
304- // Read data from attribute
296+ // Read data from dataset
305297 read_dataset_lowlevel (
306- dset, nullptr , H5TypeMap<T>::type_id, H5S_ALL, indep, arr .data ());
298+ dset, nullptr , H5TypeMap<T>::type_id, H5S_ALL, indep, tensor .data ());
307299}
308300
309301template <>
310302void read_dataset (
311- hid_t dset, xt::xarray <std::complex <double >>& arr , bool indep);
303+ hid_t dset, tensor::Tensor <std::complex <double >>& tensor , bool indep);
312304
313305template <typename T>
314306void read_dataset (
315- hid_t obj_id, const char * name, xt::xarray<T>& arr, bool indep = false )
316- {
317- // Open dataset and read array
318- hid_t dset = open_dataset (obj_id, name);
319- read_dataset (dset, arr, indep);
320- close_dataset (dset);
321- }
322-
323- template <typename T, std::size_t N>
324- void read_dataset (
325- hid_t obj_id, const char * name, xt::xtensor<T, N>& arr, bool indep = false )
307+ hid_t obj_id, const char * name, tensor::Tensor<T>& tensor, bool indep = false )
326308{
327- // Open dataset and read array
309+ // Open dataset and read tensor
328310 hid_t dset = open_dataset (obj_id, name);
329-
330- // Get shape of dataset
331- vector<hsize_t > hsize_t_shape = object_shape (dset);
311+ read_dataset (dset, tensor, indep);
332312 close_dataset (dset);
333-
334- // cast from hsize_t to size_t
335- vector<size_t > shape (hsize_t_shape.size ());
336- for (int i = 0 ; i < shape.size (); i++) {
337- shape[i] = static_cast <size_t >(hsize_t_shape[i]);
338- }
339-
340- // Allocate new xarray to read data into
341- xt::xarray<T> xarr (shape);
342-
343- // Read data from the dataset
344- read_dataset (obj_id, name, xarr);
345-
346- // Copy into xtensor
347- arr = xarr;
348313}
349314
350315// overload for Position
@@ -358,31 +323,22 @@ inline void read_dataset(
358323 r.z = x[2 ];
359324}
360325
361- template <typename T, std:: size_t N >
326+ template <typename T>
362327inline void read_dataset_as_shape (
363- hid_t obj_id, const char * name, xt::xtensor<T, N >& arr , bool indep = false )
328+ hid_t obj_id, const char * name, tensor::Tensor<T >& tensor , bool indep = false )
364329{
365330 hid_t dset = open_dataset (obj_id, name);
366331
367- // Allocate new array to read data into
368- std::size_t size = 1 ;
369- for (const auto x : arr.shape ())
370- size *= x;
371- vector<T> buffer (size);
372-
373- // Read data from attribute
332+ // Read data directly into pre-shaped tensor
374333 read_dataset_lowlevel (
375- dset, nullptr , H5TypeMap<T>::type_id, H5S_ALL, indep, buffer.data ());
376-
377- // Adapt into xarray
378- arr = xt::adapt (buffer, arr.shape ());
334+ dset, nullptr , H5TypeMap<T>::type_id, H5S_ALL, indep, tensor.data ());
379335
380336 close_dataset (dset);
381337}
382338
383- template <typename T, std:: size_t N >
384- inline void read_nd_vector (hid_t obj_id, const char * name,
385- xt::xtensor<T, N >& result, bool must_have = false )
339+ template <typename T>
340+ inline void read_nd_tensor (hid_t obj_id, const char * name,
341+ tensor::Tensor<T >& result, bool must_have = false )
386342{
387343 if (object_exists (obj_id, name)) {
388344 read_dataset_as_shape (obj_id, name, result, true );
@@ -496,12 +452,16 @@ inline void write_dataset(
496452 false , buffer.data ());
497453}
498454
499- // Template for xarray, xtensor, etc.
500- template <typename D>
501- inline void write_dataset (
502- hid_t obj_id, const char * name, const xt::xcontainer<D>& arr)
455+ // Template for Tensor and StaticTensor2D. A SFINAE guard is used here to
456+ // prevent this template from matching vector/string types that have their own
457+ // overloads above. A generic Container parameter avoids duplicating the body
458+ // for both Tensor<T> and StaticTensor2D<T,R,C>.
459+ template <typename Container,
460+ typename =
461+ std::enable_if_t <tensor::is_tensor<std::decay_t <Container>>::value>>
462+ inline void write_dataset (hid_t obj_id, const char * name, const Container& arr)
503463{
504- using T = typename D ::value_type;
464+ using T = typename std:: decay_t <Container> ::value_type;
505465 auto s = arr.shape ();
506466 vector<hsize_t > dims {s.cbegin (), s.cend ()};
507467 write_dataset_lowlevel (obj_id, dims.size (), dims.data (), name,
0 commit comments