// Eigen-unsupported 3.3.3
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Mehdi Goli Codeplay Software Ltd. 00005 // Ralph Potter Codeplay Software Ltd. 00006 // Luke Iwanski Codeplay Software Ltd. 00007 // Contact: <eigen@codeplay.com> 00008 // 00009 // This Source Code Form is subject to the terms of the Mozilla 00010 // Public License v. 2.0. If a copy of the MPL was not distributed 00011 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00012 00013 /***************************************************************** 00014 * TensorSyclExprConstructor.h 00015 * 00016 * \brief: 00017 * This file re-create an expression on the SYCL device in order 00018 * to use the original tensor evaluator. 00019 * 00020 *****************************************************************/ 00021 00022 #ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXPR_CONSTRUCTOR_HPP 00023 #define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXPR_CONSTRUCTOR_HPP 00024 00025 namespace Eigen { 00026 namespace TensorSycl { 00027 namespace internal { 00030 template <typename PtrType, size_t N, typename... Params> 00031 struct EvalToLHSConstructor { 00032 PtrType expr; 00033 EvalToLHSConstructor(const utility::tuple::Tuple<Params...> &t): expr((&(*(utility::tuple::get<N>(t).get_pointer())))) {} 00034 }; 00035 00042 template <typename OrigExpr, typename IndexExpr, typename... Params> 00043 struct ExprConstructor; 00044 00047 #define TENSORMAP(CVQual)\ 00048 template <typename Scalar_, int Options_, int Options2_, int Options3_, int NumIndices_, typename IndexType_,\ 00049 template <class> class MakePointer_, size_t N, typename... 
Params>\ 00050 struct ExprConstructor< CVQual TensorMap<Tensor<Scalar_, NumIndices_, Options_, IndexType_>, Options2_, MakeGlobalPointer>,\ 00051 CVQual PlaceHolder<CVQual TensorMap<Tensor<Scalar_, NumIndices_, Options_, IndexType_>, Options3_, MakePointer_>, N>, Params...>{\ 00052 typedef CVQual TensorMap<Tensor<Scalar_, NumIndices_, Options_, IndexType_>, Options2_, MakeGlobalPointer> Type;\ 00053 Type expr;\ 00054 template <typename FuncDetector>\ 00055 ExprConstructor(FuncDetector &fd, const utility::tuple::Tuple<Params...> &t)\ 00056 : expr(Type((&(*(utility::tuple::get<N>(t).get_pointer()))), fd.dimensions())) {}\ 00057 }; 00058 00059 TENSORMAP(const) 00060 TENSORMAP() 00061 #undef TENSORMAP 00062 00063 #define UNARYCATEGORY(CVQual)\ 00064 template <template<class, class> class UnaryCategory, typename OP, typename OrigRHSExpr, typename RHSExpr, typename... Params>\ 00065 struct ExprConstructor<CVQual UnaryCategory<OP, OrigRHSExpr>, CVQual UnaryCategory<OP, RHSExpr>, Params...> {\ 00066 typedef ExprConstructor<OrigRHSExpr, RHSExpr, Params...> my_type;\ 00067 my_type rhsExpr;\ 00068 typedef CVQual UnaryCategory<OP, typename my_type::Type> Type;\ 00069 Type expr;\ 00070 template <typename FuncDetector>\ 00071 ExprConstructor(FuncDetector &funcD, const utility::tuple::Tuple<Params...> &t)\ 00072 : rhsExpr(funcD.rhsExpr, t), expr(rhsExpr.expr, funcD.func) {}\ 00073 }; 00074 00075 UNARYCATEGORY(const) 00076 UNARYCATEGORY() 00077 #undef UNARYCATEGORY 00078 00081 #define BINARYCATEGORY(CVQual)\ 00082 template <template<class, class, class> class BinaryCategory, typename OP, typename OrigLHSExpr, typename OrigRHSExpr, typename LHSExpr,\ 00083 typename RHSExpr, typename... 
Params>\ 00084 struct ExprConstructor<CVQual BinaryCategory<OP, OrigLHSExpr, OrigRHSExpr>, CVQual BinaryCategory<OP, LHSExpr, RHSExpr>, Params...> {\ 00085 typedef ExprConstructor<OrigLHSExpr, LHSExpr, Params...> my_left_type;\ 00086 typedef ExprConstructor<OrigRHSExpr, RHSExpr, Params...> my_right_type;\ 00087 typedef CVQual BinaryCategory<OP, typename my_left_type::Type, typename my_right_type::Type> Type;\ 00088 my_left_type lhsExpr;\ 00089 my_right_type rhsExpr;\ 00090 Type expr;\ 00091 template <typename FuncDetector>\ 00092 ExprConstructor(FuncDetector &funcD, const utility::tuple::Tuple<Params...> &t)\ 00093 : lhsExpr(funcD.lhsExpr, t),rhsExpr(funcD.rhsExpr, t), expr(lhsExpr.expr, rhsExpr.expr, funcD.func) {}\ 00094 }; 00095 00096 BINARYCATEGORY(const) 00097 BINARYCATEGORY() 00098 #undef BINARYCATEGORY 00099 00102 #define TERNARYCATEGORY(CVQual)\ 00103 template <template <class, class, class, class> class TernaryCategory, typename OP, typename OrigArg1Expr, typename OrigArg2Expr,typename OrigArg3Expr,\ 00104 typename Arg1Expr, typename Arg2Expr, typename Arg3Expr, typename... 
Params>\ 00105 struct ExprConstructor<CVQual TernaryCategory<OP, OrigArg1Expr, OrigArg2Expr, OrigArg3Expr>, CVQual TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Params...> {\ 00106 typedef ExprConstructor<OrigArg1Expr, Arg1Expr, Params...> my_arg1_type;\ 00107 typedef ExprConstructor<OrigArg2Expr, Arg2Expr, Params...> my_arg2_type;\ 00108 typedef ExprConstructor<OrigArg3Expr, Arg3Expr, Params...> my_arg3_type;\ 00109 typedef CVQual TernaryCategory<OP, typename my_arg1_type::Type, typename my_arg2_type::Type, typename my_arg3_type::Type> Type;\ 00110 my_arg1_type arg1Expr;\ 00111 my_arg2_type arg2Expr;\ 00112 my_arg3_type arg3Expr;\ 00113 Type expr;\ 00114 template <typename FuncDetector>\ 00115 ExprConstructor(FuncDetector &funcD,const utility::tuple::Tuple<Params...> &t)\ 00116 : arg1Expr(funcD.arg1Expr, t), arg2Expr(funcD.arg2Expr, t), arg3Expr(funcD.arg3Expr, t), expr(arg1Expr.expr, arg2Expr.expr, arg3Expr.expr, funcD.func) {}\ 00117 }; 00118 00119 TERNARYCATEGORY(const) 00120 TERNARYCATEGORY() 00121 #undef TERNARYCATEGORY 00122 00125 #define SELECTOP(CVQual)\ 00126 template <typename OrigIfExpr, typename OrigThenExpr, typename OrigElseExpr, typename IfExpr, typename ThenExpr, typename ElseExpr, typename... 
Params>\ 00127 struct ExprConstructor< CVQual TensorSelectOp<OrigIfExpr, OrigThenExpr, OrigElseExpr>, CVQual TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Params...> {\ 00128 typedef ExprConstructor<OrigIfExpr, IfExpr, Params...> my_if_type;\ 00129 typedef ExprConstructor<OrigThenExpr, ThenExpr, Params...> my_then_type;\ 00130 typedef ExprConstructor<OrigElseExpr, ElseExpr, Params...> my_else_type;\ 00131 typedef CVQual TensorSelectOp<typename my_if_type::Type, typename my_then_type::Type, typename my_else_type::Type> Type;\ 00132 my_if_type ifExpr;\ 00133 my_then_type thenExpr;\ 00134 my_else_type elseExpr;\ 00135 Type expr;\ 00136 template <typename FuncDetector>\ 00137 ExprConstructor(FuncDetector &funcD, const utility::tuple::Tuple<Params...> &t)\ 00138 : ifExpr(funcD.ifExpr, t), thenExpr(funcD.thenExpr, t), elseExpr(funcD.elseExpr, t), expr(ifExpr.expr, thenExpr.expr, elseExpr.expr) {}\ 00139 }; 00140 00141 SELECTOP(const) 00142 SELECTOP() 00143 #undef SELECTOP 00144 00147 #define ASSIGN(CVQual)\ 00148 template <typename OrigLHSExpr, typename OrigRHSExpr, typename LHSExpr, typename RHSExpr, typename... 
Params>\ 00149 struct ExprConstructor<CVQual TensorAssignOp<OrigLHSExpr, OrigRHSExpr>, CVQual TensorAssignOp<LHSExpr, RHSExpr>, Params...> {\ 00150 typedef ExprConstructor<OrigLHSExpr, LHSExpr, Params...> my_left_type;\ 00151 typedef ExprConstructor<OrigRHSExpr, RHSExpr, Params...> my_right_type;\ 00152 typedef CVQual TensorAssignOp<typename my_left_type::Type, typename my_right_type::Type> Type;\ 00153 my_left_type lhsExpr;\ 00154 my_right_type rhsExpr;\ 00155 Type expr;\ 00156 template <typename FuncDetector>\ 00157 ExprConstructor(FuncDetector &funcD, const utility::tuple::Tuple<Params...> &t)\ 00158 : lhsExpr(funcD.lhsExpr, t), rhsExpr(funcD.rhsExpr, t), expr(lhsExpr.expr, rhsExpr.expr) {}\ 00159 }; 00160 00161 ASSIGN(const) 00162 ASSIGN() 00163 #undef ASSIGN 00164 00165 00166 #define EVALTO(CVQual)\ 00167 template <typename OrigExpr, typename Expr, typename... Params>\ 00168 struct ExprConstructor<CVQual TensorEvalToOp<OrigExpr, MakeGlobalPointer>, CVQual TensorEvalToOp<Expr>, Params...> {\ 00169 typedef ExprConstructor<OrigExpr, Expr, Params...> my_expr_type;\ 00170 typedef typename TensorEvalToOp<OrigExpr, MakeGlobalPointer>::PointerType my_buffer_type;\ 00171 typedef CVQual TensorEvalToOp<typename my_expr_type::Type, MakeGlobalPointer> Type;\ 00172 my_expr_type nestedExpression;\ 00173 EvalToLHSConstructor<my_buffer_type, 0, Params...> buffer;\ 00174 Type expr;\ 00175 template <typename FuncDetector>\ 00176 ExprConstructor(FuncDetector &funcD, const utility::tuple::Tuple<Params...> &t)\ 00177 : nestedExpression(funcD.rhsExpr, t), buffer(t), expr(buffer.expr, nestedExpression.expr) {}\ 00178 }; 00179 00180 EVALTO(const) 00181 EVALTO() 00182 #undef EVALTO 00183 00186 #define FORCEDEVAL(CVQual)\ 00187 template <typename OrigExpr, typename DevExpr, size_t N, typename... 
Params>\ 00188 struct ExprConstructor<CVQual TensorForcedEvalOp<OrigExpr, MakeGlobalPointer>,\ 00189 CVQual PlaceHolder<CVQual TensorForcedEvalOp<DevExpr>, N>, Params...> {\ 00190 typedef CVQual TensorMap<Tensor<typename TensorForcedEvalOp<DevExpr, MakeGlobalPointer>::Scalar,\ 00191 TensorForcedEvalOp<DevExpr, MakeGlobalPointer>::NumDimensions, 0, typename TensorForcedEvalOp<DevExpr>::Index>, 0, MakeGlobalPointer> Type;\ 00192 Type expr;\ 00193 template <typename FuncDetector>\ 00194 ExprConstructor(FuncDetector &fd, const utility::tuple::Tuple<Params...> &t)\ 00195 : expr(Type((&(*(utility::tuple::get<N>(t).get_pointer()))), fd.dimensions())) {}\ 00196 }; 00197 00198 FORCEDEVAL(const) 00199 FORCEDEVAL() 00200 #undef FORCEDEVAL 00201 00202 template <bool Conds, size_t X , size_t Y > struct ValueCondition { 00203 static const size_t Res =X; 00204 }; 00205 template<size_t X, size_t Y> struct ValueCondition<false, X , Y> { 00206 static const size_t Res =Y; 00207 }; 00208 00210 #define SYCLREDUCTIONEXPR(CVQual)\ 00211 template <typename OP, typename Dim, typename OrigExpr, typename DevExpr, size_t N, typename... 
Params>\ 00212 struct ExprConstructor<CVQual TensorReductionOp<OP, Dim, OrigExpr, MakeGlobalPointer>,\ 00213 CVQual PlaceHolder<CVQual TensorReductionOp<OP, Dim, DevExpr>, N>, Params...> {\ 00214 static const size_t NumIndices= ValueCondition< TensorReductionOp<OP, Dim, DevExpr, MakeGlobalPointer>::NumDimensions==0, 1, TensorReductionOp<OP, Dim, DevExpr, MakeGlobalPointer>::NumDimensions >::Res;\ 00215 typedef CVQual TensorMap<Tensor<typename TensorReductionOp<OP, Dim, DevExpr, MakeGlobalPointer>::Scalar,\ 00216 NumIndices, 0, typename TensorReductionOp<OP, Dim, DevExpr>::Index>, 0, MakeGlobalPointer> Type;\ 00217 Type expr;\ 00218 template <typename FuncDetector>\ 00219 ExprConstructor(FuncDetector &fd, const utility::tuple::Tuple<Params...> &t)\ 00220 : expr(Type((&(*(utility::tuple::get<N>(t).get_pointer()))), fd.dimensions())) {}\ 00221 }; 00222 00223 SYCLREDUCTIONEXPR(const) 00224 SYCLREDUCTIONEXPR() 00225 #undef SYCLREDUCTIONEXPR 00226 00228 template <typename OrigExpr, typename IndexExpr, typename FuncD, typename... Params> 00229 auto createDeviceExpression(FuncD &funcD, const utility::tuple::Tuple<Params...> &t) 00230 -> decltype(ExprConstructor<OrigExpr, IndexExpr, Params...>(funcD, t)) { 00231 return ExprConstructor<OrigExpr, IndexExpr, Params...>(funcD, t); 00232 } 00233 00234 } 00235 } 00236 } 00237 00238 00239 #endif // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXPR_CONSTRUCTOR_HPP