TensorContractionCuda.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014-2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
// Copyright (C) 2015 Navdeep Jaitly <ndjaitly@google.com>
// Copyright (C) 2014 Eric Martin <eric@ericmart.in>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H

#if defined(EIGEN_USE_GPU) && defined(__CUDACC__)

namespace Eigen {

template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool needs_edge_check>
__device__ EIGEN_STRONG_INLINE void
EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                               const OutputMapper output, Scalar* lhs_shmem, Scalar* rhs_shmem,
                               const Index m_size, const Index n_size, const Index k_size) {

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  // declare and initialize 64 registers for output 8x8 block

  // prefetch registers
  Scalar lhs_pf0;
  Scalar lhs_pf1;
  Scalar lhs_pf2;
  Scalar lhs_pf3;
  Scalar lhs_pf4;
  Scalar lhs_pf5;
  Scalar lhs_pf6;
  Scalar lhs_pf7;

  Scalar rhs_pf0;
  Scalar rhs_pf1;
  Scalar rhs_pf2;
  Scalar rhs_pf3;
  Scalar rhs_pf4;
  Scalar rhs_pf5;
  Scalar rhs_pf6;
  Scalar rhs_pf7;

  // shared memory is formatted
  // (contract idx in block, nocontract idx in block, block idx)
  // where block idx is column major. This transposition limits the number of
  // bank conflicts when reading the LHS. The core idea is that since the contracting
  // index is shared by both sides, then the contracting index should be in threadIdx.x.

  // On the LHS, we pad each row inside of each block with an extra element. This makes
  // each block 8 rows of 9 elements, which is 72 elements. This gives no bank conflicts
  // on writes and very few 2-way conflicts on reads. There is an 8x8 grid of these blocks.

  // On the RHS we just add 8 padding elements to the end of each block. This gives no bank
  // conflicts on writes and also none on reads.

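  // A quick arithmetic check of the padding claim (assuming 4-byte Scalars and
  // the usual 32 shared-memory banks): within one warp, threadIdx.z is constant,
  // threadIdx.x spans 0..7 and threadIdx.y spans a group of 4, so the LHS write
  // banks are (9 * x + 72 * y) % 32 = (9x + 8y) % 32. Since 9x == x (mod 8), the
  // bank pins down x, and 8y mod 32 then pins down y within its group of 4, so
  // all 32 lanes hit distinct banks. With an unpadded stride of 8 the bank would
  // be (8x + 64y) % 32 = 8x % 32, only 4 distinct banks, i.e. an 8-way conflict.
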
  // storage indices
  const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
  const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;

  const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
  const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
  const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
  const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
  const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
  const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
  const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
  const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;

  const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
  const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
  const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
  const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
  const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
  const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
  const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
  const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;

  // in the loading code, the following variables are important:
  // threadIdx.x: the vertical position in an 8x8 block
  // threadIdx.y: the vertical index of the 8x8 block in the grid
  // threadIdx.z: the horizontal position in an 8x8 block
  // k: the horizontal index of the 8x8 block in the grid
  //
  // The k parameter is implicit (it was the loop counter for a loop that went
  // from 0 to 7, but that loop is now unrolled in the code below).

  const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
  const Index lhs_vert = base_m + load_idx_vert;
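  // e.g. thread (x=3, y=2, z=5) has load_idx_vert = 3 + 8 * 2 = 19, so it loads
  // the LHS elements (base_m + 19, base_k + 5 + 8 * k) for k = 0..7.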

#define prefetchIntoRegisters(base_k)                           \
  {                                                             \
    lhs_pf0 = conv(0);                                          \
    lhs_pf1 = conv(0);                                          \
    lhs_pf2 = conv(0);                                          \
    lhs_pf3 = conv(0);                                          \
    lhs_pf4 = conv(0);                                          \
    lhs_pf5 = conv(0);                                          \
    lhs_pf6 = conv(0);                                          \
    lhs_pf7 = conv(0);                                          \
                                                                \
    rhs_pf0 = conv(0);                                          \
    rhs_pf1 = conv(0);                                          \
    rhs_pf2 = conv(0);                                          \
    rhs_pf3 = conv(0);                                          \
    rhs_pf4 = conv(0);                                          \
    rhs_pf5 = conv(0);                                          \
    rhs_pf6 = conv(0);                                          \
    rhs_pf7 = conv(0);                                          \
                                                                \
    if (!needs_edge_check || lhs_vert < m_size) {               \
      const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8;   \
      const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8;   \
      const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8;   \
      const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8;   \
      const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8;   \
      const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8;   \
      const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8;   \
      const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8;   \
                                                                \
      if (!needs_edge_check || lhs_horiz_7 < k_size) {          \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6);                   \
        lhs_pf7 = lhs(lhs_vert, lhs_horiz_7);                   \
      } else if (lhs_horiz_6 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6);                   \
      } else if (lhs_horiz_5 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
      } else if (lhs_horiz_4 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
      } else if (lhs_horiz_3 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
      } else if (lhs_horiz_2 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
      } else if (lhs_horiz_1 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
      } else if (lhs_horiz_0 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
      }                                                         \
    }                                                           \
                                                                \
    const Index rhs_vert = base_k + load_idx_vert;              \
    if (!needs_edge_check || rhs_vert < k_size) {               \
      const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8;   \
      const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8;   \
      const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8;   \
      const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8;   \
      const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8;   \
      const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8;   \
      const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8;   \
      const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8;   \
                                                                \
      if (rhs_horiz_7 < n_size) {                               \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                   \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6);                   \
        rhs_pf7 = rhs(rhs_vert, rhs_horiz_7);                   \
      } else if (rhs_horiz_6 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                   \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6);                   \
      } else if (rhs_horiz_5 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                   \
      } else if (rhs_horiz_4 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
      } else if (rhs_horiz_3 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
      } else if (rhs_horiz_2 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
      } else if (rhs_horiz_1 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
      } else if (rhs_horiz_0 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
      }                                                         \
    }                                                           \
  }                                                             \

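// Note: the descending if/else ladders above are just an unrolled form of the
// bounds-checked loop
//   for (int i = 0; i < 8; ++i)
//     if (lhs_horiz_i < k_size) lhs_pf_i = lhs(lhs_vert, lhs_horiz_i);
// (and likewise for the RHS against n_size); since the horiz indices increase
// by 8, a single comparison per ladder step suffices, and registers that are
// never written keep the zero fill from the top of the macro.
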
#define writeRegToShmem(_)                      \
  lhs_shmem[lhs_store_idx_0] = lhs_pf0;         \
  rhs_shmem[rhs_store_idx_0] = rhs_pf0;         \
                                                \
  lhs_shmem[lhs_store_idx_1] = lhs_pf1;         \
  rhs_shmem[rhs_store_idx_1] = rhs_pf1;         \
                                                \
  lhs_shmem[lhs_store_idx_2] = lhs_pf2;         \
  rhs_shmem[rhs_store_idx_2] = rhs_pf2;         \
                                                \
  lhs_shmem[lhs_store_idx_3] = lhs_pf3;         \
  rhs_shmem[rhs_store_idx_3] = rhs_pf3;         \
                                                \
  lhs_shmem[lhs_store_idx_4] = lhs_pf4;         \
  rhs_shmem[rhs_store_idx_4] = rhs_pf4;         \
                                                \
  lhs_shmem[lhs_store_idx_5] = lhs_pf5;         \
  rhs_shmem[rhs_store_idx_5] = rhs_pf5;         \
                                                \
  lhs_shmem[lhs_store_idx_6] = lhs_pf6;         \
  rhs_shmem[rhs_store_idx_6] = rhs_pf6;         \
                                                \
  lhs_shmem[lhs_store_idx_7] = lhs_pf7;         \
  rhs_shmem[rhs_store_idx_7] = rhs_pf7;         \

  // declare and initialize result array
#define res(i, j) _res_##i##j
#define initResultRow(i)                        \
  Scalar res(i, 0) = conv(0);                   \
  Scalar res(i, 1) = conv(0);                   \
  Scalar res(i, 2) = conv(0);                   \
  Scalar res(i, 3) = conv(0);                   \
  Scalar res(i, 4) = conv(0);                   \
  Scalar res(i, 5) = conv(0);                   \
  Scalar res(i, 6) = conv(0);                   \
  Scalar res(i, 7) = conv(0);                   \

  internal::scalar_cast_op<int, Scalar> conv;
  initResultRow(0);
  initResultRow(1);
  initResultRow(2);
  initResultRow(3);
  initResultRow(4);
  initResultRow(5);
  initResultRow(6);
  initResultRow(7);
#undef initResultRow

  for (Index base_k = 0; base_k < k_size; base_k += 64) {
    // wait for previous iteration to finish with shmem. Despite common sense,
    // the code is a bit faster with this here than at the bottom of the loop.
    __syncthreads();

    prefetchIntoRegisters(base_k);
    writeRegToShmem();

    #undef prefetchIntoRegisters
    #undef writeRegToShmem

    // wait for shared mem packing to be done before starting computation
    __syncthreads();

    // compute 8x8 matrix product by outer product. This involves packing one column
    // of LHS and one row of RHS into registers (takes 16 registers).

#define lcol(i) _lcol##i
    Scalar lcol(0);
    Scalar lcol(1);
    Scalar lcol(2);
    Scalar lcol(3);
    Scalar lcol(4);
    Scalar lcol(5);
    Scalar lcol(6);
    Scalar lcol(7);

#define rrow(j) _rrow##j
    Scalar rrow(0);
    Scalar rrow(1);
    Scalar rrow(2);
    Scalar rrow(3);
    Scalar rrow(4);
    Scalar rrow(5);
    Scalar rrow(6);
    Scalar rrow(7);

    // Now x corresponds to k, y to m, and z to n
    const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
    const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];

#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]

#define loadData(i, j)                          \
    lcol(0) = lhs_element(0, j);                \
    rrow(0) = rhs_element(i, 0);                \
    lcol(1) = lhs_element(1, j);                \
    rrow(1) = rhs_element(i, 1);                \
    lcol(2) = lhs_element(2, j);                \
    rrow(2) = rhs_element(i, 2);                \
    lcol(3) = lhs_element(3, j);                \
    rrow(3) = rhs_element(i, 3);                \
    lcol(4) = lhs_element(4, j);                \
    rrow(4) = rhs_element(i, 4);                \
    lcol(5) = lhs_element(5, j);                \
    rrow(5) = rhs_element(i, 5);                \
    lcol(6) = lhs_element(6, j);                \
    rrow(6) = rhs_element(i, 6);                \
    lcol(7) = lhs_element(7, j);                \
    rrow(7) = rhs_element(i, 7);                \

#define computeCol(j)                           \
    res(0, j) += lcol(0) * rrow(j);             \
    res(1, j) += lcol(1) * rrow(j);             \
    res(2, j) += lcol(2) * rrow(j);             \
    res(3, j) += lcol(3) * rrow(j);             \
    res(4, j) += lcol(4) * rrow(j);             \
    res(5, j) += lcol(5) * rrow(j);             \
    res(6, j) += lcol(6) * rrow(j);             \
    res(7, j) += lcol(7) * rrow(j);             \

#define computePass(i)                          \
    loadData(i, i);                             \
                                                \
    computeCol(0);                              \
    computeCol(1);                              \
    computeCol(2);                              \
    computeCol(3);                              \
    computeCol(4);                              \
    computeCol(5);                              \
    computeCol(6);                              \
    computeCol(7);                              \

    computePass(0);
    computePass(1);
    computePass(2);
    computePass(3);
    computePass(4);
    computePass(5);
    computePass(6);
    computePass(7);

#undef lcol
#undef rrow
#undef lhs_element
#undef rhs_element
#undef loadData
#undef computeCol
#undef computePass
  } // end loop over k

  // we've now iterated over all of the large (i.e. width 64) k blocks and
  // accumulated results in registers. At this point thread (x, y, z) contains
  // the sum across all big k blocks of the product of the little k block of
  // index (x, y) with the little block of index (x, z). To compute the final
  // output, we need to reduce the 8 threads over threadIdx.x by summation.
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)

#define reduceRow(i, mask)                      \
  shuffleInc(i, 0, mask);                       \
  shuffleInc(i, 1, mask);                       \
  shuffleInc(i, 2, mask);                       \
  shuffleInc(i, 3, mask);                       \
  shuffleInc(i, 4, mask);                       \
  shuffleInc(i, 5, mask);                       \
  shuffleInc(i, 6, mask);                       \
  shuffleInc(i, 7, mask);                       \

#define reduceMatrix(mask)                      \
  reduceRow(0, mask);                           \
  reduceRow(1, mask);                           \
  reduceRow(2, mask);                           \
  reduceRow(3, mask);                           \
  reduceRow(4, mask);                           \
  reduceRow(5, mask);                           \
  reduceRow(6, mask);                           \
  reduceRow(7, mask);                           \

  // actually perform the reduction; afterwards each thread of index (_, y, z)
  // contains the correct values in its registers that belong in the output
  // block
  reduceMatrix(1);
  reduceMatrix(2);
  reduceMatrix(4);

#undef shuffleInc
#undef reduceRow
#undef reduceMatrix
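
  // The three masks implement a butterfly reduction over the 8 lanes that share
  // a (y, z): mask 1 sums x-neighbors (lanes 0+1, 2+3, ...), mask 2 merges the
  // resulting pairs, and mask 4 merges the two halves of each 8-lane group, so
  // after three shuffle rounds every thread holds the full sum over threadIdx.x.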

  // now we need to copy the 64 values into main memory. We can't split this work
  // among threads because all variables are in registers. There are two ways
  // to do this:
  // (1) have 1 thread do 64 writes from registers into global memory
  // (2) have 1 thread do 64 writes into shared memory, and then 8 threads
  //     each do 8 writes into global memory. We can just overwrite the shared
  //     memory from the problem we just solved.
  // (2) is slightly faster than (1) due to less branching and more ILP

  // TODO: won't yield much gain, but could just use currently unused shared mem
  //       and then we won't have to sync
  // wait for shared mem to be out of use
  __syncthreads();

#define writeResultShmem(i, j)                                          \
  lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j); \

#define writeRow(i)                             \
  writeResultShmem(i, 0);                       \
  writeResultShmem(i, 1);                       \
  writeResultShmem(i, 2);                       \
  writeResultShmem(i, 3);                       \
  writeResultShmem(i, 4);                       \
  writeResultShmem(i, 5);                       \
  writeResultShmem(i, 6);                       \
  writeResultShmem(i, 7);                       \

  if (threadIdx.x == 0) {
    writeRow(0);
    writeRow(1);
    writeRow(2);
    writeRow(3);
    writeRow(4);
    writeRow(5);
    writeRow(6);
    writeRow(7);
  }
#undef writeResultShmem
#undef writeRow

  const int max_i_write = numext::mini((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
  const int max_j_write = numext::mini((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);
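  // e.g. with m_size = 100, base_m = 64 and threadIdx.y = 3, max_i_write =
  // (100 - 64 - 3 + 7) / 8 = 5, so this thread column writes rows 67, 75, 83,
  // 91 and 99 and skips the out-of-range rows 107 and beyond.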

  if (threadIdx.x < max_i_write) {
    if (max_j_write == 8) {
      // TODO: can I trade bank conflicts for coalesced writes?
      Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
      Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
      Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
      Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
      Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
      Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
      Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
      Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];

      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
    } else {
#pragma unroll 7
      for (int j = 0; j < max_j_write; j++) {
        Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
        output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
      }
    }
  }
#undef res
}


template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
__launch_bounds__(512)
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ Scalar lhs_shmem[72 * 64];
  __shared__ Scalar rhs_shmem[72 * 64];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  if (base_m + 63 < m_size && base_n + 63 < n_size) {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  } else {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  }
}
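
#if 0
// Illustrative host-side launch sketch (not part of Eigen; the tensor
// contraction evaluator normally does this). It assumes mapper objects `lhs`,
// `rhs` and `out` already built by the caller, and uses the 8x8x8 thread block
// (512 threads, matching __launch_bounds__ above) with one 64x64 output tile
// per block.
template <typename Scalar, typename Index,
          typename LhsMapper, typename RhsMapper, typename OutputMapper>
void launchEigenContraction(const LhsMapper& lhs, const RhsMapper& rhs,
                            const OutputMapper& out,
                            Index m, Index n, Index k, cudaStream_t stream) {
  const dim3 block(8, 8, 8);
  const dim3 grid((unsigned)((m + 63) / 64), (unsigned)((n + 63) / 64), 1);
  EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>
      <<<grid, block, 0, stream>>>(lhs, rhs, out, m, n, k);
}
#endif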


template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ EIGEN_STRONG_INLINE void
EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output, float2 lhs_shmem2[][16],
                       float2 rhs_shmem2[][8], const Index m_size,
                       const Index n_size, const Index k_size,
                       const Index base_m, const Index base_n) {
  typedef float Scalar;

  // prefetch registers
  float4 lhs_pf0, rhs_pf0;

  float4 results[4];
  for (int i = 0; i < 4; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }


#define prefetch_lhs(reg, row, col)                    \
    if (!CHECK_LHS_BOUNDARY) {                         \
      if (col < k_size) {                              \
        reg = lhs.loadPacket<Unaligned>(row, col);     \
      }                                                \
    } else {                                           \
      if (col < k_size) {                              \
        if (row + 3 < m_size) {                        \
          reg = lhs.loadPacket<Unaligned>(row, col);   \
        } else if (row + 2 < m_size) {                 \
          reg.x = lhs(row + 0, col);                   \
          reg.y = lhs(row + 1, col);                   \
          reg.z = lhs(row + 2, col);                   \
        } else if (row + 1 < m_size) {                 \
          reg.x = lhs(row + 0, col);                   \
          reg.y = lhs(row + 1, col);                   \
        } else if (row < m_size) {                     \
          reg.x = lhs(row + 0, col);                   \
        }                                              \
      }                                                \
    }                                                  \


  Index lhs_vert = base_m+threadIdx.x*4;

  for (Index k = 0; k < k_size; k += 16) {
    lhs_pf0 = internal::pset1<float4>(0);
    rhs_pf0 = internal::pset1<float4>(0);

    Index lhs_horiz = threadIdx.y+k;
    prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)

    Index rhs_vert = k+(threadIdx.x%4)*4;
    Index rhs_horiz0 = (threadIdx.x>>2)+threadIdx.y*4+base_n;

    if (!CHECK_RHS_BOUNDARY) {
      // no RHS boundary check needed; only guard against the k edge
      if ((rhs_vert + 3) < k_size) {
        rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
      } else if (rhs_vert + 2 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
      }
    } else {
      if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    float x1, x2;
    // the following could be done with a bitwise operation... some day.
    if ((threadIdx.x%8) < 4) {
      x1 = rhs_pf0.y;
      x2 = rhs_pf0.w;
    } else {
      x1 = rhs_pf0.x;
      x2 = rhs_pf0.z;
    }
    x1 = __shfl_xor(x1, 4);
    x2 = __shfl_xor(x2, 4);
    if ((threadIdx.x%8) < 4) {
      rhs_pf0.y = x1;
      rhs_pf0.w = x2;
    } else {
      rhs_pf0.x = x1;
      rhs_pf0.z = x2;
    }
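    // Net effect of the two guarded blocks and the xor-by-4 shuffles: within
    // each group of 8 lanes, lanes 0-3 trade their .y/.w components for the
    // .x/.z components of lanes 4-7. This pre-shuffles the packet loads so the
    // shared-memory stores below place each feature pair in the intended row.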

    // We have 64 features.
    // Row 0 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 0, 1.
    // Row 1 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 2, 3.
    // ...
    // Row 31 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 62, 63
    // Row 32 -> times (2, 6, 10, 14, 3, 7, 11, 15) for features 0, 1
    // ...
    rhs_shmem2[(threadIdx.x>>3)+ threadIdx.y*2][threadIdx.x%8] = make_float2(rhs_pf0.x, rhs_pf0.y);
    rhs_shmem2[(threadIdx.x>>3)+ threadIdx.y*2+32][threadIdx.x%8] = make_float2(rhs_pf0.z, rhs_pf0.w);

    // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
    // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
    // ...
    // Row 15 (time 15) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
    // Row 16 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), ..  (62, 63)
    // ...

    lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y+16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);


#define add_vals(fl1, fl2, fr1, fr2)\
    results[0].x += fl1.x * fr1.x;\
    results[0].y += fl1.y * fr1.x;\
    results[0].z += fl2.x * fr1.x;\
    results[0].w += fl2.y * fr1.x;\
\
    results[1].x += fl1.x * fr1.y;\
    results[1].y += fl1.y * fr1.y;\
    results[1].z += fl2.x * fr1.y;\
    results[1].w += fl2.y * fr1.y;\
\
    results[2].x += fl1.x * fr2.x;\
    results[2].y += fl1.y * fr2.x;\
    results[2].z += fl2.x * fr2.x;\
    results[2].w += fl2.y * fr2.x;\
\
    results[3].x += fl1.x * fr2.y;\
    results[3].y += fl1.y * fr2.y;\
    results[3].z += fl2.x * fr2.y;\
    results[3].w += fl2.y * fr2.y;\

    __syncthreads();

    // Do the multiplies.
    #pragma unroll
    for (int koff = 0; koff < 16; koff ++) {
      // 32 x threads.
      float2 fl1 = lhs_shmem2[koff][threadIdx.x];
      float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];

      int start_feature = threadIdx.y * 4;
      float2 fr1 = rhs_shmem2[(start_feature>>1) + 32*((koff%4)/2)][koff/4 + (koff%2)*4];
      float2 fr2 = rhs_shmem2[(start_feature>>1) + 1 + 32*((koff%4)/2)][koff/4 + (koff%2)*4];

      add_vals(fl1, fl2, fr1, fr2)
    }
    __syncthreads();
  }

#undef prefetch_lhs
#undef add_vals

  Index horiz_base = threadIdx.y*4+base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 4; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    // CHECK LHS
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    // CHECK RHS
    /*
    int ncols_rem = fminf(n_size - horiz_base, 4);
    for (int i = 0; i < ncols_rem; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }*/
    for (int i = 0; i < 4; i++) {
      if (horiz_base+i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    // CHECK both boundaries.
    for (int i = 0; i < 4; i++) {
      if (horiz_base+i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}


template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ EIGEN_STRONG_INLINE void
EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output, float2 lhs_shmem2[][32],
                       float2 rhs_shmem2[][8], const Index m_size,
                       const Index n_size, const Index k_size,
                       const Index base_m, const Index base_n) {
  typedef float Scalar;

  // prefetch registers
  float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3;
  float4 rhs_pf0, rhs_pf1;

  float4 results[8];
  for (int i = 0; i < 8; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }


  Index lhs_vert = base_m+threadIdx.x*4+(threadIdx.y%4)*32;
  for (Index k = 0; k < k_size; k += 32) {
    lhs_pf0 = internal::pset1<float4>(0);
    lhs_pf1 = internal::pset1<float4>(0);
    lhs_pf2 = internal::pset1<float4>(0);
    lhs_pf3 = internal::pset1<float4>(0);

    rhs_pf0 = internal::pset1<float4>(0);
    rhs_pf1 = internal::pset1<float4>(0);

    if (!CHECK_LHS_BOUNDARY) {
      if ((threadIdx.y/4+k+24) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
        lhs_pf3 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
      } else if ((threadIdx.y/4+k+16) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
      } else if ((threadIdx.y/4+k+8) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
      } else if ((threadIdx.y/4+k) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
      }
    } else {
      // just CHECK_LHS_BOUNDARY
      if (lhs_vert + 3 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
          lhs_pf3 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 2 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
          lhs_pf3.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 1 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
        }
      } else if (lhs_vert < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
        }
      }
    }
    __syncthreads();
    Index rhs_vert = k+threadIdx.x*4;
    Index rhs_horiz0 = threadIdx.y*2+base_n;
    Index rhs_horiz1 = threadIdx.y*2+1+base_n;
    if (!CHECK_RHS_BOUNDARY) {
      // no RHS boundary check needed; only guard against the k edge
      if ((rhs_vert + 3) < k_size) {
        rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
        rhs_pf1 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
      } else if (rhs_vert + 2 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
      }
    } else {
      if (rhs_horiz1 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
          rhs_pf1 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
        } else if (rhs_vert + 2 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
          rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
        } else if (rhs_vert + 1 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        }
      } else if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    __syncthreads();
    // Loaded. Do computation
    // Row 0 -> times (0, 4, 8, .. 28) for features 0, 1.
    // Row 1 -> times (0, 4, 8, .. 28) for features 2, 3.
    // ..
    // Row 31 -> times (0, 4, 8, .. 28) for features 62, 63
    rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x);
    // Row 32 -> times (1, 5, 9, .. 29) for features 0, 1.
    // Row 33 -> times (1, 5, 9, .. 29) for features 2, 3.
    // ..
    rhs_shmem2[threadIdx.y+32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y);
    // Row 64 -> times (2, 6, 10, .. 30) for features 0, 1.
    // Row 65 -> times (2, 6, 10, .. 30) for features 2, 3.
    rhs_shmem2[threadIdx.y+64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z);
    // Row 96 -> times (3, 7, 11, .. 31) for features 0, 1.
    // Row 97 -> times (3, 7, 11, .. 31) for features 2, 3.
    rhs_shmem2[threadIdx.y+96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w);

    // LHS.
    // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61) .. (124, 125)
    // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61) .. (124, 125)
    // ...
    // Row 8 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), ..  (62, 63) .. (126, 127)
    // Row 15 (time 7) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), ..  (62, 63) .. (126, 127)


#define add_vals(a_feat1, a_feat2, f1, f2, f3, f4)\
      results[0].x += a_feat1.x * f1.x;\
      results[1].x += a_feat1.x * f1.y;\
      results[2].x += a_feat1.x * f2.x;\
      results[3].x += a_feat1.x * f2.y;\
      results[4].x += a_feat1.x * f3.x;\
      results[5].x += a_feat1.x * f3.y;\
      results[6].x += a_feat1.x * f4.x;\
      results[7].x += a_feat1.x * f4.y;\
\
      results[0].y += a_feat1.y * f1.x;\
      results[1].y += a_feat1.y * f1.y;\
      results[2].y += a_feat1.y * f2.x;\
      results[3].y += a_feat1.y * f2.y;\
      results[4].y += a_feat1.y * f3.x;\
      results[5].y += a_feat1.y * f3.y;\
      results[6].y += a_feat1.y * f4.x;\
      results[7].y += a_feat1.y * f4.y;\
\
      results[0].z += a_feat2.x * f1.x;\
      results[1].z += a_feat2.x * f1.y;\
      results[2].z += a_feat2.x * f2.x;\
      results[3].z += a_feat2.x * f2.y;\
      results[4].z += a_feat2.x * f3.x;\
      results[5].z += a_feat2.x * f3.y;\
      results[6].z += a_feat2.x * f4.x;\
      results[7].z += a_feat2.x * f4.y;\
\
      results[0].w += a_feat2.y * f1.x;\
      results[1].w += a_feat2.y * f1.y;\
      results[2].w += a_feat2.y * f2.x;\
      results[3].w += a_feat2.y * f2.y;\
      results[4].w += a_feat2.y * f3.x;\
      results[5].w += a_feat2.y * f3.y;\
      results[6].w += a_feat2.y * f4.x;\
      results[7].w += a_feat2.y * f4.y;\

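    // add_vals accumulates one rank-1 update of the per-thread output tile:
    // results[0..7] index the 8 output columns (features), and the x/y/z/w
    // components of each float4 are the 4 output rows, i.e. a 4x8 block.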
    lhs_shmem2[threadIdx.y/4][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y/4+8][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.x, lhs_pf1.y);
    lhs_shmem2[threadIdx.y/4+16][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.x, lhs_pf2.y);
    lhs_shmem2[threadIdx.y/4+24][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.x, lhs_pf3.y);

    lhs_shmem2[threadIdx.y/4 + 32][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.z, lhs_pf0.w);
    lhs_shmem2[threadIdx.y/4 + 40][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.z, lhs_pf1.w);
    lhs_shmem2[threadIdx.y/4 + 48][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.z, lhs_pf2.w);
    lhs_shmem2[threadIdx.y/4 + 56][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.z, lhs_pf3.w);

    __syncthreads();

    // Do the multiplies.
    #pragma unroll
    for (int koff = 0; koff < 32; koff ++) {
      float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8];
      float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8];

      // the first feature is at (threadIdx.y/4) * 8; the last is at start + 7.
      int start_feature = (threadIdx.y / 4) * 8;

      float2 br1 = rhs_shmem2[start_feature/2 +     (koff % 4) * 32][koff/4];
      float2 br2 = rhs_shmem2[start_feature/2 + 1 + (koff % 4) * 32][koff/4];
      float2 br3 = rhs_shmem2[start_feature/2 + 2 + (koff % 4) * 32][koff/4];
      float2 br4 = rhs_shmem2[start_feature/2 + 3 + (koff % 4) * 32][koff/4];

      add_vals(a3, a4, br1, br2, br3, br4)
    }
    __syncthreads();
  } // end loop over k
01064 
01065 
01066   __syncthreads();
01067   Index horiz_base = (threadIdx.y/4)*8+base_n;
01068   if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
01069     for (int i = 0; i < 8; i++) {
01070       output(lhs_vert, horiz_base + i) = results[i].x;
01071       output(lhs_vert + 1, horiz_base + i) = results[i].y;
01072       output(lhs_vert + 2, horiz_base + i) = results[i].z;
01073       output(lhs_vert + 3, horiz_base + i) = results[i].w;
01074     }
01075   } else if (!CHECK_RHS_BOUNDARY) {
01076     if (lhs_vert + 3 < m_size) {
01077       for (int i = 0; i < 8; i++) {
01078         output(lhs_vert, horiz_base + i) = results[i].x;
01079         output(lhs_vert + 1, horiz_base + i) = results[i].y;
01080         output(lhs_vert + 2, horiz_base + i) = results[i].z;
01081         output(lhs_vert + 3, horiz_base + i) = results[i].w;
01082       }
01083     } else if (lhs_vert + 2 < m_size) {
01084       for (int i = 0; i < 8; i++) {
01085         output(lhs_vert, horiz_base + i) = results[i].x;
01086         output(lhs_vert + 1, horiz_base + i) = results[i].y;
01087         output(lhs_vert + 2, horiz_base + i) = results[i].z;
01088       }
01089     } else if (lhs_vert + 1 < m_size) {
01090       for (int i = 0; i < 8; i++) {
01091         output(lhs_vert, horiz_base + i) = results[i].x;
01092         output(lhs_vert + 1, horiz_base + i) = results[i].y;
01093       }
01094     } else if (lhs_vert  < m_size) {
01095       for (int i = 0; i < 8; i++) {
01096         output(lhs_vert, horiz_base + i) = results[i].x;
01097       }
01098     }
01099   } else if (!CHECK_LHS_BOUNDARY) {
01100     // CHECK BOUNDARY_B
01101     for (int i = 0; i < 8; i++) {
01102       if (horiz_base + i < n_size) {
01103         output(lhs_vert, horiz_base + i) = results[i].x;
01104         output(lhs_vert + 1, horiz_base + i) = results[i].y;
01105         output(lhs_vert + 2, horiz_base + i) = results[i].z;
01106         output(lhs_vert + 3, horiz_base + i) = results[i].w;
01107       }
01108     }
01109   } else {
01110     // CHECK both boundaries.
01111     for (int i = 0; i < 8; i++) {
01112       if (horiz_base + i < n_size) {
01113         if (lhs_vert < m_size)
01114           output(lhs_vert, horiz_base + i) = results[i].x;
01115         if (lhs_vert + 1 < m_size)
01116           output(lhs_vert + 1, horiz_base + i) = results[i].y;
01117         if (lhs_vert + 2 < m_size)
01118           output(lhs_vert + 2, horiz_base + i) = results[i].z;
01119         if (lhs_vert + 3 < m_size)
01120           output(lhs_vert + 3, horiz_base + i) = results[i].w;
01121       }
01122     }
01123   }
01124 }
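// CHECK_LHS_BOUNDARY and CHECK_RHS_BOUNDARY arrive as template parameters,
// so each of the four instantiations of this kernel compiles with its bounds
// tests constant-folded away: interior tiles pay nothing for edge handling.
// The same idiom in isolation (a minimal sketch; the function name is
// hypothetical, not part of Eigen):
template <bool check_bounds>
__device__ EIGEN_STRONG_INLINE void
eigen_contraction_demo_store(float* out, int idx, int size, float v) {
  // When check_bounds is false the condition folds to true and the compiler
  // emits an unconditional store.
  if (!check_bounds || idx < size) {
    out[idx] = v;
  }
}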
01125 
01126 
01127 template<typename Index, typename LhsMapper,
01128          typename RhsMapper, typename OutputMapper>
01129 __global__ void
01130 __launch_bounds__(256)
01131 EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
01132                        const OutputMapper output,
01133                        const Index m_size, const Index n_size, const Index k_size) {
01134   __shared__ float2 lhs_shmem[64*32];
01135   __shared__ float2 rhs_shmem[128*8];
01136 
01137   typedef float2 LHS_MEM[64][32];
01138   typedef float2 RHS_MEM[128][8];
01139 
01140   typedef float2 LHS_MEM16x16[32][16];  // tile shapes used by the 16x16 kernel;
01141   typedef float2 RHS_MEM16x16[64][8];   // not referenced in this function
01142 
01143   const Index m_block_idx = blockIdx.x;
01144   const Index n_block_idx = blockIdx.y;
01145 
01146   const Index base_m = 128 * m_block_idx;
01147   const Index base_n = 64 * n_block_idx;
01148 
01149   bool check_rhs = (base_n + 63) >= n_size;     // this block's 64-column tile sticks out past the last column
01150   bool check_lhs128 = (base_m + 127) >= m_size; // this block's 128-row tile sticks out past the last row
01151 
01152   if (!check_rhs) {
01153     if (!check_lhs128) {
01154       // >= 128 rows left
01155       EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
01156                      lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
01157     } else {
01158       EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
01159                      lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
01160     }
01161   } else {
01162     if (!check_lhs128) {
01163       // >= 128 rows left
01164       EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
01165                      lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
01166     } else {
01167       EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
01168                      lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
01169     }
01170   }
01171 }
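// The LHS_MEM / RHS_MEM casts above give the flat __shared__ buffers a 2D
// view so the internal kernel can address them as tile[row][col]. The idiom
// in isolation (a sketch; the demo kernel is hypothetical):
__global__ void eigen_contraction_demo_tile_view() {
  __shared__ float2 buf[64 * 32];
  typedef float2 Tile[64][32];
  Tile& tile = *((Tile*) buf);          // same storage, 2D indexing
  tile[1][2] = make_float2(0.f, 0.f);   // aliases buf[1 * 32 + 2]
}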
01172 
01173 template<typename Index, typename LhsMapper,
01174          typename RhsMapper, typename OutputMapper>
01175 __global__ void
01176 __launch_bounds__(256)
01177 EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs,
01178                        const OutputMapper output,
01179                        const Index m_size, const Index n_size, const Index k_size) {
01180   __shared__ float2 lhs_shmem[32][16];
01181   __shared__ float2 rhs_shmem[64][8];
01182 
01183   const Index m_block_idx = blockIdx.x;
01184   const Index n_block_idx = blockIdx.y;
01185 
01186   const Index base_m = 64 * m_block_idx;
01187   const Index base_n = 64 * n_block_idx;
01188 
01189   if (base_m + 63 < m_size) {
01190     if (base_n + 63 < n_size) {
01191       EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
01192     } else {
01193       EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
01194     }
01195   } else {
01196     if (base_n + 63 < n_size) {
01197       EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
01198     } else {
01199       EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
01200     }
01201   }
01202 }
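// The two float kernels trade tile shape for occupancy: the 16x16 variant
// covers a 64x64 output tile per 256-thread block, while the 8x32 variant
// above covers 128x64. LaunchKernels below selects the smaller tile whenever
// either output dimension is under 768.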
01203 
01204 
01205 template<typename Indices, typename LeftArgType, typename RightArgType>
01206 struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, GpuDevice> :
01207     public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, GpuDevice> > {
01208 
01209   typedef GpuDevice Device;
01210 
01211   typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
01212   typedef TensorContractionEvaluatorBase<Self> Base;
01213 
01214   typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
01215   typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
01216   typedef typename XprType::Index Index;
01217   typedef typename XprType::CoeffReturnType CoeffReturnType;
01218   typedef typename PacketType<CoeffReturnType, GpuDevice>::type PacketReturnType;
01219 
01220   enum {
01221     Layout = TensorEvaluator<LeftArgType, Device>::Layout,
01222   };
01223 
01224   // Most of the code is assuming that both input tensors are ColMajor. If the
01225   // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
01226   // If we want to compute A * B = C, where A is LHS and B is RHS, the code
01227   // will pretend B is LHS and A is RHS.
01228   typedef typename internal::conditional<
01229     static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
01230   typedef typename internal::conditional<
01231     static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
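  // The swap is valid because a RowMajor matrix read through ColMajor
  // strides is its transpose, and C = A * B is equivalent to
  // C^T = B^T * A^T; computing the swapped product in ColMajor and storing
  // it with ColMajor strides therefore produces exactly the RowMajor C.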
01232 
01233   static const int LDims =
01234       internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
01235   static const int RDims =
01236       internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
01237   static const int ContractDims = internal::array_size<Indices>::value;
01238 
01239   typedef array<Index, LDims> left_dim_mapper_t;
01240   typedef array<Index, RDims> right_dim_mapper_t;
01241 
01242   typedef array<Index, ContractDims> contract_t;
01243   typedef array<Index, LDims - ContractDims> left_nocontract_t;
01244   typedef array<Index, RDims - ContractDims> right_nocontract_t;
01245 
01246   static const int NumDims = LDims + RDims - 2 * ContractDims;
01247 
01248   typedef DSizes<Index, NumDims> Dimensions;
01249 
01250   // typedefs needed in evalTo
01251   typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
01252   typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
01253 
01254   typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
01255   typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
01256 
01257   typedef typename LeftEvaluator::Dimensions LeftDimensions;
01258   typedef typename RightEvaluator::Dimensions RightDimensions;
01259 
01260   EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
01261       Base(op, device) {}
01262 
01263   // We need to redefine this method to make nvcc happy
01264   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
01265     this->m_leftImpl.evalSubExprsIfNeeded(NULL);
01266     this->m_rightImpl.evalSubExprsIfNeeded(NULL);
01267     if (data) {
01268       evalTo(data);
01269       return false;
01270     } else {
01271       this->m_result = static_cast<Scalar *>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
01272       evalTo(this->m_result);
01273       return true;
01274     }
01275   }
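  // Returning false tells the caller the result was already written into
  // the buffer it supplied; returning true means the evaluator allocated
  // m_result itself and coefficient accesses should go through it.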
01276 
01277   void evalTo(Scalar* buffer) const {
01278     if (this->m_lhs_inner_dim_contiguous) {
01279       if (this->m_rhs_inner_dim_contiguous) {
01280         if (this->m_rhs_inner_dim_reordered) {
01281           evalTyped<true, true, true, Unaligned>(buffer);
01282         }
01283         else {
01284           evalTyped<true, true, false, Unaligned>(buffer);
01285         }
01286       }
01287       else {
01288         if (this->m_rhs_inner_dim_reordered) {
01289           evalTyped<true, false, true, Unaligned>(buffer);
01290         }
01291         else {
01292           evalTyped<true, false, false, Unaligned>(buffer);
01293         }
01294       }
01295     }
01296     else {
01297       if (this->m_rhs_inner_dim_contiguous) {
01298         if (this->m_rhs_inner_dim_reordered) {
01299           evalTyped<false, true, true, Unaligned>(buffer);
01300         }
01301         else {
01302           evalTyped<false, true, false, Unaligned>(buffer);
01303         }
01304       }
01305       else {
01306         if (this->m_rhs_inner_dim_reordered) {
01307           evalTyped<false, false, true, Unaligned>(buffer);
01308         }
01309         else {
01310           evalTyped<false, false, false, Unaligned>(buffer);
01311         }
01312       }
01313     }
01314   }
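  // evalTo spells out all 2^3 combinations of the three runtime flags so
  // that evalTyped receives them as compile-time template parameters, giving
  // each configuration its own fully specialized code path. The idiom in
  // miniature (a sketch with hypothetical names):
  //
  //   template <bool flag> void impl();
  //   void dispatch(bool flag) { flag ? impl<true>() : impl<false>(); }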
01315 
01316   template <typename LhsScalar, typename RhsScalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels {
01317     static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
01318       const Index m_blocks = (m + 63) / 64;
01319       const Index n_blocks = (n + 63) / 64;
01320       const dim3 num_blocks(m_blocks, n_blocks, 1);
01321       const dim3 block_size(8, 8, 8);
01322       LAUNCH_CUDA_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
01323     }
01324   };
01325 
01326   template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
01327     static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
01328       if (m < 768 || n < 768) {  // small output: favor the 64x64-tile kernel
01329         const Index m_blocks = (m + 63) / 64;
01330         const Index n_blocks = (n + 63) / 64;
01331         const dim3 num_blocks(m_blocks, n_blocks, 1);
01332         const dim3 block_size(16, 16, 1);
01333         LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
01334       } else {
01335         const Index m_blocks = (m + 127) / 128;
01336         const Index n_blocks = (n + 63) / 64;
01337         const dim3 num_blocks(m_blocks, n_blocks, 1);
01338         const dim3 block_size(8, 32, 1);
01339         LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
01340       }
01341     }
01342   };
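  // Both float paths launch 256-thread blocks (16x16x1 and 8x32x1), matching
  // the __launch_bounds__(256) annotations on the kernels; the generic path
  // uses 8x8x8 = 512 threads. LAUNCH_CUDA_KERNEL is assumed here to expand
  // roughly to
  //   (kernel)<<<num_blocks, block_size, shared_mem, device.stream()>>>(args...);
  // followed by a cudaGetLastError() check; see TensorDeviceCuda.h for the
  // exact definition.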
01343 
01344   template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
01345   void evalTyped(Scalar* buffer) const {
01346     // columns in left side, rows in right side
01347     const Index k = this->m_k_size;
01348     EIGEN_UNUSED_VARIABLE(k)
01349 
01350     // rows in left side
01351     const Index m = this->m_i_size;
01352 
01353     // columns in right side
01354     const Index n = this->m_j_size;
01355 
01356     // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
01357     this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
01358 
01359     typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
01360                                                    LeftEvaluator, left_nocontract_t,
01361                                                    contract_t, 4,
01362                                                    lhs_inner_dim_contiguous,
01363                                                    false, Unaligned> LhsMapper;
01364 
01365     typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
01366                                                    RightEvaluator, right_nocontract_t,
01367                                                    contract_t, 4,
01368                                                    rhs_inner_dim_contiguous,
01369                                                    rhs_inner_dim_reordered, Unaligned> RhsMapper;
01370 
01371     typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
01372 
01373 
01374     // initialize data mappers
01375     LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
01376                   this->m_left_contracting_strides, this->m_k_strides);
01377 
01378     RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
01379                   this->m_right_contracting_strides, this->m_k_strides);
01380 
01381     OutputMapper output(buffer, m);
01382 
01383     setCudaSharedMemConfig(cudaSharedMemBankSizeEightByte);
01384     LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output, m, n, k, this->m_device);
01385   }
01386 };
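// Usage sketch: this evaluator is picked up automatically when a contraction
// expression is assigned on a GpuDevice. The setup below follows the pattern
// of the Eigen CUDA tests (d_a, d_b and d_c stand for device allocations and
// are illustrative):
//
//   Eigen::CudaStreamDevice stream;
//   Eigen::GpuDevice gpu_device(&stream);
//   Eigen::TensorMap<Eigen::Tensor<float, 2> > a(d_a, 1024, 512);
//   Eigen::TensorMap<Eigen::Tensor<float, 2> > b(d_b, 512, 768);
//   Eigen::TensorMap<Eigen::Tensor<float, 2> > c(d_c, 1024, 768);
//   Eigen::array<Eigen::IndexPair<int>, 1> dims;
//   dims[0] = Eigen::IndexPair<int>(1, 0);       // contract a's dim 1 with b's dim 0
//   c.device(gpu_device) = a.contract(b, dims);  // dispatches to the kernels above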
01387 
01388 } // end namespace Eigen
01389 
01390 #endif // EIGEN_USE_GPU and __CUDACC__
01391 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H