ViennaCL - The Vienna Computing Library  1.7.1
Free open-source GPU-accelerated linear algebra and solver library.
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
matrix_product_template.hpp
Go to the documentation of this file.
1 #ifndef VIENNACL_DEVICE_SPECIFIC_TEMPLATES_MATRIX_PRODUCT_HPP
2 #define VIENNACL_DEVICE_SPECIFIC_TEMPLATES_MATRIX_PRODUCT_HPP
3 
4 /* =========================================================================
5 Copyright (c) 2010-2016, Institute for Microelectronics,
6  Institute for Analysis and Scientific Computing,
7  TU Wien.
8 Portions of this software are copyright by UChicago Argonne, LLC.
9 
10  -----------------
11  ViennaCL - The Vienna Computing Library
12  -----------------
13 
14 Project Head: Karl Rupp rupp@iue.tuwien.ac.at
15 
16 (A list of authors and contributors can be found in the manual)
17 
18 License: MIT (X11), see file LICENSE in the base directory
19 ============================================================================= */
20 
21 
27 #include <vector>
28 
30 
33 
38 #include "viennacl/forwards.h"
39 
40 #include "viennacl/tools/tools.hpp"
41 
42 namespace viennacl
43 {
44 namespace device_specific
45 {
46 
48 {
50  , unsigned int local_size_0, unsigned int KL, unsigned int local_size_1
51  , unsigned int ms, unsigned int ks, unsigned int ns
52  , fetching_policy_type A_fetching_policy_param, fetching_policy_type B_fetching_policy_param
53  , unsigned int local_fetch_0_param, unsigned int local_fetch_1_param): template_base::parameters_type(simd_width, local_size_0, local_size_1, 1),
54  kL(KL), mS(ms), kS(ks), nS(ns), A_fetching_policy(A_fetching_policy_param), B_fetching_policy(B_fetching_policy_param),
55  local_fetch_0(local_fetch_0_param), local_fetch_1(local_fetch_1_param),
56  mL(ms*local_size_0), nL(ns*local_size_1){}
57 
58  unsigned int kL;
59 
60  unsigned int mS;
61  unsigned int kS;
62  unsigned int nS;
63 
66 
67  unsigned int local_fetch_0;
68  unsigned int local_fetch_1;
69 
70  unsigned int mL;
71  unsigned int nL;
72 };
73 
74 class matrix_product_template : public template_base_impl<matrix_product_template, matrix_product_parameters>
75 {
76 
77 private:
78  unsigned int n_lmem_elements() const
79  {
80  unsigned int N = 0;
81  if (p_.A_fetching_policy==FETCH_FROM_LOCAL)
82  N += p_.kL * (p_.mL+1);
83  if (p_.B_fetching_policy==FETCH_FROM_LOCAL)
84  N += p_.nL * (p_.kL+1);
85  return N;
86  }
87 
88  int check_invalid_impl(viennacl::ocl::device const & /*device*/) const
89  {
90  if (p_.A_fetching_policy!=FETCH_FROM_LOCAL && p_.B_fetching_policy!=FETCH_FROM_LOCAL&& (p_.local_fetch_0!=0 || p_.local_fetch_1!=0))
91  return TEMPLATE_GLOBAL_MEMORY_REQUIRES_ZERO_LOCAL_FETCH;
92 
93  if ((p_.mS % p_.simd_width) > 0 || (p_.nS % p_.simd_width) > 0)
94  return TEMPLATE_MS_NS_MUST_BE_SIMD_WIDTH_MULTIPLE;
95 
96  if (p_.kS > p_.kL)
97  return TEMPLATE_KS_MUST_BE_SMALLER_THAN_KL;
98 
99  if (!(A_trans_=='N' && B_trans_=='T') && p_.simd_width>1)
100  return TEMPLATE_SIMD_WIDTH_MUST_BE_ONE;
101 
102  if (p_.A_fetching_policy==FETCH_FROM_LOCAL || p_.B_fetching_policy==FETCH_FROM_LOCAL)
103  {
104  if ((p_.local_fetch_0*p_.local_fetch_1) !=(p_.local_size_0*p_.local_size_1))
105  return TEMPLATE_LOCAL_FETCH_PRODUCT_MUST_MATCH_LOCAL_SIZE_PRODUCT;
106  }
107 
108  if (p_.A_fetching_policy==FETCH_FROM_LOCAL)
109  {
110  unsigned int bound1 = (A_trans_=='N')?p_.kL:p_.mL;
111  unsigned int bound0 = (A_trans_=='N')?p_.mL:p_.kL;
112 
113  if (p_.local_fetch_1>0 && (bound1 % p_.local_fetch_1)> 0)
114  return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
115 
116  if (p_.local_fetch_0>0 && (bound0 % (p_.local_fetch_0*p_.simd_width)) > 0)
117  return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_0_MUST_BE_NL_MULTIPLE:TEMPLATE_LOCAL_FETCH_0_MUST_BE_KL_MULTIPLE;
118 
119  }
120  if (p_.B_fetching_policy==FETCH_FROM_LOCAL)
121  {
122  unsigned int bound1 = (B_trans_=='T')?p_.kL:p_.nL;
123  unsigned int bound0 = (B_trans_=='T')?p_.nL:p_.kL;
124 
125  if (p_.local_fetch_1>0 && (bound1 % p_.local_fetch_1)> 0)
126  return B_trans_=='T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
127 
128  if (p_.local_fetch_0>0 && (bound0 % (p_.local_fetch_0*p_.simd_width)) > 0)
129  return B_trans_=='T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
130 
131  }
132 
133  return TEMPLATE_VALID;
134  }
135 
136  static void parse(scheduler::statement const & s,
137  vcl_size_t & C_idx, leaf_t & C_leaf, vcl_size_t & alpha_idx, leaf_t & alpha_leaf,
138  vcl_size_t & A_idx, leaf_t & A_leaf, bool& A_trans, vcl_size_t & B_idx, leaf_t & B_leaf, bool& B_trans,
139  vcl_size_t & beta_idx, leaf_t & beta_leaf)
140  {
141  using namespace tree_parsing;
142  using namespace scheduler;
143 
144  scheduler::statement::container_type const & array = s.array();
145  vcl_size_t root_idx = s.root();
146 
147  C_idx = root_idx;
148  C_leaf = LHS_NODE_TYPE;
149 
150  vcl_size_t node_add_idx = array[root_idx].rhs.node_index;
151 
152  vcl_size_t node_1_idx = array[node_add_idx].lhs.node_index;
153  alpha_idx = node_1_idx;
154  alpha_leaf = RHS_NODE_TYPE;
155 
156  vcl_size_t mat_prod_idx = array[node_1_idx].lhs.node_index;
157  if (array[mat_prod_idx].lhs.type_family==MATRIX_TYPE_FAMILY)
158  {
159  A_trans = false;
160  A_idx = mat_prod_idx;
161  }
162  else
163  {
164  A_trans = true;
165  A_idx = array[mat_prod_idx].lhs.node_index;
166  }
167  A_leaf = LHS_NODE_TYPE;
168 
169  if (array[mat_prod_idx].rhs.type_family==MATRIX_TYPE_FAMILY)
170  {
171  B_trans = false;
172  B_idx = mat_prod_idx;
173  B_leaf = RHS_NODE_TYPE;
174  }
175  else
176  {
177  B_trans = true;
178  B_idx = array[mat_prod_idx].rhs.node_index;
179  B_leaf = LHS_NODE_TYPE;
180  }
181 
182  vcl_size_t node_2_idx = array[node_add_idx].rhs.node_index;
183  beta_idx = node_2_idx;
184  beta_leaf = RHS_NODE_TYPE;
185  }
186 
187  void VIENNACL_HANDLE_BOUNDS(bool fallback, utils::kernel_generation_stream & stream, std::string const & inbounds, std::string const & do_if, std::string do_else) const
188  {
189  if (fallback)
190  {
191  stream << "if (" << inbounds << ")" << std::endl;
192  stream.inc_tab();
193  stream << do_if << ";" << std::endl;
194  stream.dec_tab();
195  stream << "else" << std::endl;
196  stream.inc_tab();
197  stream << do_else << ";" << std::endl;
198  stream.dec_tab();
199  }
200  else
201  stream << do_if << ";" << std::endl;
202  }
203 
204 
205  std::string generate_impl(const std::string &kernel_prefix, const statements_container &statements, const std::vector<mapping_type> &mappings, bool fallback) const
206  {
207  using std::string;
208  using tools::to_string;
209 
211  parameters_type const & p = fallback?pfallback:p_;
212 
213 #define VIENNACL_MUL_STRIDE1 string(fallback?"*#stride1":"")
214 #define VIENNACL_HANDLE_BOUNDS(in_bounds, to_load) (!fallback?string(to_load):string( string(in_bounds) + "?" + string(to_load) + ":0"))
215 #define VIENNACL_VSTORE(value, offset, ptr) vstore(p.simd_width, value, offset, ptr)
216 
217  string widthstr = tools::to_string(p.simd_width);
218 
223  scheduler::statement const & st = statements.data().front();
224  mapping_type const & mapping = mappings.front();
225 
226  bool A_trans = false, B_trans = false;
227  vcl_size_t C_idx=0, alpha_idx=0, A_idx=0, B_idx=0, beta_idx=0;
228  leaf_t C_leaf=LHS_NODE_TYPE, alpha_leaf=LHS_NODE_TYPE, A_leaf=LHS_NODE_TYPE, B_leaf=LHS_NODE_TYPE, beta_leaf=LHS_NODE_TYPE;
229  parse(st, C_idx, C_leaf, alpha_idx, alpha_leaf, A_idx, A_leaf, A_trans, B_idx, B_leaf, B_trans, beta_idx, beta_leaf);
230 
231  mapped_matrix * C = (mapped_matrix* )at(mapping, mapping_key( C_idx, C_leaf)).get();
232  mapped_host_scalar * alpha = (mapped_host_scalar*)at(mapping, mapping_key(alpha_idx, alpha_leaf)).get();
233  mapped_matrix * A = (mapped_matrix* )at(mapping, mapping_key( A_idx, A_leaf)).get();
234  mapped_matrix * B = (mapped_matrix* )at(mapping, mapping_key( B_idx, B_leaf)).get();
235  mapped_host_scalar * beta = (mapped_host_scalar*)at(mapping, mapping_key( beta_idx, beta_leaf)).get();
236 
240 
241  stream << " __attribute__((reqd_work_group_size(" << p.local_size_0 << "," << p.local_size_1 << ",1)))" << std::endl;
242  std::map<std::string, unsigned int> widths;
243  widths[A->name()] = p.simd_width;
244  widths[B->name()] = p.simd_width;
245  generate_prototype(stream, kernel_prefix, "unsigned int M, unsigned int N, unsigned int K, ", mappings, statements, widths);
246  stream << "{" << std::endl;
247  stream.inc_tab();
248  if(!fallback)
249  {
250  stream << A->process("#start1 /= " + to_string(p.simd_width) + ";") << std::endl;
251  stream << A->process("#ld /= " + to_string(p.simd_width) + ";") << std::endl;
252  stream << B->process("#start1/= " + to_string(p.simd_width) + ";") << std::endl;
253  stream << B->process("#ld /= " + to_string(p.simd_width) + ";") << std::endl;
254  }
255  tree_parsing::process(stream, PARENT_NODE_TYPE, "matrix", "#pointer += $OFFSET{#start1, #start2};", statements, mappings);
256  tree_parsing::process(stream, PARENT_NODE_TYPE, "matrix", "#ld *= #nldstride;", statements, mappings);
257 
259  stream << C->process("#scalartype rC[" + to_string(p.mS) + "][" + to_string(p.nS) + "] = {{(#scalartype)0}};") << std::endl;
261  stream << A->process("#scalartype rA[" + to_string(p.kS) + "][" + to_string(p.mS) + "];") << std::endl;
262  else
263  stream << A->process(utils::append_width("#scalartype",p.simd_width) + " rA[" + to_string(p.kS) + "][" + to_string(p.mS/p.simd_width) + "];") << std::endl;
265  stream << B->process("#scalartype rB[" + to_string(p.kS) + "][" + to_string(p.nS) + "];");
266  else
267  stream << B->process(utils::append_width("#scalartype",p.simd_width) + " rB[" + to_string(p.kS) + "][" + to_string(p.nS/p.simd_width) + "];") << std::endl;
268 
269 
271  stream << A->process("__local #scalartype lA[" + to_string(p.kL*(p.mL+1)) + "];");
273  stream << B->process("__local #scalartype lB[" + to_string(p.kL*(p.nL+1)) + "];");
274  stream << std::endl;
275 
276  stream << "size_t gidx = get_group_id(0);" << std::endl;
277  stream << "size_t gidy = get_group_id(1);" << std::endl;
278  stream << "size_t idx = get_local_id(0);" << std::endl;
279  stream << "size_t idy = get_local_id(1);" << std::endl;
280 
282  {
283  stream << std::endl;
284  stream << "size_t idt = " << p.local_size_0 << "*idy + idx;" << std::endl;
285  stream << "size_t idxT = idt % " << p.local_fetch_0 << ";" << std::endl;
286  stream << "size_t idyT = idt / " << p.local_fetch_0 << ";" << std::endl;
287  }
288  stream << std::endl;
289 
290  if (fallback)
291  {
292  //Bounds checking for M (in A, C)
293  stream << "bool in_bounds_m[" << p.mS << "];" << std::endl;
294  stream << "for(size_t m = 0; m < " << p.mS << "; m++)" << std::endl;
295  stream.inc_tab();
296  switch (p.A_fetching_policy)
297  {
299  stream << "in_bounds_m[m] = gidx*" << p.mL << " + idx*" << p.mS << " + m < M;" << std::endl;
300  break;
301  default:
302  stream << "in_bounds_m[m] = gidx*" << p.mL << " + idx + m*" << p.local_size_0 << " < M;" << std::endl;
303  break;
304  }
305  stream.dec_tab();
306 
307  //Bounds checking for A if Local
309  {
310  unsigned int fetch_size = (A_trans_=='N'?p.local_fetch_0*p.simd_width:p.local_fetch_1);
311  stream << "bool in_bounds_m_local[" << p.mL/fetch_size << "];" << std::endl;
312  stream << "for(size_t m = 0; m < " << p.mL/fetch_size << "; m++)" << std::endl;
313  stream.inc_tab();
314  stream << "in_bounds_m_local[m] = gidx*" << p.mL << " + " << (A_trans_=='N'?"idxT":"idyT") << " + m*" << fetch_size << " < M;" << std::endl;
315  stream.dec_tab();
316  }
317 
318  //Bounds checking for N (in B, C)
319  stream << "bool in_bounds_n[" << p.nS << "];" << std::endl;
320  stream << "for(size_t n = 0; n < " << p.nS << "; n++)" << std::endl;
321  stream.inc_tab();
322  switch (p.B_fetching_policy)
323  {
325  stream << "in_bounds_n[n] = gidy*" << p.nL << " + idy*" << p.nS << " + n < N;" << std::endl;
326  break;
327  default:
328  stream << "in_bounds_n[n] = gidy*" << p.nL << " + idy + n*" << p.local_size_1 << " < N;" << std::endl;
329  break;
330  }
331  stream.dec_tab();
332 
333  //Bounds checking for B if Local
335  {
336  unsigned int fetch_size = (B_trans_=='T'?p.local_fetch_0*p.simd_width:p.local_fetch_1);
337  stream << "bool in_bounds_n_local[" << p.nL/fetch_size << "];" << std::endl;
338  stream << "for(size_t n = 0; n < " << p.nL/fetch_size << "; n++)" << std::endl;
339  stream.inc_tab();
340  stream << "in_bounds_n_local[n] = gidy*" << p.nL << " + " << (B_trans_=='T'?"idxT":"idyT") << " + n*" << fetch_size << " < N;" << std::endl;
341  stream.dec_tab();
342  }
343  }
344 
345  switch (p.A_fetching_policy)
346  {
347  case FETCH_FROM_LOCAL:
348  if (A_trans_=='N')
349  stream << A->process("#pointer += (gidx*" + to_string(p.mL/p.simd_width) + " + idxT)" + VIENNACL_MUL_STRIDE1 + " + idyT*#ld;") << std::endl;
350  else
351  stream << A->process("#pointer += idxT" + VIENNACL_MUL_STRIDE1 + " + gidx*" + to_string(p.mL/p.simd_width) + "*#ld + idyT*#ld;") << std::endl;
352  break;
353 
355  if (A_trans_=='N')
356  stream << A->process("#pointer += (gidx*" + to_string(p.mL/p.simd_width) + "+ idx*" + to_string(p.mS/p.simd_width) + ")" + VIENNACL_MUL_STRIDE1 + ";") << std::endl;
357  else
358  stream << A->process("#pointer += (gidx*" + to_string(p.mL/p.simd_width) + "+ idx*" + to_string(p.mS/p.simd_width) + ")*#ld;") << std::endl;
359  break;
360 
362  if (A_trans_=='N')
363  stream << A->process("#pointer += (gidx*" + to_string(p.mL/p.simd_width) + "+ idx" + ")" + VIENNACL_MUL_STRIDE1 + ";") << std::endl;
364  else
365  stream << A->process("#pointer += (gidx*" + to_string(p.mL/p.simd_width) + "+ idx)*#ld;") << std::endl;
366  break;
367 
368  //default: break;
369  }
370 
371  switch (p.B_fetching_policy)
372  {
373  case FETCH_FROM_LOCAL:
374  if (B_trans_=='T')
375  stream << B->process("#pointer += (gidy*" + to_string(p.nL/p.simd_width) + " + idxT" + ")" + VIENNACL_MUL_STRIDE1 + " + idyT*#ld;") << std::endl;
376  else
377  stream << B->process("#pointer += idxT" + VIENNACL_MUL_STRIDE1 + " + gidy*" + to_string(p.nL/p.simd_width) + "*#ld + idyT*#ld;") << std::endl;
378  break;
379 
381  if (B_trans_=='T')
382  stream << B->process("#pointer += (gidy*" + to_string(p.nL/p.simd_width) + "+ idy*" + to_string(p.nS/p.simd_width) + ")" + VIENNACL_MUL_STRIDE1 + ";") << std::endl;
383  else
384  stream << B->process("#pointer += (gidy*" + to_string(p.nL/p.simd_width) + "+ idy*" + to_string(p.nS/p.simd_width) + ")*#ld;") << std::endl;
385  break;
386 
388  if (B_trans_=='T')
389  stream << B->process("#pointer += (gidy*" + to_string(p.nL/p.simd_width) + "+ idy" + ")" + VIENNACL_MUL_STRIDE1 + ";") << std::endl;
390  else
391  stream << B->process("#pointer += (gidy*" + to_string(p.nL/p.simd_width) + "+ idy)*#ld;") << std::endl;
392  break;
393 
394  //default: break;
395  }
396 
397  stream << std::endl;
398  stream << "size_t K_size_t = K;" << std::endl;
399  stream << "for(size_t block_k=0; block_k < K_size_t; block_k+=" << p.kL << "){" << std::endl;
400  stream.inc_tab();
401 
403  {
404  if (A_trans_=='N')
405  stream << A->process("__local #scalartype* plA = lA + idyT*" + to_string(p.mL + 1) + " + " + to_string(p.simd_width) + "*idxT;") << std::endl;
406  else
407  stream << A->process("__local #scalartype* plA = lA + idxT*" + to_string(p.mL + 1) + " + idyT;") << std::endl;
408  }
409 
410 
412  {
413  if (B_trans_=='T')
414  stream << B->process("__local #scalartype* plB = lB + idyT*" + to_string(p.nL+1) + " + " + to_string(p.simd_width) + "*idxT;") << std::endl;
415  else
416  stream << B->process("__local #scalartype* plB = lB + idxT*" + to_string(p.nL+1) + "+ idyT;") <<std::endl;
417  }
418 
419 
421  stream << "barrier(CLK_LOCAL_MEM_FENCE);" << std::endl;
422 
424  if (p.A_fetching_policy==FETCH_FROM_LOCAL && A_trans_=='N')
425  for (unsigned int k = 0; k < p.kL; k += p.local_fetch_1)
426  for (unsigned int m = 0; m < p.mL; m += p.local_fetch_0*p.simd_width)
427  {
428  string in_bounds = "in_bounds_m_local[" + to_string(m/(p.local_fetch_0*p.simd_width)) + "]";
429  string to_load = "#pointer[" + to_string(k) + "*#ld + " + to_string(m/p.simd_width) + VIENNACL_MUL_STRIDE1 + "]";
430  stream << A->process(VIENNACL_VSTORE(VIENNACL_HANDLE_BOUNDS(in_bounds, to_load), "0", "plA + " + to_string(k*(p.mL+1)+m))) << ";" << std::endl;
431  }
432  else if (p.A_fetching_policy==FETCH_FROM_LOCAL && A_trans_=='T')
433  for (unsigned int k = 0; k < p.mL; k += p.local_fetch_1)
434  for (unsigned int m = 0; m < p.kL; m += p.local_fetch_0*p.simd_width)
435  {
436  string in_bounds = "in_bounds_m_local[" + to_string(k/p.local_fetch_1) + "]";
437  string to_load = "#pointer[" + to_string(k) + "*#ld + " + to_string(m/p.simd_width) + VIENNACL_MUL_STRIDE1 + "]";
438  stream << A->process(VIENNACL_VSTORE(VIENNACL_HANDLE_BOUNDS(in_bounds, to_load), "0", "plA + " + to_string(m*(p.mL+1)+k))) << ";" << std::endl;
439  }
440 
441  if (p.B_fetching_policy==FETCH_FROM_LOCAL && B_trans_=='T')
442  for (unsigned int k = 0; k < p.kL; k += p.local_fetch_1)
443  for (unsigned int n = 0; n < p.nL; n += p.local_fetch_0*p.simd_width)
444  {
445  string in_bounds = "in_bounds_n_local[" + to_string(n/(p.local_fetch_0*p.simd_width)) + "]";
446  string to_load = "#pointer[" + to_string(k) + "*#ld + " + to_string(n/p.simd_width) + VIENNACL_MUL_STRIDE1 + "]";
447  stream << B->process(VIENNACL_VSTORE(VIENNACL_HANDLE_BOUNDS(in_bounds, to_load), "0", "plB + " + to_string(k*(p.nL+1)+n))) << ";" << std::endl;
448  }
449  else if (p.B_fetching_policy==FETCH_FROM_LOCAL && B_trans_=='N')
450  for (unsigned int k = 0; k < p.nL; k += p.local_fetch_1)
451  for (unsigned int n = 0; n < p.kL; n += p.local_fetch_0*p.simd_width)
452  {
453  string in_bounds = "in_bounds_n_local[" + to_string(k/p.local_fetch_1) + "]";
454  string to_load = "#pointer[" + to_string(k) + "*#ld + " + to_string(n/p.simd_width) + VIENNACL_MUL_STRIDE1 + "]";
455  stream << B->process(VIENNACL_VSTORE(VIENNACL_HANDLE_BOUNDS(in_bounds, to_load), "0", "plB + " + to_string(n*(p.nL+1)+k))) << ";" << std::endl;
456  }
457 
459  {
460  stream << "barrier(CLK_LOCAL_MEM_FENCE);" << std::endl;
461  stream << "size_t offA = " << p.simd_width << "*idx;" << std::endl;
462  stream << "size_t offB = " << p.simd_width << "*idy;" << std::endl;
463  }
464 
465  if (fallback)
466  stream << "for(size_t k = 0; k < " << p.kL << " && (block_k + k < K_size_t); k+=" << p.kS << "){" << std::endl;
467  else
468  stream << "for(size_t k = 0; k < " << p.kL << "; k+=" << p.kS << "){" << std::endl;
469  stream.inc_tab();
470 
472  stream << "#pragma unroll " << p.kS << std::endl;
473  stream << "for(size_t kk = 0; kk < " << p.kS << "; kk++)" << std::endl;
474  stream << "#pragma unroll " << p.mS/p.simd_width << std::endl;
475  stream << "for(size_t mm = 0; mm < " << p.mS/p.simd_width << "; mm++)" << std::endl;
476  stream << "{" << std::endl;
477  stream.inc_tab();
478  switch (p.A_fetching_policy)
479  {
480  case FETCH_FROM_LOCAL:
481  for (unsigned int ss = 0; ss < p.simd_width; ++ss)
482  stream << "rA[kk][mm*" << p.simd_width << "+" << ss << "] = lA[offA + mm*" << p.local_size_0*p.simd_width << "+" << ss << "+ kk*" << (p.mL+1) << "];" << std::endl;
483  break;
484 
486  {
487  if (A_trans_=='N')
488  stream << "rA[kk][mm] = " << A->process(VIENNACL_HANDLE_BOUNDS("in_bounds_m[mm]", "#pointer[kk*#ld + mm" + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl;
489  else
490  stream << "rA[kk][mm] = " << A->process(VIENNACL_HANDLE_BOUNDS("in_bounds_m[mm]", "#pointer[mm*#ld + kk" + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl;
491  break;
492  }
493 
495  {
496  if (A_trans_=='N')
497  stream << "rA[kk][mm] = " << A->process(VIENNACL_HANDLE_BOUNDS("in_bounds_m[mm]", "#pointer[kk*#ld + mm*" + to_string(p.local_size_0) + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl;
498  else
499  stream << "rA[kk][mm] = " << A->process(VIENNACL_HANDLE_BOUNDS("in_bounds_m[mm]", "#pointer[mm*#ld*" + to_string(p.local_size_0) + " + kk" + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl;
500  break;
501  }
502 
503  //default: break;
504  }
505  stream.dec_tab();
506  stream << "}" << std::endl;
507 
508  stream << "#pragma unroll " << p.kS << std::endl;
509  stream << "for(size_t kk = 0; kk < " << p.kS << "; kk++)" << std::endl;
510  stream << "#pragma unroll " << p.nS/p.simd_width << std::endl;
511  stream << "for(size_t nn = 0; nn < " << p.nS/p.simd_width << "; nn++)" << std::endl;
512  stream << "{" << std::endl;
513  stream.inc_tab();
514  switch (p.B_fetching_policy)
515  {
516  case FETCH_FROM_LOCAL:
517  for (unsigned int ss = 0; ss < p.simd_width; ++ss)
518  stream << "rB[kk][nn*" << p.simd_width << "+" << ss << "] = lB[offB + nn*" << p.local_size_1*p.simd_width << "+" << ss << "+ kk*" << (p.nL+1) << "];" << std::endl;
519  break;
520 
522  {
523  if (B_trans_=='T')
524  stream << "rB[kk][nn] = " << B->process(VIENNACL_HANDLE_BOUNDS("in_bounds_n[nn]", "#pointer[kk*#ld + nn" + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl;
525  else
526  stream << "rB[kk][nn] = " << B->process(VIENNACL_HANDLE_BOUNDS("in_bounds_n[nn]", "#pointer[nn*#ld + kk" + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl;
527  break;
528  }
529 
531  {
532  if (B_trans_=='T')
533  stream << "rB[kk][nn] = " << B->process(VIENNACL_HANDLE_BOUNDS("in_bounds_n[nn]", "#pointer[kk*#ld + nn*" + to_string(p.local_size_1) + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl;
534  else
535  stream << "rB[kk][nn] = " << B->process(VIENNACL_HANDLE_BOUNDS("in_bounds_n[nn]", "#pointer[nn*#ld*" + to_string(p.local_size_1) + " + kk" + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl;
536  break;
537  }
538 
539  //default: break;
540  }
541  stream.dec_tab();
542  stream << "}" << std::endl;
543 
544 
546  switch (p.A_fetching_policy)
547  {
548  case FETCH_FROM_LOCAL:
549  stream << "offA += " << p.kS*(p.mL+1) << ";" << std::endl;
550  break;
551 
552  default:
553  if (A_trans_=='N')
554  stream << A->process("#pointer += " + to_string(p.kS) + "*#ld;") << std::endl;
555  else
556  stream << A->process("#pointer += " + to_string(p.kS) + "" + VIENNACL_MUL_STRIDE1 + ";") << std::endl;
557  break;
558  }
559 
560 
561  switch (p.B_fetching_policy)
562  {
563  case FETCH_FROM_LOCAL:
564  stream << "offB += " << p.kS*(p.nL+1) << ";" << std::endl;
565  break;
566 
567  default:
568  if (B_trans_=='T')
569  stream << B->process("#pointer += " + to_string(p.kS) + "*#ld;") << std::endl;
570  else
571  stream << B->process("#pointer += " + to_string(p.kS) + "" + VIENNACL_MUL_STRIDE1 + ";") << std::endl;
572  break;
573  }
574 
575 
576  stream << "#pragma unroll " << p.kS << std::endl;
577  stream << "for(size_t kk = 0; kk <" << p.kS << "; ++kk)" << std::endl;
578  stream << "{" << std::endl;
579  stream.inc_tab();
580  for (unsigned int nn=0; nn < p.nS; ++nn)
581  for (unsigned int mm=0; mm < p.mS; ++mm)
582  {
583  string res_str, lhs_str, rhs_str;
584  res_str = "rC[" + tools::to_string(mm) + "][" + tools::to_string(nn) + "]";
586  lhs_str = "rA[kk][" + tools::to_string(mm) + "]";
587  else
588  lhs_str = "rA[kk][" + tools::to_string(mm/p.simd_width) + "].s" + tools::to_string(mm%p.simd_width);
590  rhs_str = "rB[kk]["+tools::to_string(nn)+"]";
591  else
592  rhs_str = "rB[kk]["+tools::to_string(nn/p.simd_width)+"].s"+tools::to_string(nn%p.simd_width);
593  stream << res_str << "=" << "fma(" << lhs_str << "," << rhs_str << "," << res_str << ");" << std::endl;
594  }
595  stream.dec_tab();
596  stream << "}" << std::endl;
597 
598 
599 
600 
601  stream.dec_tab();
602  stream << "}" << std::endl;
603 
604  //Increment global pointer if local memory is used
605  //Else, it's incremented directly when fetching
607  {
608  if (A_trans_=='N')
609  stream << A->process("#pointer += " + to_string(p.kL) + "*#ld;") << std::endl;
610  else
611  stream << A->process("#pointer += " + to_string(p.kL) + "" + VIENNACL_MUL_STRIDE1 + ";") << std::endl;
612  }
613 
615  {
616  if (B_trans_=='T')
617  stream << B->process("#pointer += " + to_string(p.kL) + "*#ld;") << std::endl;
618  else
619  stream << B->process("#pointer += " + to_string(p.kL) + "" + VIENNACL_MUL_STRIDE1 + ";") << std::endl;
620  }
621 
622  stream.dec_tab();
623  stream << "}" << std::endl;
624 
625 
626  if (C->row_major())
627  {
628  unsigned int ministartstride0 = p.A_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?p.mS:p.simd_width;
629  unsigned int ministartstride1 = p.B_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?p.nS:p.simd_width;
630 
631  stream << C->process("#pointer += gidx*" + to_string(p.mL) + "*#ld;") << std::endl;
632  stream << C->process("#pointer += idx*" + to_string(ministartstride0) + "*#ld;") << std::endl;
633  stream << C->process("#pointer += gidy*" + to_string(p.nL) + "*#stride2;") << std::endl;
634  stream << C->process("#pointer += idy*" + to_string(ministartstride1) + "*#stride2;") << std::endl;
635 
636  for (unsigned int n=0; n < p.nS; ++n)
637  {
638  for (unsigned int m=0; m < p.mS; ++m)
639  {
640  unsigned int ministride1 = p.A_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?1:p.local_size_0;
641  string Cj = to_string((m/p.simd_width)*(ministride1*p.simd_width) + m%p.simd_width);
642  if (fallback)
643  {
644  stream << "if (in_bounds_m[" + to_string(m) + "] && in_bounds_n[" + to_string(n) + "])" << std::endl;
645  stream.inc_tab();
646  }
647  stream << C->process("#pointer[" + Cj + "*#ld] = rC[" + to_string(m) + "][" + to_string(n) + "]*" + alpha->name() + "+ #pointer[" + Cj + "*#ld]*" + beta->name() + ";") << std::endl;
648  if (fallback)
649  stream.dec_tab();
650  }
652  stream << C->process("#pointer += #stride2;") << std::endl;
653  else
654  stream << C->process("#pointer += " + to_string((p.local_size_1*p.simd_width) - (p.simd_width-1)) + "*#stride2;") << std::endl;
655  }
656 
657  }
658  else
659  {
660  unsigned int ministartstride0 = p.A_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?p.mS:p.simd_width;
661  unsigned int ministartstride1 = p.B_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?p.nS:p.simd_width;
662 
663  stream << C->process("#pointer += gidx*" + to_string(p.mL) + "*#stride1;") << std::endl;
664  stream << C->process("#pointer += idx*" + to_string(ministartstride0) + "*#stride1;") << std::endl;
665  stream << C->process("#pointer += gidy*" + to_string(p.nL) + "*#ld;") << std::endl;
666  stream << C->process("#pointer += idy*" + to_string(ministartstride1) + "*#ld;") << std::endl;
667 
668  for (unsigned int m=0; m < p.mS; ++m)
669  {
670  for (unsigned int n=0; n < p.nS; ++n)
671  {
672  unsigned int ministride1 = p.B_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?1:p.local_size_1;
673  string Cj = to_string((n/p.simd_width)*(ministride1*p.simd_width) + n%p.simd_width);
674  if (fallback)
675  {
676  stream << "if (in_bounds_m[" + to_string(m) + "] && in_bounds_n[" + to_string(n) + "])" << std::endl;
677  stream.inc_tab();
678  }
679  stream << C->process("#pointer[" + Cj + "*#ld] = rC[" + to_string(m) + "][" + to_string(n) + "]*" + alpha->name() + " + #pointer[" + Cj + "*#ld]*" + beta->name() + ";") << std::endl;
680  if (fallback)
681  stream.dec_tab();
682  }
683 
685  stream << C->process("#pointer += #stride1;") << std::endl;
686  else
687  stream << C->process("#pointer += " + to_string((p.local_size_0*p.simd_width) - (p.simd_width-1)) + "*#stride1;") << std::endl;
688  }
689  }
690 
691  stream.dec_tab();
692  stream << "}" << std::endl;
693 
694  return stream.str();
695 
696 #undef VIENNACL_MUL_STRIDE1
697 #undef VIENNACL_HANDLE_BOUNDS
698 #undef VIENNACL_VSTORE
699  }
700 
701  std::vector<std::string> generate_impl(std::string const & kernel_prefix, statements_container const & statements, std::vector<mapping_type> const & mappings) const
702  {
703  std::vector<std::string> res;
704  res.push_back(generate_impl(kernel_prefix, statements, mappings, false));
705  res.push_back(generate_impl(kernel_prefix, statements, mappings, true));
706  return res;
707  }
708 
709  template<class NumericT>
710  void enqueue_block(scheduler::statement & statement,
712  matrix_base<NumericT> const & A, matrix_base<NumericT> const & B, matrix_base<NumericT> const & C, NumericT beta,
713  std::vector<lazy_program_compiler> & programs, std::string const & kernel_prefix, vcl_size_t id)
714  {
715  if (A.size1()==0 || A.size2()==0 || B.size1()==0 || B.size2()==0 || C.size1()==0 || C.size2()==0)
716  return;
717 
718  viennacl::ocl::kernel& kernel = programs[id].program().get_kernel(kernel_prefix);
719 
720  kernel.local_work_size(0, p_.local_size_0);
721  kernel.local_work_size(1, p_.local_size_1);
722 
727 
728  if (id==1)
729  {
732  }
733  else
734  {
735  kernel.global_work_size(0, C.size1()/p_.mS);
736  kernel.global_work_size(1, C.size2()/p_.nS);
737  }
738  unsigned int current_arg = 0;
739  kernel.arg(current_arg++, cl_uint(C.size1()));
740  kernel.arg(current_arg++, cl_uint(C.size2()));
741  if (A.row_major())
742  kernel.arg(current_arg++, cl_uint(A_trans_=='T'?A.size2():A.size1()));
743  else
744  kernel.arg(current_arg++, cl_uint(A_trans_=='N'?A.size2():A.size1()));
745  set_arguments(statement, kernel, current_arg);
746  viennacl::ocl::enqueue(kernel);
747 
748  }
749 
750  template<class NumericT>
752  vcl_size_t s0_0, vcl_size_t s0_1, vcl_size_t s1_0, vcl_size_t s1_1, bool swap)
753  {
754  matrix_base<NumericT> & M = *(element.*ptr);
755  slice s0(s0_0, 1, s0_1 - s0_0);
756  slice s1(s1_0, 1, s1_1 - s1_0);
757  if (swap)
758  std::swap(s0, s1);
760  }
761 
762  template<class NumericT>
763  void enqueue_impl(viennacl::matrix_base<NumericT>* scheduler::lhs_rhs_element::*ptr_matrix,
765  NumericT beta_value, std::vector<lazy_program_compiler> & programs, std::string const & kernel_prefix)
766  {
767  using namespace device_specific::utils;
768  vcl_size_t ldstrideA = call_on_matrix(A, leading_stride());
769  vcl_size_t ldstrideB = call_on_matrix(B, leading_stride());
770  vcl_size_t ldstrideC = call_on_matrix(C, leading_stride());
771  vcl_size_t ldstartA = call_on_matrix(A, leading_start());
772  vcl_size_t ldstartB = call_on_matrix(B, leading_start());
773  bool swap_A = ((A_trans_=='T') ^ utils::call_on_matrix(A, row_major_fun()));
774  bool swap_B = ((B_trans_=='T') ^ utils::call_on_matrix(B, row_major_fun()));
775 
776  vcl_size_t M = call_on_matrix(C, size1_fun());
777  vcl_size_t N = call_on_matrix(C, size2_fun());
778  vcl_size_t K;
779  if (utils::call_on_matrix(A, row_major_fun()))
780  K = A_trans_=='T'?call_on_matrix(A, size2_fun()):call_on_matrix(A, size1_fun());
781  else
782  K = A_trans_=='N'?call_on_matrix(A, size2_fun()):call_on_matrix(A, size1_fun());
783 
784  if (M < p_.mL || N < p_.nL || K < p_.kL || ldstrideA> 1 || ldstrideB > 1 || ldstrideC > 1 ||
785  (p_.simd_width>1 && (ldstartA % p_.simd_width > 0 || ldstartB % p_.simd_width > 0)))
786  {
787  enqueue_block(statement, A, B, C, beta, create_slice(ptr_matrix, A, 0, M, 0, K, swap_A),
788  create_slice(ptr_matrix, B, 0, K, 0, N, swap_B),
789  create_slice(ptr_matrix, C, 0, M, 0, N, false), beta_value, programs, kernel_prefix, 1);
790  return;
791  }
792 
793 
794  scheduler::lhs_rhs_element Acopy = A;
795  scheduler::lhs_rhs_element Bcopy = B;
796  scheduler::lhs_rhs_element Ccopy = C;
797 
798  vcl_size_t lM = M / p_.mL * p_.mL;
799  vcl_size_t lN = N / p_.nL * p_.nL;
800  vcl_size_t lK = K / p_.kL * p_.kL;
801 
802 
803  enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, 0, lM, 0, lK, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, 0, lK, 0, lN, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, 0, lM, 0, lN, false), beta_value, programs, kernel_prefix, 0);
804  enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, 0, lM, lK, K, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, lK, K, 0, lN, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, 0, lM, 0, lN, false), (NumericT)1, programs, kernel_prefix, 1);
805 
806  enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, 0, lM, 0, lK, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, 0, lK, lN, N, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, 0, lM, lN, N, false), beta_value, programs, kernel_prefix, 1);
807  enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, 0, lM, lK, K, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, lK, K, lN, N, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, 0, lM, lN, N, false), (NumericT)1, programs, kernel_prefix, 1);
808 
809  enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, lM, M, 0, lK, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, 0, lK, 0, lN, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, lM, M, 0, lN, false), beta_value, programs, kernel_prefix, 1);
810  enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, lM, M, lK, K, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, lK, K, 0, lN, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, lM, M, 0, lN, false), (NumericT)1, programs, kernel_prefix, 1);
811 
812  enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, lM, M, 0, lK, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, 0, lK, lN, N, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, lM, M, lN, N, false), beta_value, programs, kernel_prefix, 1);
813  enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, lM, M, lK, K, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, lK, K, lN, N, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, lM, M, lN, N, false), (NumericT)1, programs, kernel_prefix, 1);
814  }
815 
816 public:
818 
819  virtual void enqueue(std::string const & kernel_prefix, std::vector<lazy_program_compiler> & programs, statements_container const & statements)
820  {
821  using namespace device_specific::utils;
822  using namespace tree_parsing;
823 
824  scheduler::statement const & st = statements.data().front();
825  bool A_trans, B_trans;
826  vcl_size_t C_idx=0, A_idx=0, B_idx=0, alpha_idx=0, beta_idx = 0;
827  leaf_t C_leaf=LHS_NODE_TYPE, A_leaf=LHS_NODE_TYPE, B_leaf=LHS_NODE_TYPE, alpha_leaf=LHS_NODE_TYPE, beta_leaf=LHS_NODE_TYPE;
828  parse(st, C_idx, C_leaf, alpha_idx, alpha_leaf, A_idx, A_leaf, A_trans, B_idx, B_leaf, B_trans, beta_idx, beta_leaf);
829 
830  scheduler::statement stcopy = st;
831  scheduler::lhs_rhs_element& A = utils::lhs_rhs_element(stcopy, A_idx, A_leaf);
832  scheduler::lhs_rhs_element& B = utils::lhs_rhs_element(stcopy, B_idx, B_leaf);
833  scheduler::lhs_rhs_element& C = utils::lhs_rhs_element(stcopy, C_idx, C_leaf);
834  scheduler::lhs_rhs_element& beta = utils::lhs_rhs_element(stcopy, beta_idx, beta_leaf);
835 
836 
837 
838 
839 
840 
842  enqueue_impl<float>(&scheduler::lhs_rhs_element::matrix_float, stcopy, A, B, C, beta, beta.host_float, programs, kernel_prefix);
844  enqueue_impl<double>(&scheduler::lhs_rhs_element::matrix_double, stcopy, A, B, C, beta, beta.host_double, programs, kernel_prefix);
845  else
846  throw generator_not_supported_exception("GEMM only supported for float/double");
847 
848  }
849 
850 private:
851  const char A_trans_;
852  const char B_trans_;
853 };
854 
855 }
856 
857 }
858 
859 #endif
virtual void enqueue(std::string const &kernel_prefix, std::vector< lazy_program_compiler > &programs, statements_container const &statements)
#define VIENNACL_MUL_STRIDE1
Exception for the case the generator is unable to deal with the operation.
Definition: forwards.h:163
void set_arguments(statements_container const &statements, viennacl::ocl::kernel &kernel, unsigned int &current_arg)
Class for representing strided submatrices of a bigger matrix A.
Definition: forwards.h:443
Represents an OpenCL kernel within ViennaCL.
Definition: kernel.hpp:58
Various little tools used here and there in ViennaCL.
size_type local_work_size(int index=0) const
Returns the local work size at the respective dimension.
Definition: kernel.hpp:742
static void assign_element(lhs_rhs_element &elem, char const &t)
Definition: forwards.h:535
A class representing a compute device (e.g. a GPU)
Definition: device.hpp:49
This file provides the forward declarations for the main types used within ViennaCL.
A class representing the 'data' for the LHS or RHS operand of the respective node.
Definition: forwards.h:337
container_type const & array() const
Definition: forwards.h:528
viennacl::scalar< float > s1
std::list< scheduler::statement > const & data() const
Definition: forwards.h:282
Forward declaration of dense matrix classes.
viennacl::matrix_base< double > * matrix_double
Definition: forwards.h:410
void swap(vector_base< T > &vec1, vector_base< T > &vec2)
Swaps the contents of two vectors; the data is copied.
Definition: vector.hpp:1648
float NumericT
Definition: bisect.cpp:40
std::vector< value_type > container_type
Definition: forwards.h:507
#define VIENNACL_VSTORE(value, offset, ptr)
std::string to_string(viennacl::scheduler::op_element op_elem)
Helper routine for converting the operation enums to string.
Definition: io.hpp:42
Map ViennaCL objects to generator wrappers.
statement_node_numeric_type numeric_type
Definition: forwards.h:341
viennacl::matrix_base< float > * matrix_float
Definition: forwards.h:409
std::size_t vcl_size_t
Definition: forwards.h:75
std::string process(std::string const &in) const
size_type size2() const
Returns the number of columns.
Definition: matrix_def.hpp:226
Provides the datastructures for dealing with a single statement such as 'x = y + z;'.
#define VIENNACL_HANDLE_BOUNDS(in_bounds, to_load)
static void generate_prototype(utils::kernel_generation_stream &stream, std::string const &name, std::string const &first_arguments, std::vector< mapping_type > const &mappings, statements_container const &statements, std::map< std::string, unsigned int > const &widths)
size_type size1() const
Returns the number of rows.
Definition: matrix_def.hpp:224
viennacl::enable_if< viennacl::is_scalar< ScalarT1 >::value &&viennacl::is_scalar< ScalarT2 >::value >::type swap(ScalarT1 &s1, ScalarT2 &s2)
Swaps the contents of two scalars; the data is copied.
std::map< mapping_key, tools::shared_ptr< mapped_object > > mapping_type
Definition: forwards.h:191
matrix_product_parameters(unsigned int simd_width, unsigned int local_size_0, unsigned int KL, unsigned int local_size_1, unsigned int ms, unsigned int ks, unsigned int ns, fetching_policy_type A_fetching_policy_param, fetching_policy_type B_fetching_policy_param, unsigned int local_fetch_0_param, unsigned int local_fetch_1_param)
Proxy classes for matrices.
Code for parsing the expression trees.
INT_TYPE align_to_multiple(INT_TYPE to_reach, INT_TYPE base)
Rounds an integer to the next multiple of another integer.
Definition: tools.hpp:133
void enqueue(KernelType &k, viennacl::ocl::command_queue const &queue)
Enqueues a kernel in the provided queue.
Definition: enqueue.hpp:50
Internal utils.
bool row_major() const
Definition: matrix_def.hpp:248
scheduler::lhs_rhs_element & lhs_rhs_element(scheduler::statement const &st, vcl_size_t idx, leaf_t leaf)
Definition: utils.hpp:525
size_type global_work_size(int index=0) const
Returns the global work size at the respective dimension.
Definition: kernel.hpp:751
size_type root() const
Definition: forwards.h:530
The main class for representing a statement such as x = inner_prod(y,z); at runtime.
Definition: forwards.h:502
void arg(unsigned int pos, cl_char val)
Sets a char argument at the provided position.
Definition: kernel.hpp:116
matrix_product_template(matrix_product_template::parameters_type const &parameters, char A_trans, char B_trans)
ValueT const & at(std::map< KeyT, ValueT > const &map, KeyT const &key)
Emulation of C++11's .at() member for std::map<>, const-version.
Definition: forwards.h:142
std::string to_string(T const t)
Definition: tools.hpp:304
A slice class that refers to an interval [start, stop), where 'start' is included, and 'stop' is excluded.
Definition: forwards.h:429
std::pair< vcl_size_t, leaf_t > mapping_key
Definition: forwards.h:188
parameters_type(unsigned int _simd_width, unsigned int _local_size_1, unsigned int _local_size_2, unsigned int _num_kernels)
void process(utils::kernel_generation_stream &stream, leaf_t leaf, std::string const &type_key, std::string const &to_process, scheduler::statement const &statement, vcl_size_t root_idx, mapping_type const &mapping, std::set< std::string > &already_processed)
std::string append_width(std::string const &str, unsigned int width)
Definition: utils.hpp:558