TensorContractionBlocking.h
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
00005 //
00006 // This Source Code Form is subject to the terms of the Mozilla
00007 // Public License v. 2.0. If a copy of the MPL was not distributed
00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
00009 
00010 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
00011 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
00012 
00013 
00014 namespace Eigen {
00015 namespace internal {
00016 
// Sharding modes usable as the ShardingType template argument of
// TensorContractionBlocking. Enumerators take the implicit values 0 and 1.
enum {
  ShardByRow,  // value 0
  ShardByCol   // value 1
};
00021 
00022 
00023 // Default Blocking Strategy
00024 template <typename LhsMapper, typename RhsMapper, typename Index, int ShardingType=ShardByCol>
00025 class TensorContractionBlocking {
00026  public:
00027 
00028   typedef typename LhsMapper::Scalar LhsScalar;
00029   typedef typename RhsMapper::Scalar RhsScalar;
00030 
00031   EIGEN_DEVICE_FUNC TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1) :
00032       kc_(k), mc_(m), nc_(n)
00033   {
00034     if (ShardingType == ShardByCol) {
00035       computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, mc_, nc_, num_threads);
00036     }
00037     else {
00038       computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, nc_, mc_, num_threads);
00039     }
00040   }
00041 
00042   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index kc() const { return kc_; }
00043   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index mc() const { return mc_; }
00044   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index nc() const { return nc_; }
00045 
00046  private:
00047   Index kc_;
00048   Index mc_;
00049   Index nc_;
00050 };
00051 
00052 
00053 } // end namespace internal
00054 } // end namespace Eigen
00055 
00056 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
 All Classes Functions Variables Typedefs Enumerator