matrix_mul.h

Go to the documentation of this file.
00001 /* -*- C++ -*- ------------------------------------------------------------
00002  
00003 Copyright (c) 2007 Jesse Anders and Demian Nave http://cmldev.net/
00004 
00005 The Configurable Math Library (CML) is distributed under the terms of the
00006 Boost Software License, v1.0 (see cml/LICENSE for details).
00007 
00008  *-----------------------------------------------------------------------*/
00021 #ifndef matrix_mul_h
00022 #define matrix_mul_h
00023 
00024 #include <cml/et/size_checking.h>
00025 #include <cml/matrix/matrix_expr.h>
00026 
/* Tag types for readable compile-time diagnostics.
 *
 * Both structs are only declared here (never defined in this header).  They
 * are handed to CML_STATIC_REQUIRE_M() below so that, when a requirement
 * fails, the compiler error mentions a descriptive identifier:
 *
 *   mul_expects_matrix_args_error
 *       mul() was given an argument that is neither a matrix nor a
 *       MatrixExpr.
 *
 *   mul_expressions_have_wrong_size_error
 *       fixed-size arguments to mul() have incompatible sizes (the left
 *       operand's column count differs from the right operand's row count).
 */
struct mul_expects_matrix_args_error;
struct mul_expressions_have_wrong_size_error;
00036 
00037 namespace cml {
00038 namespace detail {
00039 
00044 template<typename LeftT, typename RightT> inline matrix_size
00045 MatMulCheckedSize(const LeftT&, const RightT&, fixed_size_tag)
00046 {
00047     CML_STATIC_REQUIRE_M(
00048             ((size_t)LeftT::array_cols == (size_t)RightT::array_rows),
00049             mul_expressions_have_wrong_size_error);
00050     return matrix_size(LeftT::array_rows,RightT::array_cols);
00051 }
00052 
00057 template<typename LeftT, typename RightT> inline matrix_size
00058 MatMulCheckedSize(const LeftT& left, const RightT& right, dynamic_size_tag)
00059 {
00060     matrix_size left_N = left.size(), right_N = right.size();
00061     et::GetCheckedSize<LeftT,RightT,dynamic_size_tag>()
00062         .equal_or_fail(left_N.second, right_N.first); /* cols,rows */
00063     return matrix_size(left_N.first, right_N.second); /* rows,cols */
00064 }
00065 
00066 
00071 template<class LeftT, class RightT>
00072 inline typename et::MatrixPromote<
00073     typename et::ExprTraits<LeftT>::result_type,
00074     typename et::ExprTraits<RightT>::result_type
00075 >::temporary_type
00076 mul(const LeftT& left, const RightT& right)
00077 {
00078     /* Shorthand: */
00079     typedef et::ExprTraits<LeftT> left_traits;
00080     typedef et::ExprTraits<RightT> right_traits;
00081     typedef typename left_traits::result_type left_result;
00082     typedef typename right_traits::result_type right_result;
00083 
00084     /* First, require matrix expressions: */
00085     CML_STATIC_REQUIRE_M(
00086             (et::MatrixExpressions<LeftT,RightT>::is_true),
00087             mul_expects_matrix_args_error);
00088     /* Note: parens are required here so that the preprocessor ignores the
00089      * commas.
00090      */
00091 
00092     /* Deduce size type to ensure that a run-time check is performed if
00093      * necessary:
00094      */
00095     typedef typename et::MatrixPromote<
00096         typename left_traits::result_type,
00097         typename right_traits::result_type
00098     >::type result_type;
00099     typedef typename result_type::size_tag size_tag;
00100 
00101     /* Require that left has the same number of columns as right has rows.
00102      * This automatically checks fixed-size matrices at compile time, and
00103      * throws at run-time if the sizes don't match:
00104      */
00105     matrix_size N = detail::MatMulCheckedSize(left, right, size_tag());
00106 
00107     /* Create an array with the right size (resize() is a no-op for
00108      * fixed-size matrices):
00109      */
00110     result_type C;
00111     cml::et::detail::Resize(C, N);
00112 
00113     /* XXX Specialize this for fixed-size matrices: */
00114     typedef typename result_type::value_type value_type;
00115     for(size_t i = 0; i < left.rows(); ++i) {               /* rows */
00116         for(size_t j = 0; j < right.cols(); ++j) {          /* cols */
00117             value_type sum(left(i,0)*right(0,j));
00118             for(size_t k = 1; k < right.rows(); ++k) {
00119                 sum += (left(i,k)*right(k,j));
00120             }
00121             C(i,j) = sum;
00122         }
00123     }
00124 
00125     return C;
00126 }
00127 
00128 } // namespace detail
00129 
00130 
00132 template<typename E1, class AT1, typename L1,
00133          typename E2, class AT2, typename L2,
00134          typename BO>
00135 inline typename et::MatrixPromote<
00136     matrix<E1,AT1,BO,L1>, matrix<E2,AT2,BO,L2>
00137 >::temporary_type
00138 operator*(const matrix<E1,AT1,BO,L1>& left,
00139           const matrix<E2,AT2,BO,L2>& right)
00140 {
00141     return detail::mul(left,right);
00142 }
00143 
00145 template<typename E, class AT, typename BO, typename L, typename XprT>
00146 inline typename et::MatrixPromote<
00147     matrix<E,AT,BO,L>, typename XprT::result_type
00148 >::temporary_type
00149 operator*(const matrix<E,AT,BO,L>& left,
00150           const et::MatrixXpr<XprT>& right)
00151 {
00152     /* Generate a temporary, and compute the right-hand expression: */
00153     typedef typename et::MatrixXpr<XprT>::temporary_type expr_tmp;
00154     expr_tmp tmp;
00155     cml::et::detail::Resize(tmp,right.rows(),right.cols());
00156     tmp = right;
00157 
00158     return detail::mul(left,tmp);
00159 }
00160 
00162 template<typename XprT, typename E, class AT, typename BO, typename L>
00163 inline typename et::MatrixPromote<
00164     typename XprT::result_type , matrix<E,AT,BO,L>
00165 >::temporary_type
00166 operator*(const et::MatrixXpr<XprT>& left,
00167           const matrix<E,AT,BO,L>& right)
00168 {
00169     /* Generate a temporary, and compute the left-hand expression: */
00170     typedef typename et::MatrixXpr<XprT>::temporary_type expr_tmp;
00171     expr_tmp tmp;
00172     cml::et::detail::Resize(tmp,left.rows(),left.cols());
00173     tmp = left;
00174 
00175     return detail::mul(tmp,right);
00176 }
00177 
00179 template<typename XprT1, typename XprT2>
00180 inline typename et::MatrixPromote<
00181     typename XprT1::result_type, typename XprT2::result_type
00182 >::temporary_type
00183 operator*(const et::MatrixXpr<XprT1>& left,
00184           const et::MatrixXpr<XprT2>& right)
00185 {
00186     /* Generate temporaries and compute expressions: */
00187     typedef typename et::MatrixXpr<XprT1>::temporary_type left_tmp;
00188     left_tmp ltmp;
00189     cml::et::detail::Resize(ltmp,left.rows(),left.cols());
00190     ltmp = left;
00191 
00192     typedef typename et::MatrixXpr<XprT2>::temporary_type right_tmp;
00193     right_tmp rtmp;
00194     cml::et::detail::Resize(rtmp,right.rows(),right.cols());
00195     rtmp = right;
00196 
00197     return detail::mul(ltmp,rtmp);
00198 }
00199 
00200 } // namespace cml
00201 
00202 #endif
00203 
00204 // -------------------------------------------------------------------------
00205 // vim:ft=cpp

Generated on Sat Jul 18 19:35:24 2009 for CML 1.0 by doxygen 1.5.9