// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H

namespace Eigen {

/** \class TensorConcatenationOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor concatenation class.
  *
  *
  */
namespace internal {
template<typename Axis, typename LhsXprType, typename RhsXprType>
struct traits<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> >
{
  // Type promotion to handle the case where the types of the lhs and the rhs are different.
  typedef typename promote_storage_type<typename LhsXprType::Scalar,
                                        typename RhsXprType::Scalar>::ret Scalar;
  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
                                        typename traits<RhsXprType>::StorageKind>::ret StorageKind;
  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
                                      typename traits<RhsXprType>::Index>::type Index;
  typedef typename LhsXprType::Nested LhsNested;
  typedef typename RhsXprType::Nested RhsNested;
  typedef typename remove_reference<LhsNested>::type _LhsNested;
  typedef typename remove_reference<RhsNested>::type _RhsNested;
  static const int NumDimensions = traits<LhsXprType>::NumDimensions;
  static const int Layout = traits<LhsXprType>::Layout;
  enum { Flags = 0 };
};

template<typename Axis, typename LhsXprType, typename RhsXprType>
struct eval<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, Eigen::Dense>
{
  typedef const TensorConcatenationOp<Axis, LhsXprType, RhsXprType>& type;
};

template<typename Axis, typename LhsXprType, typename RhsXprType>
struct nested<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, 1, typename eval<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> >::type>
{
  typedef TensorConcatenationOp<Axis, LhsXprType, RhsXprType> type;
};

}  // end namespace internal


template<typename Axis, typename LhsXprType, typename RhsXprType>
class TensorConcatenationOp : public TensorBase<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, WriteAccessors>
{
  public:
    typedef typename internal::traits<TensorConcatenationOp>::Scalar Scalar;
    typedef typename internal::traits<TensorConcatenationOp>::StorageKind StorageKind;
    typedef typename internal::traits<TensorConcatenationOp>::Index Index;
    typedef typename internal::nested<TensorConcatenationOp>::type Nested;
    typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
                                                    typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
    typedef typename NumTraits<Scalar>::Real RealScalar;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorConcatenationOp(const LhsXprType& lhs, const RhsXprType& rhs, Axis axis)
        : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_axis(axis) {}

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename LhsXprType::Nested>::type&
    lhsExpression() const { return m_lhs_xpr; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename RhsXprType::Nested>::type&
    rhsExpression() const { return m_rhs_xpr; }

    EIGEN_DEVICE_FUNC const Axis& axis() const { return m_axis; }

    EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE TensorConcatenationOp& operator = (const TensorConcatenationOp& other)
    {
      typedef TensorAssignOp<TensorConcatenationOp, const TensorConcatenationOp> Assign;
      Assign assign(*this, other);
      internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
      return *this;
    }

    template<typename OtherDerived>
    EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE TensorConcatenationOp& operator = (const OtherDerived& other)
    {
      typedef TensorAssignOp<TensorConcatenationOp, const OtherDerived> Assign;
      Assign assign(*this, other);
      internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
      return *this;
    }

  protected:
    typename LhsXprType::Nested m_lhs_xpr;
    typename RhsXprType::Nested m_rhs_xpr;
    const Axis m_axis;
};
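
// Usage sketch (illustrative sizes, assumed by this example): a
// TensorConcatenationOp is normally obtained through TensorBase::concatenate()
// rather than constructed directly. All dimensions of the two operands must
// match except along the concatenation axis, whose output size is the sum of
// the operand sizes:
//
//   Eigen::Tensor<float, 2> a(2, 3);
//   Eigen::Tensor<float, 2> b(4, 3);
//   a.setConstant(1.0f);
//   b.setConstant(2.0f);
//   Eigen::Tensor<float, 2> c = a.concatenate(b, 0);  // c is 6 x 3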

// Eval as rvalue
template<typename Axis, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
{
  typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumDims = internal::array_size<typename TensorEvaluator<LeftArgType, Device>::Dimensions>::value;
  static const int RightNumDims = internal::array_size<typename TensorEvaluator<RightArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  enum {
    IsAligned = false,
    PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess,
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
    : m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device), m_axis(op.axis())
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout) || NumDims == 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
    EIGEN_STATIC_ASSERT((NumDims == RightNumDims), YOU_MADE_A_PROGRAMMING_MISTAKE);
    EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);

    eigen_assert(0 <= m_axis && m_axis < NumDims);
    const Dimensions& lhs_dims = m_leftImpl.dimensions();
    const Dimensions& rhs_dims = m_rightImpl.dimensions();
    {
      int i = 0;
      for (; i < m_axis; ++i) {
        eigen_assert(lhs_dims[i] > 0);
        eigen_assert(lhs_dims[i] == rhs_dims[i]);
        m_dimensions[i] = lhs_dims[i];
      }
      eigen_assert(lhs_dims[i] > 0);  // Now i == m_axis.
      eigen_assert(rhs_dims[i] > 0);
      m_dimensions[i] = lhs_dims[i] + rhs_dims[i];
      for (++i; i < NumDims; ++i) {
        eigen_assert(lhs_dims[i] > 0);
        eigen_assert(lhs_dims[i] == rhs_dims[i]);
        m_dimensions[i] = lhs_dims[i];
      }
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_leftStrides[0] = 1;
      m_rightStrides[0] = 1;
      m_outputStrides[0] = 1;

      for (int j = 1; j < NumDims; ++j) {
        m_leftStrides[j] = m_leftStrides[j-1] * lhs_dims[j-1];
        m_rightStrides[j] = m_rightStrides[j-1] * rhs_dims[j-1];
        m_outputStrides[j] = m_outputStrides[j-1] * m_dimensions[j-1];
      }
    } else {
      m_leftStrides[NumDims - 1] = 1;
      m_rightStrides[NumDims - 1] = 1;
      m_outputStrides[NumDims - 1] = 1;

      for (int j = NumDims - 2; j >= 0; --j) {
        m_leftStrides[j] = m_leftStrides[j+1] * lhs_dims[j+1];
        m_rightStrides[j] = m_rightStrides[j+1] * rhs_dims[j+1];
        m_outputStrides[j] = m_outputStrides[j+1] * m_dimensions[j+1];
      }
    }
  }
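
  // For illustration, with the 2x3 `a` and 4x3 `b` from the sketch above
  // concatenated along axis 0 in ColMajor: m_dimensions = [6, 3],
  // m_leftStrides = [1, 2], m_rightStrides = [1, 4], m_outputStrides = [1, 6].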

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  // TODO(phli): Add short-circuit memcpy evaluation if underlying data are linear?
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/)
  {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup()
  {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();
  }

  // TODO(phli): attempt to speed this up. The integer divisions and modulo are slow.
  // See CL/76180724 comments for more ideas.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    // Collect dimension-wise indices (subs).
    array<Index, NumDims> subs;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = NumDims - 1; i > 0; --i) {
        subs[i] = index / m_outputStrides[i];
        index -= subs[i] * m_outputStrides[i];
      }
      subs[0] = index;
    } else {
      for (int i = 0; i < NumDims - 1; ++i) {
        subs[i] = index / m_outputStrides[i];
        index -= subs[i] * m_outputStrides[i];
      }
      subs[NumDims - 1] = index;
    }

    const Dimensions& left_dims = m_leftImpl.dimensions();
    if (subs[m_axis] < left_dims[m_axis]) {
      Index left_index;
      if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
        left_index = subs[0];
        for (int i = 1; i < NumDims; ++i) {
          left_index += (subs[i] % left_dims[i]) * m_leftStrides[i];
        }
      } else {
        left_index = subs[NumDims - 1];
        for (int i = NumDims - 2; i >= 0; --i) {
          left_index += (subs[i] % left_dims[i]) * m_leftStrides[i];
        }
      }
      return m_leftImpl.coeff(left_index);
    } else {
      subs[m_axis] -= left_dims[m_axis];
      const Dimensions& right_dims = m_rightImpl.dimensions();
      Index right_index;
      if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
        right_index = subs[0];
        for (int i = 1; i < NumDims; ++i) {
          right_index += (subs[i] % right_dims[i]) * m_rightStrides[i];
        }
      } else {
        right_index = subs[NumDims - 1];
        for (int i = NumDims - 2; i >= 0; --i) {
          right_index += (subs[i] % right_dims[i]) * m_rightStrides[i];
        }
      }
      return m_rightImpl.coeff(right_index);
    }
  }
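
  // Continuing the illustration: for linear output index 9, the ColMajor
  // decomposition gives subs = [3, 1]. Since subs[0] = 3 >= left_dims[0] = 2,
  // the coefficient is read from the rhs with subs[0] rebased to 1, giving
  // right_index = 1 + (1 % 3) * 4 = 5.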

  // TODO(phli): Add a real vectorization.
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
    EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index + packetSize - 1 < dimensions().TotalSize());

    EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
    for (int i = 0; i < packetSize; ++i) {
      values[i] = coeff(index+i);
    }
    PacketReturnType rslt = internal::pload<PacketReturnType>(values);
    return rslt;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    const double compute_cost = NumDims * (2 * TensorOpCost::AddCost<Index>() +
                                           2 * TensorOpCost::MulCost<Index>() +
                                           TensorOpCost::DivCost<Index>() +
                                           TensorOpCost::ModCost<Index>());
    const double lhs_size = m_leftImpl.dimensions().TotalSize();
    const double rhs_size = m_rightImpl.dimensions().TotalSize();
    return (lhs_size / (lhs_size + rhs_size)) *
               m_leftImpl.costPerCoeff(vectorized) +
           (rhs_size / (lhs_size + rhs_size)) *
               m_rightImpl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost);
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

  protected:
    Dimensions m_dimensions;
    array<Index, NumDims> m_outputStrides;
    array<Index, NumDims> m_leftStrides;
    array<Index, NumDims> m_rightStrides;
    TensorEvaluator<LeftArgType, Device> m_leftImpl;
    TensorEvaluator<RightArgType, Device> m_rightImpl;
    const Axis m_axis;
};

// Eval as lvalue
template<typename Axis, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
  : public TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
{
  typedef TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> Base;
  typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
  typedef typename Base::Dimensions Dimensions;
  enum {
    IsAligned = false,
    PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess,
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(XprType& op, const Device& device)
    : Base(op, device)
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(Layout) == static_cast<int>(ColMajor)), YOU_MADE_A_PROGRAMMING_MISTAKE);
  }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index)
  {
    // Collect dimension-wise indices (subs).
    array<Index, Base::NumDims> subs;
    for (int i = Base::NumDims - 1; i > 0; --i) {
      subs[i] = index / this->m_outputStrides[i];
      index -= subs[i] * this->m_outputStrides[i];
    }
    subs[0] = index;

    const Dimensions& left_dims = this->m_leftImpl.dimensions();
    if (subs[this->m_axis] < left_dims[this->m_axis]) {
      Index left_index = subs[0];
      for (int i = 1; i < Base::NumDims; ++i) {
        left_index += (subs[i] % left_dims[i]) * this->m_leftStrides[i];
      }
      return this->m_leftImpl.coeffRef(left_index);
    } else {
      subs[this->m_axis] -= left_dims[this->m_axis];
      const Dimensions& right_dims = this->m_rightImpl.dimensions();
      Index right_index = subs[0];
      for (int i = 1; i < Base::NumDims; ++i) {
        right_index += (subs[i] % right_dims[i]) * this->m_rightStrides[i];
      }
      return this->m_rightImpl.coeffRef(right_index);
    }
  }

  template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  void writePacket(Index index, const PacketReturnType& x)
  {
    const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
    EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index + packetSize - 1 < this->dimensions().TotalSize());

    EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
    internal::pstore<CoeffReturnType, PacketReturnType>(values, x);
    for (int i = 0; i < packetSize; ++i) {
      coeffRef(index+i) = values[i];
    }
  }
};

}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H
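
// The writable evaluator below lets a concatenation appear on the left-hand
// side of an assignment, scattering values back into the two operands. A
// sketch, reusing the tensors from the usage example above:
//
//   a.concatenate(b, 0) = c;  // rows 0-1 of c go into a, rows 2-5 into b
//
// Note the ColMajor-only static assertion in its constructor.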