ViennaCL - The Vienna Computing Library  1.4.2
viennacl/generator/get_kernels_infos.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_GENERATOR_CREATE_KERNEL_HPP
00002 #define VIENNACL_GENERATOR_CREATE_KERNEL_HPP
00003 
00004 /* =========================================================================
00005    Copyright (c) 2010-2013, Institute for Microelectronics,
00006                             Institute for Analysis and Scientific Computing,
00007                             TU Wien.
00008    Portions of this software are copyright by UChicago Argonne, LLC.
00009 
00010                             -----------------
00011                   ViennaCL - The Vienna Computing Library
00012                             -----------------
00013 
00014    Project Head:    Karl Rupp                   rupp@iue.tuwien.ac.at
00015 
00016    (A list of authors and contributors can be found in the PDF manual)
00017 
00018    License:         MIT (X11), see file LICENSE in the base directory
00019 ============================================================================= */
00020 
00027 // #include "kernel_utils.hpp"
00028 
00029 #include <map>
00030 
00031 #include "viennacl/generator/operators.hpp"
00032 #include "viennacl/generator/symbolic_types.hpp"
00033 #include "viennacl/generator/tree_operations.hpp"
00034 #include "viennacl/generator/tokens_management.hpp"
00035 #include "viennacl/generator/make_code.hpp"
00036 #include "viennacl/generator/meta_tools/typelist.hpp"
00037 #include "viennacl/generator/result_of.hpp"
00038 #include "viennacl/tools/shared_ptr.hpp"
00039 
00040 namespace viennacl
00041 {
00042 namespace generator
00043 {
00044 
00045 
00046 
00047 typedef std::multimap<std::string, std::pair<unsigned int,viennacl::tools::shared_ptr<result_of::runtime_wrapper> > > runtime_wrappers_t;
00048 
/** @brief Generic case of get_head: Result is simply the type that was passed in.
 *
 * @tparam Type  Any type not matched by a more specialized version of get_head.
 */
template<class Type>
struct get_head
{
    typedef Type Result;
};
00053 
00054 template<class Head, class Tail>
00055 struct get_head<typelist<Head, Tail> >
00056 {
00057     typedef Head Result;
00058 };
00059 
00060 
/** @brief Generic case of transform_inner_prod: the type is passed through unchanged.
 *
 * @tparam Node  Any type not matched by a more specialized version of transform_inner_prod.
 */
template<class Node>
struct transform_inner_prod
{
    typedef Node Result;
};
00067 
00068 template<class LHS, class RHS>
00069 struct transform_inner_prod<compound_node<LHS,inner_prod_type,RHS> >
00070 {
00071     typedef inner_prod_impl_t<compound_node<LHS,inner_prod_type,RHS> > Result;
00072 };
00073 
00080 template<class TreeList, class Res, int CurrentIndex=0>
00081 struct register_kernels;
00082 
00083 template<class Head,class Tail, class Res,int CurrentIndex>
00084 struct register_kernels<typelist<Head, Tail>,Res,CurrentIndex >
00085 {
00086 private:
00087     typedef typelist<Head, Tail> self_type;
00088 public:
00089     template<class T, class List, int Index>
00090     struct add_to_res
00091     {
00092         //Gets the typelist at Index
00093         typedef typename typelist_utils::type_at<List,Index>::Result Tmp;
00094 
00095         //Fuses it with the argument provided
00096         typedef typename typelist_utils::fuse<Tmp,T>::Result TmpRes;
00097 
00098         //Replace the former typelist with the new typelist
00099         typedef typename typelist_utils::replace<List,Tmp,TmpRes>::Result ResultIfTmpNotNull;
00100         typedef typename typelist_utils::append<List,T>::Result ResultIfTmpNull;
00101         typedef typename get_type_if<ResultIfTmpNull,ResultIfTmpNotNull,result_of::is_null_type<Tmp>::value>::Result Result;
00102     };
00103 
00104 private:
00105     typedef typename tree_utils::extract_if<Head,result_of::is_inner_product_leaf>::Result InProds;
00106     typedef typename add_to_res<typename typelist_utils::ForEachType<InProds,transform_inner_prod>::Result,Res,CurrentIndex - 1>::Result TmpNewRes;
00107     static const bool inc = tree_utils::count_if<typename get_head<Tail>::Result, result_of::or_is<result_of::is_product_leaf,result_of::is_inner_product_leaf>::Pred >::value
00108                             + tree_utils::count_if<Head,result_of::is_product_leaf>::value;
00109 public:
00110     typedef typename add_to_res<typelist<Head,NullType>,TmpNewRes,CurrentIndex>::Result NewRes;
00111     typedef typename register_kernels<Tail,NewRes,CurrentIndex+inc>::Result Result;
00112 };
00113 
00114 template<class Res, int CurrentIndex>
00115 struct register_kernels<NullType,Res,CurrentIndex>
00116 {
00117     typedef Res NewRes;
00118     typedef Res Result;
00119 };
00120 
00125 template<class ARG>
00126 struct program_infos
00127 {
00128 
00129     static const bool first_has_ip = tree_utils::count_if<typename ARG::Head,result_of::is_inner_product_leaf>::value;
00130     typedef typename register_kernels<ARG,NullType,first_has_ip>::Result          KernelsList;
00131 
00132     template<class Operations>
00133     struct fill_args
00134     {
00135     private:
00136             typedef typename tree_utils::extract_if<typename get_operations_from_expressions<Operations>::Unrolled,result_of::is_kernel_argument>::Result IntermediateType;
00137             typedef typename typelist_utils::no_duplicates<IntermediateType>::Result Arguments;
00138 
00139     public:
00140         template<class U>
00141         struct functor{
00142         private:
00143             typedef typename result_of::expression_type<U>::Result ExpressionType;
00144         public:
00145             static void execute(unsigned int & arg_pos, runtime_wrappers_t & runtime_wrappers, std::string const & name)
00146             {
00147                 runtime_wrappers.insert(runtime_wrappers_t::value_type(name,
00148                                                                        std::make_pair(arg_pos,
00149                                                                                       ExpressionType::runtime_descriptor())
00150                                                                        )
00151                                         );
00152                 runtime_wrappers.size();
00153                 arg_pos += ExpressionType::n_args();
00154             }
00155         };
00156 
00157         static void execute(runtime_wrappers_t & runtime_wrappers,std::string const & operation_name)
00158         {
00159             unsigned int arg_pos = 0;
00160             unsigned int n = typelist_utils::index_of<KernelsList,Operations>::value;
00161             std::string current_kernel_name("__" + operation_name + "_k" + to_string(n));
00162             typelist_utils::ForEach<Arguments,functor>::execute(arg_pos,runtime_wrappers,current_kernel_name);
00163             if(tree_utils::count_if<Operations,result_of::is_inner_product_leaf>::value || tree_utils::count_if<Operations,result_of::is_product_leaf>::value){
00164                 runtime_wrappers.insert(runtime_wrappers_t::value_type(current_kernel_name,
00165                                                                        std::make_pair(arg_pos,
00166                                                                                       new result_of::shared_memory_wrapper())));
00167             }
00168         }
00169 
00170 
00171     };
00172 
00173     template<class Operations>
00174     struct fill_sources
00175     {
00176     private:
00177             typedef typename tree_utils::extract_if<typename get_operations_from_expressions<Operations>::Unrolled,result_of::is_kernel_argument>::Result IntermediateType;
00178             typedef typename typelist_utils::no_duplicates<IntermediateType>::Result Arguments;
00179 
00180     public:
00181         template<class TList>
00182         struct header_code
00183         {
00184 
00185             template<class T>
00186             struct functor{
00187                 static void execute(std::string & res,bool & is_first){
00188                     if(is_first){
00189                         res+=T::kernel_arguments();
00190                         is_first=false;
00191                     }
00192                     else{
00193                         res+=", "+T::kernel_arguments();
00194                     }
00195                 }
00196             };
00197 
00198         public:
00199             static const std::string value ( std::string const & name )
00200             {
00201                 std::string res;
00202                 res+="__kernel void " + name + "(\n";
00203                 bool state=true;
00204                 typelist_utils::ForEach<Arguments,functor>::execute(res,state);
00205                 if(tree_utils::count_if<TList,result_of::is_inner_product_leaf>::value || tree_utils::count_if<Operations,result_of::is_product_leaf>::value)
00206                     res+=",__local float* shared_memory_ptr\n";
00207                 res+=")\n";
00208                 return res;
00209             }
00210         };
00211 
00212 
00213         static void execute(std::map<std::string,std::string> & sources,std::string const & operation_name)
00214         {
00215             unsigned int n = typelist_utils::index_of<KernelsList,Operations>::value;
00216             std::string current_kernel_name("__" + operation_name + "_k" + to_string(n));
00217             sources.insert(std::make_pair(current_kernel_name,
00218                                           header_code<Operations>::value(current_kernel_name)
00219                                           +body_code<Operations>::value()));
00220         }
00221     };
00222 
00223 
00227     static void fill(std::string const & operation_name, std::map<std::string,std::string> & sources, runtime_wrappers_t & runtime_wrappers)
00228     {
00229         //std::cout << KernelsList::name() << std::endl;
00230         typelist_utils::ForEach<KernelsList,fill_sources>::execute(sources,operation_name);
00231         typelist_utils::ForEach<KernelsList,fill_args>::execute(runtime_wrappers,operation_name);
00232     }
00233 };
00234 
00235 
00236 
00237 } // namespace generator
00238 } // namespace viennacl
00239 #endif
00240