ViennaCL - The Vienna Computing Library 1.3.0
/export/development/ViennaCL/viennacl/generator/compound_node.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_GENERATOR_COMPOUND_NODE_HPP
00002 #define VIENNACL_GENERATOR_COMPOUND_NODE_HPP
00003 
00004 /* =========================================================================
00005    Copyright (c) 2010-2012, Institute for Microelectronics,
00006                             Institute for Analysis and Scientific Computing,
00007                             TU Wien.
00008 
00009                             -----------------
00010                   ViennaCL - The Vienna Computing Library
00011                             -----------------
00012 
00013    Project Head:    Karl Rupp                   rupp@iue.tuwien.ac.at
00014                
00015    (A list of authors and contributors can be found in the PDF manual)
00016 
00017    License:         MIT (X11), see file LICENSE in the base directory
00018 ============================================================================= */
00019 
00026 #include <string>
00027 #include <sstream>
00028 #include <set>
00029 
00030 #include "viennacl/generator/forwards.h"
00031 #include "viennacl/generator/meta_tools/utils.hpp"
00032 #include "viennacl/generator/traits/general_purpose_traits.hpp"
00033 #include "viennacl/generator/traits/result_of.hpp"
00034 
00035 namespace viennacl 
00036 {
00037   namespace generator
00038   {
00039 
/** @brief Generic binary expression node: a left operand, an operator tag and a right operand.
 *
 *  The node carries no data; it only encodes the expression in its type.
 *
 *  @tparam LHS_          left operand type; must provide a static name()
 *  @tparam OP_           operator tag type; must provide a static name()
 *  @tparam RHS_          right operand type; must provide a static name()
 *  @tparam is_temporary_ value re-exported as the static member is_temporary
 */
template<class LHS_, class OP_, class RHS_, bool is_temporary_>
class compound_node
{
  public:
    typedef OP_   OP;
    typedef LHS_  LHS;
    typedef RHS_  RHS;

    static const bool is_temporary = is_temporary_;

    /** @brief Returns the identifier "<LHS::name()>_<OP::name()>_<RHS::name()>". */
    static const std::string name()
    {
      std::string joined(LHS::name());
      joined += '_';
      joined += OP::name();
      joined += '_';
      joined += RHS::name();
      return joined;
    }
};
00063 
00064     template<class LHS_, class RHS_, bool is_temporary_>
00065     class compound_node<LHS_,inner_prod_type,RHS_, is_temporary_> 
00066     {
00067       public:
00071         typedef LHS_ LHS;
00072         typedef RHS_ RHS;
00073         typedef inner_prod_type OP;
00074         typedef typename result_of::expression_type<RHS>::Result IntermediateType;  //Note: Visual Studio does not allow to combine this line with the next one directly.
00075         typedef typename IntermediateType::ScalarType ScalarType;
00076 
00077         static const bool is_temporary = is_temporary_;
00078 
00079         enum { id = -2 };
00080 
00081         static const std::string kernel_arguments() 
00082         {
00083           return  "__global float * " + name() + '\n';
00084         }
00085 
00086         static const std::string name() 
00087         {
00088           return  LHS::name() + "_inprod_" + RHS::name();
00089         }
00090 
00091         static const std::string scalar_name() 
00092         {
00093           return name() +"_s";
00094         };
00095 
00096     };
00097 
00101     template<class LHS_, class RHS_, bool is_temporary_>
00102     class compound_node<LHS_,prod_type,RHS_, is_temporary_> 
00103     {
00104       private:
00105         typedef compound_node<LHS_,prod_type,RHS_, is_temporary_> self_type;
00106 
00107       public:
00108         typedef LHS_ LHS;
00109         typedef RHS_ RHS;
00110 
00111         typedef prod_type OP;
00112         enum { id = LHS::id };
00113 
00114         typedef typename result_of::expression_type<RHS>::Result IntermediateType;    //Note: Visual Studio does not allow to combine this line with the next one directly.
00115         typedef typename IntermediateType::ScalarType ScalarType;
00116         static const unsigned int Alignment = result_of::expression_type<RHS>::Result::Alignment;
00117         static const bool is_temporary = is_temporary_;
00118 
00119         static const std::string name() 
00120         {
00121           return LHS::name() + "_prod_" + RHS::name();
00122         }
00123 
00124         static const std::string size2_name() 
00125         {
00126           return "size_"+name();
00127         }
00128 
00129         static const std::string internal_size2_name() 
00130         {
00131           return "internal_size_"+name();
00132         }
00133         
00134         static const std::string name_argument() 
00135         {
00136           return " __global " + print_type<ScalarType*,Alignment>::value() + " " + name();
00137         }
00138 
00139         static const std::string kernel_arguments() 
00140         {
00141           return name_argument() + ", unsigned int " + size2_name() + ", unsigned int " + internal_size2_name() + "\n" ;
00142         }
00143     };
00144 
00145 
00147     template<class LHS_TYPE, class RHS_TYPE>
00148     typename enable_if< is_same_expression_type<LHS_TYPE, RHS_TYPE>,
00149                         compound_node<LHS_TYPE, add_type, RHS_TYPE> >::type
00150     operator+ ( LHS_TYPE const & lhs, RHS_TYPE const & rhs ) 
00151     {
00152       return compound_node<LHS_TYPE, add_type, RHS_TYPE>();
00153     }
00154 
00156     template<class LHS_TYPE, class RHS_TYPE>
00157     typename enable_if< is_same_expression_type<LHS_TYPE, RHS_TYPE>,
00158                         compound_node<LHS_TYPE, sub_type, RHS_TYPE> >::type
00159     operator- ( LHS_TYPE const & lhs, RHS_TYPE const & rhs ) 
00160     {
00161       return compound_node<LHS_TYPE, sub_type, RHS_TYPE>();
00162     }
00163 
00165     template<class LHS, class RHS>
00166     struct make_inner_prod;
00167 
00168     template<class LHS, class LHS_SIZE_DESCRIPTOR,
00169              class RHS, class RHS_SIZE_DESCRIPTOR>
00170     struct make_inner_prod<result_of::vector_expression<LHS, LHS_SIZE_DESCRIPTOR>,
00171                            result_of::vector_expression<RHS, RHS_SIZE_DESCRIPTOR> > 
00172     {
00173       typedef compound_node<LHS,inner_prod_type,RHS,true> Result;
00174     };
00175 
00176 
00178     template<class LHS, class RHS>
00179     compound_node<LHS,inner_prod_type,RHS,true> inner_prod ( LHS vec_expr1,RHS vec_expr2 ) 
00180     {
00181       typedef typename result_of::expression_type<LHS>::Result LHS_TYPE;
00182       typedef typename result_of::expression_type<RHS>::Result RHS_TYPE;
00183       typename make_inner_prod<LHS_TYPE,RHS_TYPE>::Result result;
00184       
00185       return result;;
00186     }
00187 
00189     template<class LHS, class RHS>
00190     compound_node<LHS,prod_type,RHS> prod ( LHS vec_expr1,RHS vec_expr2 ) 
00191     {
00192       return compound_node<LHS,prod_type,RHS>();
00193     }
00194 
00195   } // namespace generator
00196 } // namespace viennacl
00197 
00198 #endif
00199