Eigen-unsupported
3.3.3
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 00005 // 00006 // This Source Code Form is subject to the terms of the Mozilla 00007 // Public License v. 2.0. If a copy of the MPL was not distributed 00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00009 00010 #ifndef EIGEN_CXX11_TENSOR_TENSOR_REF_H 00011 #define EIGEN_CXX11_TENSOR_TENSOR_REF_H 00012 00013 namespace Eigen { 00014 00015 namespace internal { 00016 00017 template <typename Dimensions, typename Scalar> 00018 class TensorLazyBaseEvaluator { 00019 public: 00020 TensorLazyBaseEvaluator() : m_refcount(0) { } 00021 virtual ~TensorLazyBaseEvaluator() { } 00022 00023 EIGEN_DEVICE_FUNC virtual const Dimensions& dimensions() const = 0; 00024 EIGEN_DEVICE_FUNC virtual const Scalar* data() const = 0; 00025 00026 EIGEN_DEVICE_FUNC virtual const Scalar coeff(DenseIndex index) const = 0; 00027 EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) = 0; 00028 00029 void incrRefCount() { ++m_refcount; } 00030 void decrRefCount() { --m_refcount; } 00031 int refCount() const { return m_refcount; } 00032 00033 private: 00034 // No copy, no assigment; 00035 TensorLazyBaseEvaluator(const TensorLazyBaseEvaluator& other); 00036 TensorLazyBaseEvaluator& operator = (const TensorLazyBaseEvaluator& other); 00037 00038 int m_refcount; 00039 }; 00040 00041 00042 template <typename Dimensions, typename Expr, typename Device> 00043 class TensorLazyEvaluatorReadOnly : public TensorLazyBaseEvaluator<Dimensions, typename TensorEvaluator<Expr, Device>::Scalar> { 00044 public: 00045 // typedef typename TensorEvaluator<Expr, Device>::Dimensions Dimensions; 00046 typedef typename TensorEvaluator<Expr, Device>::Scalar Scalar; 00047 00048 TensorLazyEvaluatorReadOnly(const Expr& expr, const Device& device) : m_impl(expr, device), m_dummy(Scalar(0)) { 00049 m_dims = 
m_impl.dimensions(); 00050 m_impl.evalSubExprsIfNeeded(NULL); 00051 } 00052 virtual ~TensorLazyEvaluatorReadOnly() { 00053 m_impl.cleanup(); 00054 } 00055 00056 EIGEN_DEVICE_FUNC virtual const Dimensions& dimensions() const { 00057 return m_dims; 00058 } 00059 EIGEN_DEVICE_FUNC virtual const Scalar* data() const { 00060 return m_impl.data(); 00061 } 00062 00063 EIGEN_DEVICE_FUNC virtual const Scalar coeff(DenseIndex index) const { 00064 return m_impl.coeff(index); 00065 } 00066 EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex /*index*/) { 00067 eigen_assert(false && "can't reference the coefficient of a rvalue"); 00068 return m_dummy; 00069 }; 00070 00071 protected: 00072 TensorEvaluator<Expr, Device> m_impl; 00073 Dimensions m_dims; 00074 Scalar m_dummy; 00075 }; 00076 00077 template <typename Dimensions, typename Expr, typename Device> 00078 class TensorLazyEvaluatorWritable : public TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> { 00079 public: 00080 typedef TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> Base; 00081 typedef typename Base::Scalar Scalar; 00082 00083 TensorLazyEvaluatorWritable(const Expr& expr, const Device& device) : Base(expr, device) { 00084 } 00085 virtual ~TensorLazyEvaluatorWritable() { 00086 } 00087 00088 EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) { 00089 return this->m_impl.coeffRef(index); 00090 } 00091 }; 00092 00093 template <typename Dimensions, typename Expr, typename Device> 00094 class TensorLazyEvaluator : public internal::conditional<bool(internal::is_lvalue<Expr>::value), 00095 TensorLazyEvaluatorWritable<Dimensions, Expr, Device>, 00096 TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type { 00097 public: 00098 typedef typename internal::conditional<bool(internal::is_lvalue<Expr>::value), 00099 TensorLazyEvaluatorWritable<Dimensions, Expr, Device>, 00100 TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type Base; 00101 typedef typename Base::Scalar Scalar; 00102 
00103 TensorLazyEvaluator(const Expr& expr, const Device& device) : Base(expr, device) { 00104 } 00105 virtual ~TensorLazyEvaluator() { 00106 } 00107 }; 00108 00109 } // namespace internal 00110 00111 00119 template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef<PlainObjectType> > 00120 { 00121 public: 00122 typedef TensorRef<PlainObjectType> Self; 00123 typedef typename PlainObjectType::Base Base; 00124 typedef typename Eigen::internal::nested<Self>::type Nested; 00125 typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind; 00126 typedef typename internal::traits<PlainObjectType>::Index Index; 00127 typedef typename internal::traits<PlainObjectType>::Scalar Scalar; 00128 typedef typename NumTraits<Scalar>::Real RealScalar; 00129 typedef typename Base::CoeffReturnType CoeffReturnType; 00130 typedef Scalar* PointerType; 00131 typedef PointerType PointerArgType; 00132 00133 static const Index NumIndices = PlainObjectType::NumIndices; 00134 typedef typename PlainObjectType::Dimensions Dimensions; 00135 00136 enum { 00137 IsAligned = false, 00138 PacketAccess = false, 00139 Layout = PlainObjectType::Layout, 00140 CoordAccess = false, // to be implemented 00141 RawAccess = false 00142 }; 00143 00144 EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) { 00145 } 00146 00147 template <typename Expression> 00148 EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : m_evaluator(new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice())) { 00149 m_evaluator->incrRefCount(); 00150 } 00151 00152 template <typename Expression> 00153 EIGEN_STRONG_INLINE TensorRef& operator = (const Expression& expr) { 00154 unrefEvaluator(); 00155 m_evaluator = new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice()); 00156 m_evaluator->incrRefCount(); 00157 return *this; 00158 } 00159 00160 ~TensorRef() { 00161 unrefEvaluator(); 00162 } 00163 00164 TensorRef(const 
TensorRef& other) : m_evaluator(other.m_evaluator) { 00165 eigen_assert(m_evaluator->refCount() > 0); 00166 m_evaluator->incrRefCount(); 00167 } 00168 00169 TensorRef& operator = (const TensorRef& other) { 00170 if (this != &other) { 00171 unrefEvaluator(); 00172 m_evaluator = other.m_evaluator; 00173 eigen_assert(m_evaluator->refCount() > 0); 00174 m_evaluator->incrRefCount(); 00175 } 00176 return *this; 00177 } 00178 00179 EIGEN_DEVICE_FUNC 00180 EIGEN_STRONG_INLINE Index rank() const { return m_evaluator->dimensions().size(); } 00181 EIGEN_DEVICE_FUNC 00182 EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; } 00183 EIGEN_DEVICE_FUNC 00184 EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); } 00185 EIGEN_DEVICE_FUNC 00186 EIGEN_STRONG_INLINE Index size() const { return m_evaluator->dimensions().TotalSize(); } 00187 EIGEN_DEVICE_FUNC 00188 EIGEN_STRONG_INLINE const Scalar* data() const { return m_evaluator->data(); } 00189 00190 EIGEN_DEVICE_FUNC 00191 EIGEN_STRONG_INLINE const Scalar operator()(Index index) const 00192 { 00193 return m_evaluator->coeff(index); 00194 } 00195 00196 #if EIGEN_HAS_VARIADIC_TEMPLATES 00197 template<typename... IndexTypes> EIGEN_DEVICE_FUNC 00198 EIGEN_STRONG_INLINE const Scalar operator()(Index firstIndex, IndexTypes... otherIndices) const 00199 { 00200 const std::size_t num_indices = (sizeof...(otherIndices) + 1); 00201 const array<Index, num_indices> indices{{firstIndex, otherIndices...}}; 00202 return coeff(indices); 00203 } 00204 template<typename... IndexTypes> EIGEN_DEVICE_FUNC 00205 EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... 
otherIndices) 00206 { 00207 const std::size_t num_indices = (sizeof...(otherIndices) + 1); 00208 const array<Index, num_indices> indices{{firstIndex, otherIndices...}}; 00209 return coeffRef(indices); 00210 } 00211 #else 00212 00213 EIGEN_DEVICE_FUNC 00214 EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1) const 00215 { 00216 array<Index, 2> indices; 00217 indices[0] = i0; 00218 indices[1] = i1; 00219 return coeff(indices); 00220 } 00221 EIGEN_DEVICE_FUNC 00222 EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2) const 00223 { 00224 array<Index, 3> indices; 00225 indices[0] = i0; 00226 indices[1] = i1; 00227 indices[2] = i2; 00228 return coeff(indices); 00229 } 00230 EIGEN_DEVICE_FUNC 00231 EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3) const 00232 { 00233 array<Index, 4> indices; 00234 indices[0] = i0; 00235 indices[1] = i1; 00236 indices[2] = i2; 00237 indices[3] = i3; 00238 return coeff(indices); 00239 } 00240 EIGEN_DEVICE_FUNC 00241 EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3, Index i4) const 00242 { 00243 array<Index, 5> indices; 00244 indices[0] = i0; 00245 indices[1] = i1; 00246 indices[2] = i2; 00247 indices[3] = i3; 00248 indices[4] = i4; 00249 return coeff(indices); 00250 } 00251 EIGEN_DEVICE_FUNC 00252 EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1) 00253 { 00254 array<Index, 2> indices; 00255 indices[0] = i0; 00256 indices[1] = i1; 00257 return coeffRef(indices); 00258 } 00259 EIGEN_DEVICE_FUNC 00260 EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2) 00261 { 00262 array<Index, 3> indices; 00263 indices[0] = i0; 00264 indices[1] = i1; 00265 indices[2] = i2; 00266 return coeffRef(indices); 00267 } 00268 EIGEN_DEVICE_FUNC 00269 EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3) 00270 { 00271 array<Index, 4> indices; 00272 indices[0] = i0; 00273 indices[1] = i1; 00274 indices[2] = i2; 00275 
indices[3] = i3; 00276 return coeffRef(indices); 00277 } 00278 EIGEN_DEVICE_FUNC 00279 EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2, Index i3, Index i4) 00280 { 00281 array<Index, 5> indices; 00282 indices[0] = i0; 00283 indices[1] = i1; 00284 indices[2] = i2; 00285 indices[3] = i3; 00286 indices[4] = i4; 00287 return coeffRef(indices); 00288 } 00289 #endif 00290 00291 template <std::size_t NumIndices> EIGEN_DEVICE_FUNC 00292 EIGEN_STRONG_INLINE const Scalar coeff(const array<Index, NumIndices>& indices) const 00293 { 00294 const Dimensions& dims = this->dimensions(); 00295 Index index = 0; 00296 if (PlainObjectType::Options & RowMajor) { 00297 index += indices[0]; 00298 for (size_t i = 1; i < NumIndices; ++i) { 00299 index = index * dims[i] + indices[i]; 00300 } 00301 } else { 00302 index += indices[NumIndices-1]; 00303 for (int i = NumIndices-2; i >= 0; --i) { 00304 index = index * dims[i] + indices[i]; 00305 } 00306 } 00307 return m_evaluator->coeff(index); 00308 } 00309 template <std::size_t NumIndices> EIGEN_DEVICE_FUNC 00310 EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices) 00311 { 00312 const Dimensions& dims = this->dimensions(); 00313 Index index = 0; 00314 if (PlainObjectType::Options & RowMajor) { 00315 index += indices[0]; 00316 for (size_t i = 1; i < NumIndices; ++i) { 00317 index = index * dims[i] + indices[i]; 00318 } 00319 } else { 00320 index += indices[NumIndices-1]; 00321 for (int i = NumIndices-2; i >= 0; --i) { 00322 index = index * dims[i] + indices[i]; 00323 } 00324 } 00325 return m_evaluator->coeffRef(index); 00326 } 00327 00328 EIGEN_DEVICE_FUNC 00329 EIGEN_STRONG_INLINE const Scalar coeff(Index index) const 00330 { 00331 return m_evaluator->coeff(index); 00332 } 00333 00334 EIGEN_DEVICE_FUNC 00335 EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) 00336 { 00337 return m_evaluator->coeffRef(index); 00338 } 00339 00340 private: 00341 EIGEN_STRONG_INLINE void unrefEvaluator() { 00342 if 
(m_evaluator) { 00343 m_evaluator->decrRefCount(); 00344 if (m_evaluator->refCount() == 0) { 00345 delete m_evaluator; 00346 } 00347 } 00348 } 00349 00350 internal::TensorLazyBaseEvaluator<Dimensions, Scalar>* m_evaluator; 00351 }; 00352 00353 00354 // evaluator for rvalues 00355 template<typename Derived, typename Device> 00356 struct TensorEvaluator<const TensorRef<Derived>, Device> 00357 { 00358 typedef typename Derived::Index Index; 00359 typedef typename Derived::Scalar Scalar; 00360 typedef typename Derived::Scalar CoeffReturnType; 00361 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; 00362 typedef typename Derived::Dimensions Dimensions; 00363 00364 enum { 00365 IsAligned = false, 00366 PacketAccess = false, 00367 Layout = TensorRef<Derived>::Layout, 00368 CoordAccess = false, // to be implemented 00369 RawAccess = false 00370 }; 00371 00372 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef<Derived>& m, const Device&) 00373 : m_ref(m) 00374 { } 00375 00376 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_ref.dimensions(); } 00377 00378 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) { 00379 return true; 00380 } 00381 00382 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { } 00383 00384 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { 00385 return m_ref.coeff(index); 00386 } 00387 00388 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { 00389 return m_ref.coeffRef(index); 00390 } 00391 00392 EIGEN_DEVICE_FUNC Scalar* data() const { return m_ref.data(); } 00393 00394 protected: 00395 TensorRef<Derived> m_ref; 00396 }; 00397 00398 00399 // evaluator for lvalues 00400 template<typename Derived, typename Device> 00401 struct TensorEvaluator<TensorRef<Derived>, Device> : public TensorEvaluator<const TensorRef<Derived>, Device> 00402 { 00403 typedef typename Derived::Index Index; 00404 typedef 
typename Derived::Scalar Scalar; 00405 typedef typename Derived::Scalar CoeffReturnType; 00406 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; 00407 typedef typename Derived::Dimensions Dimensions; 00408 00409 typedef TensorEvaluator<const TensorRef<Derived>, Device> Base; 00410 00411 enum { 00412 IsAligned = false, 00413 PacketAccess = false, 00414 RawAccess = false 00415 }; 00416 00417 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(TensorRef<Derived>& m, const Device& d) : Base(m, d) 00418 { } 00419 00420 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { 00421 return this->m_ref.coeffRef(index); 00422 } 00423 }; 00424 00425 00426 00427 } // end namespace Eigen 00428 00429 #endif // EIGEN_CXX11_TENSOR_TENSOR_REF_H