1 #ifndef VIENNACL_DEVICE_SPECIFIC_TEMPLATES_MATRIX_PRODUCT_HPP
2 #define VIENNACL_DEVICE_SPECIFIC_TEMPLATES_MATRIX_PRODUCT_HPP
44 namespace device_specific
51 ,
unsigned int ms,
unsigned int ks,
unsigned int ns
53 ,
unsigned int local_fetch_0_param,
unsigned int local_fetch_1_param):
template_base::
parameters_type(simd_width, local_size_0, local_size_1, 1),
56 mL(ms*local_size_0),
nL(ns*local_size_1){}
78 unsigned int n_lmem_elements()
const
82 N +=
p_.kL * (
p_.mL+1);
84 N +=
p_.nL * (
p_.kL+1);
91 return TEMPLATE_GLOBAL_MEMORY_REQUIRES_ZERO_LOCAL_FETCH;
94 return TEMPLATE_MS_NS_MUST_BE_SIMD_WIDTH_MULTIPLE;
97 return TEMPLATE_KS_MUST_BE_SMALLER_THAN_KL;
99 if (!(A_trans_==
'N' && B_trans_==
'T') &&
p_.
simd_width>1)
100 return TEMPLATE_SIMD_WIDTH_MUST_BE_ONE;
105 return TEMPLATE_LOCAL_FETCH_PRODUCT_MUST_MATCH_LOCAL_SIZE_PRODUCT;
110 unsigned int bound1 = (A_trans_==
'N')?
p_.kL:
p_.mL;
111 unsigned int bound0 = (A_trans_==
'N')?
p_.mL:
p_.kL;
113 if (
p_.local_fetch_1>0 && (bound1 %
p_.local_fetch_1)> 0)
114 return A_trans_==
'N'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
117 return A_trans_==
'N'?TEMPLATE_LOCAL_FETCH_0_MUST_BE_NL_MULTIPLE:TEMPLATE_LOCAL_FETCH_0_MUST_BE_KL_MULTIPLE;
122 unsigned int bound1 = (B_trans_==
'T')?
p_.kL:
p_.nL;
123 unsigned int bound0 = (B_trans_==
'T')?
p_.nL:
p_.kL;
125 if (
p_.local_fetch_1>0 && (bound1 %
p_.local_fetch_1)> 0)
126 return B_trans_==
'T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
129 return B_trans_==
'T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
133 return TEMPLATE_VALID;
141 using namespace tree_parsing;
142 using namespace scheduler;
150 vcl_size_t node_add_idx = array[root_idx].rhs.node_index;
152 vcl_size_t node_1_idx = array[node_add_idx].lhs.node_index;
153 alpha_idx = node_1_idx;
156 vcl_size_t mat_prod_idx = array[node_1_idx].lhs.node_index;
160 A_idx = mat_prod_idx;
165 A_idx = array[mat_prod_idx].lhs.node_index;
172 B_idx = mat_prod_idx;
178 B_idx = array[mat_prod_idx].rhs.node_index;
182 vcl_size_t node_2_idx = array[node_add_idx].rhs.node_index;
183 beta_idx = node_2_idx;
191 stream <<
"if (" << inbounds <<
")" << std::endl;
193 stream << do_if <<
";" << std::endl;
195 stream <<
"else" << std::endl;
197 stream << do_else <<
";" << std::endl;
201 stream << do_if <<
";" << std::endl;
205 std::string generate_impl(
const std::string &kernel_prefix,
const statements_container &statements,
const std::vector<mapping_type> &mappings,
bool fallback)
const
210 parameters_type pfallback(1,
p_.
local_size_0,
p_.
kL,
p_.
local_size_1,
p_.
mS, 1,
p_.
nS,
p_.
A_fetching_policy,
p_.
B_fetching_policy,
p_.
local_fetch_0,
p_.
local_fetch_1);
213 #define VIENNACL_MUL_STRIDE1 string(fallback?"*#stride1":"")
214 #define VIENNACL_HANDLE_BOUNDS(in_bounds, to_load) (!fallback?string(to_load):string( string(in_bounds) + "?" + string(to_load) + ":0"))
215 #define VIENNACL_VSTORE(value, offset, ptr) vstore(p.simd_width, value, offset, ptr)
226 bool A_trans =
false, B_trans =
false;
227 vcl_size_t C_idx=0, alpha_idx=0, A_idx=0, B_idx=0, beta_idx=0;
229 parse(st, C_idx, C_leaf, alpha_idx, alpha_leaf, A_idx, A_leaf, A_trans, B_idx, B_leaf, B_trans, beta_idx, beta_leaf);
241 stream <<
" __attribute__((reqd_work_group_size(" << p.
local_size_0 <<
"," << p.
local_size_1 <<
",1)))" << std::endl;
242 std::map<std::string, unsigned int> widths;
245 generate_prototype(stream, kernel_prefix,
"unsigned int M, unsigned int N, unsigned int K, ", mappings, statements, widths);
246 stream <<
"{" << std::endl;
276 stream <<
"size_t gidx = get_group_id(0);" << std::endl;
277 stream <<
"size_t gidy = get_group_id(1);" << std::endl;
278 stream <<
"size_t idx = get_local_id(0);" << std::endl;
279 stream <<
"size_t idy = get_local_id(1);" << std::endl;
284 stream <<
"size_t idt = " << p.
local_size_0 <<
"*idy + idx;" << std::endl;
285 stream <<
"size_t idxT = idt % " << p.
local_fetch_0 <<
";" << std::endl;
286 stream <<
"size_t idyT = idt / " << p.
local_fetch_0 <<
";" << std::endl;
293 stream <<
"bool in_bounds_m[" << p.
mS <<
"];" << std::endl;
294 stream <<
"for(size_t m = 0; m < " << p.
mS <<
"; m++)" << std::endl;
299 stream <<
"in_bounds_m[m] = gidx*" << p.
mL <<
" + idx*" << p.
mS <<
" + m < M;" << std::endl;
302 stream <<
"in_bounds_m[m] = gidx*" << p.
mL <<
" + idx + m*" << p.
local_size_0 <<
" < M;" << std::endl;
311 stream <<
"bool in_bounds_m_local[" << p.
mL/fetch_size <<
"];" << std::endl;
312 stream <<
"for(size_t m = 0; m < " << p.
mL/fetch_size <<
"; m++)" << std::endl;
314 stream <<
"in_bounds_m_local[m] = gidx*" << p.
mL <<
" + " << (A_trans_==
'N'?
"idxT":
"idyT") <<
" + m*" << fetch_size <<
" < M;" << std::endl;
319 stream <<
"bool in_bounds_n[" << p.
nS <<
"];" << std::endl;
320 stream <<
"for(size_t n = 0; n < " << p.
nS <<
"; n++)" << std::endl;
325 stream <<
"in_bounds_n[n] = gidy*" << p.
nL <<
" + idy*" << p.
nS <<
" + n < N;" << std::endl;
328 stream <<
"in_bounds_n[n] = gidy*" << p.
nL <<
" + idy + n*" << p.
local_size_1 <<
" < N;" << std::endl;
337 stream <<
"bool in_bounds_n_local[" << p.
nL/fetch_size <<
"];" << std::endl;
338 stream <<
"for(size_t n = 0; n < " << p.
nL/fetch_size <<
"; n++)" << std::endl;
340 stream <<
"in_bounds_n_local[n] = gidy*" << p.
nL <<
" + " << (B_trans_==
'T'?
"idxT":
"idyT") <<
" + n*" << fetch_size <<
" < N;" << std::endl;
398 stream <<
"size_t K_size_t = K;" << std::endl;
399 stream <<
"for(size_t block_k=0; block_k < K_size_t; block_k+=" << p.
kL <<
"){" << std::endl;
407 stream << A->
process(
"__local #scalartype* plA = lA + idxT*" +
to_string(p.
mL + 1) +
" + idyT;") << std::endl;
416 stream << B->
process(
"__local #scalartype* plB = lB + idxT*" +
to_string(p.
nL+1) +
"+ idyT;") <<std::endl;
421 stream <<
"barrier(CLK_LOCAL_MEM_FENCE);" << std::endl;
460 stream <<
"barrier(CLK_LOCAL_MEM_FENCE);" << std::endl;
461 stream <<
"size_t offA = " << p.
simd_width <<
"*idx;" << std::endl;
462 stream <<
"size_t offB = " << p.
simd_width <<
"*idy;" << std::endl;
466 stream <<
"for(size_t k = 0; k < " << p.
kL <<
" && (block_k + k < K_size_t); k+=" << p.
kS <<
"){" << std::endl;
468 stream <<
"for(size_t k = 0; k < " << p.
kL <<
"; k+=" << p.
kS <<
"){" << std::endl;
472 stream <<
"#pragma unroll " << p.
kS << std::endl;
473 stream <<
"for(size_t kk = 0; kk < " << p.
kS <<
"; kk++)" << std::endl;
474 stream <<
"#pragma unroll " << p.
mS/p.
simd_width << std::endl;
475 stream <<
"for(size_t mm = 0; mm < " << p.
mS/p.
simd_width <<
"; mm++)" << std::endl;
476 stream <<
"{" << std::endl;
481 for (
unsigned int ss = 0; ss < p.
simd_width; ++ss)
506 stream <<
"}" << std::endl;
508 stream <<
"#pragma unroll " << p.
kS << std::endl;
509 stream <<
"for(size_t kk = 0; kk < " << p.
kS <<
"; kk++)" << std::endl;
510 stream <<
"#pragma unroll " << p.
nS/p.
simd_width << std::endl;
511 stream <<
"for(size_t nn = 0; nn < " << p.
nS/p.
simd_width <<
"; nn++)" << std::endl;
512 stream <<
"{" << std::endl;
517 for (
unsigned int ss = 0; ss < p.
simd_width; ++ss)
542 stream <<
"}" << std::endl;
549 stream <<
"offA += " << p.
kS*(p.
mL+1) <<
";" << std::endl;
564 stream <<
"offB += " << p.
kS*(p.
nL+1) <<
";" << std::endl;
576 stream <<
"#pragma unroll " << p.
kS << std::endl;
577 stream <<
"for(size_t kk = 0; kk <" << p.
kS <<
"; ++kk)" << std::endl;
578 stream <<
"{" << std::endl;
580 for (
unsigned int nn=0; nn < p.
nS; ++nn)
581 for (
unsigned int mm=0; mm < p.
mS; ++mm)
583 string res_str, lhs_str, rhs_str;
593 stream << res_str <<
"=" <<
"fma(" << lhs_str <<
"," << rhs_str <<
"," << res_str <<
");" << std::endl;
596 stream <<
"}" << std::endl;
602 stream <<
"}" << std::endl;
623 stream <<
"}" << std::endl;
632 stream << C->
process(
"#pointer += idx*" +
to_string(ministartstride0) +
"*#ld;") << std::endl;
633 stream << C->
process(
"#pointer += gidy*" +
to_string(p.
nL) +
"*#stride2;") << std::endl;
634 stream << C->
process(
"#pointer += idy*" +
to_string(ministartstride1) +
"*#stride2;") << std::endl;
636 for (
unsigned int n=0; n < p.
nS; ++n)
638 for (
unsigned int m=0; m < p.
mS; ++m)
644 stream <<
"if (in_bounds_m[" +
to_string(m) +
"] && in_bounds_n[" +
to_string(n) +
"])" << std::endl;
647 stream << C->
process(
"#pointer[" + Cj +
"*#ld] = rC[" +
to_string(m) +
"][" +
to_string(n) +
"]*" + alpha->
name() +
"+ #pointer[" + Cj +
"*#ld]*" + beta->
name() +
";") << std::endl;
652 stream << C->
process(
"#pointer += #stride2;") << std::endl;
663 stream << C->
process(
"#pointer += gidx*" +
to_string(p.
mL) +
"*#stride1;") << std::endl;
664 stream << C->
process(
"#pointer += idx*" +
to_string(ministartstride0) +
"*#stride1;") << std::endl;
666 stream << C->
process(
"#pointer += idy*" +
to_string(ministartstride1) +
"*#ld;") << std::endl;
668 for (
unsigned int m=0; m < p.
mS; ++m)
670 for (
unsigned int n=0; n < p.
nS; ++n)
676 stream <<
"if (in_bounds_m[" +
to_string(m) +
"] && in_bounds_n[" +
to_string(n) +
"])" << std::endl;
679 stream << C->
process(
"#pointer[" + Cj +
"*#ld] = rC[" +
to_string(m) +
"][" +
to_string(n) +
"]*" + alpha->
name() +
" + #pointer[" + Cj +
"*#ld]*" + beta->
name() +
";") << std::endl;
685 stream << C->
process(
"#pointer += #stride1;") << std::endl;
692 stream <<
"}" << std::endl;
696 #undef VIENNACL_MUL_STRIDE1
697 #undef VIENNACL_HANDLE_BOUNDS
698 #undef VIENNACL_VSTORE
701 std::vector<std::string> generate_impl(std::string
const & kernel_prefix,
statements_container const & statements, std::vector<mapping_type>
const & mappings)
const
703 std::vector<std::string> res;
704 res.push_back(generate_impl(kernel_prefix, statements, mappings,
false));
705 res.push_back(generate_impl(kernel_prefix, statements, mappings,
true));
709 template<
class NumericT>
713 std::vector<lazy_program_compiler> & programs, std::string
const & kernel_prefix,
vcl_size_t id)
738 unsigned int current_arg = 0;
739 kernel.
arg(current_arg++, cl_uint(C.
size1()));
740 kernel.
arg(current_arg++, cl_uint(C.
size2()));
742 kernel.
arg(current_arg++, cl_uint(A_trans_==
'T'?A.
size2():A.
size1()));
744 kernel.
arg(current_arg++, cl_uint(A_trans_==
'N'?A.
size2():A.
size1()));
750 template<
class NumericT>
755 slice s0(s0_0, 1, s0_1 - s0_0);
756 slice s1(s1_0, 1, s1_1 - s1_0);
762 template<
class NumericT>
765 NumericT beta_value, std::vector<lazy_program_compiler> & programs, std::string
const & kernel_prefix)
767 using namespace device_specific::utils;
768 vcl_size_t ldstrideA = call_on_matrix(A, leading_stride());
769 vcl_size_t ldstrideB = call_on_matrix(B, leading_stride());
770 vcl_size_t ldstrideC = call_on_matrix(C, leading_stride());
771 vcl_size_t ldstartA = call_on_matrix(A, leading_start());
772 vcl_size_t ldstartB = call_on_matrix(B, leading_start());
773 bool swap_A = ((A_trans_==
'T') ^ utils::call_on_matrix(A, row_major_fun()));
774 bool swap_B = ((B_trans_==
'T') ^ utils::call_on_matrix(B, row_major_fun()));
776 vcl_size_t M = call_on_matrix(C, size1_fun());
777 vcl_size_t N = call_on_matrix(C, size2_fun());
779 if (utils::call_on_matrix(A, row_major_fun()))
780 K = A_trans_==
'T'?call_on_matrix(A, size2_fun()):call_on_matrix(A, size1_fun());
782 K = A_trans_==
'N'?call_on_matrix(A, size2_fun()):call_on_matrix(A, size1_fun());
784 if (M <
p_.
mL || N <
p_.
nL || K < p_.kL || ldstrideA> 1 || ldstrideB > 1 || ldstrideC > 1 ||
787 enqueue_block(statement, A, B, C, beta, create_slice(ptr_matrix, A, 0, M, 0, K, swap_A),
788 create_slice(ptr_matrix, B, 0, K, 0, N, swap_B),
789 create_slice(ptr_matrix, C, 0, M, 0, N,
false), beta_value, programs, kernel_prefix, 1);
803 enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, 0, lM, 0, lK, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, 0, lK, 0, lN, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, 0, lM, 0, lN,
false), beta_value, programs, kernel_prefix, 0);
804 enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, 0, lM, lK, K, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, lK, K, 0, lN, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, 0, lM, 0, lN,
false), (
NumericT)1, programs, kernel_prefix, 1);
806 enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, 0, lM, 0, lK, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, 0, lK, lN, N, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, 0, lM, lN, N,
false), beta_value, programs, kernel_prefix, 1);
807 enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, 0, lM, lK, K, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, lK, K, lN, N, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, 0, lM, lN, N,
false), (
NumericT)1, programs, kernel_prefix, 1);
809 enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, lM, M, 0, lK, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, 0, lK, 0, lN, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, lM, M, 0, lN,
false), beta_value, programs, kernel_prefix, 1);
810 enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, lM, M, lK, K, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, lK, K, 0, lN, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, lM, M, 0, lN,
false), (
NumericT)1, programs, kernel_prefix, 1);
812 enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, lM, M, 0, lK, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, 0, lK, lN, N, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, lM, M, lN, N,
false), beta_value, programs, kernel_prefix, 1);
813 enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, lM, M, lK, K, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, lK, K, lN, N, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, lM, M, lN, N,
false), (
NumericT)1, programs, kernel_prefix, 1);
821 using namespace device_specific::utils;
822 using namespace tree_parsing;
825 bool A_trans, B_trans;
826 vcl_size_t C_idx=0, A_idx=0, B_idx=0, alpha_idx=0, beta_idx = 0;
828 parse(st, C_idx, C_leaf, alpha_idx, alpha_leaf, A_idx, A_leaf, A_trans, B_idx, B_leaf, B_trans, beta_idx, beta_leaf);
virtual void enqueue(std::string const &kernel_prefix, std::vector< lazy_program_compiler > &programs, statements_container const &statements)
unsigned int local_size_0
#define VIENNACL_MUL_STRIDE1
unsigned int local_fetch_0
Exception for the case the generator is unable to deal with the operation.
void set_arguments(statements_container const &statements, viennacl::ocl::kernel &kernel, unsigned int &current_arg)
Class for representing strided submatrices of a bigger matrix A.
Represents an OpenCL kernel within ViennaCL.
size_type local_work_size(int index=0) const
Returns the local work size at the respective dimension.
static void assign_element(lhs_rhs_element &elem, char const &t)
parameters_type const & parameters() const
A class representing a compute device (e.g. a GPU)
This file provides the forward declarations for the main types used within ViennaCL.
A class representing the 'data' for the LHS or RHS operand of the respective node.
container_type const & array() const
fetching_policy_type A_fetching_policy
viennacl::scalar< float > s1
std::list< scheduler::statement > const & data() const
Forward declaration of dense matrix classes.
viennacl::matrix_base< double > * matrix_double
void swap(vector_base< T > &vec1, vector_base< T > &vec2)
Swaps the contents of two vectors, data is copied.
std::vector< value_type > container_type
#define VIENNACL_VSTORE(value, offset, ptr)
std::string to_string(viennacl::scheduler::op_element op_elem)
Helper routine for converting the operation enums to string.
unsigned int local_size_1
Map ViennaCL objects to generator wrappers.
statement_node_numeric_type numeric_type
viennacl::matrix_base< float > * matrix_float
std::string process(std::string const &in) const
size_type size2() const
Returns the number of columns.
Provides the datastructures for dealing with a single statement such as 'x = y + z;'.
#define VIENNACL_HANDLE_BOUNDS(in_bounds, to_load)
static void generate_prototype(utils::kernel_generation_stream &stream, std::string const &name, std::string const &first_arguments, std::vector< mapping_type > const &mappings, statements_container const &statements, std::map< std::string, unsigned int > const &widths)
size_type size1() const
Returns the number of rows.
viennacl::enable_if< viennacl::is_scalar< ScalarT1 >::value &&viennacl::is_scalar< ScalarT2 >::value >::type swap(ScalarT1 &s1, ScalarT2 &s2)
Swaps the contents of two scalars, data is copied.
std::map< mapping_key, tools::shared_ptr< mapped_object > > mapping_type
matrix_product_parameters(unsigned int simd_width, unsigned int local_size_0, unsigned int KL, unsigned int local_size_1, unsigned int ms, unsigned int ks, unsigned int ns, fetching_policy_type A_fetching_policy_param, fetching_policy_type B_fetching_policy_param, unsigned int local_fetch_0_param, unsigned int local_fetch_1_param)
Proxy classes for matrices.
Code for parsing the expression trees.
void enqueue(KernelType &k, viennacl::ocl::command_queue const &queue)
Enqueues a kernel in the provided queue.
scheduler::lhs_rhs_element & lhs_rhs_element(scheduler::statement const &st, vcl_size_t idx, leaf_t leaf)
size_type global_work_size(int index=0) const
Returns the global work size at the respective dimension.
The main class for representing a statement such as x = inner_prod(y,z); at runtime.
void arg(unsigned int pos, cl_char val)
Sets a char argument at the provided position.
matrix_product_template(matrix_product_template::parameters_type const &parameters, char A_trans, char B_trans)
ValueT const & at(std::map< KeyT, ValueT > const &map, KeyT const &key)
Emulation of C++11's .at() member for std::map<>, const-version.
A slice class that refers to an interval [start, stop), where 'start' is included, and 'stop' is excluded.
std::string const & name() const
fetching_policy_type B_fetching_policy
std::pair< vcl_size_t, leaf_t > mapping_key
unsigned int local_fetch_1
parameters_type(unsigned int _simd_width, unsigned int _local_size_1, unsigned int _local_size_2, unsigned int _num_kernels)
void process(utils::kernel_generation_stream &stream, leaf_t leaf, std::string const &type_key, std::string const &to_process, scheduler::statement const &statement, vcl_size_t root_idx, mapping_type const &mapping, std::set< std::string > &already_processed)
std::string append_width(std::string const &str, unsigned int width)