summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/matrix/matrix_exp.h
blob: c0afb54c01517c610f5e60e5f6877a9bdeb7121d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
// Copyright (C) 2006  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_MATRIx_EXP_h_
#define DLIB_MATRIx_EXP_h_

#include "../algs.h"
#include "../is_kind.h"
#include "matrix_fwd.h"
#include "matrix_exp_abstract.h"
#include <cstddef>
#include <iterator>

namespace dlib
{

// ----------------------------------------------------------------------------------------

    // We want to return the compile time constant when the NR or NC dimension
    // is not zero.  When it is zero we instead call nr() or nc() on the
    // expression object itself and return the value it reports.
    // Row-count lookup used by matrix_exp::nr().  For a non-zero NR the answer is
    // the compile time constant itself, so the expression argument is ignored.
    template <typename exp_type, long NR>
    struct get_nr_helper
    {
        static long get (
            const exp_type& /*unused*/
        )
        {
            return NR;
        }
    };

    // Specialization for NR == 0: the constant carries no row count, so the
    // value is obtained by calling nr() on the expression object.
    template <typename exp_type>
    struct get_nr_helper<exp_type, 0>
    {
        static long get (
            const exp_type& expression
        )
        {
            const long rows = expression.nr();
            return rows;
        }
    };

    // Column-count lookup used by matrix_exp::nc().  For a non-zero NC the answer
    // is the compile time constant itself, so the expression argument is ignored.
    template <typename exp_type, long NC>
    struct get_nc_helper
    {
        static long get (
            const exp_type& /*unused*/
        )
        {
            return NC;
        }
    };

    // Specialization for NC == 0: the constant carries no column count, so the
    // value is obtained by calling nc() on the expression object.
    template <typename exp_type>
    struct get_nc_helper<exp_type, 0>
    {
        static long get (
            const exp_type& expression
        )
        {
            const long columns = expression.nc();
            return columns;
        }
    };

    // Collects the properties of an expression type EXP in one place.  Every
    // member below is taken verbatim from the member of the same name in EXP.
    template <typename EXP>
    struct matrix_traits
    {
        // Compile time dimensions and cost value as declared by EXP.
        static const long NR   = EXP::NR;
        static const long NC   = EXP::NC;
        static const long cost = EXP::cost;

        // Element type and the type returned by element access.
        typedef typename EXP::type             type;
        typedef typename EXP::const_ret_type   const_ret_type;

        // Memory manager and layout types declared by EXP.
        typedef typename EXP::mem_manager_type mem_manager_type;
        typedef typename EXP::layout_type      layout_type;
    };

// ----------------------------------------------------------------------------------------

    template <typename EXP> class matrix_exp;

    // Forward iterator over the elements of a matrix expression.  Elements are
    // visited one row at a time: the column index advances first and, once it
    // reaches the column count, resets to 0 while the row index is incremented.
    //
    // The iterator stores a pointer to the expression it was created from, so it
    // must not be used after that expression object has gone away.
    template <typename EXP>
    class matrix_exp_iterator
    {
        friend class matrix_exp<EXP>;

        // Creates an iterator positioned at (r_,c_) inside m_.  Private on purpose:
        // only matrix_exp<EXP> (a friend) hands out positioned iterators.
        matrix_exp_iterator(const EXP& m_, long r_, long c_)
            : r(r_), c(c_), nc(m_.nc()), m(&m_)
        {
        }

    public:

        // The five standard iterator member types.  These were previously inherited
        // from std::iterator<std::forward_iterator_tag, type>, which is deprecated
        // since C++17; they are spelled out here with identical definitions so that
        // std::iterator_traits keeps reporting the same types.
        typedef std::forward_iterator_tag iterator_category;
        typedef typename matrix_traits<EXP>::type type;
        typedef type value_type;
        typedef std::ptrdiff_t difference_type;
        typedef type* pointer;
        typedef type& reference;

        // Type produced by dereferencing, as declared by the expression.
        typedef typename matrix_traits<EXP>::const_ret_type const_ret_type;

        // Creates an iterator that is not attached to any expression.  It must not
        // be dereferenced.
        matrix_exp_iterator() : r(0), c(0), nc(0), m(0) {}

        // Two iterators compare equal when they refer to the same (row, column)
        // position.  The expression pointer does not take part in the comparison.
        bool operator == ( const matrix_exp_iterator& itr) const
        { return r == itr.r && c == itr.c; }

        bool operator != ( const matrix_exp_iterator& itr) const
        { return !(*this == itr); }

        // Pre-increment: step to the next column, wrapping to the start of the
        // next row once the end of the current row has been reached.
        matrix_exp_iterator& operator++()
        {
            ++c;
            if (c==nc)
            {
                c = 0;
                ++r;
            }
            return *this;
        }

        // Post-increment: advances *this and returns a copy of its previous state.
        matrix_exp_iterator operator++(int)
        {
            matrix_exp_iterator temp(*this);
            ++(*this);
            return temp;
        }

        // Returns the element of the expression at the current (row, column).
        const_ret_type operator* () const { return (*m)(r,c); }

    private:
        long r, c;      // current row and column
        long nc;        // column count of *m, captured at construction
        const EXP* m;   // expression being iterated, 0 when default constructed
    };

// ----------------------------------------------------------------------------------------

    template <
        typename EXP
        >
    class matrix_exp 
    {
        /*!
            REQUIREMENTS ON EXP
                EXP should be something convertible to a matrix_exp.  That is,
                it should inherit from matrix_exp

            WHAT THIS OBJECT REPRESENTS
                This is the base class of a matrix expression of type EXP.  Element
                access, the dimension queries and the aliasing queries are all
                forwarded to the derived object returned by ref(), which is *this
                cast to const EXP&.
        !*/

    public:
        typedef typename matrix_traits<EXP>::type type;
        typedef type value_type;
        typedef typename matrix_traits<EXP>::const_ret_type const_ret_type;
        typedef typename matrix_traits<EXP>::mem_manager_type mem_manager_type;
        typedef typename matrix_traits<EXP>::layout_type layout_type;
        const static long NR = matrix_traits<EXP>::NR;
        const static long NC = matrix_traits<EXP>::NC;
        const static long cost = matrix_traits<EXP>::cost;

        typedef matrix<type,NR,NC,mem_manager_type,layout_type> matrix_type;
        typedef EXP exp_type;
        typedef matrix_exp_iterator<EXP> iterator;
        typedef matrix_exp_iterator<EXP> const_iterator;

        // Returns the element at row r and column c of the derived expression.
        // Requires 0 <= r < nr() and 0 <= c < nc() (checked by DLIB_ASSERT).
        inline const_ret_type operator() (
            long r,
            long c
        ) const 
        { 
            DLIB_ASSERT(r < nr() && c < nc() && r >= 0 && c >= 0, 
                "\tconst type matrix_exp::operator(r,c)"
                << "\n\tYou must give a valid row and column"
                << "\n\tr:    " << r 
                << "\n\tc:    " << c
                << "\n\tnr(): " << nr()
                << "\n\tnc(): " << nc() 
                << "\n\tthis: " << this
                );
            return ref()(r,c); 
        }

        // Single index access for row or column vectors.  Returns element (i,0)
        // when the expression has exactly one column and element (0,i) otherwise.
        // Using this on an expression whose compile time dimensions rule out a
        // vector is a compile time error.
        const_ret_type operator() (
            long i
        ) const 
        {
            COMPILE_TIME_ASSERT(NC == 1 || NC == 0 || NR == 1 || NR == 0);
            DLIB_ASSERT(nc() == 1 || nr() == 1, 
                "\tconst type matrix_exp::operator(i)"
                << "\n\tYou can only use this operator on column or row vectors"
                << "\n\ti:    " << i
                << "\n\tnr(): " << nr()
                << "\n\tnc(): " << nc()
                << "\n\tthis: " << this
                );
            DLIB_ASSERT( ((nc() == 1 && i < nr()) || (nr() == 1 && i < nc())) && i >= 0, 
                "\tconst type matrix_exp::operator(i)"
                << "\n\tYou must give a valid row/column number"
                << "\n\ti:    " << i
                << "\n\tnr(): " << nr()
                << "\n\tnc(): " << nc()
                << "\n\tthis: " << this
                );
            if (nc() == 1)
                return ref()(i,0);
            else
                return ref()(0,i);
        }

        // Total number of elements, i.e. nr()*nc().
        long size (
        ) const { return nr()*nc(); }

        // Number of rows: the constant NR when it is non-zero, otherwise the
        // value reported by the derived expression.
        long nr (
        ) const { return get_nr_helper<exp_type,NR>::get(ref()); }

        // Number of columns: the constant NC when it is non-zero, otherwise the
        // value reported by the derived expression.
        long nc (
        ) const { return get_nc_helper<exp_type,NC>::get(ref()); }

        // Forwards the aliasing query to the derived expression.
        template <typename U>
        bool aliases (
            const matrix_exp<U>& item
        ) const { return ref().aliases(item); }

        // Forwards the destructive aliasing query to the derived expression.
        template <typename U>
        bool destructively_aliases (
            const matrix_exp<U>& item
        ) const { return ref().destructively_aliases(item); }

        // Returns *this viewed as the derived expression type.
        inline const exp_type& ref (
        ) const { return *static_cast<const exp_type*>(this); }

        // Conversion of a 1x1 expression to its single element.  Expressions whose
        // compile time dimensions are anything other than 0 or 1 are rejected at
        // compile time; the actual 1x1 size is checked by DLIB_ASSERT.
        inline operator const type (
        ) const 
        {
            COMPILE_TIME_ASSERT(NC == 1 || NC == 0);
            COMPILE_TIME_ASSERT(NR == 1 || NR == 0);
            DLIB_ASSERT(nr() == 1 && nc() == 1, 
                "\tmatrix_exp::operator const type() const"
                << "\n\tYou can only use this operator on a 1x1 matrix"
                << "\n\tnr(): " << nr()
                << "\n\tnc(): " << nc()
                << "\n\tthis: " << this
                );

            // Put the expression contained in this matrix_exp into
            // a temporary 1x1 matrix so that the expression will encounter
            // all the overloads of matrix_assign() and have the chance to
            // go through any applicable optimizations.
            matrix<type,1,1,mem_manager_type,layout_type> temp(ref());
            return temp(0);
        }

        // Iterator to the first element, position (0,0).
        const_iterator begin() const 
        { 
            // An expression with zero columns has no elements even when nr() is
            // non-zero.  Position (0,0) would then compare unequal to end(), which
            // sits at (nr(),0), and a loop would dereference an element that does
            // not exist.  Return end() so that the range is empty.
            if (nc() == 0)
                return end();
            return matrix_exp_iterator<EXP>(ref(),0,0); 
        }

        // Iterator one past the last element, position (nr(),0).
        const_iterator end()   const { return matrix_exp_iterator<EXP>(ref(),nr(),0); }

    protected:
        // Construction and copying are reserved for derived expression classes.
        matrix_exp() {}
        matrix_exp(const matrix_exp& ) {}

    private:

        // Declared but not defined: matrix_exp objects cannot be assigned.
        matrix_exp& operator= (const matrix_exp&);
    };

// ----------------------------------------------------------------------------------------

    // A type T counts as a matrix when it is convertible to a const reference to
    // matrix_exp<typename T::exp_type>.  This partial specialization is selected
    // for exactly those types and sets value to true, so that
    // is_matrix<T>::value == 1 if T is a matrix type else 0.
    template <typename T>
    struct is_matrix<
        T,
        typename enable_if<
            is_convertible<T, const matrix_exp<typename T::exp_type>& >
            >::type
        >
    {
        static const bool value = true;
    };

// ----------------------------------------------------------------------------------------

    template <typename EXP>
    class matrix_diag_exp : public matrix_exp<EXP>
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                Base class for matrix expressions that represent diagonal matrices,
                that is, square matrices in which every element off the diagonal
                is equal to 0.
        !*/

    protected:
        // Both constructors are protected, so only derived expression classes can
        // create or copy a matrix_diag_exp.
        matrix_diag_exp() : matrix_exp<EXP>() {}

        matrix_diag_exp(const matrix_diag_exp& other) : matrix_exp<EXP>(other) {}
    };

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_MATRIx_EXP_h_