Eigen-unsupported 3.3.3
TensorContractionCuda.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014-2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
// Copyright (C) 2015 Navdeep Jaitly <ndjaitly@google.com>
// Copyright (C) 2014 Eric Martin <eric@ericmart.in>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H

#if defined(EIGEN_USE_GPU) && defined(__CUDACC__)

namespace Eigen {

template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool needs_edge_check>
__device__ EIGEN_STRONG_INLINE void
EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                               const OutputMapper output, Scalar* lhs_shmem, Scalar* rhs_shmem,
                               const Index m_size, const Index n_size, const Index k_size) {

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  // declare and initialize 64 registers for output 8x8 block

  // prefetch registers
  Scalar lhs_pf0;
  Scalar lhs_pf1;
  Scalar lhs_pf2;
  Scalar lhs_pf3;
  Scalar lhs_pf4;
  Scalar lhs_pf5;
  Scalar lhs_pf6;
  Scalar lhs_pf7;

  Scalar rhs_pf0;
  Scalar rhs_pf1;
  Scalar rhs_pf2;
  Scalar rhs_pf3;
  Scalar rhs_pf4;
  Scalar rhs_pf5;
  Scalar rhs_pf6;
  Scalar rhs_pf7;

  // Shared memory is formatted
  // (contract idx in block, nocontract idx in block, block idx),
  // where the block index is column major. This transposition limits the number
  // of bank conflicts when reading the LHS. The core idea is that since the
  // contracting index is shared by both sides, it should live in threadIdx.x.

  // On the LHS, we pad each row inside of each block with an extra element. This makes
  // each block 8 rows of 9 elements, which is 72 elements. This gives no bank conflicts
  // on writes and very few 2-way conflicts on reads. There is an 8x8 grid of these blocks.

  // On the RHS we just add 8 padding elements to the end of each block. This gives no bank
  // conflicts on writes and also none on reads.
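
  // [Illustrative note, not in the upstream file] A worked instance of the
  // layout above: each padded LHS block is 8 rows of 9 scalars (72 total), and
  // the 8x8 grid of blocks fills the 72 * 64 = 4608-element shared buffer.
  // Element (contract c, nocontract r) of block b lives at index
  //   b * 72 + r * 9 + c,
  // which is exactly the lhs_store_idx_base expression computed below with
  // b = threadIdx.y, r = threadIdx.x and c = threadIdx.z.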

  // storage indices
  const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
  const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;

  const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
  const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
  const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
  const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
  const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
  const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
  const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
  const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;

  const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
  const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
  const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
  const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
  const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
  const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
  const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
  const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;

  // In the loading code, the following variables are important:
  //   threadIdx.x: the vertical position in an 8x8 block
  //   threadIdx.y: the vertical index of the 8x8 block in the grid
  //   threadIdx.z: the horizontal position in an 8x8 block
  //   k: the horizontal index of the 8x8 block in the grid
  //
  // The k parameter is implicit (it was the loop counter for a loop that went
  // from 0 to 7, but that loop is now unrolled in the code below).

  const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
  const Index lhs_vert = base_m + load_idx_vert;

#define prefetchIntoRegisters(base_k) \
  { \
    lhs_pf0 = conv(0); \
    lhs_pf1 = conv(0); \
    lhs_pf2 = conv(0); \
    lhs_pf3 = conv(0); \
    lhs_pf4 = conv(0); \
    lhs_pf5 = conv(0); \
    lhs_pf6 = conv(0); \
    lhs_pf7 = conv(0); \
    \
    rhs_pf0 = conv(0); \
    rhs_pf1 = conv(0); \
    rhs_pf2 = conv(0); \
    rhs_pf3 = conv(0); \
    rhs_pf4 = conv(0); \
    rhs_pf5 = conv(0); \
    rhs_pf6 = conv(0); \
    rhs_pf7 = conv(0); \
    \
    if (!needs_edge_check || lhs_vert < m_size) { \
      const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8; \
      const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8; \
      const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8; \
      const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8; \
      const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8; \
      const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8; \
      const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8; \
      const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8; \
      \
      if (!needs_edge_check || lhs_horiz_7 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
        lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \
      } else if (lhs_horiz_6 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
      } else if (lhs_horiz_5 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
      } else if (lhs_horiz_4 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
      } else if (lhs_horiz_3 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
      } else if (lhs_horiz_2 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
      } else if (lhs_horiz_1 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
      } else if (lhs_horiz_0 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
      } \
    } \
    \
    const Index rhs_vert = base_k + load_idx_vert; \
    if (!needs_edge_check || rhs_vert < k_size) { \
      const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8; \
      const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8; \
      const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8; \
      const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8; \
      const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8; \
      const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8; \
      const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8; \
      const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8; \
      \
      if (rhs_horiz_7 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
        rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \
      } else if (rhs_horiz_6 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
      } else if (rhs_horiz_5 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
      } else if (rhs_horiz_4 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
      } else if (rhs_horiz_3 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
      } else if (rhs_horiz_2 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
      } else if (rhs_horiz_1 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
      } else if (rhs_horiz_0 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
      } \
    } \
  } \

#define writeRegToShmem(_) \
  lhs_shmem[lhs_store_idx_0] = lhs_pf0; \
  rhs_shmem[rhs_store_idx_0] = rhs_pf0; \
  \
  lhs_shmem[lhs_store_idx_1] = lhs_pf1; \
  rhs_shmem[rhs_store_idx_1] = rhs_pf1; \
  \
  lhs_shmem[lhs_store_idx_2] = lhs_pf2; \
  rhs_shmem[rhs_store_idx_2] = rhs_pf2; \
  \
  lhs_shmem[lhs_store_idx_3] = lhs_pf3; \
  rhs_shmem[rhs_store_idx_3] = rhs_pf3; \
  \
  lhs_shmem[lhs_store_idx_4] = lhs_pf4; \
  rhs_shmem[rhs_store_idx_4] = rhs_pf4; \
  \
  lhs_shmem[lhs_store_idx_5] = lhs_pf5; \
  rhs_shmem[rhs_store_idx_5] = rhs_pf5; \
  \
  lhs_shmem[lhs_store_idx_6] = lhs_pf6; \
  rhs_shmem[rhs_store_idx_6] = rhs_pf6; \
  \
  lhs_shmem[lhs_store_idx_7] = lhs_pf7; \
  rhs_shmem[rhs_store_idx_7] = rhs_pf7; \

  // declare and initialize result array
#define res(i, j) _res_##i##j
#define initResultRow(i) \
  Scalar res(i, 0) = conv(0); \
  Scalar res(i, 1) = conv(0); \
  Scalar res(i, 2) = conv(0); \
  Scalar res(i, 3) = conv(0); \
  Scalar res(i, 4) = conv(0); \
  Scalar res(i, 5) = conv(0); \
  Scalar res(i, 6) = conv(0); \
  Scalar res(i, 7) = conv(0); \

  internal::scalar_cast_op<int, Scalar> conv;
  initResultRow(0);
  initResultRow(1);
  initResultRow(2);
  initResultRow(3);
  initResultRow(4);
  initResultRow(5);
  initResultRow(6);
  initResultRow(7);
#undef initResultRow

  for (Index base_k = 0; base_k < k_size; base_k += 64) {
    // wait for previous iteration to finish with shmem. Despite common sense,
    // the code is a bit faster with this here than at the bottom of the loop.
    __syncthreads();

    prefetchIntoRegisters(base_k);
    writeRegToShmem();

#undef prefetchIntoRegisters
#undef writeRegToShmem

    // wait for shared mem packing to be done before starting computation
    __syncthreads();

    // compute 8x8 matrix product by outer product. This involves packing one column
    // of LHS and one row of RHS into registers (takes 16 registers).
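    // [Illustrative note, not in the upstream file] Each computePass(p) defined
    // below is one rank-1 update: it loads an 8-element LHS column into
    // lcol(0..7) and an 8-element RHS row into rrow(0..7), then accumulates
    //   res(i, j) += lcol(i) * rrow(j)   for all 0 <= i, j < 8,
    // so the eight passes together multiply the 8x8 sub-blocks staged in shared
    // memory for this iteration of the k loop.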

#define lcol(i) _lcol##i
    Scalar lcol(0);
    Scalar lcol(1);
    Scalar lcol(2);
    Scalar lcol(3);
    Scalar lcol(4);
    Scalar lcol(5);
    Scalar lcol(6);
    Scalar lcol(7);

#define rrow(j) _rrow##j
    Scalar rrow(0);
    Scalar rrow(1);
    Scalar rrow(2);
    Scalar rrow(3);
    Scalar rrow(4);
    Scalar rrow(5);
    Scalar rrow(6);
    Scalar rrow(7);

    // Now x corresponds to k, y to m, and z to n
    const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
    const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];

#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]

#define loadData(i, j) \
    lcol(0) = lhs_element(0, j); \
    rrow(0) = rhs_element(i, 0); \
    lcol(1) = lhs_element(1, j); \
    rrow(1) = rhs_element(i, 1); \
    lcol(2) = lhs_element(2, j); \
    rrow(2) = rhs_element(i, 2); \
    lcol(3) = lhs_element(3, j); \
    rrow(3) = rhs_element(i, 3); \
    lcol(4) = lhs_element(4, j); \
    rrow(4) = rhs_element(i, 4); \
    lcol(5) = lhs_element(5, j); \
    rrow(5) = rhs_element(i, 5); \
    lcol(6) = lhs_element(6, j); \
    rrow(6) = rhs_element(i, 6); \
    lcol(7) = lhs_element(7, j); \
    rrow(7) = rhs_element(i, 7); \

#define computeCol(j) \
    res(0, j) += lcol(0) * rrow(j); \
    res(1, j) += lcol(1) * rrow(j); \
    res(2, j) += lcol(2) * rrow(j); \
    res(3, j) += lcol(3) * rrow(j); \
    res(4, j) += lcol(4) * rrow(j); \
    res(5, j) += lcol(5) * rrow(j); \
    res(6, j) += lcol(6) * rrow(j); \
    res(7, j) += lcol(7) * rrow(j); \

#define computePass(i) \
    loadData(i, i); \
    \
    computeCol(0); \
    computeCol(1); \
    computeCol(2); \
    computeCol(3); \
    computeCol(4); \
    computeCol(5); \
    computeCol(6); \
    computeCol(7); \

    computePass(0);
    computePass(1);
    computePass(2);
    computePass(3);
    computePass(4);
    computePass(5);
    computePass(6);
    computePass(7);

#undef lcol
#undef rrow
#undef lhs_element
#undef rhs_element
#undef loadData
#undef computeCol
#undef computePass
  } // end loop over k

  // We've now iterated over all of the large (i.e. width 64) k blocks and
  // accumulated results in registers. At this point thread (x, y, z) contains
  // the sum across all big k blocks of the product of the little k block of
  // index (x, y) with the block of index (y, z). To compute the final output,
  // we need to reduce the 8 threads over x by summation.
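
  // [Illustrative note, not in the upstream file] The masks 1, 2 and 4 used
  // with __shfl_xor below form a butterfly reduction. With blockDim = (8, 8, 8)
  // the warp lane is (threadIdx.x + 8 * threadIdx.y) % 32, so each round swaps
  // partial sums between lanes that differ only in threadIdx.x; after the three
  // rounds every lane holds the sum over all eight threadIdx.x values.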
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)

#define reduceRow(i, mask) \
  shuffleInc(i, 0, mask); \
  shuffleInc(i, 1, mask); \
  shuffleInc(i, 2, mask); \
  shuffleInc(i, 3, mask); \
  shuffleInc(i, 4, mask); \
  shuffleInc(i, 5, mask); \
  shuffleInc(i, 6, mask); \
  shuffleInc(i, 7, mask); \

#define reduceMatrix(mask) \
  reduceRow(0, mask); \
  reduceRow(1, mask); \
  reduceRow(2, mask); \
  reduceRow(3, mask); \
  reduceRow(4, mask); \
  reduceRow(5, mask); \
  reduceRow(6, mask); \
  reduceRow(7, mask); \

  // actually perform the reduction; now each thread of index (_, y, z)
  // contains the correct values in its registers that belong in the output
  // block
  reduceMatrix(1);
  reduceMatrix(2);
  reduceMatrix(4);

#undef shuffleInc
#undef reduceRow
#undef reduceMatrix

  // Now we need to copy the 64 values into main memory. We can't split the work
  // among threads because all variables are in registers. There are two ways
  // to do this:
  // (1) have 1 thread do 64 writes from registers into global memory
  // (2) have 1 thread do 64 writes into shared memory, and then 8 threads
  //     each do 8 writes into global memory. We can just overwrite the shared
  //     memory from the problem we just solved.
  // (2) is slightly faster than (1) due to less branching and more ILP.

  // TODO: won't yield much gain, but could just use currently unused shared mem
  //       and then we won't have to sync
  // wait for shared mem to be out of use
  __syncthreads();

#define writeResultShmem(i, j) \
  lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j); \

#define writeRow(i) \
  writeResultShmem(i, 0); \
  writeResultShmem(i, 1); \
  writeResultShmem(i, 2); \
  writeResultShmem(i, 3); \
  writeResultShmem(i, 4); \
  writeResultShmem(i, 5); \
  writeResultShmem(i, 6); \
  writeResultShmem(i, 7); \

  if (threadIdx.x == 0) {
    writeRow(0);
    writeRow(1);
    writeRow(2);
    writeRow(3);
    writeRow(4);
    writeRow(5);
    writeRow(6);
    writeRow(7);
  }
#undef writeResultShmem
#undef writeRow

  const int max_i_write = numext::mini((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
  const int max_j_write = numext::mini((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);

  if (threadIdx.x < max_i_write) {
    if (max_j_write == 8) {
      // TODO: can I trade bank conflicts for coalesced writes?
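      // [Illustrative note, not in the upstream file] Each load below picks up
      // output column j of this thread's value from the staging buffer (stride
      // 512 = 8 * 8 * 8 elements between columns), and the matching store
      // writes output row base_m + threadIdx.y + 8 * threadIdx.x, column
      // base_n + threadIdx.z + 8 * j.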
      Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
      Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
      Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
      Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
      Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
      Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
      Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
      Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];

      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
    } else {
#pragma unroll 7
      for (int j = 0; j < max_j_write; j++) {
        Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
        output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
      }
    }
  }
#undef res
}


template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
__launch_bounds__(512)
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ Scalar lhs_shmem[72 * 64];
  __shared__ Scalar rhs_shmem[72 * 64];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  if (base_m + 63 < m_size && base_n + 63 < n_size) {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  } else {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  }
}


template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ EIGEN_STRONG_INLINE void
EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
                                         const OutputMapper output, float2 lhs_shmem2[][16],
                                         float2 rhs_shmem2[][8], const Index m_size,
                                         const Index n_size, const Index k_size,
                                         const Index base_m, const Index base_n) {
  typedef float Scalar;

  // prefetch registers
  float4 lhs_pf0, rhs_pf0;

  float4 results[4];
  for (int i = 0; i < 4; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }
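
  // [Illustrative note, not in the upstream file] The prefetch_lhs macro
  // defined next loads four consecutive LHS rows of one column as a single
  // float4 packet when all four rows are in bounds, and falls back to scalar
  // loads of the x/y/z components when fewer than four rows remain at the
  // bottom edge.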

#define prefetch_lhs(reg, row, col) \
  if (!CHECK_LHS_BOUNDARY) { \
    if (col < k_size) { \
      reg = lhs.loadPacket<Unaligned>(row, col); \
    } \
  } else { \
    if (col < k_size) { \
      if (row + 3 < m_size) { \
        reg = lhs.loadPacket<Unaligned>(row, col); \
      } else if (row + 2 < m_size) { \
        reg.x = lhs(row + 0, col); \
        reg.y = lhs(row + 1, col); \
        reg.z = lhs(row + 2, col); \
      } else if (row + 1 < m_size) { \
        reg.x = lhs(row + 0, col); \
        reg.y = lhs(row + 1, col); \
      } else if (row < m_size) { \
        reg.x = lhs(row + 0, col); \
      } \
    } \
  } \


  Index lhs_vert = base_m + threadIdx.x * 4;

  for (Index k = 0; k < k_size; k += 16) {
    lhs_pf0 = internal::pset1<float4>(0);
    rhs_pf0 = internal::pset1<float4>(0);

    Index lhs_horiz = threadIdx.y + k;
    prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)

    Index rhs_vert = k + (threadIdx.x % 4) * 4;
    Index rhs_horiz0 = (threadIdx.x >> 2) + threadIdx.y * 4 + base_n;

    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        // just CHECK_RHS_BOUNDARY
        rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
      } else if (rhs_vert + 2 < k_size) {
        // just CHECK_RHS_BOUNDARY
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
      }
    } else {
      if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    float x1, x2;
    // the following can be a bitwise operation..... some day.
    if ((threadIdx.x % 8) < 4) {
      x1 = rhs_pf0.y;
      x2 = rhs_pf0.w;
    } else {
      x1 = rhs_pf0.x;
      x2 = rhs_pf0.z;
    }
    x1 = __shfl_xor(x1, 4);
    x2 = __shfl_xor(x2, 4);
    if ((threadIdx.x % 8) < 4) {
      rhs_pf0.y = x1;
      rhs_pf0.w = x2;
    } else {
      rhs_pf0.x = x1;
      rhs_pf0.z = x2;
    }

    // We have 64 features.
    // Row 0 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 0, 1.
    // Row 1 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 2, 3.
    // ...
    // Row 31 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 62, 63
    // Row 32 -> times (2, 6, 10, 14, 3, 7, 11, 15) for features 0, 1
    // ...
    rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2][threadIdx.x % 8] = make_float2(rhs_pf0.x, rhs_pf0.y);
    rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2 + 32][threadIdx.x % 8] = make_float2(rhs_pf0.z, rhs_pf0.w);

    // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
    // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
    // ...
    // Row 15 (time 15) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
    // Row 16 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63)
    // ...

    lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y + 16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);


#define add_vals(fl1, fl2, fr1, fr2) \
    results[0].x += fl1.x * fr1.x; \
    results[0].y += fl1.y * fr1.x; \
    results[0].z += fl2.x * fr1.x; \
    results[0].w += fl2.y * fr1.x; \
    \
    results[1].x += fl1.x * fr1.y; \
    results[1].y += fl1.y * fr1.y; \
    results[1].z += fl2.x * fr1.y; \
    results[1].w += fl2.y * fr1.y; \
    \
    results[2].x += fl1.x * fr2.x; \
    results[2].y += fl1.y * fr2.x; \
    results[2].z += fl2.x * fr2.x; \
    results[2].w += fl2.y * fr2.x; \
    \
    results[3].x += fl1.x * fr2.y; \
    results[3].y += fl1.y * fr2.y; \
    results[3].z += fl2.x * fr2.y; \
    results[3].w += fl2.y * fr2.y; \

    __syncthreads();

    // Do the multiplies.
#pragma unroll
    for (int koff = 0; koff < 16; koff++) {
      // 32 x threads.
      float2 fl1 = lhs_shmem2[koff][threadIdx.x];
      float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];

      int start_feature = threadIdx.y * 4;
      float2 fr1 = rhs_shmem2[(start_feature >> 1) + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];
      float2 fr2 = rhs_shmem2[(start_feature >> 1) + 1 + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];

      add_vals(fl1, fl2, fr1, fr2)
    }
    __syncthreads();
  }

#undef prefetch_lhs
#undef add_vals

  Index horiz_base = threadIdx.y * 4 + base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 4; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    // CHECK LHS
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    // CHECK RHS
    /*
    int ncols_rem = fminf(n_size- horiz_base, 4);
    for (int i = 0; i < ncols_rem; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }*/
    for (int i = 0; i < 4; i++) {
      if (horiz_base + i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    // CHECK both boundaries.
    for (int i = 0; i < 4; i++) {
      if (horiz_base + i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}


template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ EIGEN_STRONG_INLINE void
EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                                    const OutputMapper output, float2 lhs_shmem2[][32],
                                    float2 rhs_shmem2[][8], const Index m_size,
                                    const Index n_size, const Index k_size,
                                    const Index base_m, const Index base_n) {
  typedef float Scalar;

  // prefetch registers
  float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3;
  float4 rhs_pf0, rhs_pf1;

  float4 results[8];
  for (int i = 0; i < 8; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }


  Index lhs_vert = base_m + threadIdx.x * 4 + (threadIdx.y % 4) * 32;
  for (Index k = 0; k < k_size; k += 32) {
    lhs_pf0 = internal::pset1<float4>(0);
    lhs_pf1 = internal::pset1<float4>(0);
    lhs_pf2 = internal::pset1<float4>(0);
    lhs_pf3 = internal::pset1<float4>(0);

    rhs_pf0 = internal::pset1<float4>(0);
    rhs_pf1 = internal::pset1<float4>(0);

    if (!CHECK_LHS_BOUNDARY) {
      if ((threadIdx.y/4+k+24) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
        lhs_pf3 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
      } else if ((threadIdx.y/4+k+16) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
      } else if ((threadIdx.y/4+k+8) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
      } else if ((threadIdx.y/4+k) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
      }
    } else {
      // just CHECK_LHS_BOUNDARY
      if (lhs_vert + 3 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
          lhs_pf3 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 2 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
          lhs_pf3.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 1 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
        }
      } else if (lhs_vert < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
        }
      }
    }
    __syncthreads();
    Index rhs_vert = k + threadIdx.x * 4;
    Index rhs_horiz0 = threadIdx.y * 2 + base_n;
    Index rhs_horiz1 = threadIdx.y * 2 + 1 + base_n;
    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        // just CHECK_RHS_BOUNDARY
        rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
        rhs_pf1 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
      } else if (rhs_vert + 2 < k_size) {
        // just CHECK_RHS_BOUNDARY
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
      }
    } else {
      if (rhs_horiz1 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          // just CHECK_RHS_BOUNDARY
          rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
          rhs_pf1 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
        } else if (rhs_vert + 2 < k_size) {
          // just CHECK_RHS_BOUNDARY
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
          rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
        } else if (k + threadIdx.x * 4 + 1 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        } else if (k + threadIdx.x * 4 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        }
      } else if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          // just CHECK_RHS_BOUNDARY
          rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          // just CHECK_RHS_BOUNDARY
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    __syncthreads();
    // Loaded. Do computation
    // Row 0 -> times (0, 4, 8, .. 28) for features 0, 1.
    // Row 1 -> times (0, 4, 8, .. 28) for features 2, 3.
    // ..
    // Row 31 -> times (0, 4, 8, .. 28) for features 62, 63
    rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x);
    // Row 32 -> times (1, 5, 9, .. 29) for features 0, 1.
    // Row 33 -> times (1, 5, 9, .. 29) for features 2, 3.
    // ..
    rhs_shmem2[threadIdx.y + 32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y);
    // Row 64 -> times (2, 6, 10, .. 30) for features 0, 1.
    // Row 65 -> times (2, 6, 10, .. 30) for features 2, 3.
    rhs_shmem2[threadIdx.y + 64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z);
    // Row 96 -> times (3, 7, 11, .. 31) for features 0, 1.
    // Row 97 -> times (3, 7, 11, .. 31) for features 2, 3.
    rhs_shmem2[threadIdx.y + 96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w);

    // LHS.
    // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) .. (124, 125)
    // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) .. (124, 125)
    // ...
    // Row 8 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63) .. (126, 127)
    // Row 15 (time 7) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63) .. (126, 127)


#define add_vals(a_feat1, a_feat2, f1, f2, f3, f4) \
    results[0].x += a_feat1.x * f1.x; \
    results[1].x += a_feat1.x * f1.y; \
    results[2].x += a_feat1.x * f2.x; \
    results[3].x += a_feat1.x * f2.y; \
    results[4].x += a_feat1.x * f3.x; \
    results[5].x += a_feat1.x * f3.y; \
    results[6].x += a_feat1.x * f4.x; \
    results[7].x += a_feat1.x * f4.y; \
    \
    results[0].y += a_feat1.y * f1.x; \
    results[1].y += a_feat1.y * f1.y; \
    results[2].y += a_feat1.y * f2.x; \
    results[3].y += a_feat1.y * f2.y; \
    results[4].y += a_feat1.y * f3.x; \
    results[5].y += a_feat1.y * f3.y; \
    results[6].y += a_feat1.y * f4.x; \
    results[7].y += a_feat1.y * f4.y; \
    \
    results[0].z += a_feat2.x * f1.x; \
    results[1].z += a_feat2.x * f1.y; \
    results[2].z += a_feat2.x * f2.x; \
    results[3].z += a_feat2.x * f2.y; \
    results[4].z += a_feat2.x * f3.x; \
    results[5].z += a_feat2.x * f3.y; \
    results[6].z += a_feat2.x * f4.x; \
    results[7].z += a_feat2.x * f4.y; \
    \
    results[0].w += a_feat2.y * f1.x; \
    results[1].w += a_feat2.y * f1.y; \
    results[2].w += a_feat2.y * f2.x; \
    results[3].w += a_feat2.y * f2.y; \
    results[4].w += a_feat2.y * f3.x; \
    results[5].w += a_feat2.y * f3.y; \
    results[6].w += a_feat2.y * f4.x; \
    results[7].w += a_feat2.y * f4.y; \

    lhs_shmem2[threadIdx.y/4][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y/4+8][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.x, lhs_pf1.y);
    lhs_shmem2[threadIdx.y/4+16][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.x, lhs_pf2.y);
    lhs_shmem2[threadIdx.y/4+24][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.x, lhs_pf3.y);

    lhs_shmem2[threadIdx.y/4 + 32][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.z, lhs_pf0.w);
    lhs_shmem2[threadIdx.y/4 + 40][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.z, lhs_pf1.w);
    lhs_shmem2[threadIdx.y/4 + 48][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.z, lhs_pf2.w);
    lhs_shmem2[threadIdx.y/4 + 56][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.z, lhs_pf3.w);

    __syncthreads();

    // Do the multiplies.
#pragma unroll
    for (int koff = 0; koff < 32; koff++) {
      float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8];
      float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8];

      // first feature is at (threadIdx.y/4) * 8; the last is at start + 8.
      int start_feature = (threadIdx.y / 4) * 8;

      float2 br1 = rhs_shmem2[start_feature/2 + (koff % 4) * 32][koff/4];
      float2 br2 = rhs_shmem2[start_feature/2 + 1 + (koff % 4) * 32][koff/4];
      float2 br3 = rhs_shmem2[start_feature/2 + 2 + (koff % 4) * 32][koff/4];
      float2 br4 = rhs_shmem2[start_feature/2 + 3 + (koff % 4) * 32][koff/4];

      add_vals(a3, a4, br1, br2, br3, br4)
    }
    __syncthreads();
  } // end loop over k


  __syncthreads();
  Index horiz_base = (threadIdx.y/4)*8 + base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 8; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    // CHECK BOUNDARY_B
    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    // CHECK both boundaries.
    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}


template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
__launch_bounds__(256)
EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                            const OutputMapper output,
                            const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[64*32];
  __shared__ float2 rhs_shmem[128*8];

  typedef float2 LHS_MEM[64][32];
  typedef float2 RHS_MEM[128][8];

  typedef float2 LHS_MEM16x16[32][16];
  typedef float2 RHS_MEM16x16[64][8];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 128 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  bool check_rhs = (base_n + 63) >= n_size;
  bool check_lhs128 = (base_m + 127) >= m_size;

  if (!check_rhs) {
    if (!check_lhs128) {
      // >= 128 rows left
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (!check_lhs128) {
      // >= 128 rows left
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  }
}

template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
__launch_bounds__(256)
EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs,
                                 const OutputMapper output,
                                 const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[32][16];
  __shared__ float2 rhs_shmem[64][8];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  if (base_m + 63 < m_size) {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  }
}


template<typename Indices, typename LeftArgType, typename RightArgType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, GpuDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, GpuDevice> > {

  typedef GpuDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, GpuDevice>::type PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
  };

  // Most of the code is assuming that both input tensors are ColMajor. If the
  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
  // If we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef DSizes<Index, NumDims> Dimensions;

  // typedefs needed in evalTo
  typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
  typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

  typedef typename LeftEvaluator::Dimensions LeftDimensions;
  typedef typename RightEvaluator::Dimensions RightDimensions;

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) {}

  // We need to redefine this method to make nvcc happy
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool
  evalSubExprsIfNeeded(Scalar* data) {
    this->m_leftImpl.evalSubExprsIfNeeded(NULL);
    this->m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      this->m_result = static_cast<Scalar *>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(this->m_result);
      return true;
    }
  }

  void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, true, true, Unaligned>(buffer);
        }
        else {
          evalTyped<true, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, false, true, Unaligned>(buffer);
        }
        else {
          evalTyped<true, false, false, Unaligned>(buffer);
        }
      }
    }
    else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, true, true, Unaligned>(buffer);
        }
        else {
          evalTyped<false, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, false, true, Unaligned>(buffer);
        }
        else {
          evalTyped<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }

  template <typename LhsScalar, typename RhsScalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
      const Index m_blocks = (m + 63) / 64;
      const Index n_blocks = (n + 63) / 64;
      const dim3 num_blocks(m_blocks, n_blocks, 1);
      const dim3 block_size(8, 8, 8);
      LAUNCH_CUDA_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
    }
  };

  template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
      if (m < 768 || n < 768) {
        const Index m_blocks = (m + 63) / 64;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(16, 16, 1);
        LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      } else {
        const Index m_blocks = (m + 127) / 128;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(8, 32, 1);
        LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      }
    }
  };

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalTyped(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;
    EIGEN_UNUSED_VARIABLE(k)

    // rows in left side
    const Index m = this->m_i_size;

    // columns in right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, 4,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, 4,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;


    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    setCudaSharedMemConfig(cudaSharedMemBankSizeEightByte);
    LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output, m, n, k, this->m_device);
  }
};

} // end namespace Eigen

#endif // EIGEN_USE_GPU and __CUDACC__
#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H
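
// ---------------------------------------------------------------------------
// [Illustrative usage sketch, not part of the upstream header] Contracting two
// float tensors on a GPU dispatches to the kernels above through the evaluator
// defined in this file. The buffer names (d_a, d_b, d_c) and sizes below are
// assumptions made for the example.
//
//   #define EIGEN_USE_GPU
//   #include <unsupported/Eigen/CXX11/Tensor>
//
//   Eigen::CudaStreamDevice stream;
//   Eigen::GpuDevice gpu_device(&stream);
//
//   // d_a, d_b and d_c are device allocations of the right sizes, e.g. from
//   // gpu_device.allocate(), with the inputs copied in beforehand.
//   Eigen::TensorMap<Eigen::Tensor<float, 2> > a_dev(d_a, 128, 256);
//   Eigen::TensorMap<Eigen::Tensor<float, 2> > b_dev(d_b, 256, 64);
//   Eigen::TensorMap<Eigen::Tensor<float, 2> > c_dev(d_c, 128, 64);
//
//   // contract over dimension 1 of a and dimension 0 of b (a matrix product)
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = { Eigen::IndexPair<int>(1, 0) };
//   c_dev.device(gpu_device) = a_dev.contract(b_dev, dims);
// ---------------------------------------------------------------------------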