Browse Source

more templates

Konstantin Ladutenko 6 years ago
parent
commit
89532e6fde
1 changed files with 11 additions and 10 deletions
  1. 11 10
      src/nmie-pybind11.cc

+ 11 - 10
src/nmie-pybind11.cc

@@ -72,20 +72,21 @@ std::vector<T> flatten(const std::vector<std::vector<T>>& v) {
 }
 
 
-py::array VectorVectorComplex2Py(const std::vector<std::vector<std::complex<double> > > &E) {
-  size_t ncoord = E.size();
-  size_t ncomp = E[0].size();
-  auto result = flatten(E);
+template <typename T>
+py::array VectorVector2Py(const std::vector<std::vector<T > > &x) {
+  size_t dim1 = x.size();
+  size_t dim2 = x[0].size();
+  auto result = flatten(x);
   // https://github.com/tdegeus/pybind11_examples/blob/master/04_numpy-2D_cpp-vector/example.cpp 
   size_t              ndim    = 2;
-  std::vector<size_t> shape   = { ncoord , ncomp };
-  std::vector<size_t> strides = { sizeof(std::complex<double>)*ncomp , sizeof(std::complex<double>) };
+  std::vector<size_t> shape   = { dim1 , dim2 };
+  std::vector<size_t> strides = { sizeof(T)*dim2 , sizeof(T) };
 
   // return 2-D NumPy array
   return py::array(py::buffer_info(
     result.data(),                           /* data as contiguous array  */
-    sizeof(std::complex<double>),            /* size of one scalar        */
-    py::format_descriptor<std::complex<double>>::format(), /* data type                 */
+    sizeof(T),                          /* size of one scalar        */
+    py::format_descriptor<T>::format(), /* data type                 */
     ndim,                                    /* number of dimensions      */
     shape,                                   /* shape of the matrix       */
     strides                                  /* strides for each axis     */
@@ -156,8 +157,8 @@ py::tuple py_fieldnlay(const py::array_t<double, py::array::c_style | py::array:
   for (auto& f : H) f.resize(3);
   int L = py_x.size(), terms;
   terms = nmie::nField(L, pl, c_x, c_m, nmax, ncoord, c_Xp, c_Yp, c_Zp, E, H);
-  auto py_E = VectorVectorComplex2Py(E);
-  auto py_H = VectorVectorComplex2Py(H);
+  auto py_E = VectorVector2Py<std::complex<double> >(E);
+  auto py_H = VectorVector2Py<std::complex<double> >(H);
   return py::make_tuple(terms, py_E, py_H);
 }