1 #ifndef VIENNACL_LINALG_OPENCL_KERNELS_SVD_HPP
2 #define VIENNACL_LINALG_OPENCL_KERNELS_SVD_HPP
/** @brief Emits the OpenCL kernel `bidiag_pack`.
 *
 * The generated kernel copies the main diagonal of A into D (D[i] = A(i,i)) and the
 * first superdiagonal into S shifted by one (S[i+1] = A(i,i+1), or 0 past size2);
 * work-item 0 additionally sets S[0] = 0. Work-items stride over i < min(size1, size2).
 *
 * @param source          String-like object receiving the OpenCL source (must provide append()).
 * @param numeric_string  OpenCL scalar type name, e.g. "float" or "double".
 * @param is_row_major    Selects the indexing used for A(i, i+1).
 */
template <typename StringType>
void generate_svd_bidiag_pack(StringType & source, std::string const & numeric_string, bool is_row_major)
{
  source.append("__kernel void bidiag_pack(__global ");
  source.append(numeric_string);
  source.append("* A, \n"
                " __global ");
  source.append(numeric_string);
  source.append("* D, \n"
                " __global ");
  source.append(numeric_string);
  source.append("* S, \n"
                " uint size1, \n"
                " uint size2, \n"
                " uint stride \n"
                ") { \n"
                " uint size = min(size1, size2); \n"
                " if(get_global_id(0) == 0) \n"
                " S[0] = 0; \n"
                " for(uint i = get_global_id(0); i < size ; i += get_global_size(0)) { \n"
                " D[i] = A[i*stride + i]; \n");

  // The diagonal offset i*stride + i is layout-independent; only A(i, i+1) differs.
  char const * superdiag_line = is_row_major
      ? " S[i + 1] = (i + 1 < size2) ? A[i*stride + (i + 1)] : 0; \n"
      : " S[i + 1] = (i + 1 < size2) ? A[i + (i + 1) * stride] : 0; \n";
  source.append(superdiag_line);

  source.append(" } \n"
                "} \n");
}
/** @brief Emits the OpenCL helper `col_reduce_lcl_array` (work-group sum in local memory).
 *
 * The generated function sums sums[0 .. lcl_sz-1] into sums[0]. Each round folds the
 * upper part of the active range onto the lower part using offset = ceil(active / 2),
 * so work-group sizes that are not powers of two are reduced correctly. For power-of-two
 * sizes the sequence of additions is identical to a plain halving tree.
 *
 * Contract for callers of the generated function: every work-item of the work-group must
 * call it (it contains barriers), after a barrier that makes all sums[] entries visible.
 * When it returns, sums[0] holds the total and is visible to the whole work-group.
 *
 * Note: the OpenCL identifier `half` is a built-in type, hence the name `offset`.
 *
 * @param source          String-like object receiving the OpenCL source (must provide append()).
 * @param numeric_string  OpenCL scalar type name, e.g. "float" or "double".
 */
template<typename StringT>
void generate_svd_col_reduce_lcl_array(StringT & source, std::string const & numeric_string)
{
  source.append("void col_reduce_lcl_array(__local "); source.append(numeric_string); source.append("* sums, uint lcl_id, uint lcl_sz) { \n");
  source.append(" uint active = lcl_sz; \n");
  source.append(" while (active > 1) { \n");
  source.append(" uint offset = (active + 1) >> 1; \n");
  // Writers touch [0, active - offset), readers touch [offset, active): disjoint ranges.
  source.append(" if (lcl_id < active - offset) { \n");
  source.append(" sums[lcl_id] += sums[lcl_id + offset]; \n");
  source.append(" } \n");
  source.append(" active = offset; \n");
  source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n");
  source.append(" } \n");
  source.append("} \n");
}
/** @brief Emits the OpenCL kernel `copy_col`.
 *
 * The generated kernel copies A(i, col_start) for row_start <= i < size into
 * V[i - row_start]; work-items stride over i by the global size.
 *
 * @param source          String-like object receiving the OpenCL source (must provide append()).
 * @param numeric_string  OpenCL scalar type name, e.g. "float" or "double".
 * @param is_row_major    Selects the indexing used for A(i, col_start).
 */
template <typename StringType>
void generate_svd_copy_col(StringType & source, std::string const & numeric_string, bool is_row_major)
{
  // Linear offset of A(i, col_start) for the requested storage layout.
  std::string const elem_offset = is_row_major ? "i * stride + col_start" : "i + col_start * stride";

  source.append("__kernel void copy_col(__global ");
  source.append(numeric_string);
  source.append("* A, \n"
                " __global ");
  source.append(numeric_string);
  source.append("* V, \n"
                " uint row_start, \n"
                " uint col_start, \n"
                " uint size, \n"
                " uint stride \n"
                " ) { \n"
                " uint glb_id = get_global_id(0); \n"
                " uint glb_sz = get_global_size(0); \n"
                " for(uint i = row_start + glb_id; i < size; i += glb_sz) { \n"
                " V[i - row_start] = A[");
  source.append(elem_offset);
  source.append("]; \n"
                " } \n"
                "} \n");
}
/** @brief Emits the OpenCL kernel `copy_row`.
 *
 * The generated kernel copies A(row_start, i) for col_start <= i < size into
 * V[i - col_start]; work-items stride over i by the global size.
 *
 * @param source          String-like object receiving the OpenCL source (must provide append()).
 * @param numeric_string  OpenCL scalar type name, e.g. "float" or "double".
 * @param is_row_major    Selects the indexing used for A(row_start, i).
 */
template <typename StringType>
void generate_svd_copy_row(StringType & source, std::string const & numeric_string, bool is_row_major)
{
  source.append("__kernel void copy_row(__global ");
  source.append(numeric_string);
  source.append("* A, \n"
                " __global ");
  source.append(numeric_string);
  source.append("* V, \n"
                " uint row_start, \n"
                " uint col_start, \n"
                " uint size, \n"
                " uint stride \n"
                " ) { \n"
                " uint glb_id = get_global_id(0); \n"
                " uint glb_sz = get_global_size(0); \n"
                " for(uint i = col_start + glb_id; i < size; i += glb_sz) { \n");

  // A(row_start, i): consecutive i are adjacent in row-major storage,
  // and `stride` apart in column-major storage.
  if (!is_row_major)
    source.append(" V[i - col_start] = A[row_start + i * stride]; \n");
  else
    source.append(" V[i - col_start] = A[row_start * stride + i]; \n");

  source.append(" } \n"
                "} \n");
}
/** @brief Emits the OpenCL kernel `final_iter_update`.
 *
 * For every px < last_n (work-items stride by the global size) the generated kernel
 * combines the two entries z = A[(n-1)*stride + px] and v_in = A[n*stride + px]:
 *   A[(n-1)*stride + px] = q*z + p*v_in
 *   A[n*stride + px]     = q*v_in - p*z
 * i.e. a 2x2 rotation with coefficients (q, p) applied to the rows n-1 and n of A
 * as addressed with row * stride + px.
 *
 * @param source          String-like object receiving the OpenCL source (must provide append()).
 * @param numeric_string  OpenCL scalar type name, e.g. "float" or "double".
 */
template<typename StringT>
void generate_svd_final_iter_update(StringT & source, std::string const & numeric_string)
{
  source.append("__kernel void final_iter_update(__global ");
  source.append(numeric_string);
  source.append("* A, \n"
                " uint stride, \n"
                " uint n, \n"
                " uint last_n, \n"
                " ");
  source.append(numeric_string);
  source.append(" q, \n"
                " ");
  source.append(numeric_string);
  source.append(" p \n"
                " ) \n"
                "{ \n"
                " uint glb_id = get_global_id(0); \n"
                " uint glb_sz = get_global_size(0); \n"
                " for (uint px = glb_id; px < last_n; px += glb_sz) \n"
                " { \n"
                " ");
  source.append(numeric_string);
  source.append(" v_in = A[n * stride + px]; \n"
                " ");
  source.append(numeric_string);
  source.append(" z = A[(n - 1) * stride + px]; \n"
                " A[(n - 1) * stride + px] = q * z + p * v_in; \n"
                " A[n * stride + px] = q * v_in - p * z; \n"
                " } \n"
                "} \n");
}
/** @brief Emits the OpenCL kernel `givens_next`.
 *
 * Work-item j (j < size) sweeps a chain of 2x2 rotations over the entries (j, c) of matr,
 * where entry (j, c) is addressed as matr[c + j * stride] when is_row_major and as
 * matr[c * stride + j] otherwise. Rotation i (coefficients cs[i], ss[i]) is applied for
 * i = end_i down to start_i. It rotates the carried value x (initially entry end_i + 1)
 * with entry i and writes the first result to entry i + 1. The final x is stored at
 * entry start_i.
 *
 * The coefficients are staged through 256-entry __local buffers in chunks. The chunk
 * size is min(local size, 256), so work-groups larger than 256 no longer write past the
 * end of the buffers. For work-groups of at most 256 work-items the kernel behaves as
 * before.
 *
 * @param source          String-like object receiving the OpenCL source (must provide append()).
 * @param numeric_string  OpenCL scalar type name, e.g. "float" or "double".
 * @param is_row_major    Selects the indexing used for entry (j, c) of matr.
 */
template <typename StringType>
void generate_svd_givens_next(StringType & source, std::string const & numeric_string, bool is_row_major)
{
  // Completes "matr[<c>" into the address of entry (j, c) for the chosen layout.
  std::string const at_row_j = is_row_major ? " + j * stride]" : " * stride + j]";

  source.append("__kernel void givens_next(__global "); source.append(numeric_string); source.append("* matr, \n");
  source.append(" __global "); source.append(numeric_string); source.append("* cs, \n");
  source.append(" __global "); source.append(numeric_string); source.append("* ss, \n");
  source.append(" uint size, \n");
  source.append(" uint stride, \n");
  source.append(" uint start_i, \n");
  source.append(" uint end_i \n");
  source.append(" ) \n");
  source.append("{ \n");
  source.append(" uint glb_id = get_global_id(0); \n");
  source.append(" uint glb_sz = get_global_size(0); \n");
  source.append(" uint lcl_id = get_local_id(0); \n");
  source.append(" uint lcl_sz = get_local_size(0); \n");
  source.append(" uint j = glb_id; \n");
  source.append(" __local "); source.append(numeric_string); source.append(" cs_lcl[256]; \n");
  source.append(" __local "); source.append(numeric_string); source.append(" ss_lcl[256]; \n");
  // Chunk size must never exceed the 256-entry local buffers declared above.
  source.append(" uint blk_sz = min(lcl_sz, (uint)256); \n");

  source.append(" "); source.append(numeric_string);
  source.append(" x = (j < size) ? matr[(end_i + 1)"); source.append(at_row_j); source.append(" : 0; \n");
  source.append(" uint elems_num = end_i - start_i + 1; \n");
  source.append(" uint block_num = (elems_num + blk_sz - 1) / blk_sz; \n");
  source.append(" for(uint block_id = 0; block_id < block_num; block_id++) \n");
  source.append(" { \n");
  source.append(" uint to = min(elems_num - block_id * blk_sz, blk_sz); \n");
  source.append(" if(lcl_id < to) \n");
  source.append(" { \n");
  source.append(" cs_lcl[lcl_id] = cs[end_i - (lcl_id + block_id * blk_sz)]; \n");
  source.append(" ss_lcl[lcl_id] = ss[end_i - (lcl_id + block_id * blk_sz)]; \n");
  source.append(" } \n");
  source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n");
  source.append(" if(j < size) \n");
  source.append(" { \n");
  source.append(" for(uint ind = 0; ind < to; ind++) \n");
  source.append(" { \n");
  source.append(" uint i = end_i - (ind + block_id * blk_sz); \n");
  source.append(" "); source.append(numeric_string); source.append(" z = matr[i"); source.append(at_row_j); source.append("; \n");
  source.append(" "); source.append(numeric_string); source.append(" cs_val = cs_lcl[ind]; \n");
  source.append(" "); source.append(numeric_string); source.append(" ss_val = ss_lcl[ind]; \n");
  source.append(" matr[(i + 1)"); source.append(at_row_j); source.append(" = x * cs_val + z * ss_val; \n");
  source.append(" x = -x * ss_val + z * cs_val; \n");
  source.append(" } \n");
  source.append(" } \n");
  // Keep the local buffers stable until every work-item has consumed this chunk.
  source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n");
  source.append(" } \n");
  source.append(" if(j < size) \n");
  source.append(" matr[(start_i)"); source.append(at_row_j); source.append(" = x; \n");
  source.append("} \n");
}
/** @brief Emits the OpenCL kernel `givens_prev`.
 *
 * Work-item j (j < size) sweeps a chain of 2x2 rotations over the entries matr[r * stride + j].
 * Rotation i (coefficients cs[i], ss[i]) is applied for i = start_i up to end_i - 1. It
 * rotates the carried value x (initially entry r = start_i - 1) with entry r = i and writes
 * the first result to entry r = i - 1. The final x is stored at entry r = end_i - 1.
 * There is no layout parameter; the addressing is always r * stride + j.
 *
 * The coefficients are staged through 256-entry __local buffers in chunks. The chunk
 * size is min(local size, 256), so work-groups larger than 256 no longer write past the
 * end of the buffers. For work-groups of at most 256 work-items the kernel behaves as
 * before.
 *
 * @param source          String-like object receiving the OpenCL source (must provide append()).
 * @param numeric_string  OpenCL scalar type name, e.g. "float" or "double".
 */
template<typename StringT>
void generate_svd_givens_prev(StringT & source, std::string const & numeric_string)
{
  source.append("__kernel void givens_prev(__global "); source.append(numeric_string); source.append("* matr, \n");
  source.append(" __global "); source.append(numeric_string); source.append("* cs, \n");
  source.append(" __global "); source.append(numeric_string); source.append("* ss, \n");
  source.append(" uint size, \n");
  source.append(" uint stride, \n");
  source.append(" uint start_i, \n");
  source.append(" uint end_i \n");
  source.append(" ) \n");
  source.append("{ \n");
  source.append(" uint glb_id = get_global_id(0); \n");
  source.append(" uint glb_sz = get_global_size(0); \n");
  source.append(" uint lcl_id = get_local_id(0); \n");
  source.append(" uint lcl_sz = get_local_size(0); \n");
  source.append(" uint j = glb_id; \n");
  source.append(" __local "); source.append(numeric_string); source.append(" cs_lcl[256]; \n");
  source.append(" __local "); source.append(numeric_string); source.append(" ss_lcl[256]; \n");
  // Chunk size must never exceed the 256-entry local buffers declared above.
  source.append(" uint blk_sz = min(lcl_sz, (uint)256); \n");

  source.append(" "); source.append(numeric_string); source.append(" x = (j < size) ? matr[(start_i - 1) * stride + j] : 0; \n");
  source.append(" uint elems_num = end_i - start_i; \n");
  source.append(" uint block_num = (elems_num + blk_sz - 1) / blk_sz; \n");
  source.append(" for (uint block_id = 0; block_id < block_num; block_id++) \n");
  source.append(" { \n");
  source.append(" uint to = min(elems_num - block_id * blk_sz, blk_sz); \n");
  source.append(" if (lcl_id < to) \n");
  source.append(" { \n");
  source.append(" cs_lcl[lcl_id] = cs[lcl_id + start_i + block_id * blk_sz]; \n");
  source.append(" ss_lcl[lcl_id] = ss[lcl_id + start_i + block_id * blk_sz]; \n");
  source.append(" } \n");
  source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n");
  source.append(" if (j < size) \n");
  source.append(" { \n");
  source.append(" for (uint ind = 0; ind < to; ind++) \n");
  source.append(" { \n");
  source.append(" uint i = ind + start_i + block_id * blk_sz; \n");
  source.append(" "); source.append(numeric_string); source.append(" z = matr[i * stride + j]; \n");
  source.append(" "); source.append(numeric_string); source.append(" cs_val = cs_lcl[ind]; \n");
  source.append(" "); source.append(numeric_string); source.append(" ss_val = ss_lcl[ind]; \n");
  source.append(" matr[(i - 1) * stride + j] = x * cs_val + z * ss_val; \n");
  source.append(" x = -x * ss_val + z * cs_val; \n");
  source.append(" } \n");
  source.append(" } \n");
  // Keep the local buffers stable until every work-item has consumed this chunk.
  source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n");
  source.append(" } \n");
  source.append(" if (j < size) \n");
  source.append(" matr[(end_i - 1) * stride + j] = x; \n");
  source.append("} \n");
}
/** @brief Emits the OpenCL kernel `house_update_A_left`.
 *
 * Each work-item handles the columns i in [col_start, size2), striding by the global size.
 * For each column it forms ss = sum_{j=row_start}^{size1-1} V[j] * A(j, i) and then
 * updates A(j, i) -= 2 * V[j] * ss over the same rows. This applies (I - 2 V V^T) from the
 * left to the trailing block. It is a Householder reflection when ||V|| = 1, which is
 * presumably ensured by the caller.
 * The `sums` argument and the group/local ids are declared but not used by this kernel.
 *
 * @param source          String-like object receiving the OpenCL source (must provide append()).
 * @param numeric_string  OpenCL scalar type name, e.g. "float" or "double".
 * @param is_row_major    Selects the indexing used for A(j, i).
 */
template <typename StringType>
void generate_svd_house_update_A_left(StringType & source, std::string const & numeric_string, bool is_row_major)
{
  // A(j, i) as an OpenCL expression for the selected storage layout.
  std::string const a_ji = is_row_major ? "A[j * stride + i]" : "A[j + i * stride]";

  source.append("__kernel void house_update_A_left( \n"
                " __global ");
  source.append(numeric_string);
  source.append("* A, \n"
                " __constant ");
  source.append(numeric_string);
  source.append("* V, \n"
                " uint row_start, \n"
                " uint col_start, \n"
                " uint size1, \n"
                " uint size2, \n"
                " uint stride, \n"
                " __local ");
  source.append(numeric_string);
  source.append("* sums \n"
                " ) { \n"
                " uint glb_id = get_global_id(0); \n"
                " uint glb_sz = get_global_size(0); \n"
                " uint grp_id = get_group_id(0); \n"
                " uint grp_nm = get_num_groups(0); \n"
                " uint lcl_id = get_local_id(0); \n"
                " uint lcl_sz = get_local_size(0); \n"
                " ");
  source.append(numeric_string);
  source.append(" ss = 0; \n"
                " for(uint i = glb_id + col_start; i < size2; i += glb_sz) { \n"
                " ss = 0; \n"
                " for(uint j = row_start; j < size1; j++) ss = ss + (V[j] * ");
  source.append(a_ji);
  source.append("); \n"
                " for(uint j = row_start; j < size1; j++) \n"
                " ");
  source.append(a_ji);
  source.append(" = ");
  source.append(a_ji);
  source.append(" - (2 * V[j] * ss); \n"
                " } \n"
                "} \n");
}
/** @brief Emits the OpenCL kernel `house_update_A_right`.
 *
 * Each work-group handles the rows i in [row_start, size1), striding by the number of groups.
 * For a row, the work-items compute partial sums of V[j] * A(i, j) over j < size2. These are
 * reduced in local memory via col_reduce_lcl_array, which must be emitted into the same
 * program source. Then A(i, j) -= 2 * V[j] * sum_Av, i.e. (I - 2 V V^T) is applied from the
 * right. The kernel starts j at 0, not col_start (col_start is unused). Entries of V outside
 * the active range are presumably zero; confirm with the caller.
 *
 * Fix: a barrier now follows the read of sums[0]. Without it, work-item 0 could move on to the
 * next row and overwrite sums[0] before slower work-items had read the current row's total.
 *
 * @param source          String-like object receiving the OpenCL source (must provide append()).
 * @param numeric_string  OpenCL scalar type name, e.g. "float" or "double".
 * @param is_row_major    Selects the indexing used for A(i, j).
 */
template <typename StringType>
void generate_svd_house_update_A_right(StringType & source, std::string const & numeric_string, bool is_row_major)
{
  // A(i, j) as an OpenCL expression for the selected storage layout.
  std::string const a_ij = is_row_major ? "A[i * stride + j]" : "A[i + j * stride]";

  source.append("__kernel void house_update_A_right( \n");
  source.append(" __global "); source.append(numeric_string); source.append("* A, \n");
  source.append(" __global "); source.append(numeric_string); source.append("* V, \n");
  source.append(" uint row_start, \n");
  source.append(" uint col_start, \n");
  source.append(" uint size1, \n");
  source.append(" uint size2, \n");
  source.append(" uint stride, \n");
  source.append(" __local "); source.append(numeric_string); source.append("* sums \n");
  source.append(" ) { \n");
  source.append(" uint glb_id = get_global_id(0); \n");
  source.append(" uint grp_id = get_group_id(0); \n");
  source.append(" uint grp_nm = get_num_groups(0); \n");
  source.append(" uint lcl_id = get_local_id(0); \n");
  source.append(" uint lcl_sz = get_local_size(0); \n");
  source.append(" "); source.append(numeric_string); source.append(" ss = 0; \n");

  source.append(" for(uint i = grp_id + row_start; i < size1; i += grp_nm) { \n");
  source.append(" ss = 0; \n");
  source.append(" for(uint j = lcl_id; j < size2; j += lcl_sz) ss = ss + (V[j] * "); source.append(a_ij); source.append("); \n");
  source.append(" sums[lcl_id] = ss; \n");
  source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n");
  source.append(" col_reduce_lcl_array(sums, lcl_id, lcl_sz); \n");
  source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n");
  source.append(" "); source.append(numeric_string); source.append(" sum_Av = sums[0]; \n");
  // All work-items must have read sums[0] before any of them writes sums[] for the next row.
  source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n");
  source.append(" for(uint j = lcl_id; j < size2; j += lcl_sz) \n");
  source.append(" "); source.append(a_ij); source.append(" = "); source.append(a_ij); source.append(" - (2 * V[j] * sum_Av); \n");
  source.append(" } \n");
  source.append("} \n");
}
/** @brief Emits the OpenCL kernel `house_update_QL`.
 *
 * Each work-group handles the rows i in [0, size1), striding by the number of groups.
 * For a row, the work-items compute partial sums of V[j] * QL(i, j) over j < size1. These are
 * reduced in local memory via col_reduce_lcl_array, which must be emitted into the same
 * program source. Then QL(i, j) -= 2 * V[j] * sum_Qv, i.e. (I - 2 V V^T) is applied to QL
 * from the right.
 *
 * Fix: a barrier now follows the read of sums[0]. Without it, work-item 0 could move on to the
 * next row and overwrite sums[0] before slower work-items had read the current row's total.
 *
 * @param source          String-like object receiving the OpenCL source (must provide append()).
 * @param numeric_string  OpenCL scalar type name, e.g. "float" or "double".
 * @param is_row_major    Selects the indexing used for QL(i, j).
 */
template <typename StringType>
void generate_svd_house_update_QL(StringType & source, std::string const & numeric_string, bool is_row_major)
{
  // QL(i, j) as an OpenCL expression for the selected storage layout.
  std::string const ql_ij = is_row_major ? "QL[i * strideQ + j]" : "QL[i + j * strideQ]";

  source.append("__kernel void house_update_QL(\n");
  source.append(" __global "); source.append(numeric_string); source.append("* QL, \n");
  source.append(" __constant "); source.append(numeric_string); source.append("* V, \n");
  source.append(" uint size1, \n");
  source.append(" uint strideQ, \n");
  source.append(" __local "); source.append(numeric_string); source.append("* sums \n");
  source.append(" ) { \n");
  source.append(" uint glb_id = get_global_id(0); \n");
  source.append(" uint glb_sz = get_global_size(0); \n");
  source.append(" uint grp_id = get_group_id(0); \n");
  source.append(" uint grp_nm = get_num_groups(0); \n");
  source.append(" uint lcl_id = get_local_id(0); \n");
  source.append(" uint lcl_sz = get_local_size(0); \n");
  source.append(" "); source.append(numeric_string); source.append(" ss = 0; \n");

  source.append(" for(uint i = grp_id; i < size1; i += grp_nm) { \n");
  source.append(" ss = 0; \n");
  source.append(" for(uint j = lcl_id; j < size1; j += lcl_sz) ss = ss + (V[j] * "); source.append(ql_ij); source.append("); \n");
  source.append(" sums[lcl_id] = ss; \n");
  source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n");
  source.append(" col_reduce_lcl_array(sums, lcl_id, lcl_sz); \n");
  source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n");
  source.append(" "); source.append(numeric_string); source.append(" sum_Qv = sums[0]; \n");
  // All work-items must have read sums[0] before any of them writes sums[] for the next row.
  source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n");
  source.append(" for(uint j = lcl_id; j < size1; j += lcl_sz) \n");
  source.append(" "); source.append(ql_ij); source.append(" = "); source.append(ql_ij); source.append(" - (2 * V[j] * sum_Qv); \n");
  source.append(" } \n");
  source.append("} \n");
}
/** @brief Emits the OpenCL kernel `house_update_QR`.
 *
 * Each work-group handles the rows i in [0, size2), striding by the number of groups.
 * For a row, the work-items compute partial sums of V[j] * QR[i * strideQ + j] over j < size2.
 * These are reduced in local memory via col_reduce_lcl_array, which must be emitted into the
 * same program source. Then QR[i * strideQ + j] -= 2 * V[j] * sum_Qv. The kernel declares
 * size1 but does not use it.
 *
 * Fix: a barrier now follows the read of sums[0]. Without it, work-item 0 could move on to the
 * next row and overwrite sums[0] before slower work-items had read the current row's total.
 *
 * @param source          String-like object receiving the OpenCL source (must provide append()).
 * @param numeric_string  OpenCL scalar type name, e.g. "float" or "double".
 */
template<typename StringT>
void generate_svd_house_update_QR(StringT & source, std::string const & numeric_string)
{
  source.append("__kernel void house_update_QR( \n");
  source.append(" __global "); source.append(numeric_string); source.append("* QR, \n");
  source.append(" __global "); source.append(numeric_string); source.append("* V, \n");
  source.append(" uint size1, \n");
  source.append(" uint size2, \n");
  source.append(" uint strideQ, \n");
  source.append(" __local "); source.append(numeric_string); source.append("* sums \n");
  source.append(" ) { \n");
  source.append(" uint glb_id = get_global_id(0); \n");
  source.append(" uint grp_id = get_group_id(0); \n");
  source.append(" uint grp_nm = get_num_groups(0); \n");
  source.append(" uint lcl_id = get_local_id(0); \n");
  source.append(" uint lcl_sz = get_local_size(0); \n");
  source.append(" "); source.append(numeric_string); source.append(" ss = 0; \n");

  source.append(" for (uint i = grp_id; i < size2; i += grp_nm) { \n");
  source.append(" ss = 0; \n");
  source.append(" for (uint j = lcl_id; j < size2; j += lcl_sz) ss = ss + (V[j] * QR[i * strideQ + j]); \n");
  source.append(" sums[lcl_id] = ss; \n");
  source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n");
  source.append(" col_reduce_lcl_array(sums, lcl_id, lcl_sz); \n");
  source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n");
  source.append(" "); source.append(numeric_string); source.append(" sum_Qv = sums[0]; \n");
  // All work-items must have read sums[0] before any of them writes sums[] for the next row.
  source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n");
  source.append(" for (uint j = lcl_id; j < size2; j += lcl_sz) \n");
  source.append(" QR[i * strideQ + j] = QR[i * strideQ + j] - (2 * V[j] * sum_Qv); \n");
  source.append(" } \n");
  source.append("} \n");
}
/** @brief Emits the OpenCL kernel `inverse_signs`.
 *
 * The generated kernel runs on a 2D index space. Work-item (x, y) with x, y < size
 * multiplies v[x * stride + y] by signs[x], i.e. every entry of the x-th stride-separated
 * line of v is scaled by signs[x].
 *
 * @param source          String-like object receiving the OpenCL source (must provide append()).
 * @param numeric_string  OpenCL scalar type name, e.g. "float" or "double".
 */
template<typename StringT>
void generate_svd_inverse_signs(StringT & source, std::string const & numeric_string)
{
  source.append("__kernel void inverse_signs(__global ");
  source.append(numeric_string);
  source.append("* v, \n"
                " __global ");
  source.append(numeric_string);
  source.append("* signs, \n"
                " uint size, \n"
                " uint stride \n"
                " ) \n"
                "{ \n"
                " uint glb_id_x = get_global_id(0); \n"
                " uint glb_id_y = get_global_id(1); \n"
                " if ((glb_id_x < size) && (glb_id_y < size)) \n"
                " v[glb_id_x * stride + glb_id_y] *= signs[glb_id_x]; \n"
                "} \n");
}
/** @brief Emits the OpenCL kernel `transpose_inplace` (plus its helper `svd_transpose_dest`).
 *
 * The buffer is read as a row-major row_num x col_num array. The element at linear position
 * i = row * col_num + col moves to col * row_num + row, which is its position in the
 * row-major col_num x row_num transpose.
 *
 * Previous behaviour swapped input[i] with its destination whenever i < destination. That is
 * only a correct in-place transpose when row_num == col_num. For other shapes the permutation
 * has cycles longer than two, so the result was wrong and work-items raced on shared
 * elements. The kernel now uses cycle-leader rotation: a work-item walks the cycle of its
 * index, and only the smallest index in a cycle rotates the whole cycle. Every element is
 * therefore written by exactly one work-item. For square matrices every cycle has length one
 * or two, so this reduces to the same single swap as before. For non-square shapes each
 * work-item's cost is proportional to the length of its cycle.
 *
 * @param source          String-like object receiving the OpenCL source (must provide append()).
 * @param numeric_string  OpenCL scalar type name, e.g. "float" or "double".
 */
template<typename StringT>
void generate_svd_transpose_inplace(StringT & source, std::string const & numeric_string)
{
  // Destination of linear position pos under the row_num x col_num -> col_num x row_num transpose.
  source.append("uint svd_transpose_dest(uint pos, uint row_num, uint col_num) { \n");
  source.append(" uint row = pos / col_num; \n");
  source.append(" uint col = pos - row * col_num; \n");
  source.append(" return col * row_num + row; \n");
  source.append("} \n");

  source.append("__kernel void transpose_inplace(__global "); source.append(numeric_string); source.append("* input, \n");
  source.append(" unsigned int row_num, \n");
  source.append(" unsigned int col_num) { \n");
  source.append(" unsigned int size = row_num * col_num; \n");
  source.append(" for (unsigned int i = get_global_id(0); i < size; i+= get_global_size(0)) { \n");
  source.append(" unsigned int next = svd_transpose_dest(i, row_num, col_num); \n");
  source.append(" if (next == i) continue; \n");
  // i leads its cycle only if no other index in the cycle is smaller.
  source.append(" bool is_leader = true; \n");
  source.append(" while (next != i) { \n");
  source.append(" if (next < i) { is_leader = false; break; } \n");
  source.append(" next = svd_transpose_dest(next, row_num, col_num); \n");
  source.append(" } \n");
  source.append(" if (!is_leader) continue; \n");
  // Rotate the cycle: carry each element forward to its destination.
  source.append(" "); source.append(numeric_string); source.append(" carried = input[i]; \n");
  source.append(" unsigned int pos = i; \n");
  source.append(" do { \n");
  source.append(" pos = svd_transpose_dest(pos, row_num, col_num); \n");
  source.append(" "); source.append(numeric_string); source.append(" displaced = input[pos]; \n");
  source.append(" input[pos] = carried; \n");
  source.append(" carried = displaced; \n");
  source.append(" } while (pos != i); \n");
  source.append(" } \n");
  source.append("} \n");
}
/** @brief Emits the OpenCL kernel `update_qr_column`.
 *
 * For each column i < last_n (work-items stride by the global size) the generated kernel
 * sweeps k = m .. n-1 over the rows k, k+1, k+2 of A (addressed as row * stride + i).
 * For each k it forms p from the five coefficients buf[5k .. 5k+4] and writes
 * A[k * stride + i] = a_ik - p. The last carried value is stored to A[n * stride + i].
 * buf presumably holds per-step reflector data prepared by the host; confirm with the caller.
 *
 * Fix: a_ik_2 is now initialized. When the k-loop runs exactly once (n - m == 1) the
 * statement `a_ik_1 = a_ik_2` previously read an indeterminate value. That value was never
 * used afterwards, so results are unchanged.
 *
 * @param source          String-like object receiving the OpenCL source (must provide append()).
 * @param numeric_string  OpenCL scalar type name, e.g. "float" or "double".
 */
template<typename StringT>
void generate_svd_update_qr_column(StringT & source, std::string const & numeric_string)
{
  source.append("__kernel void update_qr_column(__global "); source.append(numeric_string); source.append("* A, \n");
  source.append(" uint stride, \n");
  source.append(" __global "); source.append(numeric_string); source.append("* buf, \n");
  source.append(" int m, \n");
  source.append(" int n, \n");
  source.append(" int last_n) \n");
  source.append("{ \n");
  source.append(" uint glb_id = get_global_id(0); \n");
  source.append(" uint glb_sz = get_global_size(0); \n");
  source.append(" for (int i = glb_id; i < last_n; i += glb_sz) \n");
  source.append(" { \n");
  source.append(" "); source.append(numeric_string); source.append(" a_ik = A[m * stride + i], a_ik_1, a_ik_2 = 0; \n");
  source.append(" a_ik_1 = A[(m + 1) * stride + i]; \n");
  source.append(" for (int k = m; k < n; k++) \n");
  source.append(" { \n");
  source.append(" bool notlast = (k != n - 1); \n");
  source.append(" "); source.append(numeric_string); source.append(" p = buf[5 * k] * a_ik + buf[5 * k + 1] * a_ik_1; \n");
  source.append(" if (notlast) \n");
  source.append(" { \n");
  source.append(" a_ik_2 = A[(k + 2) * stride + i]; \n");
  source.append(" p = p + buf[5 * k + 2] * a_ik_2; \n");
  source.append(" a_ik_2 = a_ik_2 - p * buf[5 * k + 4]; \n");
  source.append(" } \n");
  source.append(" A[k * stride + i] = a_ik - p; \n");
  source.append(" a_ik_1 = a_ik_1 - p * buf[5 * k + 3]; \n");
  source.append(" a_ik = a_ik_1; \n");
  source.append(" a_ik_1 = a_ik_2; \n");
  source.append(" } \n");
  source.append(" A[n * stride + i] = a_ik; \n");
  source.append(" } \n");
  source.append("} \n");
}
643 template<
typename NumericT,
typename MatrixLayout = row_major>
654 static std::map<cl_context, bool> init_done;
662 source.reserve(1024);
664 viennacl::ocl::append_double_precision_pragma<NumericT>(ctx, source);
667 if (numeric_string ==
"float" || numeric_string ==
"double")
689 #ifdef VIENNACL_BUILD_INFO
690 std::cout <<
"Creating program " << prog_name << std::endl;
692 ctx.add_program(source, prog_name);
693 init_done[ctx.handle().get()] =
true;
void generate_svd_copy_row(StringType &source, std::string const &numeric_string, bool is_row_major)
Helper class for checking whether a matrix has a row-major layout.
void generate_svd_bidiag_pack(StringType &source, std::string const &numeric_string, bool is_row_major)
void generate_svd_final_iter_update(StringT &source, std::string const &numeric_string)
Manages an OpenCL context and provides the respective convenience functions for creating buffers...
void generate_svd_update_qr_column(StringT &source, std::string const &numeric_string)
Provides OpenCL-related utilities.
static std::string program_name()
void generate_svd_inverse_signs(StringT &source, std::string const &numeric_string)
static void init(viennacl::ocl::context &ctx)
const viennacl::ocl::handle< cl_context > & handle() const
Returns the context handle.
void generate_svd_house_update_QR(StringT &source, std::string const &numeric_string)
void generate_svd_transpose_inplace(StringT &source, std::string const &numeric_string)
void generate_svd_givens_prev(StringT &source, std::string const &numeric_string)
static void apply(viennacl::ocl::context const &)
const OCL_TYPE & get() const
Main kernel class for generating OpenCL kernels for singular value decomposition of dense matrices...
void generate_svd_house_update_A_right(StringType &source, std::string const &numeric_string, bool is_row_major)
Representation of an OpenCL kernel in ViennaCL.
void generate_svd_col_reduce_lcl_array(StringT &source, std::string const &numeric_string)
void generate_svd_copy_col(StringType &source, std::string const &numeric_string, bool is_row_major)
void generate_svd_house_update_QL(StringType &source, std::string const &numeric_string, bool is_row_major)
void generate_svd_givens_next(StringType &source, std::string const &numeric_string, bool is_row_major)
Helper class for converting a type to its string representation.
void generate_svd_house_update_A_left(StringType &source, std::string const &numeric_string, bool is_row_major)