summaryrefslogtreecommitdiffstats
path: root/src/tests/intg/sssd_netgroup.py
blob: 81d017fa3dc2b5b6141c05b64b024354739e73d5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
#
# Module for simulation of utility "getent netgroup -s sss" from glibc
#
# Copyright (c) 2016 Red Hat, Inc.
# Author: Lukas Slebodnik <lslebodn@redhat.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
from ctypes import (c_int, c_char, c_char_p, c_size_t, c_void_p, c_ulong,
                    POINTER, Structure, Union, create_string_buffer, get_errno)
from sssd_nss import NssReturnCode, nss_sss_ctypes_loader


class NetgroupType(object):
    """
    'enum' class with the integer tags compared against Netgrent.type to
    tell a triple entry from a nested-netgroup entry.
    """
    TRIPLE_VAL, GROUP_VAL = range(2)


class Triple(Structure):
    """
    ctypes structure holding one netgroup triple as three C strings:
    host, user and domain (in this order).
    """
    _fields_ = [
        ("host", c_char_p),
        ("user", c_char_p),
        ("domain", c_char_p),
    ]


class Val(Union):
    """
    ctypes union carrying the payload of a netgroup entry: either a
    Triple ("triple") or the name of a nested netgroup ("group").
    """
    _fields_ = [
        ("triple", Triple),
        ("group", c_char_p),
    ]


class Idx(Union):
    """
    ctypes union overlaying a char pointer ("cursor") with an unsigned long
    ("position").
    """
    # NOTE(review): presumably mirrors the anonymous cursor/position union of
    # glibc's struct __netgrent -- confirm against the C header.
    _fields_ = [
        ("cursor", POINTER(c_char)),
        ("position", c_ulong),
    ]


class NameList(Structure):
    """
    ctypes node of a singly linked list of names: "next" points to another
    NameList and "name" is a pointer to char.
    """


# The fields are attached after the class statement because "next" has to
# refer to the NameList type itself.
# NOTE(review): "name" is declared here as a char pointer -- confirm it
# matches the C definition of the list node before dereferencing it.
NameList._fields_ = [
    ("next", POINTER(NameList)),
    ("name", POINTER(c_char)),
]


class Netgrent(Structure):
    """
    ctypes counterpart of the C "struct __netgrent" that is handed to the
    _nss_sss_*netgrent functions wrapped by NetgroupRetriever.

    "type" holds a NetgroupType constant which tells whether "val" carries
    a triple or the name of a nested netgroup.
    """
    _fields_ = [
        ("type", c_int),
        ("val", Val),
        ("data", POINTER(c_char)),
        ("data_size", c_size_t),
        ("idx", Idx),
        ("first", c_int),
        ("known_groups", POINTER(NameList)),
        ("needed_groups", POINTER(NameList)),
        ("nip", c_void_p),
    ]


class NetgroupRetriever(object):
    """
    Helper which retrieves the triples of a netgroup (including the triples
    of nested netgroups) through the _nss_sss_*netgrent functions.

    Attributes:
        name            UTF-8 encoded name of the top-level netgroup
        needed_groups   nested netgroups discovered but not fetched yet
        known_groups    every netgroup already queued or fetched
        netgroups       triples collected by the last get_netgroups() call
    """

    def __init__(self, name):
        """
        @param string name name of the top-level netgroup
        """
        self.name = name.encode('utf-8')
        self.needed_groups = []
        self.known_groups = []
        self.netgroups = []

    @staticmethod
    def _decode_field(value):
        """
        Convert one member of a triple to a python string.

        @param bytes value content of the C string or None for NULL pointer

        @return string decoded value; empty string for NULL or empty input
        """
        return value.decode('utf-8') if value else ""

    @staticmethod
    def _setnetgrent(netgroup):
        """
        This private method is ctypes wrapper for
        enum nss_status _nss_sss_setnetgrent(const char *netgroup,
                                             struct __netgrent *result)

        @param bytes netgroup name of netgroup

        @return (int, POINTER(Netgrent)) (err, result_p)
            err is a constant from class NssReturnCode and in case of SUCCESS
            result_p will contain POINTER(Netgrent) which can be used in
            _getnetgrent_r or _endnetgrent.
        """
        func = nss_sss_ctypes_loader('_nss_sss_setnetgrent')
        func.restype = c_int
        func.argtypes = [c_char_p, POINTER(Netgrent)]

        result = Netgrent()
        result_p = POINTER(Netgrent)(result)

        res = func(c_char_p(netgroup), result_p)

        return (int(res), result_p)

    @staticmethod
    def _getnetgrent_r(result_p, buff, buff_len):
        """
        This private method is ctypes wrapper for
        enum nss_status _nss_sss_getnetgrent_r(struct __netgrent *result,
                                               char *buffer, size_t buflen,
                                               int *errnop)
        @param POINTER(Netgrent) result_p pointer to initialized C structure
               struct __netgrent
        @param ctypes.c_char_Array buff buffer used by C functions
        @param int buff_len size of c_char_Array passed as a parameter buff

        @return (int, int, POINTER(Netgrent)) (err, errno, result_p)
            err is a constant from class NssReturnCode, errno is the value
            stored by the C function into errnop and result_p is the pointer
            passed as a parameter; the structure it points to is the one
            handed to the C function as its result argument.
        """
        func = nss_sss_ctypes_loader('_nss_sss_getnetgrent_r')
        func.restype = c_int
        func.argtypes = [POINTER(Netgrent), POINTER(c_char), c_size_t,
                         POINTER(c_int)]

        errno = POINTER(c_int)(c_int(0))

        res = func(result_p, buff, buff_len, errno)

        return (int(res), int(errno[0]), result_p)

    @staticmethod
    def _endnetgrent(result_p):
        """
        This private method is ctypes wrapper for
        enum nss_status _nss_sss_endnetgrent(struct __netgrent *result)

        @param POINTER(Netgrent) result_p pointer to initialized C structure
               struct __netgrent

        @return int a constant from class NssReturnCode
        """
        func = nss_sss_ctypes_loader('_nss_sss_endnetgrent')
        func.restype = c_int
        func.argtypes = [POINTER(Netgrent)]

        res = func(result_p)

        return int(res)

    def get_netgroups(self):
        """
        Function will return netgroup triplets for the netgroup given to the
        constructor. All nested netgroups will be retrieved as part of
        execution and their content will be merged with direct triplets.
        Missing nested netgroups will not cause failure and are considered
        as an empty netgroup without triplets.

        Every call starts from a clean state, therefore repeated calls do
        not accumulate triplets from previous calls.

        @return (int, int, List[(string, string, string)])
                (err, errno, netgroups)
            if err is NssReturnCode.SUCCESS netgroups will contain list of
            tuples. Each tuple will consist of 3 strings
            (host, user, domain); a missing member is an empty string.
        """
        self.needed_groups = []
        # The top-level netgroup is already "known"; otherwise a nested
        # netgroup referring back to it would cause it to be fetched again
        # and its triplets would be duplicated.
        self.known_groups = [self.name]
        self.netgroups = []

        res, errno, result = self._flat_fetch_netgroups(self.name)
        if res != NssReturnCode.SUCCESS:
            return (res, errno, self.netgroups)

        self.netgroups += result

        while self.needed_groups:
            name = self.needed_groups.pop(0)

            nest_res, nest_errno, result = self._flat_fetch_netgroups(name)
            # do not fail for missing nested netgroup
            if nest_res not in (NssReturnCode.SUCCESS, NssReturnCode.NOTFOUND):
                return (nest_res, nest_errno, self.netgroups)

            # triplets of nested netgroups are put before the ones
            # collected so far
            self.netgroups = result + self.netgroups

        return (res, errno, self.netgroups)

    def _flat_fetch_netgroups(self, name):
        """
        Function will return direct netgroup triplets for given netgroup.
        Triplets of nested netgroups will not be returned. Nested netgroups
        which are not in known_groups yet will be appended to the arrays
        needed_groups and known_groups.

        @param bytes name name of netgroup

        @return (int, int, List[(string, string, string)])
                (err, errno, netgroups)
            if err is NssReturnCode.SUCCESS netgroups will contain list of
            tuples. Each tuple will consist of 3 strings
            (host, user, domain); a missing member is an empty string.
        """
        buff_len = 1024 * 1024
        buff = create_string_buffer(buff_len)

        result = []

        res, result_p = self._setnetgrent(name)
        if res != NssReturnCode.SUCCESS:
            # NOTE(review): ctypes.get_errno() returns the ctypes private
            # copy of errno which is only updated for libraries loaded with
            # use_errno=True -- confirm in nss_sss_ctypes_loader.
            return (res, get_errno(), result)

        # _setnetgrent succeeded; make sure that _endnetgrent is called even
        # if the iteration fails or raises an exception.
        try:
            res, errno, result_p = self._getnetgrent_r(result_p, buff,
                                                       buff_len)
            while res == NssReturnCode.SUCCESS:
                entry = result_p[0]

                if entry.type == NetgroupType.GROUP_VAL:
                    nested_netgroup = entry.val.group
                    if nested_netgroup not in self.known_groups:
                        self.needed_groups.append(nested_netgroup)
                        self.known_groups.append(nested_netgroup)
                elif entry.type == NetgroupType.TRIPLE_VAL:
                    triple = entry.val.triple
                    result.append((self._decode_field(triple.host),
                                   self._decode_field(triple.user),
                                   self._decode_field(triple.domain)))

                res, errno, result_p = self._getnetgrent_r(result_p, buff,
                                                           buff_len)
        finally:
            end_res = self._endnetgrent(result_p)

        # The loop is expected to be terminated by NssReturnCode.RETURN;
        # any other code is reported to the caller as it is.
        if res != NssReturnCode.RETURN:
            return (res, errno, result)

        return (end_res, errno, result)


def get_sssd_netgroups(name):
    """
    Function will return netgroup triplets for given netgroup. It will
    gather netgroups only provided by sssd.
    The equivalent of "getent netgroup -s sss name"

    @param string name name of netgroup

    @return (int, int, List[(string, string, string)]) (err, errno, netgroups)
        if err is NssReturnCode.SUCCESS netgroups will contain list of tuples.
        Each tuple will consist of 3 strings (host, user, domain); a missing
        member is an empty string.
    """
    return NetgroupRetriever(name).get_netgroups()