summaryrefslogtreecommitdiffstats
path: root/src/shared/nscd-flush.c
blob: 9b0ba2d67a0826099d674779ec1829fb8323ab5a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
/* SPDX-License-Identifier: LGPL-2.1-or-later */

#include <fcntl.h>
#include <poll.h>

#include "fd-util.h"
#include "io-util.h"
#include "nscd-flush.h"
#include "socket-util.h"
#include "strv.h"
#include "time-util.h"

/* Total time budget for one nscd_flush_cache() call. A single deadline derived from this is shared by all
 * databases flushed in that call; it is not a per-database timeout. */
#define NSCD_FLUSH_CACHE_TIMEOUT_USEC (5*USEC_PER_SEC)

/* Invalidation request as written verbatim to the nscd socket: three int32_t header fields followed directly by
 * the NUL-terminated name of the database to invalidate (flexible array member). key_len is the length of dbname
 * including its trailing NUL. */
struct nscdInvalidateRequest {
        int32_t version;
        int32_t type; /* glibc declares this field as an enum. We deliberately do not replicate that enum 1:1 and use
                       * a fixed-width int32_t instead, since the size of an enum type is implementation-defined in C
                       * and therefore a poor fit for a wire format. */
        int32_t key_len;
        char dbname[];
};

/* Sends a single cache invalidation request for 'database' to nscd and waits for nscd's int32_t reply.
 *
 * All I/O happens on a non-blocking AF_UNIX stream socket connected to /run/nscd/socket. Writing the request and
 * reading the reply are interleaved in one loop driven by fd_wait_for_event(), whose timeout is derived from the
 * CLOCK_MONOTONIC deadline 'end'.
 *
 * Returns 1 on success (nscd replied 0, or it closed the connection after receiving the complete request without
 * sending any reply), -ETIMEDOUT once 'end' has passed, and a negative error code on any other failure. Most
 * failures are logged at debug level; the timeout and fd_wait_for_event() failures are returned without logging
 * here. */
static int nscd_flush_cache_one(const char *database, usec_t end) {
        size_t req_size, has_written = 0, has_read = 0, l;
        struct nscdInvalidateRequest *req;
        _cleanup_close_ int fd = -1;
        int32_t resp;
        int events, r;

        assert(database);

        /* The request is the fixed-size header followed by the database name including its trailing NUL, which is
         * also exactly what key_len counts. */
        l = strlen(database);
        req_size = offsetof(struct nscdInvalidateRequest, dbname) + l + 1;

        /* NOTE(review): version 2 and type 10 presumably correspond to glibc's NSCD_VERSION and its INVALIDATE
         * request type — confirm against glibc's nscd-client.h. */
        req = alloca_safe(req_size);
        *req = (struct nscdInvalidateRequest) {
                .version = 2,
                .type = 10,
                .key_len = l + 1,
        };

        /* Cannot overflow: req was sized above for l + 1 bytes of name. */
        strcpy(req->dbname, database);

        fd = socket(AF_UNIX, SOCK_STREAM|SOCK_CLOEXEC|SOCK_NONBLOCK, 0);
        if (fd < 0)
                return log_debug_errno(errno, "Failed to allocate nscd socket: %m");

        /* Note: connect() returns EINPROGRESS if O_NONBLOCK is set and establishing a connection takes time. The
         * kernel lets us know this way that the connection is now being established, and we should watch with poll()
         * to learn when it is fully established. That said, AF_UNIX on Linux never triggers this IRL (connect() is
         * always instant on AF_UNIX), hence handling this is mostly just an exercise in defensive, protocol-agnostic
         * programming.
         *
         * connect() returns EAGAIN if the socket's backlog limit has been reached. When we see this we give up right
         * away, after all this entire function here is written in a defensive style so that a non-responding nscd
         * doesn't stall us for good. (Even if we wanted to handle this better: the Linux kernel doesn't really have a
         * nice way to connect() to a server synchronously with a time limit that would also cover dealing with the
         * backlog limit. After all SO_RCVTIMEO and SO_SNDTIMEO don't apply to connect(), and alarm() is ugly and not
         * really reasonably usable from threads-aware code.) */
        r = connect_unix_path(fd, AT_FDCWD, "/run/nscd/socket");
        if (r < 0) {
                if (r == -EAGAIN)
                        return log_debug_errno(r, "nscd is overloaded (backlog limit reached) and refuses to take further connections: %m");
                if (r != -EINPROGRESS)
                        return log_debug_errno(r, "Failed to connect to nscd socket: %m");

                /* Continue in case of EINPROGRESS, but don't bother with send() or recv() until being notified that
                 * establishing the connection is complete. */
                events = 0;
        } else
                events = POLLIN|POLLOUT; /* Let's assume initially that we can write and read to the fd, to suppress
                                          * one poll() invocation */

        /* 'events' holds the revents reported by the previous fd_wait_for_event() call (or the initial assumption
         * above). Each iteration first makes whatever progress those events allow, then checks for completion and
         * for the deadline, and only then waits again. */
        for (;;) {
                usec_t p;

                if (events & POLLOUT) {
                        ssize_t m;

                        /* POLLOUT is only requested while part of the request is still unsent (see the wait below). */
                        assert(has_written < req_size);

                        m = send(fd, (uint8_t*) req + has_written, req_size - has_written, MSG_NOSIGNAL);
                        if (m < 0) {
                                if (errno != EAGAIN) /* Note that EAGAIN is returned by the kernel whenever it can't
                                                      * take the data right now, and that includes if the connect() is
                                                      * asynchronous and we saw EINPROGRESS on it, and it hasn't
                                                      * completed yet. */
                                        return log_debug_errno(errno, "Failed to write to nscd socket: %m");
                        } else
                                has_written += m;
                }

                if (events & (POLLIN|POLLERR|POLLHUP)) {
                        ssize_t m;

                        /* The complete reply was already received in an earlier iteration, yet the socket is reported
                         * readable (or hung up/errored) again: treat that as a protocol error. */
                        if (has_read >= sizeof(resp))
                                return log_debug_errno(SYNTHETIC_ERRNO(EIO), "Response from nscd longer than expected: %m");

                        m = recv(fd, (uint8_t*) &resp + has_read, sizeof(resp) - has_read, 0);
                        if (m < 0) {
                                if (errno != EAGAIN)
                                        return log_debug_errno(errno, "Failed to read from nscd socket: %m");
                        } else if (m == 0) { /* EOF */
                                if (has_read == 0 && has_written >= req_size) /* Older nscd immediately terminated the
                                                                               * connection, accept that as OK */
                                        return 1;

                                return log_debug_errno(SYNTHETIC_ERRNO(EIO), "nscd prematurely ended connection.");
                        } else
                                has_read += m;
                }

                /* A reply of 0 means success; a positive reply is passed on as an errno value. */
                if (has_written >= req_size && has_read >= sizeof(resp)) { /* done? */
                        if (resp < 0)
                                return log_debug_errno(SYNTHETIC_ERRNO(EBADMSG), "nscd sent us a negative error number: %i", resp);
                        if (resp > 0)
                                return log_debug_errno(resp, "nscd return failure code on invalidating '%s'.", database);
                        return 1;
                }

                p = now(CLOCK_MONOTONIC);
                if (p >= end)
                        return -ETIMEDOUT;

                /* Always wait for readability; only ask for writability while part of the request is still unsent.
                 * A timeout leaves events at 0, so the next iteration just re-checks the deadline. */
                events = fd_wait_for_event(fd, POLLIN | (has_written < req_size ? POLLOUT : 0), end - p);
                if (events < 0)
                        return events;
        }
}

int nscd_flush_cache(char **databases) {
        usec_t end;
        int r = 0;

        /* Tries to invalidate the specified database in nscd. We do this carefully, with a 5s timeout, so that we
         * don't block indefinitely on another service. */

        end = usec_add(now(CLOCK_MONOTONIC), NSCD_FLUSH_CACHE_TIMEOUT_USEC);

        STRV_FOREACH(i, databases) {
                int k;

                k = nscd_flush_cache_one(*i, end);
                if (k < 0 && r >= 0)
                        r = k;
        }

        return r;
}