ViennaCL - The Vienna Computing Library  1.7.1
Free open-source GPU-accelerated linear algebra and solver library.
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
matrix_size_deducer.hpp
Go to the documentation of this file.
1 #ifndef VIENNACL_TOOLS_MATRIX_SIZE_DEDUCER_HPP_
2 #define VIENNACL_TOOLS_MATRIX_SIZE_DEDUCER_HPP_
3 
4 /* =========================================================================
5  Copyright (c) 2010-2016, Institute for Microelectronics,
6  Institute for Analysis and Scientific Computing,
7  TU Wien.
8  Portions of this software are copyright by UChicago Argonne, LLC.
9 
10  -----------------
11  ViennaCL - The Vienna Computing Library
12  -----------------
13 
14  Project Head: Karl Rupp rupp@iue.tuwien.ac.at
15 
16  (A list of authors and contributors can be found in the manual)
17 
18  License: MIT (X11), see file LICENSE in the base directory
19 ============================================================================= */
20 
25 #include <string>
26 #include <fstream>
27 #include <sstream>
28 #include <cmath>
29 #include <vector>
30 #include <map>
31 
32 #include "viennacl/forwards.h"
34 
35 namespace viennacl
36 {
37 namespace tools
38 {
39 
46 template<typename LHS, typename RHS, typename OP>
48 {
49  //Standard case: size1 from lhs, size2 from rhs (fits most cases)
50  static vcl_size_t size1(LHS & lhs, RHS & /*rhs*/) { return lhs.size1(); }
51  static vcl_size_t size2(LHS & /*lhs*/, RHS & rhs) { return rhs.size2(); }
52 };
53 
55 //special case: outer vector product:
56 template<typename ScalarType>
57 struct MATRIX_SIZE_DEDUCER<const viennacl::vector_base<ScalarType>,
58  const viennacl::vector_base<ScalarType>,
60 {
62  viennacl::vector_base<ScalarType> const & /*rhs*/) { return lhs.size(); }
63 
64  static vcl_size_t size2(viennacl::vector_base<ScalarType> const & /*lhs*/,
65  viennacl::vector_base<ScalarType> const & rhs) { return rhs.size(); }
66 };
67 
68 
69 //special case: multiplication with a scalar
70 template<typename LHS, typename RHS, typename OP, typename ScalarType>
71 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<const LHS, const RHS, OP>,
72  const ScalarType,
74 {
76  ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size1(lhs.lhs(), lhs.rhs()); }
77 
79  ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size2(lhs.lhs(), lhs.rhs()); }
80 };
81 
82 //special case: multiplication with a scalar
83 template<typename T, typename ScalarType>
84 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_base<T>,
85  const ScalarType,
87 {
88  static vcl_size_t size1(viennacl::matrix_base<T> const & lhs,
89  ScalarType const & /*rhs*/) { return lhs.size1(); }
90 
91  static vcl_size_t size2(viennacl::matrix_base<T> const & lhs,
92  ScalarType const & /*rhs*/) { return lhs.size2(); }
93 };
94 
95 
96 //special case: division with a scalar
97 template<typename LHS, typename RHS, typename OP, typename ScalarType>
98 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<const LHS, const RHS, OP>,
99  const ScalarType,
101 {
103  ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size1(lhs.lhs(), lhs.rhs()); }
104 
106  ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size2(lhs.lhs(), lhs.rhs()); }
107 };
108 
109 //special case: division with a scalar
110 template<typename T, typename ScalarType>
111 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_base<T>,
112  const ScalarType,
114 {
115  static vcl_size_t size1(viennacl::matrix_base<T> const & lhs,
116  ScalarType const & /*rhs*/) { return lhs.size1(); }
117 
118  static vcl_size_t size2(viennacl::matrix_base<T> const & lhs,
119  ScalarType const & /*rhs*/) { return lhs.size2(); }
120 };
121 
122 //special case: diagonal from vector
123 template<typename T>
124 struct MATRIX_SIZE_DEDUCER<const viennacl::vector_base<T>,
125  const int,
127 {
128  static vcl_size_t size1(viennacl::vector_base<T> const & lhs,
129  const int k) { return lhs.size() + static_cast<vcl_size_t>(std::fabs(double(k))); }
130 
131  static vcl_size_t size2(viennacl::vector_base<T> const & lhs,
132  const int k) { return lhs.size() + static_cast<vcl_size_t>(std::fabs(double(k))); }
133 };
134 
135 //special case: transposed matrix-vector product: Return the number of rows of the matrix
136 template<typename MatrixType>
137 struct MATRIX_SIZE_DEDUCER<MatrixType,
138  MatrixType,
139  viennacl::op_trans>
140 {
141  static vcl_size_t size1(const MatrixType & lhs,
142  const MatrixType & /*rhs*/) { return lhs.size2(); }
143  static vcl_size_t size2(const MatrixType & lhs,
144  const MatrixType & /*rhs*/) { return lhs.size1(); }
145 };
146 
147 // A^T * B
148 template<typename ScalarType, typename T1>
149 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<T1,
150  T1, op_trans>,
151  const viennacl::matrix_base<ScalarType>,
153 {
155  T1,
156  op_trans> const & lhs,
157  viennacl::matrix_base<ScalarType> const & /*rhs*/) { return lhs.lhs().size2(); }
159  T1,
160  op_trans> const & /*lhs*/,
161  viennacl::matrix_base<ScalarType> const & rhs) { return rhs.size2(); }
162 };
163 
164 
165 // A * B^T
166 template<typename ScalarType, typename T2>
167 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_base<ScalarType>,
169  T2, op_trans>,
171 {
174  T2,
175  op_trans> const & /*rhs*/) { return lhs.size1(); }
176  static vcl_size_t size2(viennacl::matrix_base<ScalarType> const & /*lhs*/,
178  T2,
179  op_trans> const & rhs) { return rhs.lhs().size1(); }
180 };
181 
182 // A^T * B^T
183 template<typename T1, typename T2>
184 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<T1,
185  T1, op_trans>,
187  T2, op_trans>,
189 {
192 
193  static vcl_size_t size1(LHSType const & lhs,
194  RHSType const & /*rhs*/) { return lhs.lhs().size2(); }
195  static vcl_size_t size2(LHSType const & /*lhs*/,
196  RHSType const & rhs) { return rhs.lhs().size1(); }
197 };
200 }
201 }
202 
203 #endif
204 
A tag class representing multiplication by a scalar.
Definition: forwards.h:92
Adapter classes for sparse matrices made of the STL type std::vector<std::map<SizeT, NumericT> >
A tag class representing a matrix given by a vector placed on a certain (off-)diagonal.
Definition: forwards.h:189
Expression template class for representing a tree of expressions which ultimately result in a matrix...
Definition: forwards.h:341
This file provides the forward declarations for the main types used within ViennaCL.
A tag class representing division.
Definition: forwards.h:98
static vcl_size_t size1(LHS &lhs, RHS &)
Deduces the size of the resulting matrix represented by a matrix_expression from the operands...
Common base class for dense vectors, vector ranges, and vector slices.
Definition: vector_def.hpp:104
A tag class representing matrix-matrix products.
Definition: forwards.h:96
std::size_t vcl_size_t
Definition: forwards.h:75
size_type size2() const
Returns the number of columns.
Definition: matrix_def.hpp:226
size_type size1() const
Returns the number of rows.
Definition: matrix_def.hpp:224
RHS & rhs() const
Get right hand side operand.
Definition: matrix.hpp:69
A tag class representing matrix-vector products and element-wise multiplications. ...
Definition: forwards.h:94
size_type size() const
Returns the length of the vector (cf. std::vector)
Definition: vector_def.hpp:118
float ScalarType
Definition: fft_1d.cpp:42
static vcl_size_t size2(LHS &, RHS &rhs)
LHS & lhs() const
Get left hand side operand.
Definition: matrix.hpp:66