1 #ifndef VIENNACL_LINALG_OPENCL_KERNELS_COORDINATE_MATRIX_HPP
2 #define VIENNACL_LINALG_OPENCL_KERNELS_COORDINATE_MATRIX_HPP
/** @brief Appends the OpenCL source text of the kernel `vec_mul` to `source`, splicing `numeric_string` in wherever the scalar type name is needed.
 *
 * The emitted kernel multiplies each stored entry `elements[i]` with an entry of `x`, sums the products of
 * entries that share the same `coords[i].x`, and writes `alpha` times that sum into `result`.
 *
 * NOTE(review): the signature line and the enclosing braces of this generator are not part of this listing
 * (the embedded numbers are line numbers of the original file); the body only uses `source` and `numeric_string`.
 */
41 template<
typename StringT>
// Kernel parameter list. `layout_x` / `layout_result` are uint4 values of which only .x (added as offset)
// and .y (multiplied with the index) are used by the kernel body below.
44 source.append(
"__kernel void vec_mul( \n");
45 source.append(
" __global const uint2 * coords, \n");
46 source.append(
" __global const "); source.append(numeric_string); source.append(
" * elements, \n");
47 source.append(
" __global const uint * group_boundaries, \n");
48 source.append(
" __global const "); source.append(numeric_string); source.append(
" * x, \n");
49 source.append(
" uint4 layout_x, \n");
50 source.append(
" "); source.append(numeric_string); source.append(
" alpha, \n");
51 source.append(
" __global "); source.append(numeric_string); source.append(
" * result, \n");
52 source.append(
" uint4 layout_result, \n");
53 source.append(
" "); source.append(numeric_string); source.append(
" beta, \n");
54 source.append(
" __local unsigned int * shared_rows, \n");
55 source.append(
" __local "); source.append(numeric_string); source.append(
" * inter_results) \n");
56 source.append(
"{ \n");
57 source.append(
" uint2 tmp; \n");
58 source.append(
" "); source.append(numeric_string); source.append(
" val; \n");
// Each work group processes the entry range [group_boundaries[g], group_boundaries[g + 1]).
59 source.append(
" uint group_start = group_boundaries[get_group_id(0)]; \n");
60 source.append(
" uint group_end = group_boundaries[get_group_id(0) + 1]; \n");
// k_end = number of rounds of get_local_size(0) entries needed to cover the range (rounded up; 0 for an empty range).
61 source.append(
" uint k_end = (group_end > group_start) ? 1 + (group_end - group_start - 1) / get_local_size(0) : 0; \n");
63 source.append(
" uint local_index = 0; \n");
65 source.append(
" for (uint k = 0; k < k_end; ++k) { \n");
66 source.append(
" local_index = group_start + k * get_local_size(0) + get_local_id(0); \n");
// Work items past group_end use a zero coordinate pair and a zero product, so they still take part in every barrier.
68 source.append(
" tmp = (local_index < group_end) ? coords[local_index] : (uint2) 0; \n");
69 source.append(
" val = (local_index < group_end) ? elements[local_index] * x[tmp.y * layout_x.y + layout_x.x] : 0; \n");
// Carry from the previous round: work item 0 either continues the partial sum left in the last local slot
// (same row index) or writes that finished sum to result.
// NOTE(review): when beta != 0 the kernel accumulates with += and never multiplies by beta itself;
// presumably the caller scales result beforehand - confirm against the host code.
72 source.append(
" if (get_local_id(0) == 0 && k > 0) { \n");
73 source.append(
" if (tmp.x == shared_rows[get_local_size(0)-1]) \n");
74 source.append(
" val += inter_results[get_local_size(0)-1]; \n");
75 source.append(
" else if (beta != 0) \n");
76 source.append(
" result[shared_rows[get_local_size(0)-1] * layout_result.y + layout_result.x] += alpha * inter_results[get_local_size(0)-1]; \n");
77 source.append(
" else \n");
78 source.append(
" result[shared_rows[get_local_size(0)-1] * layout_result.y + layout_result.x] = alpha * inter_results[get_local_size(0)-1]; \n");
79 source.append(
" } \n");
// Publish row index and product to local memory, then accumulate with doubling strides:
// a work item adds the value `stride` slots to its left only if that slot carries the same row index.
82 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
83 source.append(
" shared_rows[get_local_id(0)] = tmp.x; \n");
84 source.append(
" inter_results[get_local_id(0)] = val; \n");
85 source.append(
" "); source.append(numeric_string); source.append(
" left = 0; \n");
86 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
88 source.append(
" for (unsigned int stride = 1; stride < get_local_size(0); stride *= 2) { \n");
89 source.append(
" left = (get_local_id(0) >= stride && tmp.x == shared_rows[get_local_id(0) - stride]) ? inter_results[get_local_id(0) - stride] : 0; \n");
90 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
91 source.append(
" inter_results[get_local_id(0)] += left; \n");
92 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
93 source.append(
" } \n");
// A work item whose right neighbour holds a different row index owns the finished sum of its row and writes it.
// The last work item of the group and the last entry of the range are excluded here; they are handled
// by the carry logic above and by the final write after the loop.
96 source.append(
" if (local_index < group_end - 1 && get_local_id(0) < get_local_size(0) - 1 && \n");
97 source.append(
" shared_rows[get_local_id(0)] != shared_rows[get_local_id(0) + 1]) { \n");
98 source.append(
" if (beta != 0) result[tmp.x * layout_result.y + layout_result.x] += alpha * inter_results[get_local_id(0)]; \n");
99 source.append(
" else result[tmp.x * layout_result.y + layout_result.x] = alpha * inter_results[get_local_id(0)]; \n");
100 source.append(
" } \n");
102 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
103 source.append(
" } \n");
// After the last round, the work item holding the final entry of the range writes the remaining sum.
105 source.append(
" if (local_index + 1 == group_end) {\n");
106 source.append(
" if (beta != 0) result[tmp.x * layout_result.y + layout_result.x] += alpha * inter_results[get_local_id(0)]; \n");
107 source.append(
" else result[tmp.x * layout_result.y + layout_result.x] = alpha * inter_results[get_local_id(0)]; \n");
108 source.append(
" } \n");
109 source.append(
"} \n");
/** @brief Appends the OpenCL source text of a sparse-times-dense matrix product kernel to `source`.
 *
 * `numeric_string` is spliced in as the scalar type name. The flags `B_transposed` and `B_row_major` select
 * which index expression into `d_mat` is emitted.
 *
 * NOTE(review): several lines of the original file are absent from this listing: the first signature line,
 * the statement that appends the kernel name (between "__kernel void " and "("), the `else` in front of the
 * fourth `d_mat` variant, and the conditions selecting between the paired `result` stores. The paired stores
 * differ only in row-major vs. column-major indexing, so the selector is presumably `C_row_major` - confirm
 * against the full file.
 */
116 template<
typename StringT>
118 bool B_transposed,
bool B_row_major,
bool C_row_major)
120 source.append(
"__kernel void ");
122 source.append(
"( \n");
123 source.append(
" __global const uint2 * coords, \n");
124 source.append(
" __global const "); source.append(numeric_string); source.append(
" * elements, \n");
125 source.append(
" __global const uint * group_boundaries, \n");
// Dense input operand with start/increment/size/internal-size parameters for rows and columns.
126 source.append(
" __global const "); source.append(numeric_string); source.append(
" * d_mat, \n");
127 source.append(
" unsigned int d_mat_row_start, \n");
128 source.append(
" unsigned int d_mat_col_start, \n");
129 source.append(
" unsigned int d_mat_row_inc, \n");
130 source.append(
" unsigned int d_mat_col_inc, \n");
131 source.append(
" unsigned int d_mat_row_size, \n");
132 source.append(
" unsigned int d_mat_col_size, \n");
133 source.append(
" unsigned int d_mat_internal_rows, \n");
134 source.append(
" unsigned int d_mat_internal_cols, \n");
// Dense output operand, described by the same set of parameters.
135 source.append(
" __global "); source.append(numeric_string); source.append(
" * result, \n");
136 source.append(
" unsigned int result_row_start, \n");
137 source.append(
" unsigned int result_col_start, \n");
138 source.append(
" unsigned int result_row_inc, \n");
139 source.append(
" unsigned int result_col_inc, \n");
140 source.append(
" unsigned int result_row_size, \n");
141 source.append(
" unsigned int result_col_size, \n");
142 source.append(
" unsigned int result_internal_rows, \n");
143 source.append(
" unsigned int result_internal_cols, \n");
144 source.append(
" __local unsigned int * shared_rows, \n");
145 source.append(
" __local "); source.append(numeric_string); source.append(
" * inter_results) \n");
146 source.append(
"{ \n");
147 source.append(
" uint2 tmp; \n");
148 source.append(
" "); source.append(numeric_string); source.append(
" val; \n");
// Each work group processes the entry range [group_boundaries[g], group_boundaries[g + 1]).
149 source.append(
" uint group_start = group_boundaries[get_group_id(0)]; \n");
150 source.append(
" uint group_end = group_boundaries[get_group_id(0) + 1]; \n");
// k_end = number of rounds of get_local_size(0) entries needed to cover the range (rounded up; 0 for an empty range).
151 source.append(
" uint k_end = (group_end > group_start) ? 1 + (group_end - group_start - 1) / get_local_size(0) : 0; \n");
153 source.append(
" uint local_index = 0; \n");
// The whole entry range is swept once per column of the result.
155 source.append(
" for (uint result_col = 0; result_col < result_col_size; ++result_col) { \n");
156 source.append(
" for (uint k = 0; k < k_end; ++k) { \n");
157 source.append(
" local_index = group_start + k * get_local_size(0) + get_local_id(0); \n");
159 source.append(
" tmp = (local_index < group_end) ? coords[local_index] : (uint2) 0; \n");
// Exactly one of the following four product statements is emitted. In the transposed variants
// result_col drives the row index of d_mat and tmp.y the column index; otherwise the roles are swapped.
// Row-major variants scale the row by d_mat_internal_cols, the others scale the column by d_mat_internal_rows.
160 if (B_transposed && B_row_major)
161 source.append(
" val = (local_index < group_end) ? elements[local_index] * d_mat[ (d_mat_row_start + result_col * d_mat_row_inc) * d_mat_internal_cols + d_mat_col_start + tmp.y * d_mat_col_inc ] : 0; \n");
162 else if (B_transposed && !B_row_major)
163 source.append(
" val = (local_index < group_end) ? elements[local_index] * d_mat[ (d_mat_row_start + result_col * d_mat_row_inc) + (d_mat_col_start + tmp.y * d_mat_col_inc) * d_mat_internal_rows ] : 0; \n");
164 else if (!B_transposed && B_row_major)
165 source.append(
" val = (local_index < group_end) ? elements[local_index] * d_mat[ (d_mat_row_start + tmp.y * d_mat_row_inc) * d_mat_internal_cols + d_mat_col_start + result_col * d_mat_col_inc ] : 0; \n");
167 source.append(
" val = (local_index < group_end) ? elements[local_index] * d_mat[ (d_mat_row_start + tmp.y * d_mat_row_inc) + (d_mat_col_start + result_col * d_mat_col_inc) * d_mat_internal_rows ] : 0; \n");
// Carry from the previous round: work item 0 either continues the partial sum left in the last local slot
// (same row index) or stores that finished sum into result.
170 source.append(
" if (get_local_id(0) == 0 && k > 0) { \n");
171 source.append(
" if (tmp.x == shared_rows[get_local_size(0)-1]) \n");
172 source.append(
" val += inter_results[get_local_size(0)-1]; \n");
173 source.append(
" else \n");
// Row-major store followed by its column-major twin (selector not visible in this listing, see NOTE above).
175 source.append(
" result[(shared_rows[get_local_size(0)-1] * result_row_inc + result_row_start) * result_internal_cols + result_col_start + result_col * result_col_inc ] = inter_results[get_local_size(0)-1]; \n");
177 source.append(
" result[(shared_rows[get_local_size(0)-1] * result_row_inc + result_row_start) + (result_col_start + result_col * result_col_inc) * result_internal_rows ] = inter_results[get_local_size(0)-1]; \n");
178 source.append(
" } \n");
// Publish row index and product to local memory, then accumulate with doubling strides:
// a work item adds the value `stride` slots to its left only if that slot carries the same row index.
181 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
182 source.append(
" shared_rows[get_local_id(0)] = tmp.x; \n");
183 source.append(
" inter_results[get_local_id(0)] = val; \n");
184 source.append(
" "); source.append(numeric_string); source.append(
" left = 0; \n");
185 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
187 source.append(
" for (unsigned int stride = 1; stride < get_local_size(0); stride *= 2) { \n");
188 source.append(
" left = (get_local_id(0) >= stride && tmp.x == shared_rows[get_local_id(0) - stride]) ? inter_results[get_local_id(0) - stride] : 0; \n");
189 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
190 source.append(
" inter_results[get_local_id(0)] += left; \n");
191 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
192 source.append(
" } \n");
// A work item whose right neighbour holds a different row index stores the finished sum of its row.
195 source.append(
" if (local_index < group_end && get_local_id(0) < get_local_size(0) - 1 && \n");
196 source.append(
" shared_rows[get_local_id(0)] != shared_rows[get_local_id(0) + 1]) { \n");
// Row-major store followed by its column-major twin (selector not visible in this listing).
198 source.append(
" result[(tmp.x * result_row_inc + result_row_start) * result_internal_cols + result_col_start + result_col * result_col_inc ] = inter_results[get_local_id(0)]; \n");
200 source.append(
" result[(tmp.x * result_row_inc + result_row_start) + (result_col_start + result_col * result_col_inc) * result_internal_rows ] = inter_results[get_local_id(0)]; \n");
201 source.append(
" } \n");
203 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
204 source.append(
" } \n");
// After the last round of a column, the work item holding the final entry of the range stores the remaining sum.
206 source.append(
" if (local_index + 1 == group_end) \n");
// Row-major store followed by its column-major twin (selector not visible in this listing).
208 source.append(
" result[(tmp.x * result_row_inc + result_row_start) * result_internal_cols + result_col_start + result_col * result_col_inc ] = inter_results[get_local_id(0)]; \n");
210 source.append(
" result[(tmp.x * result_row_inc + result_row_start) + (result_col_start + result_col * result_col_inc) * result_internal_rows ] = inter_results[get_local_id(0)]; \n");
211 source.append(
" } \n");
212 source.append(
"} \n");
217 template<
typename StringT>
/** @brief Appends the OpenCL source text of the kernel `row_info_extractor` to `source`, splicing `numeric_string` in wherever the scalar type name is needed.
 *
 * The emitted kernel reduces the stored entries per row index (`coords[i].x`) and writes one value per row to
 * `result`. The runtime argument `option` selects the reduction as used by the kernel text below:
 *   - 0: maximum of the entries,
 *   - 1: sum of fabs() of the entries,
 *   - 2: sqrt() of the sum of squared entries,
 *   - 3: like 0, but only entries with equal row and column index contribute.
 *
 * NOTE(review): the signature line and the enclosing braces of this generator are not part of this listing.
 * NOTE(review): for options 0 and 3 fabs() is applied only to the carried partial result, never to a freshly
 * loaded entry; if option 0 is meant to be a norm of absolute values, negative entries look mishandled - confirm.
 */
231 template<
typename StringT>
234 source.append(
"__kernel void row_info_extractor( \n");
235 source.append(
" __global const uint2 * coords, \n");
236 source.append(
" __global const "); source.append(numeric_string); source.append(
" * elements, \n");
237 source.append(
" __global const uint * group_boundaries, \n");
238 source.append(
" __global "); source.append(numeric_string); source.append(
" * result, \n");
239 source.append(
" unsigned int option, \n");
240 source.append(
" __local unsigned int * shared_rows, \n");
241 source.append(
" __local "); source.append(numeric_string); source.append(
" * inter_results) \n");
242 source.append(
"{ \n");
243 source.append(
" uint2 tmp; \n");
244 source.append(
" "); source.append(numeric_string); source.append(
" val; \n");
245 source.append(
" uint last_index = get_local_size(0) - 1; \n");
// Each work group processes the entry range [group_boundaries[g], group_boundaries[g + 1]).
246 source.append(
" uint group_start = group_boundaries[get_group_id(0)]; \n");
247 source.append(
" uint group_end = group_boundaries[get_group_id(0) + 1]; \n");
// k_end = number of rounds needed to cover the range (rounded up). The zero of the empty-range branch is
// cast to the scalar type although k_end itself is declared uint.
248 source.append(
" uint k_end = (group_end > group_start) ? 1 + (group_end - group_start - 1) / get_local_size(0) : ("); source.append(numeric_string); source.append(
")0; \n");
250 source.append(
" uint local_index = 0; \n");
252 source.append(
" for (uint k = 0; k < k_end; ++k) \n");
253 source.append(
" { \n");
254 source.append(
" local_index = group_start + k * get_local_size(0) + get_local_id(0); \n");
// Out-of-range work items, and for option 3 every entry whose row and column index differ, contribute 0.
256 source.append(
" tmp = (local_index < group_end) ? coords[local_index] : (uint2) 0; \n");
257 source.append(
" val = (local_index < group_end && (option != 3 || tmp.x == tmp.y) ) ? elements[local_index] : 0; \n");
// Carry from the previous round (work item 0 only): if the row continues, fold the partial result of the
// last local slot into val; otherwise that row is finished and its value is written to result.
260 source.append(
" if (get_local_id(0) == 0 && k > 0) \n");
261 source.append(
" { \n");
262 source.append(
" if (tmp.x == shared_rows[last_index]) \n");
263 source.append(
" { \n");
264 source.append(
" switch (option) \n");
265 source.append(
" { \n");
266 source.append(
" case 0: \n");
267 source.append(
" case 3: \n");
268 source.append(
" val = max(val, fabs(inter_results[last_index])); \n");
269 source.append(
" break; \n");
271 source.append(
" case 1: \n");
272 source.append(
" val = fabs(val) + inter_results[last_index]; \n");
273 source.append(
" break; \n");
// Option 2 keeps sums of squares in inter_results; sqrt is taken here because val gets squared again
// when it is stored to local memory further down.
275 source.append(
" case 2: \n");
276 source.append(
" val = sqrt(val * val + inter_results[last_index]); \n");
277 source.append(
" break; \n");
279 source.append(
" default: \n");
280 source.append(
" break; \n");
281 source.append(
" } \n");
282 source.append(
" } \n");
283 source.append(
" else \n");
284 source.append(
" { \n");
285 source.append(
" switch (option) \n");
286 source.append(
" { \n");
287 source.append(
" case 0: \n");
288 source.append(
" case 1: \n");
289 source.append(
" case 3: \n");
290 source.append(
" result[shared_rows[last_index]] = inter_results[last_index]; \n");
291 source.append(
" break; \n");
// case 2 has no break of its own; it falls through into `default: break;`, which has no further effect.
293 source.append(
" case 2: \n");
294 source.append(
" result[shared_rows[last_index]] = sqrt(inter_results[last_index]); \n");
295 source.append(
" default: \n");
296 source.append(
" break; \n");
297 source.append(
" } \n");
298 source.append(
" } \n");
299 source.append(
" } \n");
// Publish the row index and the per-entry contribution (val, fabs(val) or val * val) to local memory.
302 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
303 source.append(
" shared_rows[get_local_id(0)] = tmp.x; \n");
304 source.append(
" switch (option) \n");
305 source.append(
" { \n");
306 source.append(
" case 0: \n");
307 source.append(
" case 3: \n");
308 source.append(
" inter_results[get_local_id(0)] = val; \n");
309 source.append(
" break; \n");
310 source.append(
" case 1: \n");
311 source.append(
" inter_results[get_local_id(0)] = fabs(val); \n");
312 source.append(
" break; \n");
// As above, case 2 falls through into `default: break;` without further effect.
313 source.append(
" case 2: \n");
314 source.append(
" inter_results[get_local_id(0)] = val * val; \n");
315 source.append(
" default: \n");
316 source.append(
" break; \n");
317 source.append(
" } \n");
318 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
// Combine with doubling strides: a work item folds in the value `stride` slots to its left only if that
// slot carries the same row index (max for options 0/3, addition for options 1/2).
320 source.append(
" for (unsigned int stride = 1; stride < get_local_size(0); stride *= 2) \n");
321 source.append(
" { \n");
322 source.append(
" "); source.append(numeric_string); source.append(
" left = (get_local_id(0) >= stride && tmp.x == shared_rows[get_local_id(0) - stride]) ? inter_results[get_local_id(0) - stride] : ("); source.append(numeric_string); source.append(
")0; \n");
323 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
324 source.append(
" switch (option) \n");
325 source.append(
" { \n");
326 source.append(
" case 0: \n");
327 source.append(
" case 3: \n");
328 source.append(
" inter_results[get_local_id(0)] = max(inter_results[get_local_id(0)], left); \n");
329 source.append(
" break; \n");
331 source.append(
" case 1: \n");
332 source.append(
" inter_results[get_local_id(0)] += left; \n");
333 source.append(
" break; \n");
335 source.append(
" case 2: \n");
336 source.append(
" inter_results[get_local_id(0)] += left; \n");
337 source.append(
" break; \n");
339 source.append(
" default: \n");
340 source.append(
" break; \n");
341 source.append(
" } \n");
342 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
343 source.append(
" } \n");
// A work item (other than the last one) whose right neighbour holds a different row index writes the
// finished value of its row; zero results are not written.
346 source.append(
" if (get_local_id(0) != last_index && \n");
347 source.append(
" shared_rows[get_local_id(0)] != shared_rows[get_local_id(0) + 1] && \n");
348 source.append(
" inter_results[get_local_id(0)] != 0) \n");
349 source.append(
" { \n");
350 source.append(
" result[tmp.x] = (option == 2) ? sqrt(inter_results[get_local_id(0)]) : inter_results[get_local_id(0)]; \n");
351 source.append(
" } \n");
353 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
354 source.append(
" } \n");
// After the last round, the work item holding the final entry of the range writes the remaining non-zero value.
356 source.append(
" if (local_index + 1 == group_end && inter_results[get_local_id(0)] != 0) \n");
357 source.append(
" result[tmp.x] = (option == 2) ? sqrt(inter_results[get_local_id(0)]) : inter_results[get_local_id(0)]; \n");
358 source.append(
"} \n");
// Fragments of the kernel class templated on NumericT and of its program-setup routine.
// NOTE(review): the class header, the routine's signature, the check of init_done, the calls that fill
// `source`, the definition of `prog_name` and the closing #endif are absent from this listing.
365 template<
typename NumericT>
// Function-local static map from an OpenCL context handle to a flag; the flag is set at the end of the routine.
375 static std::map<cl_context, bool> init_done;
// Pre-allocate room for the program text before anything is appended to it.
382 source.reserve(1024);
// Lets the helper add to `source` whatever double-precision pragma it deems necessary for NumericT on `ctx`
// (the helper is defined elsewhere).
384 viennacl::ocl::append_double_precision_pragma<NumericT>(ctx, source);
// Optional build-time diagnostic naming the program that is about to be registered.
391 #ifdef VIENNACL_BUILD_INFO
392 std::cout <<
"Creating program " << prog_name << std::endl;
// Register the assembled program text with the context under `prog_name`, then mark this context as done.
394 ctx.add_program(source, prog_name);
395 init_done[ctx.handle().get()] =
true;
static void init(viennacl::ocl::context &ctx)
std::string sparse_dense_matmult_kernel_name(bool B_transposed, bool B_row_major, bool C_row_major)
Returns the OpenCL kernel string for the operation C = A * B with A sparse, B, C dense matrices...
Manages an OpenCL context and provides the respective convenience functions for creating buffers...
Main kernel class for generating OpenCL kernels for coordinate_matrix.
Provides OpenCL-related utilities.
void generate_coordinate_matrix_dense_matrix_mul(StringT &source, std::string const &numeric_string, bool B_transposed, bool B_row_major, bool C_row_major)
Generate kernel for C = A * B with A being a coordinate_matrix, B and C dense.
const viennacl::ocl::handle< cl_context > & handle() const
Returns the context handle.
Common implementations shared by OpenCL-based operations.
static void apply(viennacl::ocl::context const &)
const OCL_TYPE & get() const
static std::string program_name()
void generate_coordinate_matrix_dense_matrix_multiplication(StringT &source, std::string const &numeric_string)
Representation of an OpenCL kernel in ViennaCL.
void generate_coordinate_matrix_vec_mul(StringT &source, std::string const &numeric_string)
Helper class for converting a type to its string representation.
void generate_coordinate_matrix_row_info_extractor(StringT &source, std::string const &numeric_string)