summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/matrix/lapack/syev.h
blob: 0c9fd251a175fd67f5d7aa609cb525bd27db37dd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
// Copyright (C) 2010  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_LAPACk_EV_Hh_
#define DLIB_LAPACk_EV_Hh_

#include "fortran_id.h"
#include "../matrix.h"

namespace dlib
{
    namespace lapack
    {
        namespace binding
        {
            // Raw LAPACK entry points for the symmetric eigensolvers (double: dsyev,
            // float: ssyev).  Every argument is passed by pointer, as Fortran requires.
            // NOTE(review): DLIB_FORTRAN_ID and `integer` come from fortran_id.h; they
            // presumably adapt names/types to the Fortran compiler in use -- confirm there.
            extern "C"
            {
                void DLIB_FORTRAN_ID(dsyev) (char *jobz, char *uplo, integer *n, double *a,
                                             integer *lda, double *w, double *work, integer *lwork, 
                                             integer *info);

                void DLIB_FORTRAN_ID(ssyev) (char *jobz, char *uplo, integer *n, float *a,
                                             integer *lda, float *w, float *work, integer *lwork, 
                                             integer *info);

            }

            // By-value wrapper around dsyev.  The scalar arguments are taken by value so
            // their addresses can be handed to Fortran; LAPACK's INFO code is returned.
            inline int syev (char jobz, char uplo, integer n, double *a,
                             integer lda, double *w, double *work, integer lwork)
            {
                integer info = 0;
                DLIB_FORTRAN_ID(dsyev)(&jobz, &uplo, &n, a,
                                       &lda, w, work, &lwork, &info);
                return info;
            }

            // Single-precision twin of the overload above; forwards to ssyev.
            inline int syev (char jobz, char uplo, integer n, float *a,
                             integer lda, float *w, float *work, integer lwork)
            {
                integer info = 0;
                DLIB_FORTRAN_ID(ssyev)(&jobz, &uplo, &n, a,
                                       &lda, w, work, &lwork, &info);
                return info;
            }


        }

    // ------------------------------------------------------------------------------------

/*  -- LAPACK driver routine (version 3.1) -- */
/*     Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */
/*     November 2006 */

/*     .. Scalar Arguments .. */
/*     .. */
/*     .. Array Arguments .. */
/*     .. */

/*  Purpose */
/*  ======= */

/*  DSYEV computes all eigenvalues and, optionally, eigenvectors of a */
/*  real symmetric matrix A. */

/*  SSYEV is the single-precision counterpart: where the text below says */
/*  DOUBLE PRECISION, read REAL for SSYEV.  The binding::syev overloads */
/*  above call DSYEV for double data and SSYEV for float data. */

/*  Arguments */
/*  ========= */

/*  JOBZ    (input) CHARACTER*1 */
/*          = 'N':  Compute eigenvalues only; */
/*          = 'V':  Compute eigenvalues and eigenvectors. */

/*  UPLO    (input) CHARACTER*1 */
/*          = 'U':  Upper triangle of A is stored; */
/*          = 'L':  Lower triangle of A is stored. */

/*  N       (input) INTEGER */
/*          The order of the matrix A.  N >= 0. */

/*  A       (input/output) DOUBLE PRECISION array, dimension (LDA, N) */
/*          On entry, the symmetric matrix A.  If UPLO = 'U', the */
/*          leading N-by-N upper triangular part of A contains the */
/*          upper triangular part of the matrix A.  If UPLO = 'L', */
/*          the leading N-by-N lower triangular part of A contains */
/*          the lower triangular part of the matrix A. */
/*          On exit, if JOBZ = 'V', then if INFO = 0, A contains the */
/*          orthonormal eigenvectors of the matrix A. */
/*          If JOBZ = 'N', then on exit the lower triangle (if UPLO='L') */
/*          or the upper triangle (if UPLO='U') of A, including the */
/*          diagonal, is destroyed. */

/*  LDA     (input) INTEGER */
/*          The leading dimension of the array A.  LDA >= max(1,N). */

/*  W       (output) DOUBLE PRECISION array, dimension (N) */
/*          If INFO = 0, the eigenvalues in ascending order. */

/*  WORK    (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK)) */
/*          On exit, if INFO = 0, WORK(1) returns the optimal LWORK. */

/*  LWORK   (input) INTEGER */
/*          The length of the array WORK.  LWORK >= max(1,3*N-1). */
/*          For optimal efficiency, LWORK >= (NB+2)*N, */
/*          where NB is the blocksize for DSYTRD returned by ILAENV. */

/*          If LWORK = -1, then a workspace query is assumed; the routine */
/*          only calculates the optimal size of the WORK array, returns */
/*          this value as the first entry of the WORK array, and no error */
/*          message related to LWORK is issued by XERBLA. */

/*  INFO    (output) INTEGER */
/*          = 0:  successful exit */
/*          < 0:  if INFO = -i, the i-th argument had an illegal value */
/*          > 0:  if INFO = i, the algorithm failed to converge; i */
/*                off-diagonal elements of an intermediate tridiagonal */
/*                form did not converge to zero. */


    // ------------------------------------------------------------------------------------

        // Eigen-decomposition of a symmetric matrix stored in column-major order.  The
        // storage already matches LAPACK's layout, so it is passed through unchanged.
        //
        //   jobz - 'N': eigenvalues only, 'V': eigenvalues and eigenvectors.
        //   uplo - 'U' or 'L': which triangle of a holds the matrix.
        //   a    - the symmetric matrix; overwritten.  With jobz == 'V' and a zero
        //          return value, its columns are the orthonormal eigenvectors.
        //          Precondition (not checked here): a is square.  Its order is taken
        //          from a.nr(), which is also used as the leading dimension.
        //   w    - resized to a.nr() x 1; receives the eigenvalues in ascending order.
        //
        // Returns LAPACK's INFO code: 0 on success, -i if argument i was illegal, and
        // i > 0 if the algorithm failed to converge.  An empty a returns 0 immediately.
        template <
            typename T, 
            long NR1, long NR2, 
            long NC1, long NC2,
            typename MM
            >
        int syev (
            const char jobz,
            const char uplo,
            matrix<T,NR1,NC1,MM,column_major_layout>& a,
            matrix<T,NR2,NC2,MM,column_major_layout>& w
        )
        {
            const long n = a.nr();

            w.set_size(n,1);

            // Nothing to decompose.  Returning here also avoids indexing a(0,0) and
            // w(0,0) on empty matrices, and avoids handing LAPACK LDA == 0, which
            // violates LDA >= max(1,N) and is rejected as an illegal argument.
            if (n == 0)
                return 0;

            // Workspace query: with LWORK == -1 LAPACK only stores the optimal LWORK
            // in work_size.
            T work_size = 1;
            int info = binding::syev(jobz, uplo, n, &a(0,0),
                                     a.nr(), &w(0,0), &work_size, -1);

            if (info != 0)
                return info;

            matrix<T,0,1,MM,column_major_layout> work;
            if (work.size() < work_size)
                work.set_size(static_cast<long>(work_size), 1);

            // compute the actual decomposition 
            info = binding::syev(jobz, uplo, n, &a(0,0),
                                 a.nr(), &w(0,0), &work(0,0), work.size());

            return info;
        }

    // ------------------------------------------------------------------------------------

        // Eigen-decomposition of a symmetric matrix stored in row-major order.
        //
        // Seen through LAPACK's column-major lens, a's row-major buffer is the
        // transpose of a.  The matrix is symmetric, so that is the same matrix, but the
        // triangles trade places: uplo is swapped before the call.  Afterwards a is
        // transposed so the eigenvectors end up as its columns, as in the
        // column_major_layout overload.
        //
        //   jobz - 'N': eigenvalues only, 'V': eigenvalues and eigenvectors.
        //   uplo - 'U' or 'L': which triangle of a (in row-major terms) holds the matrix.
        //          Lower-case 'l' is also recognized.  Any other value is treated as 'U'.
        //   a    - the symmetric matrix; overwritten.  With jobz == 'V' and a zero
        //          return value, its columns are the orthonormal eigenvectors.
        //          Precondition (not checked here): a is square.  The order is a.nr()
        //          and the leading dimension handed to LAPACK is a.nc().
        //   w    - resized to a.nr() x 1; receives the eigenvalues in ascending order.
        //
        // Returns LAPACK's INFO code: 0 on success, -i if argument i was illegal, and
        // i > 0 if the algorithm failed to converge.  An empty a returns 0 immediately.
        template <
            typename T, 
            long NR1, long NR2, 
            long NC1, long NC2,
            typename MM
            >
        int syev (
            char jobz,
            char uplo,
            matrix<T,NR1,NC1,MM,row_major_layout>& a,
            matrix<T,NR2,NC2,MM,row_major_layout>& w
        )
        {
            // Swap triangles for the transposed (column-major) view.  LAPACK compares
            // option letters case-insensitively, so a lower-case 'l' must map to 'U'
            // here as well; otherwise the wrong triangle would be used.
            if (uplo == 'L' || uplo == 'l')
                uplo = 'U';
            else
                uplo = 'L';

            const long n = a.nr();

            w.set_size(n,1);

            // Nothing to decompose.  Returning here also avoids indexing a(0,0) and
            // w(0,0) on empty matrices, and avoids handing LAPACK LDA == 0, which
            // violates LDA >= max(1,N) and is rejected as an illegal argument.
            if (n == 0)
                return 0;

            // Workspace query: with LWORK == -1 LAPACK only stores the optimal LWORK
            // in work_size.
            T work_size = 1;
            int info = binding::syev(jobz, uplo, n, &a(0,0),
                                     a.nc(), &w(0,0), &work_size, -1);

            if (info != 0)
                return info;

            matrix<T,0,1,MM,row_major_layout> work;
            if (work.size() < work_size)
                work.set_size(static_cast<long>(work_size), 1);

            // compute the actual decomposition 
            info = binding::syev(jobz, uplo, n, &a(0,0),
                                 a.nc(), &w(0,0), &work(0,0), work.size());


            // LAPACK wrote the eigenvectors as columns of its column-major view, i.e.
            // as rows of a.  Transpose so they become the columns of a.
            a = trans(a);

            return info;
        }

    // ------------------------------------------------------------------------------------

    }

}

// ----------------------------------------------------------------------------------------

#endif // DLIB_LAPACk_EV_Hh_