ViennaCL - The Vienna Computing Library  1.7.1
Free open-source GPU-accelerated linear algebra and solver library.
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
execute_elementwise.hpp
Go to the documentation of this file.
1 #ifndef VIENNACL_SCHEDULER_EXECUTE_ELEMENTWISE_HPP
2 #define VIENNACL_SCHEDULER_EXECUTE_ELEMENTWISE_HPP
3 
4 /* =========================================================================
5  Copyright (c) 2010-2016, Institute for Microelectronics,
6  Institute for Analysis and Scientific Computing,
7  TU Wien.
8  Portions of this software are copyright by UChicago Argonne, LLC.
9 
10  -----------------
11  ViennaCL - The Vienna Computing Library
12  -----------------
13 
14  Project Head: Karl Rupp rupp@iue.tuwien.ac.at
15 
16  (A list of authors and contributors can be found in the manual)
17 
18  License: MIT (X11), see file LICENSE in the base directory
19 ============================================================================= */
20 
21 
26 #include "viennacl/forwards.h"
31 
32 namespace viennacl
33 {
34 namespace scheduler
35 {
36 namespace detail
37 {
38  // result = element_op(x,y) for vectors or matrices x, y
39  inline void element_op(lhs_rhs_element result,
40  lhs_rhs_element const & x,
41  operation_node_type op_type)
42  {
43  assert( result.numeric_type == x.numeric_type && bool("Numeric type not the same!"));
44  assert( result.type_family == x.type_family && bool("Subtype not the same!"));
45 
46  if (x.subtype == DENSE_VECTOR_TYPE)
47  {
48  assert( result.subtype == x.subtype && bool("result not of vector type for unary elementwise operation"));
49  if (x.numeric_type == FLOAT_TYPE)
50  {
51  switch (op_type)
52  {
53 #define VIENNACL_SCHEDULER_GENERATE_UNARY_ELEMENT_OP(OPNAME, NumericT, OPTAG) \
54  case OPNAME: viennacl::linalg::element_op(*result.vector_##NumericT, \
55  viennacl::vector_expression<const vector_base<NumericT>, const vector_base<NumericT>, \
56  op_element_unary<OPTAG> >(*x.vector_##NumericT, *x.vector_##NumericT)); break;
57 
75  default:
76  throw statement_not_supported_exception("Invalid op_type in unary elementwise operations");
77  }
78  }
79  else if (x.numeric_type == DOUBLE_TYPE)
80  {
81  switch (op_type)
82  {
100 
101 #undef VIENNACL_SCHEDULER_GENERATE_UNARY_ELEMENT_OP
102  default:
103  throw statement_not_supported_exception("Invalid op_type in unary elementwise operations");
104  }
105  }
106  else
107  throw statement_not_supported_exception("Invalid numeric type in unary elementwise operator");
108  }
109  else if (x.subtype == DENSE_MATRIX_TYPE)
110  {
111  if (x.numeric_type == FLOAT_TYPE)
112  {
113  switch (op_type)
114  {
115 #define VIENNACL_SCHEDULER_GENERATE_UNARY_ELEMENT_OP(OPNAME, NumericT, OPTAG) \
116  case OPNAME: viennacl::linalg::element_op(*result.matrix_##NumericT, \
117  viennacl::matrix_expression<const matrix_base<NumericT>, const matrix_base<NumericT>, \
118  op_element_unary<OPTAG> >(*x.matrix_##NumericT, *x.matrix_##NumericT)); break;
119 
137  default:
138  throw statement_not_supported_exception("Invalid op_type in unary elementwise operations");
139  }
140 
141  }
142  else if (x.numeric_type == DOUBLE_TYPE)
143  {
144  switch (op_type)
145  {
163  default:
164  throw statement_not_supported_exception("Invalid op_type in unary elementwise operations");
165  }
166  }
167  else
168  throw statement_not_supported_exception("Invalid numeric type in unary elementwise operator");
169 
170 #undef VIENNACL_SCHEDULER_GENERATE_UNARY_ELEMENT_OP
171 
172  }
173  }
174 
175  // result = element_op(x,y) for vectors or matrices x, y
176  inline void element_op(lhs_rhs_element result,
177  lhs_rhs_element const & x,
178  lhs_rhs_element const & y,
179  operation_node_type op_type)
180  {
181  assert( x.numeric_type == y.numeric_type && bool("Numeric type not the same!"));
182  assert( result.numeric_type == y.numeric_type && bool("Numeric type not the same!"));
183 
184  assert( x.type_family == y.type_family && bool("Subtype not the same!"));
185  assert( result.type_family == y.type_family && bool("Subtype not the same!"));
186 
187  switch (op_type)
188  {
189 
191  if (x.subtype == DENSE_VECTOR_TYPE)
192  {
193  switch (x.numeric_type)
194  {
195  case FLOAT_TYPE:
198  const vector_base<float>,
200  break;
201  case DOUBLE_TYPE:
204  const vector_base<double>,
206  break;
207  default:
208  throw statement_not_supported_exception("Invalid numeric type for binary elementwise division");
209  }
210  }
211  else if (x.subtype == DENSE_MATRIX_TYPE)
212  {
213  switch (x.numeric_type)
214  {
215  case FLOAT_TYPE:
218  const matrix_base<float>,
220  break;
221  case DOUBLE_TYPE:
224  const matrix_base<double>,
226  break;
227  default:
228  throw statement_not_supported_exception("Invalid numeric type for binary elementwise division");
229  }
230  }
231  else
232  throw statement_not_supported_exception("Invalid operand type for binary elementwise division");
233  break;
234 
235 
237  if (x.subtype == DENSE_VECTOR_TYPE)
238  {
239  switch (x.numeric_type)
240  {
241  case FLOAT_TYPE:
244  const vector_base<float>,
246  break;
247  case DOUBLE_TYPE:
250  const vector_base<double>,
252  break;
253  default:
254  throw statement_not_supported_exception("Invalid numeric type for binary elementwise division");
255  }
256  }
257  else if (x.subtype == DENSE_MATRIX_TYPE)
258  {
259  switch (x.numeric_type)
260  {
261  case FLOAT_TYPE:
264  const matrix_base<float>,
266  break;
267  case DOUBLE_TYPE:
270  const matrix_base<double>,
272  break;
273  default:
274  throw statement_not_supported_exception("Invalid numeric type for binary elementwise division");
275  }
276  }
277  else
278  throw statement_not_supported_exception("Invalid operand type for binary elementwise division");
279  break;
280 
281 
283  if (x.subtype == DENSE_VECTOR_TYPE)
284  {
285  switch (x.numeric_type)
286  {
287  case FLOAT_TYPE:
290  const vector_base<float>,
292  break;
293  case DOUBLE_TYPE:
296  const vector_base<double>,
298  break;
299  default:
300  throw statement_not_supported_exception("Invalid numeric type for binary elementwise division");
301  }
302  }
303  else if (x.subtype == DENSE_MATRIX_TYPE)
304  {
305  switch (x.numeric_type)
306  {
307  case FLOAT_TYPE:
310  const matrix_base<float>,
312  break;
313  case DOUBLE_TYPE:
316  const matrix_base<double>,
318  break;
319  default:
320  throw statement_not_supported_exception("Invalid numeric type for binary elementwise power");
321  }
322  }
323  else
324  throw statement_not_supported_exception("Invalid operand type for binary elementwise power");
325  break;
326 
327  default:
328  throw statement_not_supported_exception("Invalid operation type for binary elementwise operations");
329  }
330  }
331 }
332 
// Evaluates root_node.lhs = <element-wise op>(operands of the leaf node), where the leaf is
// s.array()[root_node.rhs.node_index].
// Leaf operands whose type_family is COMPOSITE_OPERATION_FAMILY are first evaluated into
// temporaries (created from root_node.lhs via detail::new_element) by recursive
// detail::execute_composite calls. detail::element_op then writes the result into
// root_node.lhs, and the temporaries are released with detail::delete_element.
// Throws statement_not_supported_exception if the leaf's op family is neither the binary
// branch nor OPERATION_UNARY_TYPE_FAMILY.
// NOTE(review): `ctx` and the conditions guarding the temporary-creation/cleanup scopes are on
// lines not shown in this listing. Judging from the selects that define x and y, temporaries are
// presumably only created for composite operands - confirm.
334 inline void execute_element_composite(statement const & s, statement_node const & root_node)
335 {
336  statement_node const & leaf = s.array()[root_node.rhs.node_index];
338 
339  statement_node new_root_lhs;
340  statement_node new_root_rhs;
341 
342  // check for temporary on lhs:
 // Assignment node "temporary = <leaf.lhs subexpression>", evaluated recursively below.
344  {
345  detail::new_element(new_root_lhs.lhs, root_node.lhs, ctx);
346 
348  new_root_lhs.op.type = OPERATION_BINARY_ASSIGN_TYPE;
349 
351  new_root_lhs.rhs.subtype = INVALID_SUBTYPE;
352  new_root_lhs.rhs.numeric_type = INVALID_NUMERIC_TYPE;
353  new_root_lhs.rhs.node_index = leaf.lhs.node_index;
354 
355  // work on subexpression:
356  // TODO: Catch exception, free temporary, then rethrow
 // NOTE(review): if this call throws, new_root_lhs.lhs is never passed to delete_element.
357  detail::execute_composite(s, new_root_lhs);
358  }
359 
 // Binary branch (its guard is not shown; presumably the leaf's op is in the binary family - confirm).
361  {
362  // check for temporary on rhs:
364  {
365  detail::new_element(new_root_rhs.lhs, root_node.lhs, ctx);
366 
368  new_root_rhs.op.type = OPERATION_BINARY_ASSIGN_TYPE;
369 
371  new_root_rhs.rhs.subtype = INVALID_SUBTYPE;
372  new_root_rhs.rhs.numeric_type = INVALID_NUMERIC_TYPE;
373  new_root_rhs.rhs.node_index = leaf.rhs.node_index;
374 
375  // work on subexpression:
376  // TODO: Catch exception, free temporary, then rethrow
 // NOTE(review): if this call throws, both temporaries (lhs and rhs) leak.
377  detail::execute_composite(s, new_root_rhs);
378  }
379 
 // Use the evaluated temporary for composite operands, the leaf operand itself otherwise.
380  lhs_rhs_element x = (leaf.lhs.type_family == COMPOSITE_OPERATION_FAMILY) ? new_root_lhs.lhs : leaf.lhs;
381  lhs_rhs_element y = (leaf.rhs.type_family == COMPOSITE_OPERATION_FAMILY) ? new_root_rhs.lhs : leaf.rhs;
382 
383  // compute element-wise operation:
384  detail::element_op(root_node.lhs, x, y, leaf.op.type);
385 
 // Release the rhs temporary (the guard on the omitted line presumably mirrors its creation - confirm).
387  detail::delete_element(new_root_rhs.lhs);
388  }
389  else if (leaf.op.type_family == OPERATION_UNARY_TYPE_FAMILY)
390  {
 // Unary ops only have a lhs operand; no rhs temporary is involved here.
391  lhs_rhs_element x = (leaf.lhs.type_family == COMPOSITE_OPERATION_FAMILY) ? new_root_lhs.lhs : leaf.lhs;
392 
393  // compute element-wise operation:
394  detail::element_op(root_node.lhs, x, leaf.op.type);
395  }
396  else
 // NOTE(review): throwing here skips the delete_element below, so an lhs temporary created above leaks.
397  throw statement_not_supported_exception("Unsupported elementwise operation.");
398 
399  // clean up:
 // Release the lhs temporary (the guard on the omitted line presumably mirrors its creation - confirm).
401  detail::delete_element(new_root_lhs.lhs);
402 
403 }
404 
405 
406 } // namespace scheduler
407 } // namespace viennacl
408 
409 #endif
410 
A tag class representing the cosh() function.
Definition: forwards.h:155
A tag class representing the tan() function.
Definition: forwards.h:181
#define VIENNACL_SCHEDULER_GENERATE_UNARY_ELEMENT_OP(OPNAME, NumericT, OPTAG)
Implementations of dense matrix related operations including matrix-vector products.
Implementations of vector operations.
viennacl::context extract_context(statement_node const &root_node)
Helper routine for extracting the context in which a statement is executed.
A tag class representing the modulus function for integers.
Definition: forwards.h:137
statement_node_subtype subtype
Definition: forwards.h:340
Expression template class for representing a tree of expressions which ultimately result in a matrix...
Definition: forwards.h:341
A tag class representing the ceil() function.
Definition: forwards.h:151
This file provides the forward declarations for the main types used within ViennaCL.
statement_node_type_family type_family
Definition: forwards.h:339
A class representing the 'data' for the LHS or RHS operand of the respective node.
Definition: forwards.h:337
container_type const & array() const
Definition: forwards.h:528
void execute_element_composite(statement const &s, statement_node const &root_node)
Deals with x = RHS where RHS is an element-wise vector or matrix expression.
An expression template class that represents a binary operation that yields a vector.
Definition: forwards.h:239
void element_op(matrix_base< T > &A, matrix_expression< const matrix_base< T >, const matrix_base< T >, OP > const &proxy)
Implementation of the element-wise operation A = B .* C and A = B ./ C for matrices (using MATLAB syn...
A tag class representing the log() function.
Definition: forwards.h:171
operation_node_type_family type_family
Definition: forwards.h:473
A tag class representing the tanh() function.
Definition: forwards.h:183
A tag class representing the fabs() function.
Definition: forwards.h:159
viennacl::matrix_base< double > * matrix_double
Definition: forwards.h:410
void delete_element(lhs_rhs_element &elem)
Represents a generic 'context' similar to an OpenCL context, but is backend-agnostic and thus also su...
Definition: context.hpp:39
A tag class representing the atan() function.
Definition: forwards.h:147
A tag class representing the sinh() function.
Definition: forwards.h:177
A tag class representing the exp() function.
Definition: forwards.h:157
statement_node_numeric_type numeric_type
Definition: forwards.h:341
viennacl::vector_base< float > * vector_float
Definition: forwards.h:385
viennacl::matrix_base< float > * matrix_float
Definition: forwards.h:409
A tag class representing the sqrt() function.
Definition: forwards.h:179
viennacl::vector_base< double > * vector_double
Definition: forwards.h:386
Provides the datastructures for dealing with a single statement such as 'x = y + z;'.
void execute_composite(statement const &s, statement_node const &root_node)
Deals with x = RHS where RHS is an expression and x is either a scalar, a vector, or a matrix...
Definition: execute.hpp:42
operation_node_type
Enumeration for identifying the possible operations.
Definition: forwards.h:68
A tag class representing the sin() function.
Definition: forwards.h:175
void new_element(lhs_rhs_element &new_elem, lhs_rhs_element const &old_element, viennacl::context const &ctx)
A tag class representing the floor() function.
Definition: forwards.h:163
A tag class representing the asin() function.
Definition: forwards.h:141
operation_node_type type
Definition: forwards.h:474
A tag class representing element-wise binary operations (like multiplication) on vectors or matrices...
Definition: forwards.h:130
The main class for representing a statement such as x = inner_prod(y,z); at runtime.
Definition: forwards.h:502
A tag class representing the acos() function.
Definition: forwards.h:139
A tag class representing the log10() function.
Definition: forwards.h:173
void element_op(lhs_rhs_element result, lhs_rhs_element const &x, operation_node_type op_type)
Provides various utilities for implementing the execution of statements.
A tag class representing the cos() function.
Definition: forwards.h:153
Main datastructure for an node in the statement tree.
Definition: forwards.h:478
Exception for the case the scheduler is unable to deal with the operation.
Definition: forwards.h:38