ViennaCL - The Vienna Computing Library  1.7.1
Free open-source GPU-accelerated linear algebra and solver library.
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
execute_matrix_prod.hpp
Go to the documentation of this file.
1 #ifndef VIENNACL_SCHEDULER_EXECUTE_MATRIX_PROD_HPP
2 #define VIENNACL_SCHEDULER_EXECUTE_MATRIX_PROD_HPP
3 
4 /* =========================================================================
5  Copyright (c) 2010-2016, Institute for Microelectronics,
6  Institute for Analysis and Scientific Computing,
7  TU Wien.
8  Portions of this software are copyright by UChicago Argonne, LLC.
9 
10  -----------------
11  ViennaCL - The Vienna Computing Library
12  -----------------
13 
14  Project Head: Karl Rupp rupp@iue.tuwien.ac.at
15 
16  (A list of authors and contributors can be found in the manual)
17 
18  License: MIT (X11), see file LICENSE in the base directory
19 ============================================================================= */
20 
21 
26 #include "viennacl/forwards.h"
35 #include "viennacl/ell_matrix.hpp"
36 #include "viennacl/hyb_matrix.hpp"
37 
38 namespace viennacl
39 {
40 namespace scheduler
41 {
42 namespace detail
43 {
// Decides whether operand `elem` of a matrix product must first be evaluated
// into a temporary before a product kernel can consume it.
//
// @param s     Statement that owns the node array.
// @param elem  One operand (lhs or rhs) of the product node.
// @return      false if one of the two early-exit conditions below holds,
//              true otherwise.
//
// NOTE(review): the guarding `if` lines of this function (listing lines 46
// and 51) are missing from this doxygen extract. Judging from the comment on
// listing line 49, presumably line 46 returns false for operands that are not
// composite nodes, and line 51 returns false for a transposed matrix proxy,
// which matrix_matrix_prod / matrix_vector_prod accept directly. Confirm
// against the original source.
44  inline bool matrix_prod_temporary_required(statement const & s, lhs_rhs_element const & elem)
45  {
 // (listing line 46, the condition guarding this early return, is missing here)
47  return false;
48 
49  // check composite node for being a transposed matrix proxy:
50  statement_node const & leaf = s.array()[elem.node_index];
 // (listing line 51, the condition guarding this early return, is missing here)
52  return false;
53 
 // Reached only when neither early exit applied: the operand has to be
 // evaluated into a temporary first.
54  return true;
55  }
56 
// Executes a dense matrix-matrix product for the scheduler, dispatching on
// whether each operand is a plain matrix or a transposed-matrix node.
//
// Supported operand combinations (selected by type_family):
//   MATRIX    * MATRIX     -> C = A * B
//   MATRIX    * COMPOSITE  -> C = A * B^T
//   COMPOSITE * MATRIX     -> C = A^T * B
//   COMPOSITE * COMPOSITE  -> C = A^T * B^T
// A COMPOSITE operand must be an OPERATION_UNARY_TRANS_TYPE node whose lhs is
// a matrix. This is checked by assert only, so it is not checked when NDEBUG
// is defined.
//
// @param s       Statement owning the node array (used to resolve transposes).
// @param result  Output matrix element C.
// @param A       Left product operand.
// @param B       Right product operand.
// @param alpha   Scaling factor forwarded to prod_impl (cast to float on the
//                FLOAT_TYPE path).
// @param beta    Scaling factor forwarded to prod_impl. Presumably the
//                operation is C = alpha*op(A)*op(B) + beta*C, which matches how
//                execute_matrix_prod chooses alpha/beta. Confirm against
//                linalg::prod_impl.
// @throws statement_not_supported_exception for numeric types other than
//         FLOAT_TYPE/DOUBLE_TYPE, or for operand type families other than
//         the four combinations above.
//
// NOTE(review): parts of this listing are missing from the doxygen extract:
// the subtype guards that precede listing lines 70, 92 and 121, and the
// opening lines of most prod_impl calls.
57  inline void matrix_matrix_prod(statement const & s,
58  lhs_rhs_element result,
59  lhs_rhs_element const & A,
60  lhs_rhs_element const & B,
61  double alpha,
62  double beta)
63  {
64  if (A.type_family == MATRIX_TYPE_FAMILY && B.type_family == MATRIX_TYPE_FAMILY) // C = A * B
65  {
66  assert( A.numeric_type == B.numeric_type && bool("Numeric type not the same!"));
67  assert( result.numeric_type == B.numeric_type && bool("Numeric type not the same!"));
68 
 // (listing line 69, presumably a dense-subtype guard like listing line 151, is missing here)
70  {
71  switch (result.numeric_type)
72  {
73  case FLOAT_TYPE:
74  viennacl::linalg::prod_impl(*A.matrix_float, *B.matrix_float, *result.matrix_float, static_cast<float>(alpha), static_cast<float>(beta)); break;
75  case DOUBLE_TYPE:
76  viennacl::linalg::prod_impl(*A.matrix_double, *B.matrix_double, *result.matrix_double, alpha, beta); break;
77  default:
78  throw statement_not_supported_exception("Invalid numeric type in matrix-matrix multiplication");
79  }
80  }
81 
82  }
83  else if (A.type_family == MATRIX_TYPE_FAMILY && B.type_family == COMPOSITE_OPERATION_FAMILY) // C = A * B^T
84  {
 // B is a composite node; its lhs holds the matrix that gets transposed.
85  statement_node const & leaf = s.array()[B.node_index];
86 
87  assert(leaf.lhs.type_family == MATRIX_TYPE_FAMILY && leaf.op.type == OPERATION_UNARY_TRANS_TYPE && bool("Logic error: Argument not a matrix transpose!"));
88  assert(leaf.lhs.numeric_type == result.numeric_type && bool("Numeric type not the same!"));
89  assert(result.numeric_type == A.numeric_type && bool("Numeric type not the same!"));
90 
92  {
93  switch (result.numeric_type)
94  {
95  case FLOAT_TYPE:
98  const matrix_base<float>,
99  op_trans> (*(leaf.lhs.matrix_float), *(leaf.lhs.matrix_float)),
100  *result.matrix_float, static_cast<float>(alpha), static_cast<float>(beta)); break;
101  case DOUBLE_TYPE:
104  const matrix_base<double>,
105  op_trans>(*(leaf.lhs.matrix_double), *(leaf.lhs.matrix_double)),
106  *result.matrix_double, alpha, beta); break;
107  default:
108  throw statement_not_supported_exception("Invalid numeric type in matrix-matrix multiplication");
109  }
110  }
111  }
112  else if (A.type_family == COMPOSITE_OPERATION_FAMILY && B.type_family == MATRIX_TYPE_FAMILY) // C = A^T * B
113  {
 // A is a composite node; its lhs holds the matrix that gets transposed.
114  statement_node const & leaf = s.array()[A.node_index];
115 
116  assert(leaf.lhs.type_family == MATRIX_TYPE_FAMILY && leaf.op.type == OPERATION_UNARY_TRANS_TYPE && bool("Logic error: Argument not a matrix transpose!"));
117  assert(leaf.lhs.numeric_type == result.numeric_type && bool("Numeric type not the same!"));
118  assert(result.numeric_type == B.numeric_type && bool("Numeric type not the same!"));
119 
121  {
122  switch (result.numeric_type)
123  {
124  case FLOAT_TYPE:
126  const matrix_base<float>,
127  op_trans>(*leaf.lhs.matrix_float, *leaf.lhs.matrix_float),
128  *B.matrix_float,
129  *result.matrix_float, static_cast<float>(alpha), static_cast<float>(beta)); break;
130  case DOUBLE_TYPE:
132  const matrix_base<double>,
133  op_trans>(*leaf.lhs.matrix_double, *leaf.lhs.matrix_double),
134  *B.matrix_double,
135  *result.matrix_double, alpha, beta); break;
136  default:
137  throw statement_not_supported_exception("Invalid numeric type in matrix-matrix multiplication");
138  }
139  }
140  }
141  else if (A.type_family == COMPOSITE_OPERATION_FAMILY && B.type_family == COMPOSITE_OPERATION_FAMILY) // C = A^T * B^T
142  {
143  statement_node const & leafA = s.array()[A.node_index];
144  statement_node const & leafB = s.array()[B.node_index];
145 
146  assert(leafA.lhs.type_family == MATRIX_TYPE_FAMILY && leafA.op.type == OPERATION_UNARY_TRANS_TYPE && bool("Logic error: Argument not a matrix transpose!"));
147  assert(leafB.lhs.type_family == MATRIX_TYPE_FAMILY && leafB.op.type == OPERATION_UNARY_TRANS_TYPE && bool("Logic error: Argument not a matrix transpose!"));
148  assert(leafA.lhs.numeric_type == result.numeric_type && bool("Numeric type not the same!"));
149  assert(leafB.lhs.numeric_type == result.numeric_type && bool("Numeric type not the same!"));
150 
151  if (leafA.lhs.subtype == DENSE_MATRIX_TYPE && leafB.lhs.subtype == DENSE_MATRIX_TYPE && result.subtype == DENSE_MATRIX_TYPE)
152  {
153  switch (result.numeric_type)
154  {
155  case FLOAT_TYPE:
157  const matrix_base<float>,
158  op_trans>(*leafA.lhs.matrix_float, *leafA.lhs.matrix_float),
160  const matrix_base<float>,
161  op_trans>(*leafB.lhs.matrix_float, *leafB.lhs.matrix_float),
162  *result.matrix_float, static_cast<float>(alpha), static_cast<float>(beta)); break;
163  case DOUBLE_TYPE:
165  const matrix_base<double>,
166  op_trans>(*leafA.lhs.matrix_double, *leafA.lhs.matrix_double),
168  const matrix_base<double>,
169  op_trans>(*leafB.lhs.matrix_double, *leafB.lhs.matrix_double),
170  *result.matrix_double, alpha, beta); break;
171  default:
172  throw statement_not_supported_exception("Invalid numeric type in matrix-matrix multiplication");
173  }
174  }
 // NOTE(review): there is no `else` here. If any operand is not
 // DENSE_MATRIX_TYPE, this branch neither computes nor throws, so the
 // caller silently gets an unmodified result. Consider adding
 // `else throw statement_not_supported_exception(...)`. The guards of the
 // three branches above are missing from this extract; if they mirror
 // listing line 151, those branches share the same silent fall-through.
175  }
176  else
177  throw statement_not_supported_exception("Matrix-matrix multiplication encountered operands being neither dense matrices nor transposed dense matrices");
178  }
179 
// Executes result = op(A) * x for the scheduler, where op(A) is either A or
// trans(A).
//
// @param s       Statement owning the node array (used to resolve trans(A)).
// @param result  Output vector. It must be DENSE_VECTOR_TYPE; this is
//                checked by assert only.
// @param A       Matrix operand. A COMPOSITE_OPERATION_FAMILY operand must be
//                a transpose of a matrix, and only a dense transposed matrix is
//                supported. Otherwise A.subtype selects the dense, compressed,
//                coordinate, ELL or HYB path.
// @param x       Input vector. Its numeric type must match result's.
// @throws statement_not_supported_exception for numeric types other than
//         FLOAT_TYPE/DOUBLE_TYPE and for unsupported matrix subtypes.
//
// The ELL/HYB calls pass 1 and 0 as the scaling factors, presumably
// overwriting result with A*x. Confirm against linalg::prod_impl.
// NOTE(review): the prod_impl calls for the transposed, dense, compressed and
// coordinate cases (listing lines 203, 209, 226, 229, 240, 243, 254, 257) are
// missing from this doxygen extract. In this view those cases appear as a
// bare `break;`.
180  inline void matrix_vector_prod(statement const & s,
181  lhs_rhs_element result,
182  lhs_rhs_element const & A,
183  lhs_rhs_element const & x)
184  {
185  assert( result.numeric_type == x.numeric_type && bool("Numeric type not the same!"));
 // NOTE(review): this assert compares type_family, but its message says
 // "Subtype"; the message is misleading when it fires.
186  assert( result.type_family == x.type_family && bool("Subtype not the same!"));
187  assert( result.subtype == DENSE_VECTOR_TYPE && bool("Result node for matrix-vector product not a vector type!"));
188 
189  // deal with transposed product first:
190  // switch: trans for A
191  if (A.type_family == COMPOSITE_OPERATION_FAMILY) // prod(trans(A), x)
192  {
193  statement_node const & leaf = s.array()[A.node_index];
194 
195  assert(leaf.lhs.type_family == MATRIX_TYPE_FAMILY && leaf.op.type == OPERATION_UNARY_TRANS_TYPE && bool("Logic error: Argument not a matrix transpose!"));
196  assert(leaf.lhs.numeric_type == x.numeric_type && bool("Numeric type not the same!"));
197 
198  if (leaf.lhs.subtype == DENSE_MATRIX_TYPE)
199  {
200  switch (leaf.lhs.numeric_type)
201  {
202  case FLOAT_TYPE:
204  const matrix_base<float>,
205  op_trans>(*leaf.lhs.matrix_float, *leaf.lhs.matrix_float),
206  *x.vector_float,
207  *result.vector_float); break;
208  case DOUBLE_TYPE:
210  const matrix_base<double>,
211  op_trans>(*leaf.lhs.matrix_double, *leaf.lhs.matrix_double),
212  *x.vector_double,
213  *result.vector_double); break;
214  default:
215  throw statement_not_supported_exception("Invalid numeric type in matrix-{matrix,vector} multiplication");
216  }
217  }
218  else
219  throw statement_not_supported_exception("Invalid matrix type for transposed matrix-vector product");
220  }
 // Non-transposed case: dispatch on the storage format of A.
221  else if (A.subtype == DENSE_MATRIX_TYPE)
222  {
223  switch (A.numeric_type)
224  {
225  case FLOAT_TYPE:
227  break;
228  case DOUBLE_TYPE:
230  break;
231  default:
232  throw statement_not_supported_exception("Invalid numeric type in matrix-{matrix,vector} multiplication");
233  }
234  }
235  else if (A.subtype == COMPRESSED_MATRIX_TYPE)
236  {
237  switch (A.numeric_type)
238  {
239  case FLOAT_TYPE:
241  break;
242  case DOUBLE_TYPE:
244  break;
245  default:
246  throw statement_not_supported_exception("Invalid numeric type in matrix-{matrix,vector} multiplication");
247  }
248  }
249  else if (A.subtype == COORDINATE_MATRIX_TYPE)
250  {
251  switch (A.numeric_type)
252  {
253  case FLOAT_TYPE:
255  break;
256  case DOUBLE_TYPE:
258  break;
259  default:
260  throw statement_not_supported_exception("Invalid numeric type in matrix-{matrix,vector} multiplication");
261  }
262  }
263  else if (A.subtype == ELL_MATRIX_TYPE)
264  {
265  switch (A.numeric_type)
266  {
267  case FLOAT_TYPE:
268  viennacl::linalg::prod_impl(*A.ell_matrix_float, *x.vector_float, float(1), *result.vector_float, float(0));
269  break;
270  case DOUBLE_TYPE:
271  viennacl::linalg::prod_impl(*A.ell_matrix_double, *x.vector_double, double(1), *result.vector_double, double(0));
272  break;
273  default:
274  throw statement_not_supported_exception("Invalid numeric type in matrix-{matrix,vector} multiplication");
275  }
276  }
277  else if (A.subtype == HYB_MATRIX_TYPE)
278  {
279  switch (A.numeric_type)
280  {
281  case FLOAT_TYPE:
282  viennacl::linalg::prod_impl(*A.hyb_matrix_float, *x.vector_float, float(1), *result.vector_float, float(0));
283  break;
284  case DOUBLE_TYPE:
285  viennacl::linalg::prod_impl(*A.hyb_matrix_double, *x.vector_double, double(1), *result.vector_double, double(0));
286  break;
287  default:
288  throw statement_not_supported_exception("Invalid numeric type in matrix-{matrix,vector} multiplication");
289  }
290  }
291  else
292  {
 // NOTE(review): debug output to std::cout from library code right before
 // throwing. Consider putting A.subtype into the exception message instead.
 // <iostream> is not among the includes visible in this extract.
293  std::cout << "A.subtype: " << A.subtype << std::endl;
294  throw statement_not_supported_exception("Invalid matrix type for matrix-vector product");
295  }
296  }
297 
298 } // namespace detail
299 
// Executes a statement whose root node assigns a matrix product to
// root_node.lhs. root_node.rhs indexes the product node `leaf`, whose lhs and
// rhs are the two product operands.
//
// Steps:
//   1. Operands for which detail::matrix_prod_temporary_required returns true
//      are evaluated into temporaries via detail::execute_composite.
//   2. Vector result:
//      - For plain assignment, the result is computed directly by
//        detail::matrix_vector_prod.
//      - For += / -=, the product is written to a temporary z, and then
//        y = 1.0*y + (+1 or -1)*z is formed with detail::axbx.
//      Matrix result: computed by detail::matrix_matrix_prod with alpha = -1
//      for INPLACE_SUB (else 1), and beta = 0 for ASSIGN (else 1).
//   3. Temporaries from step 1 are released with detail::delete_element.
//
// @param s          Statement owning the node array.
// @param root_node  Assignment node whose lhs is the result.
// @throws statement_not_supported_exception from the helpers, or for a vector
//         result whose operator is not ASSIGN, INPLACE_ADD or INPLACE_SUB.
//
// NOTE(review): `ctx` is declared on listing line 303, which is missing from
// this doxygen extract. It is presumably
// `viennacl::context ctx = detail::extract_context(root_node);`.
// Listing lines 319, 322, 337 and 340 are missing too (likely the
// type_family setup of the temporaries' op/rhs).
// NOTE(review): cleanup is not exception-safe. If anything after a
// new_element call throws, the matching delete_element is never reached and
// the temporary leaks (see the TODOs below).
300 inline void execute_matrix_prod(statement const & s, statement_node const & root_node)
301 {
302  statement_node const & leaf = s.array()[root_node.rhs.node_index];
304 
305  // Part 1: Check whether temporaries are required //
306 
307  statement_node new_root_lhs;
308  statement_node new_root_rhs;
309 
310  bool lhs_needs_temporary = detail::matrix_prod_temporary_required(s, leaf.lhs);
311  bool rhs_needs_temporary = detail::matrix_prod_temporary_required(s, leaf.rhs);
312 
313  // check for temporary on lhs:
314  if (lhs_needs_temporary)
315  {
 // NOTE(review): leftover debug output to stdout (LHS path only); consider removing.
316  std::cout << "Temporary for LHS!" << std::endl;
 // NOTE(review): the temporary is modeled on root_node.lhs (the final
 // result), not on the operand it will hold. If new_element copies the
 // template's type and size (presumably; confirm), this is wrong whenever
 // the operand's shape differs from the result's. Examples: the matrix
 // operand of a matrix-vector product, or A+B in prod(A+B, D) with
 // non-square factors.
317  detail::new_element(new_root_lhs.lhs, root_node.lhs, ctx);
318 
320  new_root_lhs.op.type = OPERATION_BINARY_ASSIGN_TYPE;
321 
 // rhs of the helper node points back into the statement at the operand's subexpression.
323  new_root_lhs.rhs.subtype = INVALID_SUBTYPE;
324  new_root_lhs.rhs.numeric_type = INVALID_NUMERIC_TYPE;
325  new_root_lhs.rhs.node_index = leaf.lhs.node_index;
326 
327  // work on subexpression:
328  // TODO: Catch exception, free temporary, then rethrow
329  detail::execute_composite(s, new_root_lhs);
330  }
331 
332  // check for temporary on rhs:
333  if (rhs_needs_temporary)
334  {
 // NOTE(review): same caveat as the LHS temporary above; it is modeled on the result, not on leaf.rhs.
335  detail::new_element(new_root_rhs.lhs, root_node.lhs, ctx);
336 
338  new_root_rhs.op.type = OPERATION_BINARY_ASSIGN_TYPE;
339 
341  new_root_rhs.rhs.subtype = INVALID_SUBTYPE;
342  new_root_rhs.rhs.numeric_type = INVALID_NUMERIC_TYPE;
343  new_root_rhs.rhs.node_index = leaf.rhs.node_index;
344 
345  // work on subexpression:
346  // TODO: Catch exception, free temporary, then rethrow
347  detail::execute_composite(s, new_root_rhs);
348  }
349 
350  // Part 2: Run the actual computations //
351 
 // Naming: `x` is the product's left operand (the matrix) and `y` its right
 // operand. In the vector branch `y` is the input vector, not the result
 // that the `y += A*x` comment below refers to.
352  lhs_rhs_element x = lhs_needs_temporary ? new_root_lhs.lhs : leaf.lhs;
353  lhs_rhs_element y = rhs_needs_temporary ? new_root_rhs.lhs : leaf.rhs;
354 
355  if (root_node.lhs.type_family == VECTOR_TYPE_FAMILY)
356  {
357  if (root_node.op.type != OPERATION_BINARY_ASSIGN_TYPE)
358  {
359  //split y += A*x
360  statement_node new_root_z;
361  detail::new_element(new_root_z.lhs, root_node.lhs, ctx);
362 
363  // compute z = A * x
364  detail::matrix_vector_prod(s, new_root_z.lhs, x, y);
365 
366  // assignment y = z
367  double alpha = 0;
368  if (root_node.op.type == OPERATION_BINARY_INPLACE_ADD_TYPE)
369  alpha = 1.0;
370  else if (root_node.op.type == OPERATION_BINARY_INPLACE_SUB_TYPE)
371  alpha = -1.0;
 // NOTE(review): this check runs after new_root_z has been allocated and
 // the product computed. Throwing here leaks new_root_z.lhs. Validating
 // the operator before listing line 361 would avoid both the leak and
 // the wasted product.
372  else
373  throw statement_not_supported_exception("Invalid assignment type for matrix-vector product");
374 
 // y = 1.0 * y + alpha * z
375  lhs_rhs_element y2 = root_node.lhs;
376  detail::axbx(y2,
377  y2, 1.0, 1, false, false,
378  new_root_z.lhs, alpha, 1, false, false);
379 
380  detail::delete_element(new_root_z.lhs);
381  }
382  else
383  detail::matrix_vector_prod(s, root_node.lhs, x, y);
384  }
385  else
386  {
 // NOTE(review): the operator is not validated here. Anything other than
 // ASSIGN and INPLACE_SUB is treated like INPLACE_ADD (alpha = 1, beta = 1).
387  double alpha = (root_node.op.type == OPERATION_BINARY_INPLACE_SUB_TYPE) ? -1.0 : 1.0;
388  double beta = (root_node.op.type != OPERATION_BINARY_ASSIGN_TYPE) ? 1.0 : 0.0;
389 
390  detail::matrix_matrix_prod(s, root_node.lhs, x, y, alpha, beta);
391  }
392 
393  // Part 3: Clean up //
394 
395  if (lhs_needs_temporary)
396  detail::delete_element(new_root_lhs.lhs);
397 
398  if (rhs_needs_temporary)
399  detail::delete_element(new_root_rhs.lhs);
400 }
401 
402 } // namespace scheduler
403 } // namespace viennacl
404 
405 #endif
406 
Implementations of dense matrix related operations including matrix-vector products.
void execute_matrix_prod(statement const &s, statement_node const &root_node)
Implementations of vector operations.
viennacl::context extract_context(statement_node const &root_node)
Helper routine for extracting the context in which a statement is executed.
statement_node_subtype subtype
Definition: forwards.h:340
Expression template class for representing a tree of expressions which ultimately result in a matrix...
Definition: forwards.h:341
viennacl::coordinate_matrix< double > * coordinate_matrix_double
Definition: forwards.h:443
This file provides the forward declarations for the main types used within ViennaCL.
void matrix_vector_prod(statement const &s, lhs_rhs_element result, lhs_rhs_element const &A, lhs_rhs_element const &x)
statement_node_type_family type_family
Definition: forwards.h:339
A class representing the 'data' for the LHS or RHS operand of the respective node.
Definition: forwards.h:337
container_type const & array() const
Definition: forwards.h:528
Provides unified wrappers for the common routines {as(), asbs(), asbs_s()}, {av(), avbv(), avbv_v()}, and {am(), ambm(), ambm_m()} such that scheduler logic is not cluttered with numeric type deductions.
operation_node_type_family type_family
Definition: forwards.h:473
Implementation of the coordinate_matrix class.
viennacl::matrix_base< double > * matrix_double
Definition: forwards.h:410
void delete_element(lhs_rhs_element &elem)
viennacl::hyb_matrix< double > * hyb_matrix_double
Definition: forwards.h:465
Represents a generic 'context' similar to an OpenCL context, but is backend-agnostic and thus also suitable for CUDA and OpenMP.
Definition: context.hpp:39
Implementation of the hyb_matrix class.
viennacl::compressed_matrix< float > * compressed_matrix_float
Definition: forwards.h:431
void axbx(lhs_rhs_element &x1, lhs_rhs_element const &x2, ScalarType1 const &alpha, vcl_size_t len_alpha, bool reciprocal_alpha, bool flip_sign_alpha, lhs_rhs_element const &x3, ScalarType2 const &beta, vcl_size_t len_beta, bool reciprocal_beta, bool flip_sign_beta)
Wrapper for viennacl::linalg::avbv(), taking care of the argument unwrapping.
statement_node_numeric_type numeric_type
Definition: forwards.h:341
Implementation of the compressed_matrix class.
viennacl::vector_base< float > * vector_float
Definition: forwards.h:385
Implementations of operations using sparse matrices.
viennacl::matrix_base< float > * matrix_float
Definition: forwards.h:409
viennacl::vector_base< double > * vector_double
Definition: forwards.h:386
Implementation of the ell_matrix class.
viennacl::compressed_matrix< double > * compressed_matrix_double
Definition: forwards.h:432
viennacl::ell_matrix< double > * ell_matrix_double
Definition: forwards.h:454
viennacl::hyb_matrix< float > * hyb_matrix_float
Definition: forwards.h:464
Provides the data structures for dealing with a single statement such as 'x = y + z;'.
void execute_composite(statement const &s, statement_node const &root_node)
Deals with x = RHS where RHS is an expression and x is either a scalar, a vector, or a matrix...
Definition: execute.hpp:42
void new_element(lhs_rhs_element &new_elem, lhs_rhs_element const &old_element, viennacl::context const &ctx)
viennacl::ell_matrix< float > * ell_matrix_float
Definition: forwards.h:453
operation_node_type type
Definition: forwards.h:474
A tag class representing transposed matrices.
Definition: forwards.h:220
The main class for representing a statement such as x = inner_prod(y,z); at runtime.
Definition: forwards.h:502
void matrix_matrix_prod(statement const &s, lhs_rhs_element result, lhs_rhs_element const &A, lhs_rhs_element const &B, double alpha, double beta)
void prod_impl(const matrix_base< NumericT > &mat, const vector_base< NumericT > &vec, vector_base< NumericT > &result)
Carries out matrix-vector multiplication.
bool matrix_prod_temporary_required(statement const &s, lhs_rhs_element const &elem)
viennacl::coordinate_matrix< float > * coordinate_matrix_float
Definition: forwards.h:442
Provides various utilities for implementing the execution of statements.
Main data structure for a node in the statement tree.
Definition: forwards.h:478
Exception for the case the scheduler is unable to deal with the operation.
Definition: forwards.h:38