diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-06 01:02:30 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-06 01:02:30 +0000 |
commit | 76cb841cb886eef6b3bee341a2266c76578724ad (patch) | |
tree | f5892e5ba6cc11949952a6ce4ecbe6d516d6ce58 /net/sunrpc | |
parent | Initial commit. (diff) | |
download | linux-76cb841cb886eef6b3bee341a2266c76578724ad.tar.xz linux-76cb841cb886eef6b3bee341a2266c76578724ad.zip |
Adding upstream version 4.19.249.upstream/4.19.249
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'net/sunrpc')
62 files changed, 45172 insertions, 0 deletions
diff --git a/net/sunrpc/Kconfig b/net/sunrpc/Kconfig new file mode 100644 index 000000000..ac09ca803 --- /dev/null +++ b/net/sunrpc/Kconfig @@ -0,0 +1,64 @@ +config SUNRPC + tristate + depends on MULTIUSER + +config SUNRPC_GSS + tristate + select OID_REGISTRY + depends on MULTIUSER + +config SUNRPC_BACKCHANNEL + bool + depends on SUNRPC + +config SUNRPC_SWAP + bool + depends on SUNRPC + +config RPCSEC_GSS_KRB5 + tristate "Secure RPC: Kerberos V mechanism" + depends on SUNRPC && CRYPTO + depends on CRYPTO_MD5 && CRYPTO_DES && CRYPTO_CBC && CRYPTO_CTS + depends on CRYPTO_ECB && CRYPTO_HMAC && CRYPTO_SHA1 && CRYPTO_AES + depends on CRYPTO_ARC4 + default y + select SUNRPC_GSS + help + Choose Y here to enable Secure RPC using the Kerberos version 5 + GSS-API mechanism (RFC 1964). + + Secure RPC calls with Kerberos require an auxiliary user-space + daemon which may be found in the Linux nfs-utils package + available from http://linux-nfs.org/. In addition, user-space + Kerberos support should be installed. + + If unsure, say Y. + +config SUNRPC_DEBUG + bool "RPC: Enable dprintk debugging" + depends on SUNRPC && SYSCTL + select DEBUG_FS + help + This option enables a sysctl-based debugging interface + that is be used by the 'rpcdebug' utility to turn on or off + logging of different aspects of the kernel RPC activity. + + Disabling this option will make your kernel slightly smaller, + but makes troubleshooting NFS issues significantly harder. + + If unsure, say Y. + +config SUNRPC_XPRT_RDMA + tristate "RPC-over-RDMA transport" + depends on SUNRPC && INFINIBAND && INFINIBAND_ADDR_TRANS + default SUNRPC && INFINIBAND + select SG_POOL + help + This option allows the NFS client and server to use RDMA + transports (InfiniBand, iWARP, or RoCE). + + To compile this support as a module, choose M. The module + will be called rpcrdma.ko. + + If unsure, or you know there is no RDMA capability on your + hardware platform, say N. diff --git a/net/sunrpc/Makefile b/net/sunrpc/Makefile new file mode 100644 index 000000000..090658c3d --- /dev/null +++ b/net/sunrpc/Makefile @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for Linux kernel SUN RPC +# + + +obj-$(CONFIG_SUNRPC) += sunrpc.o +obj-$(CONFIG_SUNRPC_GSS) += auth_gss/ +obj-$(CONFIG_SUNRPC_XPRT_RDMA) += xprtrdma/ + +sunrpc-y := clnt.o xprt.o socklib.o xprtsock.o sched.o \ + auth.o auth_null.o auth_unix.o auth_generic.o \ + svc.o svcsock.o svcauth.o svcauth_unix.o \ + addr.o rpcb_clnt.o timer.o xdr.o \ + sunrpc_syms.o cache.o rpc_pipe.o \ + svc_xprt.o \ + xprtmultipath.o +sunrpc-$(CONFIG_SUNRPC_DEBUG) += debugfs.o +sunrpc-$(CONFIG_SUNRPC_BACKCHANNEL) += backchannel_rqst.o +sunrpc-$(CONFIG_PROC_FS) += stats.o +sunrpc-$(CONFIG_SYSCTL) += sysctl.o diff --git a/net/sunrpc/addr.c b/net/sunrpc/addr.c new file mode 100644 index 000000000..7404f0270 --- /dev/null +++ b/net/sunrpc/addr.c @@ -0,0 +1,357 @@ +/* + * Copyright 2009, Oracle. All rights reserved. + * + * Convert socket addresses to presentation addresses and universal + * addresses, and vice versa. + * + * Universal addresses are introduced by RFC 1833 and further refined by + * recent RFCs describing NFSv4. The universal address format is part + * of the external (network) interface provided by rpcbind version 3 + * and 4, and by NFSv4. Such an address is a string containing a + * presentation format IP address followed by a port number in + * "hibyte.lobyte" format. + * + * IPv6 addresses can also include a scope ID, typically denoted by + * a '%' followed by a device name or a non-negative integer. Refer to + * RFC 4291, Section 2.2 for details on IPv6 presentation formats. + */ + +#include <net/ipv6.h> +#include <linux/sunrpc/addr.h> +#include <linux/sunrpc/msg_prot.h> +#include <linux/slab.h> +#include <linux/export.h> + +#if IS_ENABLED(CONFIG_IPV6) + +static size_t rpc_ntop6_noscopeid(const struct sockaddr *sap, + char *buf, const int buflen) +{ + const struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)sap; + const struct in6_addr *addr = &sin6->sin6_addr; + + /* + * RFC 4291, Section 2.2.2 + * + * Shorthanded ANY address + */ + if (ipv6_addr_any(addr)) + return snprintf(buf, buflen, "::"); + + /* + * RFC 4291, Section 2.2.2 + * + * Shorthanded loopback address + */ + if (ipv6_addr_loopback(addr)) + return snprintf(buf, buflen, "::1"); + + /* + * RFC 4291, Section 2.2.3 + * + * Special presentation address format for mapped v4 + * addresses. + */ + if (ipv6_addr_v4mapped(addr)) + return snprintf(buf, buflen, "::ffff:%pI4", + &addr->s6_addr32[3]); + + /* + * RFC 4291, Section 2.2.1 + */ + return snprintf(buf, buflen, "%pI6c", addr); +} + +static size_t rpc_ntop6(const struct sockaddr *sap, + char *buf, const size_t buflen) +{ + const struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)sap; + char scopebuf[IPV6_SCOPE_ID_LEN]; + size_t len; + int rc; + + len = rpc_ntop6_noscopeid(sap, buf, buflen); + if (unlikely(len == 0)) + return len; + + if (!(ipv6_addr_type(&sin6->sin6_addr) & IPV6_ADDR_LINKLOCAL)) + return len; + if (sin6->sin6_scope_id == 0) + return len; + + rc = snprintf(scopebuf, sizeof(scopebuf), "%c%u", + IPV6_SCOPE_DELIMITER, sin6->sin6_scope_id); + if (unlikely((size_t)rc >= sizeof(scopebuf))) + return 0; + + len += rc; + if (unlikely(len >= buflen)) + return 0; + + strcat(buf, scopebuf); + return len; +} + +#else /* !IS_ENABLED(CONFIG_IPV6) */ + +static size_t rpc_ntop6_noscopeid(const struct sockaddr *sap, + char *buf, const int buflen) +{ + return 0; +} + +static size_t rpc_ntop6(const struct sockaddr *sap, + char *buf, const size_t buflen) +{ + return 0; +} + +#endif /* !IS_ENABLED(CONFIG_IPV6) */ + +static int rpc_ntop4(const struct sockaddr *sap, + char *buf, const size_t buflen) +{ + const struct sockaddr_in *sin = (struct sockaddr_in *)sap; + + return snprintf(buf, buflen, "%pI4", &sin->sin_addr); +} + +/** + * rpc_ntop - construct a presentation address in @buf + * @sap: socket address + * @buf: construction area + * @buflen: size of @buf, in bytes + * + * Plants a %NUL-terminated string in @buf and returns the length + * of the string, excluding the %NUL. Otherwise zero is returned. + */ +size_t rpc_ntop(const struct sockaddr *sap, char *buf, const size_t buflen) +{ + switch (sap->sa_family) { + case AF_INET: + return rpc_ntop4(sap, buf, buflen); + case AF_INET6: + return rpc_ntop6(sap, buf, buflen); + } + + return 0; +} +EXPORT_SYMBOL_GPL(rpc_ntop); + +static size_t rpc_pton4(const char *buf, const size_t buflen, + struct sockaddr *sap, const size_t salen) +{ + struct sockaddr_in *sin = (struct sockaddr_in *)sap; + u8 *addr = (u8 *)&sin->sin_addr.s_addr; + + if (buflen > INET_ADDRSTRLEN || salen < sizeof(struct sockaddr_in)) + return 0; + + memset(sap, 0, sizeof(struct sockaddr_in)); + + if (in4_pton(buf, buflen, addr, '\0', NULL) == 0) + return 0; + + sin->sin_family = AF_INET; + return sizeof(struct sockaddr_in); +} + +#if IS_ENABLED(CONFIG_IPV6) +static int rpc_parse_scope_id(struct net *net, const char *buf, + const size_t buflen, const char *delim, + struct sockaddr_in6 *sin6) +{ + char *p; + size_t len; + + if ((buf + buflen) == delim) + return 1; + + if (*delim != IPV6_SCOPE_DELIMITER) + return 0; + + if (!(ipv6_addr_type(&sin6->sin6_addr) & IPV6_ADDR_LINKLOCAL)) + return 0; + + len = (buf + buflen) - delim - 1; + p = kstrndup(delim + 1, len, GFP_KERNEL); + if (p) { + u32 scope_id = 0; + struct net_device *dev; + + dev = dev_get_by_name(net, p); + if (dev != NULL) { + scope_id = dev->ifindex; + dev_put(dev); + } else { + if (kstrtou32(p, 10, &scope_id) != 0) { + kfree(p); + return 0; + } + } + + kfree(p); + + sin6->sin6_scope_id = scope_id; + return 1; + } + + return 0; +} + +static size_t rpc_pton6(struct net *net, const char *buf, const size_t buflen, + struct sockaddr *sap, const size_t salen) +{ + struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)sap; + u8 *addr = (u8 *)&sin6->sin6_addr.in6_u; + const char *delim; + + if (buflen > (INET6_ADDRSTRLEN + IPV6_SCOPE_ID_LEN) || + salen < sizeof(struct sockaddr_in6)) + return 0; + + memset(sap, 0, sizeof(struct sockaddr_in6)); + + if (in6_pton(buf, buflen, addr, IPV6_SCOPE_DELIMITER, &delim) == 0) + return 0; + + if (!rpc_parse_scope_id(net, buf, buflen, delim, sin6)) + return 0; + + sin6->sin6_family = AF_INET6; + return sizeof(struct sockaddr_in6); +} +#else +static size_t rpc_pton6(struct net *net, const char *buf, const size_t buflen, + struct sockaddr *sap, const size_t salen) +{ + return 0; +} +#endif + +/** + * rpc_pton - Construct a sockaddr in @sap + * @net: applicable network namespace + * @buf: C string containing presentation format IP address + * @buflen: length of presentation address in bytes + * @sap: buffer into which to plant socket address + * @salen: size of buffer in bytes + * + * Returns the size of the socket address if successful; otherwise + * zero is returned. + * + * Plants a socket address in @sap and returns the size of the + * socket address, if successful. Returns zero if an error + * occurred. + */ +size_t rpc_pton(struct net *net, const char *buf, const size_t buflen, + struct sockaddr *sap, const size_t salen) +{ + unsigned int i; + + for (i = 0; i < buflen; i++) + if (buf[i] == ':') + return rpc_pton6(net, buf, buflen, sap, salen); + return rpc_pton4(buf, buflen, sap, salen); +} +EXPORT_SYMBOL_GPL(rpc_pton); + +/** + * rpc_sockaddr2uaddr - Construct a universal address string from @sap. + * @sap: socket address + * @gfp_flags: allocation mode + * + * Returns a %NUL-terminated string in dynamically allocated memory; + * otherwise NULL is returned if an error occurred. Caller must + * free the returned string. + */ +char *rpc_sockaddr2uaddr(const struct sockaddr *sap, gfp_t gfp_flags) +{ + char portbuf[RPCBIND_MAXUADDRPLEN]; + char addrbuf[RPCBIND_MAXUADDRLEN]; + unsigned short port; + + switch (sap->sa_family) { + case AF_INET: + if (rpc_ntop4(sap, addrbuf, sizeof(addrbuf)) == 0) + return NULL; + port = ntohs(((struct sockaddr_in *)sap)->sin_port); + break; + case AF_INET6: + if (rpc_ntop6_noscopeid(sap, addrbuf, sizeof(addrbuf)) == 0) + return NULL; + port = ntohs(((struct sockaddr_in6 *)sap)->sin6_port); + break; + default: + return NULL; + } + + if (snprintf(portbuf, sizeof(portbuf), + ".%u.%u", port >> 8, port & 0xff) > (int)sizeof(portbuf)) + return NULL; + + if (strlcat(addrbuf, portbuf, sizeof(addrbuf)) > sizeof(addrbuf)) + return NULL; + + return kstrdup(addrbuf, gfp_flags); +} + +/** + * rpc_uaddr2sockaddr - convert a universal address to a socket address. + * @net: applicable network namespace + * @uaddr: C string containing universal address to convert + * @uaddr_len: length of universal address string + * @sap: buffer into which to plant socket address + * @salen: size of buffer + * + * @uaddr does not have to be '\0'-terminated, but kstrtou8() and + * rpc_pton() require proper string termination to be successful. + * + * Returns the size of the socket address if successful; otherwise + * zero is returned. + */ +size_t rpc_uaddr2sockaddr(struct net *net, const char *uaddr, + const size_t uaddr_len, struct sockaddr *sap, + const size_t salen) +{ + char *c, buf[RPCBIND_MAXUADDRLEN + sizeof('\0')]; + u8 portlo, porthi; + unsigned short port; + + if (uaddr_len > RPCBIND_MAXUADDRLEN) + return 0; + + memcpy(buf, uaddr, uaddr_len); + + buf[uaddr_len] = '\0'; + c = strrchr(buf, '.'); + if (unlikely(c == NULL)) + return 0; + if (unlikely(kstrtou8(c + 1, 10, &portlo) != 0)) + return 0; + + *c = '\0'; + c = strrchr(buf, '.'); + if (unlikely(c == NULL)) + return 0; + if (unlikely(kstrtou8(c + 1, 10, &porthi) != 0)) + return 0; + + port = (unsigned short)((porthi << 8) | portlo); + + *c = '\0'; + if (rpc_pton(net, buf, strlen(buf), sap, salen) == 0) + return 0; + + switch (sap->sa_family) { + case AF_INET: + ((struct sockaddr_in *)sap)->sin_port = htons(port); + return sizeof(struct sockaddr_in); + case AF_INET6: + ((struct sockaddr_in6 *)sap)->sin6_port = htons(port); + return sizeof(struct sockaddr_in6); + } + + return 0; +} +EXPORT_SYMBOL_GPL(rpc_uaddr2sockaddr); diff --git a/net/sunrpc/auth.c b/net/sunrpc/auth.c new file mode 100644 index 000000000..305ecea92 --- /dev/null +++ b/net/sunrpc/auth.c @@ -0,0 +1,896 @@ +/* + * linux/net/sunrpc/auth.c + * + * Generic RPC client authentication API. + * + * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de> + */ + +#include <linux/types.h> +#include <linux/sched.h> +#include <linux/cred.h> +#include <linux/module.h> +#include <linux/slab.h> +#include <linux/errno.h> +#include <linux/hash.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/gss_api.h> +#include <linux/spinlock.h> + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +#define RPC_CREDCACHE_DEFAULT_HASHBITS (4) +struct rpc_cred_cache { + struct hlist_head *hashtable; + unsigned int hashbits; + spinlock_t lock; +}; + +static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS; + +static DEFINE_SPINLOCK(rpc_authflavor_lock); +static const struct rpc_authops *auth_flavors[RPC_AUTH_MAXFLAVOR] = { + &authnull_ops, /* AUTH_NULL */ + &authunix_ops, /* AUTH_UNIX */ + NULL, /* others can be loadable modules */ +}; + +static LIST_HEAD(cred_unused); +static unsigned long number_cred_unused; + +#define MAX_HASHTABLE_BITS (14) +static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp) +{ + unsigned long num; + unsigned int nbits; + int ret; + + if (!val) + goto out_inval; + ret = kstrtoul(val, 0, &num); + if (ret) + goto out_inval; + nbits = fls(num - 1); + if (nbits > MAX_HASHTABLE_BITS || nbits < 2) + goto out_inval; + *(unsigned int *)kp->arg = nbits; + return 0; +out_inval: + return -EINVAL; +} + +static int param_get_hashtbl_sz(char *buffer, const struct kernel_param *kp) +{ + unsigned int nbits; + + nbits = *(unsigned int *)kp->arg; + return sprintf(buffer, "%u", 1U << nbits); +} + +#define param_check_hashtbl_sz(name, p) __param_check(name, p, unsigned int); + +static const struct kernel_param_ops param_ops_hashtbl_sz = { + .set = param_set_hashtbl_sz, + .get = param_get_hashtbl_sz, +}; + +module_param_named(auth_hashtable_size, auth_hashbits, hashtbl_sz, 0644); +MODULE_PARM_DESC(auth_hashtable_size, "RPC credential cache hashtable size"); + +static unsigned long auth_max_cred_cachesize = ULONG_MAX; +module_param(auth_max_cred_cachesize, ulong, 0644); +MODULE_PARM_DESC(auth_max_cred_cachesize, "RPC credential maximum total cache size"); + +static u32 +pseudoflavor_to_flavor(u32 flavor) { + if (flavor > RPC_AUTH_MAXFLAVOR) + return RPC_AUTH_GSS; + return flavor; +} + +int +rpcauth_register(const struct rpc_authops *ops) +{ + rpc_authflavor_t flavor; + int ret = -EPERM; + + if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR) + return -EINVAL; + spin_lock(&rpc_authflavor_lock); + if (auth_flavors[flavor] == NULL) { + auth_flavors[flavor] = ops; + ret = 0; + } + spin_unlock(&rpc_authflavor_lock); + return ret; +} +EXPORT_SYMBOL_GPL(rpcauth_register); + +int +rpcauth_unregister(const struct rpc_authops *ops) +{ + rpc_authflavor_t flavor; + int ret = -EPERM; + + if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR) + return -EINVAL; + spin_lock(&rpc_authflavor_lock); + if (auth_flavors[flavor] == ops) { + auth_flavors[flavor] = NULL; + ret = 0; + } + spin_unlock(&rpc_authflavor_lock); + return ret; +} +EXPORT_SYMBOL_GPL(rpcauth_unregister); + +/** + * rpcauth_get_pseudoflavor - check if security flavor is supported + * @flavor: a security flavor + * @info: a GSS mech OID, quality of protection, and service value + * + * Verifies that an appropriate kernel module is available or already loaded. + * Returns an equivalent pseudoflavor, or RPC_AUTH_MAXFLAVOR if "flavor" is + * not supported locally. + */ +rpc_authflavor_t +rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info) +{ + const struct rpc_authops *ops; + rpc_authflavor_t pseudoflavor; + + ops = auth_flavors[flavor]; + if (ops == NULL) + request_module("rpc-auth-%u", flavor); + spin_lock(&rpc_authflavor_lock); + ops = auth_flavors[flavor]; + if (ops == NULL || !try_module_get(ops->owner)) { + spin_unlock(&rpc_authflavor_lock); + return RPC_AUTH_MAXFLAVOR; + } + spin_unlock(&rpc_authflavor_lock); + + pseudoflavor = flavor; + if (ops->info2flavor != NULL) + pseudoflavor = ops->info2flavor(info); + + module_put(ops->owner); + return pseudoflavor; +} +EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor); + +/** + * rpcauth_get_gssinfo - find GSS tuple matching a GSS pseudoflavor + * @pseudoflavor: GSS pseudoflavor to match + * @info: rpcsec_gss_info structure to fill in + * + * Returns zero and fills in "info" if pseudoflavor matches a + * supported mechanism. + */ +int +rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info) +{ + rpc_authflavor_t flavor = pseudoflavor_to_flavor(pseudoflavor); + const struct rpc_authops *ops; + int result; + + if (flavor >= RPC_AUTH_MAXFLAVOR) + return -EINVAL; + + ops = auth_flavors[flavor]; + if (ops == NULL) + request_module("rpc-auth-%u", flavor); + spin_lock(&rpc_authflavor_lock); + ops = auth_flavors[flavor]; + if (ops == NULL || !try_module_get(ops->owner)) { + spin_unlock(&rpc_authflavor_lock); + return -ENOENT; + } + spin_unlock(&rpc_authflavor_lock); + + result = -ENOENT; + if (ops->flavor2info != NULL) + result = ops->flavor2info(pseudoflavor, info); + + module_put(ops->owner); + return result; +} +EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo); + +/** + * rpcauth_list_flavors - discover registered flavors and pseudoflavors + * @array: array to fill in + * @size: size of "array" + * + * Returns the number of array items filled in, or a negative errno. + * + * The returned array is not sorted by any policy. Callers should not + * rely on the order of the items in the returned array. + */ +int +rpcauth_list_flavors(rpc_authflavor_t *array, int size) +{ + rpc_authflavor_t flavor; + int result = 0; + + spin_lock(&rpc_authflavor_lock); + for (flavor = 0; flavor < RPC_AUTH_MAXFLAVOR; flavor++) { + const struct rpc_authops *ops = auth_flavors[flavor]; + rpc_authflavor_t pseudos[4]; + int i, len; + + if (result >= size) { + result = -ENOMEM; + break; + } + + if (ops == NULL) + continue; + if (ops->list_pseudoflavors == NULL) { + array[result++] = ops->au_flavor; + continue; + } + len = ops->list_pseudoflavors(pseudos, ARRAY_SIZE(pseudos)); + if (len < 0) { + result = len; + break; + } + for (i = 0; i < len; i++) { + if (result >= size) { + result = -ENOMEM; + break; + } + array[result++] = pseudos[i]; + } + } + spin_unlock(&rpc_authflavor_lock); + + dprintk("RPC: %s returns %d\n", __func__, result); + return result; +} +EXPORT_SYMBOL_GPL(rpcauth_list_flavors); + +struct rpc_auth * +rpcauth_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) +{ + struct rpc_auth *auth; + const struct rpc_authops *ops; + u32 flavor = pseudoflavor_to_flavor(args->pseudoflavor); + + auth = ERR_PTR(-EINVAL); + if (flavor >= RPC_AUTH_MAXFLAVOR) + goto out; + + if ((ops = auth_flavors[flavor]) == NULL) + request_module("rpc-auth-%u", flavor); + spin_lock(&rpc_authflavor_lock); + ops = auth_flavors[flavor]; + if (ops == NULL || !try_module_get(ops->owner)) { + spin_unlock(&rpc_authflavor_lock); + goto out; + } + spin_unlock(&rpc_authflavor_lock); + auth = ops->create(args, clnt); + module_put(ops->owner); + if (IS_ERR(auth)) + return auth; + if (clnt->cl_auth) + rpcauth_release(clnt->cl_auth); + clnt->cl_auth = auth; + +out: + return auth; +} +EXPORT_SYMBOL_GPL(rpcauth_create); + +void +rpcauth_release(struct rpc_auth *auth) +{ + if (!atomic_dec_and_test(&auth->au_count)) + return; + auth->au_ops->destroy(auth); +} + +static DEFINE_SPINLOCK(rpc_credcache_lock); + +static void +rpcauth_unhash_cred_locked(struct rpc_cred *cred) +{ + hlist_del_rcu(&cred->cr_hash); + smp_mb__before_atomic(); + clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags); +} + +static int +rpcauth_unhash_cred(struct rpc_cred *cred) +{ + spinlock_t *cache_lock; + int ret; + + cache_lock = &cred->cr_auth->au_credcache->lock; + spin_lock(cache_lock); + ret = atomic_read(&cred->cr_count) == 0; + if (ret) + rpcauth_unhash_cred_locked(cred); + spin_unlock(cache_lock); + return ret; +} + +/* + * Initialize RPC credential cache + */ +int +rpcauth_init_credcache(struct rpc_auth *auth) +{ + struct rpc_cred_cache *new; + unsigned int hashsize; + + new = kmalloc(sizeof(*new), GFP_KERNEL); + if (!new) + goto out_nocache; + new->hashbits = auth_hashbits; + hashsize = 1U << new->hashbits; + new->hashtable = kcalloc(hashsize, sizeof(new->hashtable[0]), GFP_KERNEL); + if (!new->hashtable) + goto out_nohashtbl; + spin_lock_init(&new->lock); + auth->au_credcache = new; + return 0; +out_nohashtbl: + kfree(new); +out_nocache: + return -ENOMEM; +} +EXPORT_SYMBOL_GPL(rpcauth_init_credcache); + +/* + * Setup a credential key lifetime timeout notification + */ +int +rpcauth_key_timeout_notify(struct rpc_auth *auth, struct rpc_cred *cred) +{ + if (!cred->cr_auth->au_ops->key_timeout) + return 0; + return cred->cr_auth->au_ops->key_timeout(auth, cred); +} +EXPORT_SYMBOL_GPL(rpcauth_key_timeout_notify); + +bool +rpcauth_cred_key_to_expire(struct rpc_auth *auth, struct rpc_cred *cred) +{ + if (auth->au_flags & RPCAUTH_AUTH_NO_CRKEY_TIMEOUT) + return false; + if (!cred->cr_ops->crkey_to_expire) + return false; + return cred->cr_ops->crkey_to_expire(cred); +} +EXPORT_SYMBOL_GPL(rpcauth_cred_key_to_expire); + +char * +rpcauth_stringify_acceptor(struct rpc_cred *cred) +{ + if (!cred->cr_ops->crstringify_acceptor) + return NULL; + return cred->cr_ops->crstringify_acceptor(cred); +} +EXPORT_SYMBOL_GPL(rpcauth_stringify_acceptor); + +/* + * Destroy a list of credentials + */ +static inline +void rpcauth_destroy_credlist(struct list_head *head) +{ + struct rpc_cred *cred; + + while (!list_empty(head)) { + cred = list_entry(head->next, struct rpc_cred, cr_lru); + list_del_init(&cred->cr_lru); + put_rpccred(cred); + } +} + +/* + * Clear the RPC credential cache, and delete those credentials + * that are not referenced. + */ +void +rpcauth_clear_credcache(struct rpc_cred_cache *cache) +{ + LIST_HEAD(free); + struct hlist_head *head; + struct rpc_cred *cred; + unsigned int hashsize = 1U << cache->hashbits; + int i; + + spin_lock(&rpc_credcache_lock); + spin_lock(&cache->lock); + for (i = 0; i < hashsize; i++) { + head = &cache->hashtable[i]; + while (!hlist_empty(head)) { + cred = hlist_entry(head->first, struct rpc_cred, cr_hash); + get_rpccred(cred); + if (!list_empty(&cred->cr_lru)) { + list_del(&cred->cr_lru); + number_cred_unused--; + } + list_add_tail(&cred->cr_lru, &free); + rpcauth_unhash_cred_locked(cred); + } + } + spin_unlock(&cache->lock); + spin_unlock(&rpc_credcache_lock); + rpcauth_destroy_credlist(&free); +} + +/* + * Destroy the RPC credential cache + */ +void +rpcauth_destroy_credcache(struct rpc_auth *auth) +{ + struct rpc_cred_cache *cache = auth->au_credcache; + + if (cache) { + auth->au_credcache = NULL; + rpcauth_clear_credcache(cache); + kfree(cache->hashtable); + kfree(cache); + } +} +EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache); + + +#define RPC_AUTH_EXPIRY_MORATORIUM (60 * HZ) + +/* + * Remove stale credentials. Avoid sleeping inside the loop. + */ +static long +rpcauth_prune_expired(struct list_head *free, int nr_to_scan) +{ + spinlock_t *cache_lock; + struct rpc_cred *cred, *next; + unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM; + long freed = 0; + + list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) { + + if (nr_to_scan-- == 0) + break; + /* + * Enforce a 60 second garbage collection moratorium + * Note that the cred_unused list must be time-ordered. + */ + if (time_in_range(cred->cr_expire, expired, jiffies) && + test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) { + freed = SHRINK_STOP; + break; + } + + list_del_init(&cred->cr_lru); + number_cred_unused--; + freed++; + if (atomic_read(&cred->cr_count) != 0) + continue; + + cache_lock = &cred->cr_auth->au_credcache->lock; + spin_lock(cache_lock); + if (atomic_read(&cred->cr_count) == 0) { + get_rpccred(cred); + list_add_tail(&cred->cr_lru, free); + rpcauth_unhash_cred_locked(cred); + } + spin_unlock(cache_lock); + } + return freed; +} + +static unsigned long +rpcauth_cache_do_shrink(int nr_to_scan) +{ + LIST_HEAD(free); + unsigned long freed; + + spin_lock(&rpc_credcache_lock); + freed = rpcauth_prune_expired(&free, nr_to_scan); + spin_unlock(&rpc_credcache_lock); + rpcauth_destroy_credlist(&free); + + return freed; +} + +/* + * Run memory cache shrinker. + */ +static unsigned long +rpcauth_cache_shrink_scan(struct shrinker *shrink, struct shrink_control *sc) + +{ + if ((sc->gfp_mask & GFP_KERNEL) != GFP_KERNEL) + return SHRINK_STOP; + + /* nothing left, don't come back */ + if (list_empty(&cred_unused)) + return SHRINK_STOP; + + return rpcauth_cache_do_shrink(sc->nr_to_scan); +} + +static unsigned long +rpcauth_cache_shrink_count(struct shrinker *shrink, struct shrink_control *sc) + +{ + return number_cred_unused * sysctl_vfs_cache_pressure / 100; +} + +static void +rpcauth_cache_enforce_limit(void) +{ + unsigned long diff; + unsigned int nr_to_scan; + + if (number_cred_unused <= auth_max_cred_cachesize) + return; + diff = number_cred_unused - auth_max_cred_cachesize; + nr_to_scan = 100; + if (diff < nr_to_scan) + nr_to_scan = diff; + rpcauth_cache_do_shrink(nr_to_scan); +} + +/* + * Look up a process' credentials in the authentication cache + */ +struct rpc_cred * +rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred, + int flags, gfp_t gfp) +{ + LIST_HEAD(free); + struct rpc_cred_cache *cache = auth->au_credcache; + struct rpc_cred *cred = NULL, + *entry, *new; + unsigned int nr; + + nr = auth->au_ops->hash_cred(acred, cache->hashbits); + + rcu_read_lock(); + hlist_for_each_entry_rcu(entry, &cache->hashtable[nr], cr_hash) { + if (!entry->cr_ops->crmatch(acred, entry, flags)) + continue; + if (flags & RPCAUTH_LOOKUP_RCU) { + if (test_bit(RPCAUTH_CRED_HASHED, &entry->cr_flags) && + !test_bit(RPCAUTH_CRED_NEW, &entry->cr_flags)) + cred = entry; + break; + } + spin_lock(&cache->lock); + if (test_bit(RPCAUTH_CRED_HASHED, &entry->cr_flags) == 0) { + spin_unlock(&cache->lock); + continue; + } + cred = get_rpccred(entry); + spin_unlock(&cache->lock); + break; + } + rcu_read_unlock(); + + if (cred != NULL) + goto found; + + if (flags & RPCAUTH_LOOKUP_RCU) + return ERR_PTR(-ECHILD); + + new = auth->au_ops->crcreate(auth, acred, flags, gfp); + if (IS_ERR(new)) { + cred = new; + goto out; + } + + spin_lock(&cache->lock); + hlist_for_each_entry(entry, &cache->hashtable[nr], cr_hash) { + if (!entry->cr_ops->crmatch(acred, entry, flags)) + continue; + cred = get_rpccred(entry); + break; + } + if (cred == NULL) { + cred = new; + set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags); + hlist_add_head_rcu(&cred->cr_hash, &cache->hashtable[nr]); + } else + list_add_tail(&new->cr_lru, &free); + spin_unlock(&cache->lock); + rpcauth_cache_enforce_limit(); +found: + if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) && + cred->cr_ops->cr_init != NULL && + !(flags & RPCAUTH_LOOKUP_NEW)) { + int res = cred->cr_ops->cr_init(auth, cred); + if (res < 0) { + put_rpccred(cred); + cred = ERR_PTR(res); + } + } + rpcauth_destroy_credlist(&free); +out: + return cred; +} +EXPORT_SYMBOL_GPL(rpcauth_lookup_credcache); + +struct rpc_cred * +rpcauth_lookupcred(struct rpc_auth *auth, int flags) +{ + struct auth_cred acred; + struct rpc_cred *ret; + const struct cred *cred = current_cred(); + + dprintk("RPC: looking up %s cred\n", + auth->au_ops->au_name); + + memset(&acred, 0, sizeof(acred)); + acred.uid = cred->fsuid; + acred.gid = cred->fsgid; + acred.group_info = cred->group_info; + ret = auth->au_ops->lookup_cred(auth, &acred, flags); + return ret; +} +EXPORT_SYMBOL_GPL(rpcauth_lookupcred); + +void +rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred, + struct rpc_auth *auth, const struct rpc_credops *ops) +{ + INIT_HLIST_NODE(&cred->cr_hash); + INIT_LIST_HEAD(&cred->cr_lru); + atomic_set(&cred->cr_count, 1); + cred->cr_auth = auth; + cred->cr_ops = ops; + cred->cr_expire = jiffies; + cred->cr_uid = acred->uid; +} +EXPORT_SYMBOL_GPL(rpcauth_init_cred); + +struct rpc_cred * +rpcauth_generic_bind_cred(struct rpc_task *task, struct rpc_cred *cred, int lookupflags) +{ + dprintk("RPC: %5u holding %s cred %p\n", task->tk_pid, + cred->cr_auth->au_ops->au_name, cred); + return get_rpccred(cred); +} +EXPORT_SYMBOL_GPL(rpcauth_generic_bind_cred); + +static struct rpc_cred * +rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags) +{ + struct rpc_auth *auth = task->tk_client->cl_auth; + struct auth_cred acred = { + .uid = GLOBAL_ROOT_UID, + .gid = GLOBAL_ROOT_GID, + }; + + dprintk("RPC: %5u looking up %s cred\n", + task->tk_pid, task->tk_client->cl_auth->au_ops->au_name); + return auth->au_ops->lookup_cred(auth, &acred, lookupflags); +} + +static struct rpc_cred * +rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags) +{ + struct rpc_auth *auth = task->tk_client->cl_auth; + + dprintk("RPC: %5u looking up %s cred\n", + task->tk_pid, auth->au_ops->au_name); + return rpcauth_lookupcred(auth, lookupflags); +} + +static int +rpcauth_bindcred(struct rpc_task *task, struct rpc_cred *cred, int flags) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_cred *new; + int lookupflags = 0; + + if (flags & RPC_TASK_ASYNC) + lookupflags |= RPCAUTH_LOOKUP_NEW; + if (cred != NULL) + new = cred->cr_ops->crbind(task, cred, lookupflags); + else if (flags & RPC_TASK_ROOTCREDS) + new = rpcauth_bind_root_cred(task, lookupflags); + else + new = rpcauth_bind_new_cred(task, lookupflags); + if (IS_ERR(new)) + return PTR_ERR(new); + put_rpccred(req->rq_cred); + req->rq_cred = new; + return 0; +} + +void +put_rpccred(struct rpc_cred *cred) +{ + if (cred == NULL) + return; + /* Fast path for unhashed credentials */ + if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) == 0) { + if (atomic_dec_and_test(&cred->cr_count)) + cred->cr_ops->crdestroy(cred); + return; + } + + if (!atomic_dec_and_lock(&cred->cr_count, &rpc_credcache_lock)) + return; + if (!list_empty(&cred->cr_lru)) { + number_cred_unused--; + list_del_init(&cred->cr_lru); + } + if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) { + if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) { + cred->cr_expire = jiffies; + list_add_tail(&cred->cr_lru, &cred_unused); + number_cred_unused++; + goto out_nodestroy; + } + if (!rpcauth_unhash_cred(cred)) { + /* We were hashed and someone looked us up... */ + goto out_nodestroy; + } + } + spin_unlock(&rpc_credcache_lock); + cred->cr_ops->crdestroy(cred); + return; +out_nodestroy: + spin_unlock(&rpc_credcache_lock); +} +EXPORT_SYMBOL_GPL(put_rpccred); + +__be32 * +rpcauth_marshcred(struct rpc_task *task, __be32 *p) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + + dprintk("RPC: %5u marshaling %s cred %p\n", + task->tk_pid, cred->cr_auth->au_ops->au_name, cred); + + return cred->cr_ops->crmarshal(task, p); +} + +__be32 * +rpcauth_checkverf(struct rpc_task *task, __be32 *p) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + + dprintk("RPC: %5u validating %s cred %p\n", + task->tk_pid, cred->cr_auth->au_ops->au_name, cred); + + return cred->cr_ops->crvalidate(task, p); +} + +static void rpcauth_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp, + __be32 *data, void *obj) +{ + struct xdr_stream xdr; + + xdr_init_encode(&xdr, &rqstp->rq_snd_buf, data); + encode(rqstp, &xdr, obj); +} + +int +rpcauth_wrap_req(struct rpc_task *task, kxdreproc_t encode, void *rqstp, + __be32 *data, void *obj) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + + dprintk("RPC: %5u using %s cred %p to wrap rpc data\n", + task->tk_pid, cred->cr_ops->cr_name, cred); + if (cred->cr_ops->crwrap_req) + return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj); + /* By default, we encode the arguments normally. */ + rpcauth_wrap_req_encode(encode, rqstp, data, obj); + return 0; +} + +static int +rpcauth_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp, + __be32 *data, void *obj) +{ + struct xdr_stream xdr; + + xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, data); + return decode(rqstp, &xdr, obj); +} + +int +rpcauth_unwrap_resp(struct rpc_task *task, kxdrdproc_t decode, void *rqstp, + __be32 *data, void *obj) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + + dprintk("RPC: %5u using %s cred %p to unwrap rpc data\n", + task->tk_pid, cred->cr_ops->cr_name, cred); + if (cred->cr_ops->crunwrap_resp) + return cred->cr_ops->crunwrap_resp(task, decode, rqstp, + data, obj); + /* By default, we decode the arguments normally. */ + return rpcauth_unwrap_req_decode(decode, rqstp, data, obj); +} + +int +rpcauth_refreshcred(struct rpc_task *task) +{ + struct rpc_cred *cred; + int err; + + cred = task->tk_rqstp->rq_cred; + if (cred == NULL) { + err = rpcauth_bindcred(task, task->tk_msg.rpc_cred, task->tk_flags); + if (err < 0) + goto out; + cred = task->tk_rqstp->rq_cred; + } + dprintk("RPC: %5u refreshing %s cred %p\n", + task->tk_pid, cred->cr_auth->au_ops->au_name, cred); + + err = cred->cr_ops->crrefresh(task); +out: + if (err < 0) + task->tk_status = err; + return err; +} + +void +rpcauth_invalcred(struct rpc_task *task) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + + dprintk("RPC: %5u invalidating %s cred %p\n", + task->tk_pid, cred->cr_auth->au_ops->au_name, cred); + if (cred) + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); +} + +int +rpcauth_uptodatecred(struct rpc_task *task) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + + return cred == NULL || + test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0; +} + +static struct shrinker rpc_cred_shrinker = { + .count_objects = rpcauth_cache_shrink_count, + .scan_objects = rpcauth_cache_shrink_scan, + .seeks = DEFAULT_SEEKS, +}; + +int __init rpcauth_init_module(void) +{ + int err; + + err = rpc_init_authunix(); + if (err < 0) + goto out1; + err = rpc_init_generic_auth(); + if (err < 0) + goto out2; + err = register_shrinker(&rpc_cred_shrinker); + if (err < 0) + goto out3; + return 0; +out3: + rpc_destroy_generic_auth(); +out2: + rpc_destroy_authunix(); +out1: + return err; +} + +void rpcauth_remove_module(void) +{ + rpc_destroy_authunix(); + rpc_destroy_generic_auth(); + unregister_shrinker(&rpc_cred_shrinker); +} diff --git a/net/sunrpc/auth_generic.c b/net/sunrpc/auth_generic.c new file mode 100644 index 000000000..1ac08dcbf --- /dev/null +++ b/net/sunrpc/auth_generic.c @@ -0,0 +1,293 @@ +/* + * Generic RPC credential + * + * Copyright (C) 2008, Trond Myklebust <Trond.Myklebust@netapp.com> + */ + +#include <linux/err.h> +#include <linux/slab.h> +#include <linux/types.h> +#include <linux/module.h> +#include <linux/sched.h> +#include <linux/sunrpc/auth.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/debug.h> +#include <linux/sunrpc/sched.h> + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +#define RPC_MACHINE_CRED_USERID GLOBAL_ROOT_UID +#define RPC_MACHINE_CRED_GROUPID GLOBAL_ROOT_GID + +struct generic_cred { + struct rpc_cred gc_base; + struct auth_cred acred; +}; + +static struct rpc_auth generic_auth; +static const struct rpc_credops generic_credops; + +/* + * Public call interface + */ +struct rpc_cred *rpc_lookup_cred(void) +{ + return rpcauth_lookupcred(&generic_auth, 0); +} +EXPORT_SYMBOL_GPL(rpc_lookup_cred); + +struct rpc_cred * +rpc_lookup_generic_cred(struct auth_cred *acred, int flags, gfp_t gfp) +{ + return rpcauth_lookup_credcache(&generic_auth, acred, flags, gfp); +} +EXPORT_SYMBOL_GPL(rpc_lookup_generic_cred); + +struct rpc_cred *rpc_lookup_cred_nonblock(void) +{ + return rpcauth_lookupcred(&generic_auth, RPCAUTH_LOOKUP_RCU); +} +EXPORT_SYMBOL_GPL(rpc_lookup_cred_nonblock); + +/* + * Public call interface for looking up machine creds. + */ +struct rpc_cred *rpc_lookup_machine_cred(const char *service_name) +{ + struct auth_cred acred = { + .uid = RPC_MACHINE_CRED_USERID, + .gid = RPC_MACHINE_CRED_GROUPID, + .principal = service_name, + .machine_cred = 1, + }; + + dprintk("RPC: looking up machine cred for service %s\n", + service_name); + return generic_auth.au_ops->lookup_cred(&generic_auth, &acred, 0); +} +EXPORT_SYMBOL_GPL(rpc_lookup_machine_cred); + +static struct rpc_cred *generic_bind_cred(struct rpc_task *task, + struct rpc_cred *cred, int lookupflags) +{ + struct rpc_auth *auth = task->tk_client->cl_auth; + struct auth_cred *acred = &container_of(cred, struct generic_cred, gc_base)->acred; + + return auth->au_ops->lookup_cred(auth, acred, lookupflags); +} + +static int +generic_hash_cred(struct auth_cred *acred, unsigned int hashbits) +{ + return hash_64(from_kgid(&init_user_ns, acred->gid) | + ((u64)from_kuid(&init_user_ns, acred->uid) << + (sizeof(gid_t) * 8)), hashbits); +} + +/* + * Lookup generic creds for current process + */ +static struct rpc_cred * +generic_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) +{ + return rpcauth_lookup_credcache(&generic_auth, acred, flags, GFP_KERNEL); +} + +static struct rpc_cred * +generic_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags, gfp_t gfp) +{ + struct generic_cred *gcred; + + gcred = kmalloc(sizeof(*gcred), gfp); + if (gcred == NULL) + return ERR_PTR(-ENOMEM); + + rpcauth_init_cred(&gcred->gc_base, acred, &generic_auth, &generic_credops); + gcred->gc_base.cr_flags = 1UL << RPCAUTH_CRED_UPTODATE; + + gcred->acred.uid = acred->uid; + gcred->acred.gid = acred->gid; + gcred->acred.group_info = acred->group_info; + gcred->acred.ac_flags = 0; + if (gcred->acred.group_info != NULL) + get_group_info(gcred->acred.group_info); + gcred->acred.machine_cred = acred->machine_cred; + gcred->acred.principal = acred->principal; + + dprintk("RPC: allocated %s cred %p for uid %d gid %d\n", + gcred->acred.machine_cred ? "machine" : "generic", + gcred, + from_kuid(&init_user_ns, acred->uid), + from_kgid(&init_user_ns, acred->gid)); + return &gcred->gc_base; +} + +static void +generic_free_cred(struct rpc_cred *cred) +{ + struct generic_cred *gcred = container_of(cred, struct generic_cred, gc_base); + + dprintk("RPC: generic_free_cred %p\n", gcred); + if (gcred->acred.group_info != NULL) + put_group_info(gcred->acred.group_info); + kfree(gcred); +} + +static void +generic_free_cred_callback(struct rcu_head *head) +{ + struct rpc_cred *cred = container_of(head, struct rpc_cred, cr_rcu); + generic_free_cred(cred); +} + +static void +generic_destroy_cred(struct rpc_cred *cred) +{ + call_rcu(&cred->cr_rcu, generic_free_cred_callback); +} + +static int +machine_cred_match(struct auth_cred *acred, struct generic_cred *gcred, int flags) +{ + if (!gcred->acred.machine_cred || + gcred->acred.principal != acred->principal || + !uid_eq(gcred->acred.uid, acred->uid) || + !gid_eq(gcred->acred.gid, acred->gid)) + return 0; + return 1; +} + +/* + * Match credentials against current process creds. + */ +static int +generic_match(struct auth_cred *acred, struct rpc_cred *cred, int flags) +{ + struct generic_cred *gcred = container_of(cred, struct generic_cred, gc_base); + int i; + + if (acred->machine_cred) + return machine_cred_match(acred, gcred, flags); + + if (!uid_eq(gcred->acred.uid, acred->uid) || + !gid_eq(gcred->acred.gid, acred->gid) || + gcred->acred.machine_cred != 0) + goto out_nomatch; + + /* Optimisation in the case where pointers are identical... */ + if (gcred->acred.group_info == acred->group_info) + goto out_match; + + /* Slow path... */ + if (gcred->acred.group_info->ngroups != acred->group_info->ngroups) + goto out_nomatch; + for (i = 0; i < gcred->acred.group_info->ngroups; i++) { + if (!gid_eq(gcred->acred.group_info->gid[i], + acred->group_info->gid[i])) + goto out_nomatch; + } +out_match: + return 1; +out_nomatch: + return 0; +} + +int __init rpc_init_generic_auth(void) +{ + return rpcauth_init_credcache(&generic_auth); +} + +void rpc_destroy_generic_auth(void) +{ + rpcauth_destroy_credcache(&generic_auth); +} + +/* + * Test the the current time (now) against the underlying credential key expiry + * minus a timeout and setup notification. + * + * The normal case: + * If 'now' is before the key expiry minus RPC_KEY_EXPIRE_TIMEO, set + * the RPC_CRED_NOTIFY_TIMEOUT flag to setup the underlying credential + * rpc_credops crmatch routine to notify this generic cred when it's key + * expiration is within RPC_KEY_EXPIRE_TIMEO, and return 0. + * + * The error case: + * If the underlying cred lookup fails, return -EACCES. + * + * The 'almost' error case: + * If 'now' is within key expiry minus RPC_KEY_EXPIRE_TIMEO, but not within + * key expiry minus RPC_KEY_EXPIRE_FAIL, set the RPC_CRED_EXPIRE_SOON bit + * on the acred ac_flags and return 0. + */ +static int +generic_key_timeout(struct rpc_auth *auth, struct rpc_cred *cred) +{ + struct auth_cred *acred = &container_of(cred, struct generic_cred, + gc_base)->acred; + struct rpc_cred *tcred; + int ret = 0; + + + /* Fast track for non crkey_timeout (no key) underlying credentials */ + if (auth->au_flags & RPCAUTH_AUTH_NO_CRKEY_TIMEOUT) + return 0; + + /* Fast track for the normal case */ + if (test_bit(RPC_CRED_NOTIFY_TIMEOUT, &acred->ac_flags)) + return 0; + + /* lookup_cred either returns a valid referenced rpc_cred, or PTR_ERR */ + tcred = auth->au_ops->lookup_cred(auth, acred, 0); + if (IS_ERR(tcred)) + return -EACCES; + + /* Test for the almost error case */ + ret = tcred->cr_ops->crkey_timeout(tcred); + if (ret != 0) { + set_bit(RPC_CRED_KEY_EXPIRE_SOON, &acred->ac_flags); + ret = 0; + } else { + /* In case underlying cred key has been reset */ + if (test_and_clear_bit(RPC_CRED_KEY_EXPIRE_SOON, + &acred->ac_flags)) + dprintk("RPC: UID %d Credential key reset\n", + from_kuid(&init_user_ns, tcred->cr_uid)); + /* set up fasttrack for the normal case */ + set_bit(RPC_CRED_NOTIFY_TIMEOUT, &acred->ac_flags); + } + + put_rpccred(tcred); + return ret; +} + +static const struct rpc_authops generic_auth_ops = { + .owner = THIS_MODULE, + .au_name = "Generic", + .hash_cred = generic_hash_cred, + .lookup_cred = generic_lookup_cred, + .crcreate = generic_create_cred, + .key_timeout = generic_key_timeout, +}; + +static struct rpc_auth generic_auth = { + .au_ops = &generic_auth_ops, + .au_count = ATOMIC_INIT(0), +}; + +static bool generic_key_to_expire(struct rpc_cred *cred) +{ + struct auth_cred *acred = &container_of(cred, struct generic_cred, + gc_base)->acred; + return test_bit(RPC_CRED_KEY_EXPIRE_SOON, &acred->ac_flags); +} + +static const struct rpc_credops generic_credops = { + .cr_name = "Generic cred", + .crdestroy = generic_destroy_cred, + .crbind = generic_bind_cred, + .crmatch = generic_match, + .crkey_to_expire = generic_key_to_expire, +}; diff --git a/net/sunrpc/auth_gss/Makefile b/net/sunrpc/auth_gss/Makefile new file mode 100644 index 000000000..c374268b0 --- /dev/null +++ b/net/sunrpc/auth_gss/Makefile @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for Linux kernel rpcsec_gss implementation +# + +obj-$(CONFIG_SUNRPC_GSS) += auth_rpcgss.o + +auth_rpcgss-y := auth_gss.o gss_generic_token.o \ + gss_mech_switch.o svcauth_gss.o \ + gss_rpc_upcall.o gss_rpc_xdr.o + +obj-$(CONFIG_RPCSEC_GSS_KRB5) += rpcsec_gss_krb5.o + +rpcsec_gss_krb5-y := gss_krb5_mech.o gss_krb5_seal.o gss_krb5_unseal.o \ + gss_krb5_seqnum.o gss_krb5_wrap.o gss_krb5_crypto.o gss_krb5_keys.o diff --git a/net/sunrpc/auth_gss/auth_gss.c b/net/sunrpc/auth_gss/auth_gss.c new file mode 100644 index 000000000..e61c48c1b --- /dev/null +++ b/net/sunrpc/auth_gss/auth_gss.c @@ -0,0 +1,2127 @@ +/* + * linux/net/sunrpc/auth_gss/auth_gss.c + * + * RPCSEC_GSS client authentication. + * + * Copyright (c) 2000 The Regents of the University of Michigan. + * All rights reserved. + * + * Dug Song <dugsong@monkey.org> + * Andy Adamson <andros@umich.edu> + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the University nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED + * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR + * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + + +#include <linux/module.h> +#include <linux/init.h> +#include <linux/types.h> +#include <linux/slab.h> +#include <linux/sched.h> +#include <linux/pagemap.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/auth.h> +#include <linux/sunrpc/auth_gss.h> +#include <linux/sunrpc/svcauth_gss.h> +#include <linux/sunrpc/gss_err.h> +#include <linux/workqueue.h> +#include <linux/sunrpc/rpc_pipe_fs.h> +#include <linux/sunrpc/gss_api.h> +#include <linux/uaccess.h> +#include <linux/hashtable.h> + +#include "auth_gss_internal.h" +#include "../netns.h" + +static const struct rpc_authops authgss_ops; + +static const struct rpc_credops gss_credops; +static const struct rpc_credops gss_nullops; + +#define GSS_RETRY_EXPIRED 5 +static unsigned int gss_expired_cred_retry_delay = GSS_RETRY_EXPIRED; + +#define GSS_KEY_EXPIRE_TIMEO 240 +static unsigned int gss_key_expire_timeo = GSS_KEY_EXPIRE_TIMEO; + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +#define GSS_CRED_SLACK (RPC_MAX_AUTH_SIZE * 2) +/* length of a krb5 verifier (48), plus data added before arguments when + * using integrity (two 4-byte integers): */ +#define GSS_VERF_SLACK 100 + +static DEFINE_HASHTABLE(gss_auth_hash_table, 4); +static DEFINE_SPINLOCK(gss_auth_hash_lock); + +struct gss_pipe { + struct rpc_pipe_dir_object pdo; + struct rpc_pipe *pipe; + struct rpc_clnt *clnt; + const char *name; + struct kref kref; +}; + +struct gss_auth { + struct kref kref; + struct hlist_node hash; + struct rpc_auth rpc_auth; + struct gss_api_mech *mech; + enum rpc_gss_svc service; + struct rpc_clnt *client; + struct net *net; + /* + * There are two upcall pipes; dentry[1], named "gssd", is used + * for the new text-based upcall; dentry[0] is named after the + * mechanism (for example, "krb5") and exists for + * backwards-compatibility with older gssd's. + */ + struct gss_pipe *gss_pipe[2]; + const char *target_name; +}; + +/* pipe_version >= 0 if and only if someone has a pipe open. */ +static DEFINE_SPINLOCK(pipe_version_lock); +static struct rpc_wait_queue pipe_version_rpc_waitqueue; +static DECLARE_WAIT_QUEUE_HEAD(pipe_version_waitqueue); +static void gss_put_auth(struct gss_auth *gss_auth); + +static void gss_free_ctx(struct gss_cl_ctx *); +static const struct rpc_pipe_ops gss_upcall_ops_v0; +static const struct rpc_pipe_ops gss_upcall_ops_v1; + +static inline struct gss_cl_ctx * +gss_get_ctx(struct gss_cl_ctx *ctx) +{ + refcount_inc(&ctx->count); + return ctx; +} + +static inline void +gss_put_ctx(struct gss_cl_ctx *ctx) +{ + if (refcount_dec_and_test(&ctx->count)) + gss_free_ctx(ctx); +} + +/* gss_cred_set_ctx: + * called by gss_upcall_callback and gss_create_upcall in order + * to set the gss context. The actual exchange of an old context + * and a new one is protected by the pipe->lock. + */ +static void +gss_cred_set_ctx(struct rpc_cred *cred, struct gss_cl_ctx *ctx) +{ + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); + + if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags)) + return; + gss_get_ctx(ctx); + rcu_assign_pointer(gss_cred->gc_ctx, ctx); + set_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + smp_mb__before_atomic(); + clear_bit(RPCAUTH_CRED_NEW, &cred->cr_flags); +} + +static struct gss_cl_ctx * +gss_cred_get_ctx(struct rpc_cred *cred) +{ + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); + struct gss_cl_ctx *ctx = NULL; + + rcu_read_lock(); + ctx = rcu_dereference(gss_cred->gc_ctx); + if (ctx) + gss_get_ctx(ctx); + rcu_read_unlock(); + return ctx; +} + +static struct gss_cl_ctx * +gss_alloc_context(void) +{ + struct gss_cl_ctx *ctx; + + ctx = kzalloc(sizeof(*ctx), GFP_NOFS); + if (ctx != NULL) { + ctx->gc_proc = RPC_GSS_PROC_DATA; + ctx->gc_seq = 1; /* NetApp 6.4R1 doesn't accept seq. no. 0 */ + spin_lock_init(&ctx->gc_seq_lock); + refcount_set(&ctx->count,1); + } + return ctx; +} + +#define GSSD_MIN_TIMEOUT (60 * 60) +static const void * +gss_fill_context(const void *p, const void *end, struct gss_cl_ctx *ctx, struct gss_api_mech *gm) +{ + const void *q; + unsigned int seclen; + unsigned int timeout; + unsigned long now = jiffies; + u32 window_size; + int ret; + + /* First unsigned int gives the remaining lifetime in seconds of the + * credential - e.g. the remaining TGT lifetime for Kerberos or + * the -t value passed to GSSD. + */ + p = simple_get_bytes(p, end, &timeout, sizeof(timeout)); + if (IS_ERR(p)) + goto err; + if (timeout == 0) + timeout = GSSD_MIN_TIMEOUT; + ctx->gc_expiry = now + ((unsigned long)timeout * HZ); + /* Sequence number window. Determines the maximum number of + * simultaneous requests + */ + p = simple_get_bytes(p, end, &window_size, sizeof(window_size)); + if (IS_ERR(p)) + goto err; + ctx->gc_win = window_size; + /* gssd signals an error by passing ctx->gc_win = 0: */ + if (ctx->gc_win == 0) { + /* + * in which case, p points to an error code. Anything other + * than -EKEYEXPIRED gets converted to -EACCES. + */ + p = simple_get_bytes(p, end, &ret, sizeof(ret)); + if (!IS_ERR(p)) + p = (ret == -EKEYEXPIRED) ? ERR_PTR(-EKEYEXPIRED) : + ERR_PTR(-EACCES); + goto err; + } + /* copy the opaque wire context */ + p = simple_get_netobj(p, end, &ctx->gc_wire_ctx); + if (IS_ERR(p)) + goto err; + /* import the opaque security context */ + p = simple_get_bytes(p, end, &seclen, sizeof(seclen)); + if (IS_ERR(p)) + goto err; + q = (const void *)((const char *)p + seclen); + if (unlikely(q > end || q < p)) { + p = ERR_PTR(-EFAULT); + goto err; + } + ret = gss_import_sec_context(p, seclen, gm, &ctx->gc_gss_ctx, NULL, GFP_NOFS); + if (ret < 0) { + p = ERR_PTR(ret); + goto err; + } + + /* is there any trailing data? */ + if (q == end) { + p = q; + goto done; + } + + /* pull in acceptor name (if there is one) */ + p = simple_get_netobj(q, end, &ctx->gc_acceptor); + if (IS_ERR(p)) + goto err; +done: + dprintk("RPC: %s Success. gc_expiry %lu now %lu timeout %u acceptor %.*s\n", + __func__, ctx->gc_expiry, now, timeout, ctx->gc_acceptor.len, + ctx->gc_acceptor.data); + return p; +err: + dprintk("RPC: %s returns error %ld\n", __func__, -PTR_ERR(p)); + return p; +} + +/* XXX: Need some documentation about why UPCALL_BUF_LEN is so small. + * Is user space expecting no more than UPCALL_BUF_LEN bytes? + * Note that there are now _two_ NI_MAXHOST sized data items + * being passed in this string. + */ +#define UPCALL_BUF_LEN 256 + +struct gss_upcall_msg { + refcount_t count; + kuid_t uid; + struct rpc_pipe_msg msg; + struct list_head list; + struct gss_auth *auth; + struct rpc_pipe *pipe; + struct rpc_wait_queue rpc_waitqueue; + wait_queue_head_t waitqueue; + struct gss_cl_ctx *ctx; + char databuf[UPCALL_BUF_LEN]; +}; + +static int get_pipe_version(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + int ret; + + spin_lock(&pipe_version_lock); + if (sn->pipe_version >= 0) { + atomic_inc(&sn->pipe_users); + ret = sn->pipe_version; + } else + ret = -EAGAIN; + spin_unlock(&pipe_version_lock); + return ret; +} + +static void put_pipe_version(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + if (atomic_dec_and_lock(&sn->pipe_users, &pipe_version_lock)) { + sn->pipe_version = -1; + spin_unlock(&pipe_version_lock); + } +} + +static void +gss_release_msg(struct gss_upcall_msg *gss_msg) +{ + struct net *net = gss_msg->auth->net; + if (!refcount_dec_and_test(&gss_msg->count)) + return; + put_pipe_version(net); + BUG_ON(!list_empty(&gss_msg->list)); + if (gss_msg->ctx != NULL) + gss_put_ctx(gss_msg->ctx); + rpc_destroy_wait_queue(&gss_msg->rpc_waitqueue); + gss_put_auth(gss_msg->auth); + kfree(gss_msg); +} + +static struct gss_upcall_msg * +__gss_find_upcall(struct rpc_pipe *pipe, kuid_t uid, const struct gss_auth *auth) +{ + struct gss_upcall_msg *pos; + list_for_each_entry(pos, &pipe->in_downcall, list) { + if (!uid_eq(pos->uid, uid)) + continue; + if (auth && pos->auth->service != auth->service) + continue; + refcount_inc(&pos->count); + dprintk("RPC: %s found msg %p\n", __func__, pos); + return pos; + } + dprintk("RPC: %s found nothing\n", __func__); + return NULL; +} + +/* Try to add an upcall to the pipefs queue. + * If an upcall owned by our uid already exists, then we return a reference + * to that upcall instead of adding the new upcall. + */ +static inline struct gss_upcall_msg * +gss_add_msg(struct gss_upcall_msg *gss_msg) +{ + struct rpc_pipe *pipe = gss_msg->pipe; + struct gss_upcall_msg *old; + + spin_lock(&pipe->lock); + old = __gss_find_upcall(pipe, gss_msg->uid, gss_msg->auth); + if (old == NULL) { + refcount_inc(&gss_msg->count); + list_add(&gss_msg->list, &pipe->in_downcall); + } else + gss_msg = old; + spin_unlock(&pipe->lock); + return gss_msg; +} + +static void +__gss_unhash_msg(struct gss_upcall_msg *gss_msg) +{ + list_del_init(&gss_msg->list); + rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno); + wake_up_all(&gss_msg->waitqueue); + refcount_dec(&gss_msg->count); +} + +static void +gss_unhash_msg(struct gss_upcall_msg *gss_msg) +{ + struct rpc_pipe *pipe = gss_msg->pipe; + + if (list_empty(&gss_msg->list)) + return; + spin_lock(&pipe->lock); + if (!list_empty(&gss_msg->list)) + __gss_unhash_msg(gss_msg); + spin_unlock(&pipe->lock); +} + +static void +gss_handle_downcall_result(struct gss_cred *gss_cred, struct gss_upcall_msg *gss_msg) +{ + switch (gss_msg->msg.errno) { + case 0: + if (gss_msg->ctx == NULL) + break; + clear_bit(RPCAUTH_CRED_NEGATIVE, &gss_cred->gc_base.cr_flags); + gss_cred_set_ctx(&gss_cred->gc_base, gss_msg->ctx); + break; + case -EKEYEXPIRED: + set_bit(RPCAUTH_CRED_NEGATIVE, &gss_cred->gc_base.cr_flags); + } + gss_cred->gc_upcall_timestamp = jiffies; + gss_cred->gc_upcall = NULL; + rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno); +} + +static void +gss_upcall_callback(struct rpc_task *task) +{ + struct gss_cred *gss_cred = container_of(task->tk_rqstp->rq_cred, + struct gss_cred, gc_base); + struct gss_upcall_msg *gss_msg = gss_cred->gc_upcall; + struct rpc_pipe *pipe = gss_msg->pipe; + + spin_lock(&pipe->lock); + gss_handle_downcall_result(gss_cred, gss_msg); + spin_unlock(&pipe->lock); + task->tk_status = gss_msg->msg.errno; + gss_release_msg(gss_msg); +} + +static void gss_encode_v0_msg(struct gss_upcall_msg *gss_msg) +{ + uid_t uid = from_kuid(&init_user_ns, gss_msg->uid); + memcpy(gss_msg->databuf, &uid, sizeof(uid)); + gss_msg->msg.data = gss_msg->databuf; + gss_msg->msg.len = sizeof(uid); + + BUILD_BUG_ON(sizeof(uid) > sizeof(gss_msg->databuf)); +} + +static int gss_encode_v1_msg(struct gss_upcall_msg *gss_msg, + const char *service_name, + const char *target_name) +{ + struct gss_api_mech *mech = gss_msg->auth->mech; + char *p = gss_msg->databuf; + size_t buflen = sizeof(gss_msg->databuf); + int len; + + len = scnprintf(p, buflen, "mech=%s uid=%d ", mech->gm_name, + from_kuid(&init_user_ns, gss_msg->uid)); + buflen -= len; + p += len; + gss_msg->msg.len = len; + + /* + * target= is a full service principal that names the remote + * identity that we are authenticating to. + */ + if (target_name) { + len = scnprintf(p, buflen, "target=%s ", target_name); + buflen -= len; + p += len; + gss_msg->msg.len += len; + } + + /* + * gssd uses service= and srchost= to select a matching key from + * the system's keytab to use as the source principal. + * + * service= is the service name part of the source principal, + * or "*" (meaning choose any). + * + * srchost= is the hostname part of the source principal. When + * not provided, gssd uses the local hostname. + */ + if (service_name) { + char *c = strchr(service_name, '@'); + + if (!c) + len = scnprintf(p, buflen, "service=%s ", + service_name); + else + len = scnprintf(p, buflen, + "service=%.*s srchost=%s ", + (int)(c - service_name), + service_name, c + 1); + buflen -= len; + p += len; + gss_msg->msg.len += len; + } + + if (mech->gm_upcall_enctypes) { + len = scnprintf(p, buflen, "enctypes=%s ", + mech->gm_upcall_enctypes); + buflen -= len; + p += len; + gss_msg->msg.len += len; + } + len = scnprintf(p, buflen, "\n"); + if (len == 0) + goto out_overflow; + gss_msg->msg.len += len; + + gss_msg->msg.data = gss_msg->databuf; + return 0; +out_overflow: + WARN_ON_ONCE(1); + return -ENOMEM; +} + +static struct gss_upcall_msg * +gss_alloc_msg(struct gss_auth *gss_auth, + kuid_t uid, const char *service_name) +{ + struct gss_upcall_msg *gss_msg; + int vers; + int err = -ENOMEM; + + gss_msg = kzalloc(sizeof(*gss_msg), GFP_NOFS); + if (gss_msg == NULL) + goto err; + vers = get_pipe_version(gss_auth->net); + err = vers; + if (err < 0) + goto err_free_msg; + gss_msg->pipe = gss_auth->gss_pipe[vers]->pipe; + INIT_LIST_HEAD(&gss_msg->list); + rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq"); + init_waitqueue_head(&gss_msg->waitqueue); + refcount_set(&gss_msg->count, 1); + gss_msg->uid = uid; + gss_msg->auth = gss_auth; + switch (vers) { + case 0: + gss_encode_v0_msg(gss_msg); + break; + default: + err = gss_encode_v1_msg(gss_msg, service_name, gss_auth->target_name); + if (err) + goto err_put_pipe_version; + } + kref_get(&gss_auth->kref); + return gss_msg; +err_put_pipe_version: + put_pipe_version(gss_auth->net); +err_free_msg: + kfree(gss_msg); +err: + return ERR_PTR(err); +} + +static struct gss_upcall_msg * +gss_setup_upcall(struct gss_auth *gss_auth, struct rpc_cred *cred) +{ + struct gss_cred *gss_cred = container_of(cred, + struct gss_cred, gc_base); + struct gss_upcall_msg *gss_new, *gss_msg; + kuid_t uid = cred->cr_uid; + + gss_new = gss_alloc_msg(gss_auth, uid, gss_cred->gc_principal); + if (IS_ERR(gss_new)) + return gss_new; + gss_msg = gss_add_msg(gss_new); + if (gss_msg == gss_new) { + int res; + refcount_inc(&gss_msg->count); + res = rpc_queue_upcall(gss_new->pipe, &gss_new->msg); + if (res) { + gss_unhash_msg(gss_new); + refcount_dec(&gss_msg->count); + gss_release_msg(gss_new); + gss_msg = ERR_PTR(res); + } + } else + gss_release_msg(gss_new); + return gss_msg; +} + +static void warn_gssd(void) +{ + dprintk("AUTH_GSS upcall failed. Please check user daemon is running.\n"); +} + +static inline int +gss_refresh_upcall(struct rpc_task *task) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + struct gss_auth *gss_auth = container_of(cred->cr_auth, + struct gss_auth, rpc_auth); + struct gss_cred *gss_cred = container_of(cred, + struct gss_cred, gc_base); + struct gss_upcall_msg *gss_msg; + struct rpc_pipe *pipe; + int err = 0; + + dprintk("RPC: %5u %s for uid %u\n", + task->tk_pid, __func__, from_kuid(&init_user_ns, cred->cr_uid)); + gss_msg = gss_setup_upcall(gss_auth, cred); + if (PTR_ERR(gss_msg) == -EAGAIN) { + /* XXX: warning on the first, under the assumption we + * shouldn't normally hit this case on a refresh. */ + warn_gssd(); + task->tk_timeout = 15*HZ; + rpc_sleep_on(&pipe_version_rpc_waitqueue, task, NULL); + return -EAGAIN; + } + if (IS_ERR(gss_msg)) { + err = PTR_ERR(gss_msg); + goto out; + } + pipe = gss_msg->pipe; + spin_lock(&pipe->lock); + if (gss_cred->gc_upcall != NULL) + rpc_sleep_on(&gss_cred->gc_upcall->rpc_waitqueue, task, NULL); + else if (gss_msg->ctx == NULL && gss_msg->msg.errno >= 0) { + task->tk_timeout = 0; + gss_cred->gc_upcall = gss_msg; + /* gss_upcall_callback will release the reference to gss_upcall_msg */ + refcount_inc(&gss_msg->count); + rpc_sleep_on(&gss_msg->rpc_waitqueue, task, gss_upcall_callback); + } else { + gss_handle_downcall_result(gss_cred, gss_msg); + err = gss_msg->msg.errno; + } + spin_unlock(&pipe->lock); + gss_release_msg(gss_msg); +out: + dprintk("RPC: %5u %s for uid %u result %d\n", + task->tk_pid, __func__, + from_kuid(&init_user_ns, cred->cr_uid), err); + return err; +} + +static inline int +gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred) +{ + struct net *net = gss_auth->net; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_pipe *pipe; + struct rpc_cred *cred = &gss_cred->gc_base; + struct gss_upcall_msg *gss_msg; + DEFINE_WAIT(wait); + int err; + + dprintk("RPC: %s for uid %u\n", + __func__, from_kuid(&init_user_ns, cred->cr_uid)); +retry: + err = 0; + /* if gssd is down, just skip upcalling altogether */ + if (!gssd_running(net)) { + warn_gssd(); + return -EACCES; + } + gss_msg = gss_setup_upcall(gss_auth, cred); + if (PTR_ERR(gss_msg) == -EAGAIN) { + err = wait_event_interruptible_timeout(pipe_version_waitqueue, + sn->pipe_version >= 0, 15 * HZ); + if (sn->pipe_version < 0) { + warn_gssd(); + err = -EACCES; + } + if (err < 0) + goto out; + goto retry; + } + if (IS_ERR(gss_msg)) { + err = PTR_ERR(gss_msg); + goto out; + } + pipe = gss_msg->pipe; + for (;;) { + prepare_to_wait(&gss_msg->waitqueue, &wait, TASK_KILLABLE); + spin_lock(&pipe->lock); + if (gss_msg->ctx != NULL || gss_msg->msg.errno < 0) { + break; + } + spin_unlock(&pipe->lock); + if (fatal_signal_pending(current)) { + err = -ERESTARTSYS; + goto out_intr; + } + schedule(); + } + if (gss_msg->ctx) + gss_cred_set_ctx(cred, gss_msg->ctx); + else + err = gss_msg->msg.errno; + spin_unlock(&pipe->lock); +out_intr: + finish_wait(&gss_msg->waitqueue, &wait); + gss_release_msg(gss_msg); +out: + dprintk("RPC: %s for uid %u result %d\n", + __func__, from_kuid(&init_user_ns, cred->cr_uid), err); + return err; +} + +#define MSG_BUF_MAXSIZE 1024 + +static ssize_t +gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen) +{ + const void *p, *end; + void *buf; + struct gss_upcall_msg *gss_msg; + struct rpc_pipe *pipe = RPC_I(file_inode(filp))->pipe; + struct gss_cl_ctx *ctx; + uid_t id; + kuid_t uid; + ssize_t err = -EFBIG; + + if (mlen > MSG_BUF_MAXSIZE) + goto out; + err = -ENOMEM; + buf = kmalloc(mlen, GFP_NOFS); + if (!buf) + goto out; + + err = -EFAULT; + if (copy_from_user(buf, src, mlen)) + goto err; + + end = (const void *)((char *)buf + mlen); + p = simple_get_bytes(buf, end, &id, sizeof(id)); + if (IS_ERR(p)) { + err = PTR_ERR(p); + goto err; + } + + uid = make_kuid(&init_user_ns, id); + if (!uid_valid(uid)) { + err = -EINVAL; + goto err; + } + + err = -ENOMEM; + ctx = gss_alloc_context(); + if (ctx == NULL) + goto err; + + err = -ENOENT; + /* Find a matching upcall */ + spin_lock(&pipe->lock); + gss_msg = __gss_find_upcall(pipe, uid, NULL); + if (gss_msg == NULL) { + spin_unlock(&pipe->lock); + goto err_put_ctx; + } + list_del_init(&gss_msg->list); + spin_unlock(&pipe->lock); + + p = gss_fill_context(p, end, ctx, gss_msg->auth->mech); + if (IS_ERR(p)) { + err = PTR_ERR(p); + switch (err) { + case -EACCES: + case -EKEYEXPIRED: + gss_msg->msg.errno = err; + err = mlen; + break; + case -EFAULT: + case -ENOMEM: + case -EINVAL: + case -ENOSYS: + gss_msg->msg.errno = -EAGAIN; + break; + default: + printk(KERN_CRIT "%s: bad return from " + "gss_fill_context: %zd\n", __func__, err); + gss_msg->msg.errno = -EIO; + } + goto err_release_msg; + } + gss_msg->ctx = gss_get_ctx(ctx); + err = mlen; + +err_release_msg: + spin_lock(&pipe->lock); + __gss_unhash_msg(gss_msg); + spin_unlock(&pipe->lock); + gss_release_msg(gss_msg); +err_put_ctx: + gss_put_ctx(ctx); +err: + kfree(buf); +out: + dprintk("RPC: %s returning %zd\n", __func__, err); + return err; +} + +static int gss_pipe_open(struct inode *inode, int new_version) +{ + struct net *net = inode->i_sb->s_fs_info; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + int ret = 0; + + spin_lock(&pipe_version_lock); + if (sn->pipe_version < 0) { + /* First open of any gss pipe determines the version: */ + sn->pipe_version = new_version; + rpc_wake_up(&pipe_version_rpc_waitqueue); + wake_up(&pipe_version_waitqueue); + } else if (sn->pipe_version != new_version) { + /* Trying to open a pipe of a different version */ + ret = -EBUSY; + goto out; + } + atomic_inc(&sn->pipe_users); +out: + spin_unlock(&pipe_version_lock); + return ret; + +} + +static int gss_pipe_open_v0(struct inode *inode) +{ + return gss_pipe_open(inode, 0); +} + +static int gss_pipe_open_v1(struct inode *inode) +{ + return gss_pipe_open(inode, 1); +} + +static void +gss_pipe_release(struct inode *inode) +{ + struct net *net = inode->i_sb->s_fs_info; + struct rpc_pipe *pipe = RPC_I(inode)->pipe; + struct gss_upcall_msg *gss_msg; + +restart: + spin_lock(&pipe->lock); + list_for_each_entry(gss_msg, &pipe->in_downcall, list) { + + if (!list_empty(&gss_msg->msg.list)) + continue; + gss_msg->msg.errno = -EPIPE; + refcount_inc(&gss_msg->count); + __gss_unhash_msg(gss_msg); + spin_unlock(&pipe->lock); + gss_release_msg(gss_msg); + goto restart; + } + spin_unlock(&pipe->lock); + + put_pipe_version(net); +} + +static void +gss_pipe_destroy_msg(struct rpc_pipe_msg *msg) +{ + struct gss_upcall_msg *gss_msg = container_of(msg, struct gss_upcall_msg, msg); + + if (msg->errno < 0) { + dprintk("RPC: %s releasing msg %p\n", + __func__, gss_msg); + refcount_inc(&gss_msg->count); + gss_unhash_msg(gss_msg); + if (msg->errno == -ETIMEDOUT) + warn_gssd(); + gss_release_msg(gss_msg); + } + gss_release_msg(gss_msg); +} + +static void gss_pipe_dentry_destroy(struct dentry *dir, + struct rpc_pipe_dir_object *pdo) +{ + struct gss_pipe *gss_pipe = pdo->pdo_data; + struct rpc_pipe *pipe = gss_pipe->pipe; + + if (pipe->dentry != NULL) { + rpc_unlink(pipe->dentry); + pipe->dentry = NULL; + } +} + +static int gss_pipe_dentry_create(struct dentry *dir, + struct rpc_pipe_dir_object *pdo) +{ + struct gss_pipe *p = pdo->pdo_data; + struct dentry *dentry; + + dentry = rpc_mkpipe_dentry(dir, p->name, p->clnt, p->pipe); + if (IS_ERR(dentry)) + return PTR_ERR(dentry); + p->pipe->dentry = dentry; + return 0; +} + +static const struct rpc_pipe_dir_object_ops gss_pipe_dir_object_ops = { + .create = gss_pipe_dentry_create, + .destroy = gss_pipe_dentry_destroy, +}; + +static struct gss_pipe *gss_pipe_alloc(struct rpc_clnt *clnt, + const char *name, + const struct rpc_pipe_ops *upcall_ops) +{ + struct gss_pipe *p; + int err = -ENOMEM; + + p = kmalloc(sizeof(*p), GFP_KERNEL); + if (p == NULL) + goto err; + p->pipe = rpc_mkpipe_data(upcall_ops, RPC_PIPE_WAIT_FOR_OPEN); + if (IS_ERR(p->pipe)) { + err = PTR_ERR(p->pipe); + goto err_free_gss_pipe; + } + p->name = name; + p->clnt = clnt; + kref_init(&p->kref); + rpc_init_pipe_dir_object(&p->pdo, + &gss_pipe_dir_object_ops, + p); + return p; +err_free_gss_pipe: + kfree(p); +err: + return ERR_PTR(err); +} + +struct gss_alloc_pdo { + struct rpc_clnt *clnt; + const char *name; + const struct rpc_pipe_ops *upcall_ops; +}; + +static int gss_pipe_match_pdo(struct rpc_pipe_dir_object *pdo, void *data) +{ + struct gss_pipe *gss_pipe; + struct gss_alloc_pdo *args = data; + + if (pdo->pdo_ops != &gss_pipe_dir_object_ops) + return 0; + gss_pipe = container_of(pdo, struct gss_pipe, pdo); + if (strcmp(gss_pipe->name, args->name) != 0) + return 0; + if (!kref_get_unless_zero(&gss_pipe->kref)) + return 0; + return 1; +} + +static struct rpc_pipe_dir_object *gss_pipe_alloc_pdo(void *data) +{ + struct gss_pipe *gss_pipe; + struct gss_alloc_pdo *args = data; + + gss_pipe = gss_pipe_alloc(args->clnt, args->name, args->upcall_ops); + if (!IS_ERR(gss_pipe)) + return &gss_pipe->pdo; + return NULL; +} + +static struct gss_pipe *gss_pipe_get(struct rpc_clnt *clnt, + const char *name, + const struct rpc_pipe_ops *upcall_ops) +{ + struct net *net = rpc_net_ns(clnt); + struct rpc_pipe_dir_object *pdo; + struct gss_alloc_pdo args = { + .clnt = clnt, + .name = name, + .upcall_ops = upcall_ops, + }; + + pdo = rpc_find_or_alloc_pipe_dir_object(net, + &clnt->cl_pipedir_objects, + gss_pipe_match_pdo, + gss_pipe_alloc_pdo, + &args); + if (pdo != NULL) + return container_of(pdo, struct gss_pipe, pdo); + return ERR_PTR(-ENOMEM); +} + +static void __gss_pipe_free(struct gss_pipe *p) +{ + struct rpc_clnt *clnt = p->clnt; + struct net *net = rpc_net_ns(clnt); + + rpc_remove_pipe_dir_object(net, + &clnt->cl_pipedir_objects, + &p->pdo); + rpc_destroy_pipe_data(p->pipe); + kfree(p); +} + +static void __gss_pipe_release(struct kref *kref) +{ + struct gss_pipe *p = container_of(kref, struct gss_pipe, kref); + + __gss_pipe_free(p); +} + +static void gss_pipe_free(struct gss_pipe *p) +{ + if (p != NULL) + kref_put(&p->kref, __gss_pipe_release); +} + +/* + * NOTE: we have the opportunity to use different + * parameters based on the input flavor (which must be a pseudoflavor) + */ +static struct gss_auth * +gss_create_new(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) +{ + rpc_authflavor_t flavor = args->pseudoflavor; + struct gss_auth *gss_auth; + struct gss_pipe *gss_pipe; + struct rpc_auth * auth; + int err = -ENOMEM; /* XXX? */ + + dprintk("RPC: creating GSS authenticator for client %p\n", clnt); + + if (!try_module_get(THIS_MODULE)) + return ERR_PTR(err); + if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL))) + goto out_dec; + INIT_HLIST_NODE(&gss_auth->hash); + gss_auth->target_name = NULL; + if (args->target_name) { + gss_auth->target_name = kstrdup(args->target_name, GFP_KERNEL); + if (gss_auth->target_name == NULL) + goto err_free; + } + gss_auth->client = clnt; + gss_auth->net = get_net(rpc_net_ns(clnt)); + err = -EINVAL; + gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor); + if (!gss_auth->mech) { + dprintk("RPC: Pseudoflavor %d not found!\n", flavor); + goto err_put_net; + } + gss_auth->service = gss_pseudoflavor_to_service(gss_auth->mech, flavor); + if (gss_auth->service == 0) + goto err_put_mech; + if (!gssd_running(gss_auth->net)) + goto err_put_mech; + auth = &gss_auth->rpc_auth; + auth->au_cslack = GSS_CRED_SLACK >> 2; + auth->au_rslack = GSS_VERF_SLACK >> 2; + auth->au_flags = 0; + auth->au_ops = &authgss_ops; + auth->au_flavor = flavor; + if (gss_pseudoflavor_to_datatouch(gss_auth->mech, flavor)) + auth->au_flags |= RPCAUTH_AUTH_DATATOUCH; + atomic_set(&auth->au_count, 1); + kref_init(&gss_auth->kref); + + err = rpcauth_init_credcache(auth); + if (err) + goto err_put_mech; + /* + * Note: if we created the old pipe first, then someone who + * examined the directory at the right moment might conclude + * that we supported only the old pipe. So we instead create + * the new pipe first. + */ + gss_pipe = gss_pipe_get(clnt, "gssd", &gss_upcall_ops_v1); + if (IS_ERR(gss_pipe)) { + err = PTR_ERR(gss_pipe); + goto err_destroy_credcache; + } + gss_auth->gss_pipe[1] = gss_pipe; + + gss_pipe = gss_pipe_get(clnt, gss_auth->mech->gm_name, + &gss_upcall_ops_v0); + if (IS_ERR(gss_pipe)) { + err = PTR_ERR(gss_pipe); + goto err_destroy_pipe_1; + } + gss_auth->gss_pipe[0] = gss_pipe; + + return gss_auth; +err_destroy_pipe_1: + gss_pipe_free(gss_auth->gss_pipe[1]); +err_destroy_credcache: + rpcauth_destroy_credcache(auth); +err_put_mech: + gss_mech_put(gss_auth->mech); +err_put_net: + put_net(gss_auth->net); +err_free: + kfree(gss_auth->target_name); + kfree(gss_auth); +out_dec: + module_put(THIS_MODULE); + return ERR_PTR(err); +} + +static void +gss_free(struct gss_auth *gss_auth) +{ + gss_pipe_free(gss_auth->gss_pipe[0]); + gss_pipe_free(gss_auth->gss_pipe[1]); + gss_mech_put(gss_auth->mech); + put_net(gss_auth->net); + kfree(gss_auth->target_name); + + kfree(gss_auth); + module_put(THIS_MODULE); +} + +static void +gss_free_callback(struct kref *kref) +{ + struct gss_auth *gss_auth = container_of(kref, struct gss_auth, kref); + + gss_free(gss_auth); +} + +static void +gss_put_auth(struct gss_auth *gss_auth) +{ + kref_put(&gss_auth->kref, gss_free_callback); +} + +static void +gss_destroy(struct rpc_auth *auth) +{ + struct gss_auth *gss_auth = container_of(auth, + struct gss_auth, rpc_auth); + + dprintk("RPC: destroying GSS authenticator %p flavor %d\n", + auth, auth->au_flavor); + + if (hash_hashed(&gss_auth->hash)) { + spin_lock(&gss_auth_hash_lock); + hash_del(&gss_auth->hash); + spin_unlock(&gss_auth_hash_lock); + } + + gss_pipe_free(gss_auth->gss_pipe[0]); + gss_auth->gss_pipe[0] = NULL; + gss_pipe_free(gss_auth->gss_pipe[1]); + gss_auth->gss_pipe[1] = NULL; + rpcauth_destroy_credcache(auth); + + gss_put_auth(gss_auth); +} + +/* + * Auths may be shared between rpc clients that were cloned from a + * common client with the same xprt, if they also share the flavor and + * target_name. + * + * The auth is looked up from the oldest parent sharing the same + * cl_xprt, and the auth itself references only that common parent + * (which is guaranteed to last as long as any of its descendants). + */ +static struct gss_auth * +gss_auth_find_or_add_hashed(const struct rpc_auth_create_args *args, + struct rpc_clnt *clnt, + struct gss_auth *new) +{ + struct gss_auth *gss_auth; + unsigned long hashval = (unsigned long)clnt; + + spin_lock(&gss_auth_hash_lock); + hash_for_each_possible(gss_auth_hash_table, + gss_auth, + hash, + hashval) { + if (gss_auth->client != clnt) + continue; + if (gss_auth->rpc_auth.au_flavor != args->pseudoflavor) + continue; + if (gss_auth->target_name != args->target_name) { + if (gss_auth->target_name == NULL) + continue; + if (args->target_name == NULL) + continue; + if (strcmp(gss_auth->target_name, args->target_name)) + continue; + } + if (!atomic_inc_not_zero(&gss_auth->rpc_auth.au_count)) + continue; + goto out; + } + if (new) + hash_add(gss_auth_hash_table, &new->hash, hashval); + gss_auth = new; +out: + spin_unlock(&gss_auth_hash_lock); + return gss_auth; +} + +static struct gss_auth * +gss_create_hashed(const struct rpc_auth_create_args *args, + struct rpc_clnt *clnt) +{ + struct gss_auth *gss_auth; + struct gss_auth *new; + + gss_auth = gss_auth_find_or_add_hashed(args, clnt, NULL); + if (gss_auth != NULL) + goto out; + new = gss_create_new(args, clnt); + if (IS_ERR(new)) + return new; + gss_auth = gss_auth_find_or_add_hashed(args, clnt, new); + if (gss_auth != new) + gss_destroy(&new->rpc_auth); +out: + return gss_auth; +} + +static struct rpc_auth * +gss_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) +{ + struct gss_auth *gss_auth; + struct rpc_xprt_switch *xps = rcu_access_pointer(clnt->cl_xpi.xpi_xpswitch); + + while (clnt != clnt->cl_parent) { + struct rpc_clnt *parent = clnt->cl_parent; + /* Find the original parent for this transport */ + if (rcu_access_pointer(parent->cl_xpi.xpi_xpswitch) != xps) + break; + clnt = parent; + } + + gss_auth = gss_create_hashed(args, clnt); + if (IS_ERR(gss_auth)) + return ERR_CAST(gss_auth); + return &gss_auth->rpc_auth; +} + +/* + * gss_destroying_context will cause the RPCSEC_GSS to send a NULL RPC call + * to the server with the GSS control procedure field set to + * RPC_GSS_PROC_DESTROY. This should normally cause the server to release + * all RPCSEC_GSS state associated with that context. + */ +static int +gss_destroying_context(struct rpc_cred *cred) +{ + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); + struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth); + struct gss_cl_ctx *ctx = rcu_dereference_protected(gss_cred->gc_ctx, 1); + struct rpc_task *task; + + if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) == 0) + return 0; + + ctx->gc_proc = RPC_GSS_PROC_DESTROY; + cred->cr_ops = &gss_nullops; + + /* Take a reference to ensure the cred will be destroyed either + * by the RPC call or by the put_rpccred() below */ + get_rpccred(cred); + + task = rpc_call_null(gss_auth->client, cred, RPC_TASK_ASYNC|RPC_TASK_SOFT); + if (!IS_ERR(task)) + rpc_put_task(task); + + put_rpccred(cred); + return 1; +} + +/* gss_destroy_cred (and gss_free_ctx) are used to clean up after failure + * to create a new cred or context, so they check that things have been + * allocated before freeing them. */ +static void +gss_do_free_ctx(struct gss_cl_ctx *ctx) +{ + dprintk("RPC: %s\n", __func__); + + gss_delete_sec_context(&ctx->gc_gss_ctx); + kfree(ctx->gc_wire_ctx.data); + kfree(ctx->gc_acceptor.data); + kfree(ctx); +} + +static void +gss_free_ctx_callback(struct rcu_head *head) +{ + struct gss_cl_ctx *ctx = container_of(head, struct gss_cl_ctx, gc_rcu); + gss_do_free_ctx(ctx); +} + +static void +gss_free_ctx(struct gss_cl_ctx *ctx) +{ + call_rcu(&ctx->gc_rcu, gss_free_ctx_callback); +} + +static void +gss_free_cred(struct gss_cred *gss_cred) +{ + dprintk("RPC: %s cred=%p\n", __func__, gss_cred); + kfree(gss_cred); +} + +static void +gss_free_cred_callback(struct rcu_head *head) +{ + struct gss_cred *gss_cred = container_of(head, struct gss_cred, gc_base.cr_rcu); + gss_free_cred(gss_cred); +} + +static void +gss_destroy_nullcred(struct rpc_cred *cred) +{ + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); + struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth); + struct gss_cl_ctx *ctx = rcu_dereference_protected(gss_cred->gc_ctx, 1); + + RCU_INIT_POINTER(gss_cred->gc_ctx, NULL); + call_rcu(&cred->cr_rcu, gss_free_cred_callback); + if (ctx) + gss_put_ctx(ctx); + gss_put_auth(gss_auth); +} + +static void +gss_destroy_cred(struct rpc_cred *cred) +{ + + if (gss_destroying_context(cred)) + return; + gss_destroy_nullcred(cred); +} + +static int +gss_hash_cred(struct auth_cred *acred, unsigned int hashbits) +{ + return hash_64(from_kuid(&init_user_ns, acred->uid), hashbits); +} + +/* + * Lookup RPCSEC_GSS cred for the current process + */ +static struct rpc_cred * +gss_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) +{ + return rpcauth_lookup_credcache(auth, acred, flags, GFP_NOFS); +} + +static struct rpc_cred * +gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags, gfp_t gfp) +{ + struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth); + struct gss_cred *cred = NULL; + int err = -ENOMEM; + + dprintk("RPC: %s for uid %d, flavor %d\n", + __func__, from_kuid(&init_user_ns, acred->uid), + auth->au_flavor); + + if (!(cred = kzalloc(sizeof(*cred), gfp))) + goto out_err; + + rpcauth_init_cred(&cred->gc_base, acred, auth, &gss_credops); + /* + * Note: in order to force a call to call_refresh(), we deliberately + * fail to flag the credential as RPCAUTH_CRED_UPTODATE. + */ + cred->gc_base.cr_flags = 1UL << RPCAUTH_CRED_NEW; + cred->gc_service = gss_auth->service; + cred->gc_principal = NULL; + if (acred->machine_cred) + cred->gc_principal = acred->principal; + kref_get(&gss_auth->kref); + return &cred->gc_base; + +out_err: + dprintk("RPC: %s failed with error %d\n", __func__, err); + return ERR_PTR(err); +} + +static int +gss_cred_init(struct rpc_auth *auth, struct rpc_cred *cred) +{ + struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth); + struct gss_cred *gss_cred = container_of(cred,struct gss_cred, gc_base); + int err; + + do { + err = gss_create_upcall(gss_auth, gss_cred); + } while (err == -EAGAIN); + return err; +} + +static char * +gss_stringify_acceptor(struct rpc_cred *cred) +{ + char *string = NULL; + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); + struct gss_cl_ctx *ctx; + unsigned int len; + struct xdr_netobj *acceptor; + + rcu_read_lock(); + ctx = rcu_dereference(gss_cred->gc_ctx); + if (!ctx) + goto out; + + len = ctx->gc_acceptor.len; + rcu_read_unlock(); + + /* no point if there's no string */ + if (!len) + return NULL; +realloc: + string = kmalloc(len + 1, GFP_KERNEL); + if (!string) + return NULL; + + rcu_read_lock(); + ctx = rcu_dereference(gss_cred->gc_ctx); + + /* did the ctx disappear or was it replaced by one with no acceptor? */ + if (!ctx || !ctx->gc_acceptor.len) { + kfree(string); + string = NULL; + goto out; + } + + acceptor = &ctx->gc_acceptor; + + /* + * Did we find a new acceptor that's longer than the original? Allocate + * a longer buffer and try again. + */ + if (len < acceptor->len) { + len = acceptor->len; + rcu_read_unlock(); + kfree(string); + goto realloc; + } + + memcpy(string, acceptor->data, acceptor->len); + string[acceptor->len] = '\0'; +out: + rcu_read_unlock(); + return string; +} + +/* + * Returns -EACCES if GSS context is NULL or will expire within the + * timeout (miliseconds) + */ +static int +gss_key_timeout(struct rpc_cred *rc) +{ + struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base); + struct gss_cl_ctx *ctx; + unsigned long timeout = jiffies + (gss_key_expire_timeo * HZ); + int ret = 0; + + rcu_read_lock(); + ctx = rcu_dereference(gss_cred->gc_ctx); + if (!ctx || time_after(timeout, ctx->gc_expiry)) + ret = -EACCES; + rcu_read_unlock(); + + return ret; +} + +static int +gss_match(struct auth_cred *acred, struct rpc_cred *rc, int flags) +{ + struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base); + struct gss_cl_ctx *ctx; + int ret; + + if (test_bit(RPCAUTH_CRED_NEW, &rc->cr_flags)) + goto out; + /* Don't match with creds that have expired. */ + rcu_read_lock(); + ctx = rcu_dereference(gss_cred->gc_ctx); + if (!ctx || time_after(jiffies, ctx->gc_expiry)) { + rcu_read_unlock(); + return 0; + } + rcu_read_unlock(); + if (!test_bit(RPCAUTH_CRED_UPTODATE, &rc->cr_flags)) + return 0; +out: + if (acred->principal != NULL) { + if (gss_cred->gc_principal == NULL) + return 0; + ret = strcmp(acred->principal, gss_cred->gc_principal) == 0; + goto check_expire; + } + if (gss_cred->gc_principal != NULL) + return 0; + ret = uid_eq(rc->cr_uid, acred->uid); + +check_expire: + if (ret == 0) + return ret; + + /* Notify acred users of GSS context expiration timeout */ + if (test_bit(RPC_CRED_NOTIFY_TIMEOUT, &acred->ac_flags) && + (gss_key_timeout(rc) != 0)) { + /* test will now be done from generic cred */ + test_and_clear_bit(RPC_CRED_NOTIFY_TIMEOUT, &acred->ac_flags); + /* tell NFS layer that key will expire soon */ + set_bit(RPC_CRED_KEY_EXPIRE_SOON, &acred->ac_flags); + } + return ret; +} + +/* +* Marshal credentials. +* Maybe we should keep a cached credential for performance reasons. +*/ +static __be32 * +gss_marshal(struct rpc_task *task, __be32 *p) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_cred *cred = req->rq_cred; + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, + gc_base); + struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); + __be32 *cred_len; + u32 maj_stat = 0; + struct xdr_netobj mic; + struct kvec iov; + struct xdr_buf verf_buf; + + dprintk("RPC: %5u %s\n", task->tk_pid, __func__); + + *p++ = htonl(RPC_AUTH_GSS); + cred_len = p++; + + spin_lock(&ctx->gc_seq_lock); + req->rq_seqno = ctx->gc_seq++; + spin_unlock(&ctx->gc_seq_lock); + + *p++ = htonl((u32) RPC_GSS_VERSION); + *p++ = htonl((u32) ctx->gc_proc); + *p++ = htonl((u32) req->rq_seqno); + *p++ = htonl((u32) gss_cred->gc_service); + p = xdr_encode_netobj(p, &ctx->gc_wire_ctx); + *cred_len = htonl((p - (cred_len + 1)) << 2); + + /* We compute the checksum for the verifier over the xdr-encoded bytes + * starting with the xid and ending at the end of the credential: */ + iov.iov_base = xprt_skip_transport_header(req->rq_xprt, + req->rq_snd_buf.head[0].iov_base); + iov.iov_len = (u8 *)p - (u8 *)iov.iov_base; + xdr_buf_from_iov(&iov, &verf_buf); + + /* set verifier flavor*/ + *p++ = htonl(RPC_AUTH_GSS); + + mic.data = (u8 *)(p + 1); + maj_stat = gss_get_mic(ctx->gc_gss_ctx, &verf_buf, &mic); + if (maj_stat == GSS_S_CONTEXT_EXPIRED) { + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + } else if (maj_stat != 0) { + printk("gss_marshal: gss_get_mic FAILED (%d)\n", maj_stat); + goto out_put_ctx; + } + p = xdr_encode_opaque(p, NULL, mic.len); + gss_put_ctx(ctx); + return p; +out_put_ctx: + gss_put_ctx(ctx); + return NULL; +} + +static int gss_renew_cred(struct rpc_task *task) +{ + struct rpc_cred *oldcred = task->tk_rqstp->rq_cred; + struct gss_cred *gss_cred = container_of(oldcred, + struct gss_cred, + gc_base); + struct rpc_auth *auth = oldcred->cr_auth; + struct auth_cred acred = { + .uid = oldcred->cr_uid, + .principal = gss_cred->gc_principal, + .machine_cred = (gss_cred->gc_principal != NULL ? 1 : 0), + }; + struct rpc_cred *new; + + new = gss_lookup_cred(auth, &acred, RPCAUTH_LOOKUP_NEW); + if (IS_ERR(new)) + return PTR_ERR(new); + task->tk_rqstp->rq_cred = new; + put_rpccred(oldcred); + return 0; +} + +static int gss_cred_is_negative_entry(struct rpc_cred *cred) +{ + if (test_bit(RPCAUTH_CRED_NEGATIVE, &cred->cr_flags)) { + unsigned long now = jiffies; + unsigned long begin, expire; + struct gss_cred *gss_cred; + + gss_cred = container_of(cred, struct gss_cred, gc_base); + begin = gss_cred->gc_upcall_timestamp; + expire = begin + gss_expired_cred_retry_delay * HZ; + + if (time_in_range_open(now, begin, expire)) + return 1; + } + return 0; +} + +/* +* Refresh credentials. XXX - finish +*/ +static int +gss_refresh(struct rpc_task *task) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + int ret = 0; + + if (gss_cred_is_negative_entry(cred)) + return -EKEYEXPIRED; + + if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) && + !test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags)) { + ret = gss_renew_cred(task); + if (ret < 0) + goto out; + cred = task->tk_rqstp->rq_cred; + } + + if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags)) + ret = gss_refresh_upcall(task); +out: + return ret; +} + +/* Dummy refresh routine: used only when destroying the context */ +static int +gss_refresh_null(struct rpc_task *task) +{ + return 0; +} + +static __be32 * +gss_validate(struct rpc_task *task, __be32 *p) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); + __be32 *seq = NULL; + struct kvec iov; + struct xdr_buf verf_buf; + struct xdr_netobj mic; + u32 flav,len; + u32 maj_stat; + __be32 *ret = ERR_PTR(-EIO); + + dprintk("RPC: %5u %s\n", task->tk_pid, __func__); + + flav = ntohl(*p++); + if ((len = ntohl(*p++)) > RPC_MAX_AUTH_SIZE) + goto out_bad; + if (flav != RPC_AUTH_GSS) + goto out_bad; + seq = kmalloc(4, GFP_NOFS); + if (!seq) + goto out_bad; + *seq = htonl(task->tk_rqstp->rq_seqno); + iov.iov_base = seq; + iov.iov_len = 4; + xdr_buf_from_iov(&iov, &verf_buf); + mic.data = (u8 *)p; + mic.len = len; + + ret = ERR_PTR(-EACCES); + maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic); + if (maj_stat == GSS_S_CONTEXT_EXPIRED) + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + if (maj_stat) { + dprintk("RPC: %5u %s: gss_verify_mic returned error 0x%08x\n", + task->tk_pid, __func__, maj_stat); + goto out_bad; + } + /* We leave it to unwrap to calculate au_rslack. For now we just + * calculate the length of the verifier: */ + cred->cr_auth->au_verfsize = XDR_QUADLEN(len) + 2; + gss_put_ctx(ctx); + dprintk("RPC: %5u %s: gss_verify_mic succeeded.\n", + task->tk_pid, __func__); + kfree(seq); + return p + XDR_QUADLEN(len); +out_bad: + gss_put_ctx(ctx); + dprintk("RPC: %5u %s failed ret %ld.\n", task->tk_pid, __func__, + PTR_ERR(ret)); + kfree(seq); + return ret; +} + +static void gss_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp, + __be32 *p, void *obj) +{ + struct xdr_stream xdr; + + xdr_init_encode(&xdr, &rqstp->rq_snd_buf, p); + encode(rqstp, &xdr, obj); +} + +static inline int +gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx, + kxdreproc_t encode, struct rpc_rqst *rqstp, + __be32 *p, void *obj) +{ + struct xdr_buf *snd_buf = &rqstp->rq_snd_buf; + struct xdr_buf integ_buf; + __be32 *integ_len = NULL; + struct xdr_netobj mic; + u32 offset; + __be32 *q; + struct kvec *iov; + u32 maj_stat = 0; + int status = -EIO; + + integ_len = p++; + offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base; + *p++ = htonl(rqstp->rq_seqno); + + gss_wrap_req_encode(encode, rqstp, p, obj); + + if (xdr_buf_subsegment(snd_buf, &integ_buf, + offset, snd_buf->len - offset)) + return status; + *integ_len = htonl(integ_buf.len); + + /* guess whether we're in the head or the tail: */ + if (snd_buf->page_len || snd_buf->tail[0].iov_len) + iov = snd_buf->tail; + else + iov = snd_buf->head; + p = iov->iov_base + iov->iov_len; + mic.data = (u8 *)(p + 1); + + maj_stat = gss_get_mic(ctx->gc_gss_ctx, &integ_buf, &mic); + status = -EIO; /* XXX? */ + if (maj_stat == GSS_S_CONTEXT_EXPIRED) + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + else if (maj_stat) + return status; + q = xdr_encode_opaque(p, NULL, mic.len); + + offset = (u8 *)q - (u8 *)p; + iov->iov_len += offset; + snd_buf->len += offset; + return 0; +} + +static void +priv_release_snd_buf(struct rpc_rqst *rqstp) +{ + int i; + + for (i=0; i < rqstp->rq_enc_pages_num; i++) + __free_page(rqstp->rq_enc_pages[i]); + kfree(rqstp->rq_enc_pages); + rqstp->rq_release_snd_buf = NULL; +} + +static int +alloc_enc_pages(struct rpc_rqst *rqstp) +{ + struct xdr_buf *snd_buf = &rqstp->rq_snd_buf; + int first, last, i; + + if (rqstp->rq_release_snd_buf) + rqstp->rq_release_snd_buf(rqstp); + + if (snd_buf->page_len == 0) { + rqstp->rq_enc_pages_num = 0; + return 0; + } + + first = snd_buf->page_base >> PAGE_SHIFT; + last = (snd_buf->page_base + snd_buf->page_len - 1) >> PAGE_SHIFT; + rqstp->rq_enc_pages_num = last - first + 1 + 1; + rqstp->rq_enc_pages + = kmalloc_array(rqstp->rq_enc_pages_num, + sizeof(struct page *), + GFP_NOFS); + if (!rqstp->rq_enc_pages) + goto out; + for (i=0; i < rqstp->rq_enc_pages_num; i++) { + rqstp->rq_enc_pages[i] = alloc_page(GFP_NOFS); + if (rqstp->rq_enc_pages[i] == NULL) + goto out_free; + } + rqstp->rq_release_snd_buf = priv_release_snd_buf; + return 0; +out_free: + rqstp->rq_enc_pages_num = i; + priv_release_snd_buf(rqstp); +out: + return -EAGAIN; +} + +static inline int +gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx, + kxdreproc_t encode, struct rpc_rqst *rqstp, + __be32 *p, void *obj) +{ + struct xdr_buf *snd_buf = &rqstp->rq_snd_buf; + u32 offset; + u32 maj_stat; + int status; + __be32 *opaque_len; + struct page **inpages; + int first; + int pad; + struct kvec *iov; + char *tmp; + + opaque_len = p++; + offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base; + *p++ = htonl(rqstp->rq_seqno); + + gss_wrap_req_encode(encode, rqstp, p, obj); + + status = alloc_enc_pages(rqstp); + if (status) + return status; + first = snd_buf->page_base >> PAGE_SHIFT; + inpages = snd_buf->pages + first; + snd_buf->pages = rqstp->rq_enc_pages; + snd_buf->page_base -= first << PAGE_SHIFT; + /* + * Give the tail its own page, in case we need extra space in the + * head when wrapping: + * + * call_allocate() allocates twice the slack space required + * by the authentication flavor to rq_callsize. + * For GSS, slack is GSS_CRED_SLACK. + */ + if (snd_buf->page_len || snd_buf->tail[0].iov_len) { + tmp = page_address(rqstp->rq_enc_pages[rqstp->rq_enc_pages_num - 1]); + memcpy(tmp, snd_buf->tail[0].iov_base, snd_buf->tail[0].iov_len); + snd_buf->tail[0].iov_base = tmp; + } + maj_stat = gss_wrap(ctx->gc_gss_ctx, offset, snd_buf, inpages); + /* slack space should prevent this ever happening: */ + BUG_ON(snd_buf->len > snd_buf->buflen); + status = -EIO; + /* We're assuming that when GSS_S_CONTEXT_EXPIRED, the encryption was + * done anyway, so it's safe to put the request on the wire: */ + if (maj_stat == GSS_S_CONTEXT_EXPIRED) + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + else if (maj_stat) + return status; + + *opaque_len = htonl(snd_buf->len - offset); + /* guess whether we're in the head or the tail: */ + if (snd_buf->page_len || snd_buf->tail[0].iov_len) + iov = snd_buf->tail; + else + iov = snd_buf->head; + p = iov->iov_base + iov->iov_len; + pad = 3 - ((snd_buf->len - offset - 1) & 3); + memset(p, 0, pad); + iov->iov_len += pad; + snd_buf->len += pad; + + return 0; +} + +static int +gss_wrap_req(struct rpc_task *task, + kxdreproc_t encode, void *rqstp, __be32 *p, void *obj) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, + gc_base); + struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); + int status = -EIO; + + dprintk("RPC: %5u %s\n", task->tk_pid, __func__); + if (ctx->gc_proc != RPC_GSS_PROC_DATA) { + /* The spec seems a little ambiguous here, but I think that not + * wrapping context destruction requests makes the most sense. + */ + gss_wrap_req_encode(encode, rqstp, p, obj); + status = 0; + goto out; + } + switch (gss_cred->gc_service) { + case RPC_GSS_SVC_NONE: + gss_wrap_req_encode(encode, rqstp, p, obj); + status = 0; + break; + case RPC_GSS_SVC_INTEGRITY: + status = gss_wrap_req_integ(cred, ctx, encode, rqstp, p, obj); + break; + case RPC_GSS_SVC_PRIVACY: + status = gss_wrap_req_priv(cred, ctx, encode, rqstp, p, obj); + break; + } +out: + gss_put_ctx(ctx); + dprintk("RPC: %5u %s returning %d\n", task->tk_pid, __func__, status); + return status; +} + +static inline int +gss_unwrap_resp_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx, + struct rpc_rqst *rqstp, __be32 **p) +{ + struct xdr_buf *rcv_buf = &rqstp->rq_rcv_buf; + struct xdr_buf integ_buf; + struct xdr_netobj mic; + u32 data_offset, mic_offset; + u32 integ_len; + u32 maj_stat; + int status = -EIO; + + integ_len = ntohl(*(*p)++); + if (integ_len & 3) + return status; + data_offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base; + mic_offset = integ_len + data_offset; + if (mic_offset > rcv_buf->len) + return status; + if (ntohl(*(*p)++) != rqstp->rq_seqno) + return status; + + if (xdr_buf_subsegment(rcv_buf, &integ_buf, data_offset, + mic_offset - data_offset)) + return status; + + if (xdr_buf_read_netobj(rcv_buf, &mic, mic_offset)) + return status; + + maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &integ_buf, &mic); + if (maj_stat == GSS_S_CONTEXT_EXPIRED) + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + if (maj_stat != GSS_S_COMPLETE) + return status; + return 0; +} + +static inline int +gss_unwrap_resp_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx, + struct rpc_rqst *rqstp, __be32 **p) +{ + struct xdr_buf *rcv_buf = &rqstp->rq_rcv_buf; + u32 offset; + u32 opaque_len; + u32 maj_stat; + int status = -EIO; + + opaque_len = ntohl(*(*p)++); + offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base; + if (offset + opaque_len > rcv_buf->len) + return status; + /* remove padding: */ + rcv_buf->len = offset + opaque_len; + + maj_stat = gss_unwrap(ctx->gc_gss_ctx, offset, rcv_buf); + if (maj_stat == GSS_S_CONTEXT_EXPIRED) + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + if (maj_stat != GSS_S_COMPLETE) + return status; + if (ntohl(*(*p)++) != rqstp->rq_seqno) + return status; + + return 0; +} + +static int +gss_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp, + __be32 *p, void *obj) +{ + struct xdr_stream xdr; + + xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, p); + return decode(rqstp, &xdr, obj); +} + +static int +gss_unwrap_resp(struct rpc_task *task, + kxdrdproc_t decode, void *rqstp, __be32 *p, void *obj) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, + gc_base); + struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); + __be32 *savedp = p; + struct kvec *head = ((struct rpc_rqst *)rqstp)->rq_rcv_buf.head; + int savedlen = head->iov_len; + int status = -EIO; + + if (ctx->gc_proc != RPC_GSS_PROC_DATA) + goto out_decode; + switch (gss_cred->gc_service) { + case RPC_GSS_SVC_NONE: + break; + case RPC_GSS_SVC_INTEGRITY: + status = gss_unwrap_resp_integ(cred, ctx, rqstp, &p); + if (status) + goto out; + break; + case RPC_GSS_SVC_PRIVACY: + status = gss_unwrap_resp_priv(cred, ctx, rqstp, &p); + if (status) + goto out; + break; + } + /* take into account extra slack for integrity and privacy cases: */ + cred->cr_auth->au_rslack = cred->cr_auth->au_verfsize + (p - savedp) + + (savedlen - head->iov_len); +out_decode: + status = gss_unwrap_req_decode(decode, rqstp, p, obj); +out: + gss_put_ctx(ctx); + dprintk("RPC: %5u %s returning %d\n", + task->tk_pid, __func__, status); + return status; +} + +static const struct rpc_authops authgss_ops = { + .owner = THIS_MODULE, + .au_flavor = RPC_AUTH_GSS, + .au_name = "RPCSEC_GSS", + .create = gss_create, + .destroy = gss_destroy, + .hash_cred = gss_hash_cred, + .lookup_cred = gss_lookup_cred, + .crcreate = gss_create_cred, + .list_pseudoflavors = gss_mech_list_pseudoflavors, + .info2flavor = gss_mech_info2flavor, + .flavor2info = gss_mech_flavor2info, +}; + +static const struct rpc_credops gss_credops = { + .cr_name = "AUTH_GSS", + .crdestroy = gss_destroy_cred, + .cr_init = gss_cred_init, + .crbind = rpcauth_generic_bind_cred, + .crmatch = gss_match, + .crmarshal = gss_marshal, + .crrefresh = gss_refresh, + .crvalidate = gss_validate, + .crwrap_req = gss_wrap_req, + .crunwrap_resp = gss_unwrap_resp, + .crkey_timeout = gss_key_timeout, + .crstringify_acceptor = gss_stringify_acceptor, +}; + +static const struct rpc_credops gss_nullops = { + .cr_name = "AUTH_GSS", + .crdestroy = gss_destroy_nullcred, + .crbind = rpcauth_generic_bind_cred, + .crmatch = gss_match, + .crmarshal = gss_marshal, + .crrefresh = gss_refresh_null, + .crvalidate = gss_validate, + .crwrap_req = gss_wrap_req, + .crunwrap_resp = gss_unwrap_resp, + .crstringify_acceptor = gss_stringify_acceptor, +}; + +static const struct rpc_pipe_ops gss_upcall_ops_v0 = { + .upcall = rpc_pipe_generic_upcall, + .downcall = gss_pipe_downcall, + .destroy_msg = gss_pipe_destroy_msg, + .open_pipe = gss_pipe_open_v0, + .release_pipe = gss_pipe_release, +}; + +static const struct rpc_pipe_ops gss_upcall_ops_v1 = { + .upcall = rpc_pipe_generic_upcall, + .downcall = gss_pipe_downcall, + .destroy_msg = gss_pipe_destroy_msg, + .open_pipe = gss_pipe_open_v1, + .release_pipe = gss_pipe_release, +}; + +static __net_init int rpcsec_gss_init_net(struct net *net) +{ + return gss_svc_init_net(net); +} + +static __net_exit void rpcsec_gss_exit_net(struct net *net) +{ + gss_svc_shutdown_net(net); +} + +static struct pernet_operations rpcsec_gss_net_ops = { + .init = rpcsec_gss_init_net, + .exit = rpcsec_gss_exit_net, +}; + +/* + * Initialize RPCSEC_GSS module + */ +static int __init init_rpcsec_gss(void) +{ + int err = 0; + + err = rpcauth_register(&authgss_ops); + if (err) + goto out; + err = gss_svc_init(); + if (err) + goto out_unregister; + err = register_pernet_subsys(&rpcsec_gss_net_ops); + if (err) + goto out_svc_exit; + rpc_init_wait_queue(&pipe_version_rpc_waitqueue, "gss pipe version"); + return 0; +out_svc_exit: + gss_svc_shutdown(); +out_unregister: + rpcauth_unregister(&authgss_ops); +out: + return err; +} + +static void __exit exit_rpcsec_gss(void) +{ + unregister_pernet_subsys(&rpcsec_gss_net_ops); + gss_svc_shutdown(); + rpcauth_unregister(&authgss_ops); + rcu_barrier(); /* Wait for completion of call_rcu()'s */ +} + +MODULE_ALIAS("rpc-auth-6"); +MODULE_LICENSE("GPL"); +module_param_named(expired_cred_retry_delay, + gss_expired_cred_retry_delay, + uint, 0644); +MODULE_PARM_DESC(expired_cred_retry_delay, "Timeout (in seconds) until " + "the RPC engine retries an expired credential"); + +module_param_named(key_expire_timeo, + gss_key_expire_timeo, + uint, 0644); +MODULE_PARM_DESC(key_expire_timeo, "Time (in seconds) at the end of a " + "credential keys lifetime where the NFS layer cleans up " + "prior to key expiration"); + +module_init(init_rpcsec_gss) +module_exit(exit_rpcsec_gss) diff --git a/net/sunrpc/auth_gss/auth_gss_internal.h b/net/sunrpc/auth_gss/auth_gss_internal.h new file mode 100644 index 000000000..f6d9631bd --- /dev/null +++ b/net/sunrpc/auth_gss/auth_gss_internal.h @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: BSD-3-Clause +/* + * linux/net/sunrpc/auth_gss/auth_gss_internal.h + * + * Internal definitions for RPCSEC_GSS client authentication + * + * Copyright (c) 2000 The Regents of the University of Michigan. + * All rights reserved. + * + */ +#include <linux/err.h> +#include <linux/string.h> +#include <linux/sunrpc/xdr.h> + +static inline const void * +simple_get_bytes(const void *p, const void *end, void *res, size_t len) +{ + const void *q = (const void *)((const char *)p + len); + if (unlikely(q > end || q < p)) + return ERR_PTR(-EFAULT); + memcpy(res, p, len); + return q; +} + +static inline const void * +simple_get_netobj(const void *p, const void *end, struct xdr_netobj *dest) +{ + const void *q; + unsigned int len; + + p = simple_get_bytes(p, end, &len, sizeof(len)); + if (IS_ERR(p)) + return p; + q = (const void *)((const char *)p + len); + if (unlikely(q > end || q < p)) + return ERR_PTR(-EFAULT); + if (len) { + dest->data = kmemdup(p, len, GFP_NOFS); + if (unlikely(dest->data == NULL)) + return ERR_PTR(-ENOMEM); + } else + dest->data = NULL; + dest->len = len; + return q; +} diff --git a/net/sunrpc/auth_gss/gss_generic_token.c b/net/sunrpc/auth_gss/gss_generic_token.c new file mode 100644 index 000000000..fe97f3106 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_generic_token.c @@ -0,0 +1,233 @@ +/* + * linux/net/sunrpc/gss_generic_token.c + * + * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/generic/util_token.c + * + * Copyright (c) 2000 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson <andros@umich.edu> + */ + +/* + * Copyright 1993 by OpenVision Technologies, Inc. + * + * Permission to use, copy, modify, distribute, and sell this software + * and its documentation for any purpose is hereby granted without fee, + * provided that the above copyright notice appears in all copies and + * that both that copyright notice and this permission notice appear in + * supporting documentation, and that the name of OpenVision not be used + * in advertising or publicity pertaining to distribution of the software + * without specific, written prior permission. OpenVision makes no + * representations about the suitability of this software for any + * purpose. It is provided "as is" without express or implied warranty. + * + * OPENVISION DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, + * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO + * EVENT SHALL OPENVISION BE LIABLE FOR ANY SPECIAL, INDIRECT OR + * CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF + * USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR + * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. + */ + +#include <linux/types.h> +#include <linux/module.h> +#include <linux/string.h> +#include <linux/sunrpc/sched.h> +#include <linux/sunrpc/gss_asn1.h> + + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + + +/* TWRITE_STR from gssapiP_generic.h */ +#define TWRITE_STR(ptr, str, len) \ + memcpy((ptr), (char *) (str), (len)); \ + (ptr) += (len); + +/* XXXX this code currently makes the assumption that a mech oid will + never be longer than 127 bytes. This assumption is not inherent in + the interfaces, so the code can be fixed if the OSI namespace + balloons unexpectedly. */ + +/* Each token looks like this: + +0x60 tag for APPLICATION 0, SEQUENCE + (constructed, definite-length) + <length> possible multiple bytes, need to parse/generate + 0x06 tag for OBJECT IDENTIFIER + <moid_length> compile-time constant string (assume 1 byte) + <moid_bytes> compile-time constant string + <inner_bytes> the ANY containing the application token + bytes 0,1 are the token type + bytes 2,n are the token data + +For the purposes of this abstraction, the token "header" consists of +the sequence tag and length octets, the mech OID DER encoding, and the +first two inner bytes, which indicate the token type. The token +"body" consists of everything else. + +*/ + +static int +der_length_size( int length) +{ + if (length < (1<<7)) + return 1; + else if (length < (1<<8)) + return 2; +#if (SIZEOF_INT == 2) + else + return 3; +#else + else if (length < (1<<16)) + return 3; + else if (length < (1<<24)) + return 4; + else + return 5; +#endif +} + +static void +der_write_length(unsigned char **buf, int length) +{ + if (length < (1<<7)) { + *(*buf)++ = (unsigned char) length; + } else { + *(*buf)++ = (unsigned char) (der_length_size(length)+127); +#if (SIZEOF_INT > 2) + if (length >= (1<<24)) + *(*buf)++ = (unsigned char) (length>>24); + if (length >= (1<<16)) + *(*buf)++ = (unsigned char) ((length>>16)&0xff); +#endif + if (length >= (1<<8)) + *(*buf)++ = (unsigned char) ((length>>8)&0xff); + *(*buf)++ = (unsigned char) (length&0xff); + } +} + +/* returns decoded length, or < 0 on failure. Advances buf and + decrements bufsize */ + +static int +der_read_length(unsigned char **buf, int *bufsize) +{ + unsigned char sf; + int ret; + + if (*bufsize < 1) + return -1; + sf = *(*buf)++; + (*bufsize)--; + if (sf & 0x80) { + if ((sf &= 0x7f) > ((*bufsize)-1)) + return -1; + if (sf > SIZEOF_INT) + return -1; + ret = 0; + for (; sf; sf--) { + ret = (ret<<8) + (*(*buf)++); + (*bufsize)--; + } + } else { + ret = sf; + } + + return ret; +} + +/* returns the length of a token, given the mech oid and the body size */ + +int +g_token_size(struct xdr_netobj *mech, unsigned int body_size) +{ + /* set body_size to sequence contents size */ + body_size += 2 + (int) mech->len; /* NEED overflow check */ + return 1 + der_length_size(body_size) + body_size; +} + +EXPORT_SYMBOL_GPL(g_token_size); + +/* fills in a buffer with the token header. The buffer is assumed to + be the right size. buf is advanced past the token header */ + +void +g_make_token_header(struct xdr_netobj *mech, int body_size, unsigned char **buf) +{ + *(*buf)++ = 0x60; + der_write_length(buf, 2 + mech->len + body_size); + *(*buf)++ = 0x06; + *(*buf)++ = (unsigned char) mech->len; + TWRITE_STR(*buf, mech->data, ((int) mech->len)); +} + +EXPORT_SYMBOL_GPL(g_make_token_header); + +/* + * Given a buffer containing a token, reads and verifies the token, + * leaving buf advanced past the token header, and setting body_size + * to the number of remaining bytes. Returns 0 on success, + * G_BAD_TOK_HEADER for a variety of errors, and G_WRONG_MECH if the + * mechanism in the token does not match the mech argument. buf and + * *body_size are left unmodified on error. + */ +u32 +g_verify_token_header(struct xdr_netobj *mech, int *body_size, + unsigned char **buf_in, int toksize) +{ + unsigned char *buf = *buf_in; + int seqsize; + struct xdr_netobj toid; + int ret = 0; + + if ((toksize-=1) < 0) + return G_BAD_TOK_HEADER; + if (*buf++ != 0x60) + return G_BAD_TOK_HEADER; + + if ((seqsize = der_read_length(&buf, &toksize)) < 0) + return G_BAD_TOK_HEADER; + + if (seqsize != toksize) + return G_BAD_TOK_HEADER; + + if ((toksize-=1) < 0) + return G_BAD_TOK_HEADER; + if (*buf++ != 0x06) + return G_BAD_TOK_HEADER; + + if ((toksize-=1) < 0) + return G_BAD_TOK_HEADER; + toid.len = *buf++; + + if ((toksize-=toid.len) < 0) + return G_BAD_TOK_HEADER; + toid.data = buf; + buf+=toid.len; + + if (! g_OID_equal(&toid, mech)) + ret = G_WRONG_MECH; + + /* G_WRONG_MECH is not returned immediately because it's more important + to return G_BAD_TOK_HEADER if the token header is in fact bad */ + + if ((toksize-=2) < 0) + return G_BAD_TOK_HEADER; + + if (ret) + return ret; + + if (!ret) { + *buf_in = buf; + *body_size = toksize; + } + + return ret; +} + +EXPORT_SYMBOL_GPL(g_verify_token_header); diff --git a/net/sunrpc/auth_gss/gss_krb5_crypto.c b/net/sunrpc/auth_gss/gss_krb5_crypto.c new file mode 100644 index 000000000..0220e1ca5 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_crypto.c @@ -0,0 +1,1083 @@ +/* + * linux/net/sunrpc/gss_krb5_crypto.c + * + * Copyright (c) 2000-2008 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson <andros@umich.edu> + * Bruce Fields <bfields@umich.edu> + */ + +/* + * Copyright (C) 1998 by the FundsXpress, INC. + * + * All rights reserved. + * + * Export of this software from the United States of America may require + * a specific license from the United States Government. It is the + * responsibility of any person or organization contemplating export to + * obtain such a license before exporting. + * + * WITHIN THAT CONSTRAINT, permission to use, copy, modify, and + * distribute this software and its documentation for any purpose and + * without fee is hereby granted, provided that the above copyright + * notice appear in all copies and that both that copyright notice and + * this permission notice appear in supporting documentation, and that + * the name of FundsXpress. not be used in advertising or publicity pertaining + * to distribution of the software without specific, written prior + * permission. FundsXpress makes no representations about the suitability of + * this software for any purpose. It is provided "as is" without express + * or implied warranty. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED + * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE. + */ + +#include <crypto/algapi.h> +#include <crypto/hash.h> +#include <crypto/skcipher.h> +#include <linux/err.h> +#include <linux/types.h> +#include <linux/mm.h> +#include <linux/scatterlist.h> +#include <linux/highmem.h> +#include <linux/pagemap.h> +#include <linux/random.h> +#include <linux/sunrpc/gss_krb5.h> +#include <linux/sunrpc/xdr.h> + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +u32 +krb5_encrypt( + struct crypto_skcipher *tfm, + void * iv, + void * in, + void * out, + int length) +{ + u32 ret = -EINVAL; + struct scatterlist sg[1]; + u8 local_iv[GSS_KRB5_MAX_BLOCKSIZE] = {0}; + SKCIPHER_REQUEST_ON_STACK(req, tfm); + + if (length % crypto_skcipher_blocksize(tfm) != 0) + goto out; + + if (crypto_skcipher_ivsize(tfm) > GSS_KRB5_MAX_BLOCKSIZE) { + dprintk("RPC: gss_k5encrypt: tfm iv size too large %d\n", + crypto_skcipher_ivsize(tfm)); + goto out; + } + + if (iv) + memcpy(local_iv, iv, crypto_skcipher_ivsize(tfm)); + + memcpy(out, in, length); + sg_init_one(sg, out, length); + + skcipher_request_set_tfm(req, tfm); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, sg, sg, length, local_iv); + + ret = crypto_skcipher_encrypt(req); + skcipher_request_zero(req); +out: + dprintk("RPC: krb5_encrypt returns %d\n", ret); + return ret; +} + +u32 +krb5_decrypt( + struct crypto_skcipher *tfm, + void * iv, + void * in, + void * out, + int length) +{ + u32 ret = -EINVAL; + struct scatterlist sg[1]; + u8 local_iv[GSS_KRB5_MAX_BLOCKSIZE] = {0}; + SKCIPHER_REQUEST_ON_STACK(req, tfm); + + if (length % crypto_skcipher_blocksize(tfm) != 0) + goto out; + + if (crypto_skcipher_ivsize(tfm) > GSS_KRB5_MAX_BLOCKSIZE) { + dprintk("RPC: gss_k5decrypt: tfm iv size too large %d\n", + crypto_skcipher_ivsize(tfm)); + goto out; + } + if (iv) + memcpy(local_iv,iv, crypto_skcipher_ivsize(tfm)); + + memcpy(out, in, length); + sg_init_one(sg, out, length); + + skcipher_request_set_tfm(req, tfm); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, sg, sg, length, local_iv); + + ret = crypto_skcipher_decrypt(req); + skcipher_request_zero(req); +out: + dprintk("RPC: gss_k5decrypt returns %d\n",ret); + return ret; +} + +static int +checksummer(struct scatterlist *sg, void *data) +{ + struct ahash_request *req = data; + + ahash_request_set_crypt(req, sg, NULL, sg->length); + + return crypto_ahash_update(req); +} + +static int +arcfour_hmac_md5_usage_to_salt(unsigned int usage, u8 salt[4]) +{ + unsigned int ms_usage; + + switch (usage) { + case KG_USAGE_SIGN: + ms_usage = 15; + break; + case KG_USAGE_SEAL: + ms_usage = 13; + break; + default: + return -EINVAL; + } + salt[0] = (ms_usage >> 0) & 0xff; + salt[1] = (ms_usage >> 8) & 0xff; + salt[2] = (ms_usage >> 16) & 0xff; + salt[3] = (ms_usage >> 24) & 0xff; + + return 0; +} + +static u32 +make_checksum_hmac_md5(struct krb5_ctx *kctx, char *header, int hdrlen, + struct xdr_buf *body, int body_offset, u8 *cksumkey, + unsigned int usage, struct xdr_netobj *cksumout) +{ + struct scatterlist sg[1]; + int err = -1; + u8 *checksumdata; + u8 *rc4salt; + struct crypto_ahash *md5; + struct crypto_ahash *hmac_md5; + struct ahash_request *req; + + if (cksumkey == NULL) + return GSS_S_FAILURE; + + if (cksumout->len < kctx->gk5e->cksumlength) { + dprintk("%s: checksum buffer length, %u, too small for %s\n", + __func__, cksumout->len, kctx->gk5e->name); + return GSS_S_FAILURE; + } + + rc4salt = kmalloc_array(4, sizeof(*rc4salt), GFP_NOFS); + if (!rc4salt) + return GSS_S_FAILURE; + + if (arcfour_hmac_md5_usage_to_salt(usage, rc4salt)) { + dprintk("%s: invalid usage value %u\n", __func__, usage); + goto out_free_rc4salt; + } + + checksumdata = kmalloc(GSS_KRB5_MAX_CKSUM_LEN, GFP_NOFS); + if (!checksumdata) + goto out_free_rc4salt; + + md5 = crypto_alloc_ahash("md5", 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(md5)) + goto out_free_cksum; + + hmac_md5 = crypto_alloc_ahash(kctx->gk5e->cksum_name, 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(hmac_md5)) + goto out_free_md5; + + req = ahash_request_alloc(md5, GFP_NOFS); + if (!req) + goto out_free_hmac_md5; + + ahash_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP, NULL, NULL); + + err = crypto_ahash_init(req); + if (err) + goto out; + sg_init_one(sg, rc4salt, 4); + ahash_request_set_crypt(req, sg, NULL, 4); + err = crypto_ahash_update(req); + if (err) + goto out; + + sg_init_one(sg, header, hdrlen); + ahash_request_set_crypt(req, sg, NULL, hdrlen); + err = crypto_ahash_update(req); + if (err) + goto out; + err = xdr_process_buf(body, body_offset, body->len - body_offset, + checksummer, req); + if (err) + goto out; + ahash_request_set_crypt(req, NULL, checksumdata, 0); + err = crypto_ahash_final(req); + if (err) + goto out; + + ahash_request_free(req); + req = ahash_request_alloc(hmac_md5, GFP_NOFS); + if (!req) + goto out_free_hmac_md5; + + ahash_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP, NULL, NULL); + + err = crypto_ahash_setkey(hmac_md5, cksumkey, kctx->gk5e->keylength); + if (err) + goto out; + + sg_init_one(sg, checksumdata, crypto_ahash_digestsize(md5)); + ahash_request_set_crypt(req, sg, checksumdata, + crypto_ahash_digestsize(md5)); + err = crypto_ahash_digest(req); + if (err) + goto out; + + memcpy(cksumout->data, checksumdata, kctx->gk5e->cksumlength); + cksumout->len = kctx->gk5e->cksumlength; +out: + ahash_request_free(req); +out_free_hmac_md5: + crypto_free_ahash(hmac_md5); +out_free_md5: + crypto_free_ahash(md5); +out_free_cksum: + kfree(checksumdata); +out_free_rc4salt: + kfree(rc4salt); + return err ? GSS_S_FAILURE : 0; +} + +/* + * checksum the plaintext data and hdrlen bytes of the token header + * The checksum is performed over the first 8 bytes of the + * gss token header and then over the data body + */ +u32 +make_checksum(struct krb5_ctx *kctx, char *header, int hdrlen, + struct xdr_buf *body, int body_offset, u8 *cksumkey, + unsigned int usage, struct xdr_netobj *cksumout) +{ + struct crypto_ahash *tfm; + struct ahash_request *req; + struct scatterlist sg[1]; + int err = -1; + u8 *checksumdata; + unsigned int checksumlen; + + if (kctx->gk5e->ctype == CKSUMTYPE_HMAC_MD5_ARCFOUR) + return make_checksum_hmac_md5(kctx, header, hdrlen, + body, body_offset, + cksumkey, usage, cksumout); + + if (cksumout->len < kctx->gk5e->cksumlength) { + dprintk("%s: checksum buffer length, %u, too small for %s\n", + __func__, cksumout->len, kctx->gk5e->name); + return GSS_S_FAILURE; + } + + checksumdata = kmalloc(GSS_KRB5_MAX_CKSUM_LEN, GFP_NOFS); + if (checksumdata == NULL) + return GSS_S_FAILURE; + + tfm = crypto_alloc_ahash(kctx->gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(tfm)) + goto out_free_cksum; + + req = ahash_request_alloc(tfm, GFP_NOFS); + if (!req) + goto out_free_ahash; + + ahash_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP, NULL, NULL); + + checksumlen = crypto_ahash_digestsize(tfm); + + if (cksumkey != NULL) { + err = crypto_ahash_setkey(tfm, cksumkey, + kctx->gk5e->keylength); + if (err) + goto out; + } + + err = crypto_ahash_init(req); + if (err) + goto out; + sg_init_one(sg, header, hdrlen); + ahash_request_set_crypt(req, sg, NULL, hdrlen); + err = crypto_ahash_update(req); + if (err) + goto out; + err = xdr_process_buf(body, body_offset, body->len - body_offset, + checksummer, req); + if (err) + goto out; + ahash_request_set_crypt(req, NULL, checksumdata, 0); + err = crypto_ahash_final(req); + if (err) + goto out; + + switch (kctx->gk5e->ctype) { + case CKSUMTYPE_RSA_MD5: + err = kctx->gk5e->encrypt(kctx->seq, NULL, checksumdata, + checksumdata, checksumlen); + if (err) + goto out; + memcpy(cksumout->data, + checksumdata + checksumlen - kctx->gk5e->cksumlength, + kctx->gk5e->cksumlength); + break; + case CKSUMTYPE_HMAC_SHA1_DES3: + memcpy(cksumout->data, checksumdata, kctx->gk5e->cksumlength); + break; + default: + BUG(); + break; + } + cksumout->len = kctx->gk5e->cksumlength; +out: + ahash_request_free(req); +out_free_ahash: + crypto_free_ahash(tfm); +out_free_cksum: + kfree(checksumdata); + return err ? GSS_S_FAILURE : 0; +} + +/* + * checksum the plaintext data and hdrlen bytes of the token header + * Per rfc4121, sec. 4.2.4, the checksum is performed over the data + * body then over the first 16 octets of the MIC token + * Inclusion of the header data in the calculation of the + * checksum is optional. + */ +u32 +make_checksum_v2(struct krb5_ctx *kctx, char *header, int hdrlen, + struct xdr_buf *body, int body_offset, u8 *cksumkey, + unsigned int usage, struct xdr_netobj *cksumout) +{ + struct crypto_ahash *tfm; + struct ahash_request *req; + struct scatterlist sg[1]; + int err = -1; + u8 *checksumdata; + + if (kctx->gk5e->keyed_cksum == 0) { + dprintk("%s: expected keyed hash for %s\n", + __func__, kctx->gk5e->name); + return GSS_S_FAILURE; + } + if (cksumkey == NULL) { + dprintk("%s: no key supplied for %s\n", + __func__, kctx->gk5e->name); + return GSS_S_FAILURE; + } + + checksumdata = kmalloc(GSS_KRB5_MAX_CKSUM_LEN, GFP_NOFS); + if (!checksumdata) + return GSS_S_FAILURE; + + tfm = crypto_alloc_ahash(kctx->gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(tfm)) + goto out_free_cksum; + + req = ahash_request_alloc(tfm, GFP_NOFS); + if (!req) + goto out_free_ahash; + + ahash_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP, NULL, NULL); + + err = crypto_ahash_setkey(tfm, cksumkey, kctx->gk5e->keylength); + if (err) + goto out; + + err = crypto_ahash_init(req); + if (err) + goto out; + err = xdr_process_buf(body, body_offset, body->len - body_offset, + checksummer, req); + if (err) + goto out; + if (header != NULL) { + sg_init_one(sg, header, hdrlen); + ahash_request_set_crypt(req, sg, NULL, hdrlen); + err = crypto_ahash_update(req); + if (err) + goto out; + } + ahash_request_set_crypt(req, NULL, checksumdata, 0); + err = crypto_ahash_final(req); + if (err) + goto out; + + cksumout->len = kctx->gk5e->cksumlength; + + switch (kctx->gk5e->ctype) { + case CKSUMTYPE_HMAC_SHA1_96_AES128: + case CKSUMTYPE_HMAC_SHA1_96_AES256: + /* note that this truncates the hash */ + memcpy(cksumout->data, checksumdata, kctx->gk5e->cksumlength); + break; + default: + BUG(); + break; + } +out: + ahash_request_free(req); +out_free_ahash: + crypto_free_ahash(tfm); +out_free_cksum: + kfree(checksumdata); + return err ? GSS_S_FAILURE : 0; +} + +struct encryptor_desc { + u8 iv[GSS_KRB5_MAX_BLOCKSIZE]; + struct skcipher_request *req; + int pos; + struct xdr_buf *outbuf; + struct page **pages; + struct scatterlist infrags[4]; + struct scatterlist outfrags[4]; + int fragno; + int fraglen; +}; + +static int +encryptor(struct scatterlist *sg, void *data) +{ + struct encryptor_desc *desc = data; + struct xdr_buf *outbuf = desc->outbuf; + struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(desc->req); + struct page *in_page; + int thislen = desc->fraglen + sg->length; + int fraglen, ret; + int page_pos; + + /* Worst case is 4 fragments: head, end of page 1, start + * of page 2, tail. Anything more is a bug. */ + BUG_ON(desc->fragno > 3); + + page_pos = desc->pos - outbuf->head[0].iov_len; + if (page_pos >= 0 && page_pos < outbuf->page_len) { + /* pages are not in place: */ + int i = (page_pos + outbuf->page_base) >> PAGE_SHIFT; + in_page = desc->pages[i]; + } else { + in_page = sg_page(sg); + } + sg_set_page(&desc->infrags[desc->fragno], in_page, sg->length, + sg->offset); + sg_set_page(&desc->outfrags[desc->fragno], sg_page(sg), sg->length, + sg->offset); + desc->fragno++; + desc->fraglen += sg->length; + desc->pos += sg->length; + + fraglen = thislen & (crypto_skcipher_blocksize(tfm) - 1); + thislen -= fraglen; + + if (thislen == 0) + return 0; + + sg_mark_end(&desc->infrags[desc->fragno - 1]); + sg_mark_end(&desc->outfrags[desc->fragno - 1]); + + skcipher_request_set_crypt(desc->req, desc->infrags, desc->outfrags, + thislen, desc->iv); + + ret = crypto_skcipher_encrypt(desc->req); + if (ret) + return ret; + + sg_init_table(desc->infrags, 4); + sg_init_table(desc->outfrags, 4); + + if (fraglen) { + sg_set_page(&desc->outfrags[0], sg_page(sg), fraglen, + sg->offset + sg->length - fraglen); + desc->infrags[0] = desc->outfrags[0]; + sg_assign_page(&desc->infrags[0], in_page); + desc->fragno = 1; + desc->fraglen = fraglen; + } else { + desc->fragno = 0; + desc->fraglen = 0; + } + return 0; +} + +int +gss_encrypt_xdr_buf(struct crypto_skcipher *tfm, struct xdr_buf *buf, + int offset, struct page **pages) +{ + int ret; + struct encryptor_desc desc; + SKCIPHER_REQUEST_ON_STACK(req, tfm); + + BUG_ON((buf->len - offset) % crypto_skcipher_blocksize(tfm) != 0); + + skcipher_request_set_tfm(req, tfm); + skcipher_request_set_callback(req, 0, NULL, NULL); + + memset(desc.iv, 0, sizeof(desc.iv)); + desc.req = req; + desc.pos = offset; + desc.outbuf = buf; + desc.pages = pages; + desc.fragno = 0; + desc.fraglen = 0; + + sg_init_table(desc.infrags, 4); + sg_init_table(desc.outfrags, 4); + + ret = xdr_process_buf(buf, offset, buf->len - offset, encryptor, &desc); + skcipher_request_zero(req); + return ret; +} + +struct decryptor_desc { + u8 iv[GSS_KRB5_MAX_BLOCKSIZE]; + struct skcipher_request *req; + struct scatterlist frags[4]; + int fragno; + int fraglen; +}; + +static int +decryptor(struct scatterlist *sg, void *data) +{ + struct decryptor_desc *desc = data; + int thislen = desc->fraglen + sg->length; + struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(desc->req); + int fraglen, ret; + + /* Worst case is 4 fragments: head, end of page 1, start + * of page 2, tail. Anything more is a bug. */ + BUG_ON(desc->fragno > 3); + sg_set_page(&desc->frags[desc->fragno], sg_page(sg), sg->length, + sg->offset); + desc->fragno++; + desc->fraglen += sg->length; + + fraglen = thislen & (crypto_skcipher_blocksize(tfm) - 1); + thislen -= fraglen; + + if (thislen == 0) + return 0; + + sg_mark_end(&desc->frags[desc->fragno - 1]); + + skcipher_request_set_crypt(desc->req, desc->frags, desc->frags, + thislen, desc->iv); + + ret = crypto_skcipher_decrypt(desc->req); + if (ret) + return ret; + + sg_init_table(desc->frags, 4); + + if (fraglen) { + sg_set_page(&desc->frags[0], sg_page(sg), fraglen, + sg->offset + sg->length - fraglen); + desc->fragno = 1; + desc->fraglen = fraglen; + } else { + desc->fragno = 0; + desc->fraglen = 0; + } + return 0; +} + +int +gss_decrypt_xdr_buf(struct crypto_skcipher *tfm, struct xdr_buf *buf, + int offset) +{ + int ret; + struct decryptor_desc desc; + SKCIPHER_REQUEST_ON_STACK(req, tfm); + + /* XXXJBF: */ + BUG_ON((buf->len - offset) % crypto_skcipher_blocksize(tfm) != 0); + + skcipher_request_set_tfm(req, tfm); + skcipher_request_set_callback(req, 0, NULL, NULL); + + memset(desc.iv, 0, sizeof(desc.iv)); + desc.req = req; + desc.fragno = 0; + desc.fraglen = 0; + + sg_init_table(desc.frags, 4); + + ret = xdr_process_buf(buf, offset, buf->len - offset, decryptor, &desc); + skcipher_request_zero(req); + return ret; +} + +/* + * This function makes the assumption that it was ultimately called + * from gss_wrap(). + * + * The client auth_gss code moves any existing tail data into a + * separate page before calling gss_wrap. + * The server svcauth_gss code ensures that both the head and the + * tail have slack space of RPC_MAX_AUTH_SIZE before calling gss_wrap. + * + * Even with that guarantee, this function may be called more than + * once in the processing of gss_wrap(). The best we can do is + * verify at compile-time (see GSS_KRB5_SLACK_CHECK) that the + * largest expected shift will fit within RPC_MAX_AUTH_SIZE. + * At run-time we can verify that a single invocation of this + * function doesn't attempt to use more the RPC_MAX_AUTH_SIZE. + */ + +int +xdr_extend_head(struct xdr_buf *buf, unsigned int base, unsigned int shiftlen) +{ + u8 *p; + + if (shiftlen == 0) + return 0; + + BUILD_BUG_ON(GSS_KRB5_MAX_SLACK_NEEDED > RPC_MAX_AUTH_SIZE); + BUG_ON(shiftlen > RPC_MAX_AUTH_SIZE); + + p = buf->head[0].iov_base + base; + + memmove(p + shiftlen, p, buf->head[0].iov_len - base); + + buf->head[0].iov_len += shiftlen; + buf->len += shiftlen; + + return 0; +} + +static u32 +gss_krb5_cts_crypt(struct crypto_skcipher *cipher, struct xdr_buf *buf, + u32 offset, u8 *iv, struct page **pages, int encrypt) +{ + u32 ret; + struct scatterlist sg[1]; + SKCIPHER_REQUEST_ON_STACK(req, cipher); + u8 *data; + struct page **save_pages; + u32 len = buf->len - offset; + + if (len > GSS_KRB5_MAX_BLOCKSIZE * 2) { + WARN_ON(0); + return -ENOMEM; + } + data = kmalloc(GSS_KRB5_MAX_BLOCKSIZE * 2, GFP_NOFS); + if (!data) + return -ENOMEM; + + /* + * For encryption, we want to read from the cleartext + * page cache pages, and write the encrypted data to + * the supplied xdr_buf pages. + */ + save_pages = buf->pages; + if (encrypt) + buf->pages = pages; + + ret = read_bytes_from_xdr_buf(buf, offset, data, len); + buf->pages = save_pages; + if (ret) + goto out; + + sg_init_one(sg, data, len); + + skcipher_request_set_tfm(req, cipher); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, sg, sg, len, iv); + + if (encrypt) + ret = crypto_skcipher_encrypt(req); + else + ret = crypto_skcipher_decrypt(req); + + skcipher_request_zero(req); + + if (ret) + goto out; + + ret = write_bytes_to_xdr_buf(buf, offset, data, len); + +out: + kfree(data); + return ret; +} + +u32 +gss_krb5_aes_encrypt(struct krb5_ctx *kctx, u32 offset, + struct xdr_buf *buf, struct page **pages) +{ + u32 err; + struct xdr_netobj hmac; + u8 *cksumkey; + u8 *ecptr; + struct crypto_skcipher *cipher, *aux_cipher; + int blocksize; + struct page **save_pages; + int nblocks, nbytes; + struct encryptor_desc desc; + u32 cbcbytes; + unsigned int usage; + + if (kctx->initiate) { + cipher = kctx->initiator_enc; + aux_cipher = kctx->initiator_enc_aux; + cksumkey = kctx->initiator_integ; + usage = KG_USAGE_INITIATOR_SEAL; + } else { + cipher = kctx->acceptor_enc; + aux_cipher = kctx->acceptor_enc_aux; + cksumkey = kctx->acceptor_integ; + usage = KG_USAGE_ACCEPTOR_SEAL; + } + blocksize = crypto_skcipher_blocksize(cipher); + + /* hide the gss token header and insert the confounder */ + offset += GSS_KRB5_TOK_HDR_LEN; + if (xdr_extend_head(buf, offset, kctx->gk5e->conflen)) + return GSS_S_FAILURE; + gss_krb5_make_confounder(buf->head[0].iov_base + offset, kctx->gk5e->conflen); + offset -= GSS_KRB5_TOK_HDR_LEN; + + if (buf->tail[0].iov_base != NULL) { + ecptr = buf->tail[0].iov_base + buf->tail[0].iov_len; + } else { + buf->tail[0].iov_base = buf->head[0].iov_base + + buf->head[0].iov_len; + buf->tail[0].iov_len = 0; + ecptr = buf->tail[0].iov_base; + } + + /* copy plaintext gss token header after filler (if any) */ + memcpy(ecptr, buf->head[0].iov_base + offset, GSS_KRB5_TOK_HDR_LEN); + buf->tail[0].iov_len += GSS_KRB5_TOK_HDR_LEN; + buf->len += GSS_KRB5_TOK_HDR_LEN; + + /* Do the HMAC */ + hmac.len = GSS_KRB5_MAX_CKSUM_LEN; + hmac.data = buf->tail[0].iov_base + buf->tail[0].iov_len; + + /* + * When we are called, pages points to the real page cache + * data -- which we can't go and encrypt! buf->pages points + * to scratch pages which we are going to send off to the + * client/server. Swap in the plaintext pages to calculate + * the hmac. + */ + save_pages = buf->pages; + buf->pages = pages; + + err = make_checksum_v2(kctx, NULL, 0, buf, + offset + GSS_KRB5_TOK_HDR_LEN, + cksumkey, usage, &hmac); + buf->pages = save_pages; + if (err) + return GSS_S_FAILURE; + + nbytes = buf->len - offset - GSS_KRB5_TOK_HDR_LEN; + nblocks = (nbytes + blocksize - 1) / blocksize; + cbcbytes = 0; + if (nblocks > 2) + cbcbytes = (nblocks - 2) * blocksize; + + memset(desc.iv, 0, sizeof(desc.iv)); + + if (cbcbytes) { + SKCIPHER_REQUEST_ON_STACK(req, aux_cipher); + + desc.pos = offset + GSS_KRB5_TOK_HDR_LEN; + desc.fragno = 0; + desc.fraglen = 0; + desc.pages = pages; + desc.outbuf = buf; + desc.req = req; + + skcipher_request_set_tfm(req, aux_cipher); + skcipher_request_set_callback(req, 0, NULL, NULL); + + sg_init_table(desc.infrags, 4); + sg_init_table(desc.outfrags, 4); + + err = xdr_process_buf(buf, offset + GSS_KRB5_TOK_HDR_LEN, + cbcbytes, encryptor, &desc); + skcipher_request_zero(req); + if (err) + goto out_err; + } + + /* Make sure IV carries forward from any CBC results. */ + err = gss_krb5_cts_crypt(cipher, buf, + offset + GSS_KRB5_TOK_HDR_LEN + cbcbytes, + desc.iv, pages, 1); + if (err) { + err = GSS_S_FAILURE; + goto out_err; + } + + /* Now update buf to account for HMAC */ + buf->tail[0].iov_len += kctx->gk5e->cksumlength; + buf->len += kctx->gk5e->cksumlength; + +out_err: + if (err) + err = GSS_S_FAILURE; + return err; +} + +u32 +gss_krb5_aes_decrypt(struct krb5_ctx *kctx, u32 offset, struct xdr_buf *buf, + u32 *headskip, u32 *tailskip) +{ + struct xdr_buf subbuf; + u32 ret = 0; + u8 *cksum_key; + struct crypto_skcipher *cipher, *aux_cipher; + struct xdr_netobj our_hmac_obj; + u8 our_hmac[GSS_KRB5_MAX_CKSUM_LEN]; + u8 pkt_hmac[GSS_KRB5_MAX_CKSUM_LEN]; + int nblocks, blocksize, cbcbytes; + struct decryptor_desc desc; + unsigned int usage; + + if (kctx->initiate) { + cipher = kctx->acceptor_enc; + aux_cipher = kctx->acceptor_enc_aux; + cksum_key = kctx->acceptor_integ; + usage = KG_USAGE_ACCEPTOR_SEAL; + } else { + cipher = kctx->initiator_enc; + aux_cipher = kctx->initiator_enc_aux; + cksum_key = kctx->initiator_integ; + usage = KG_USAGE_INITIATOR_SEAL; + } + blocksize = crypto_skcipher_blocksize(cipher); + + + /* create a segment skipping the header and leaving out the checksum */ + xdr_buf_subsegment(buf, &subbuf, offset + GSS_KRB5_TOK_HDR_LEN, + (buf->len - offset - GSS_KRB5_TOK_HDR_LEN - + kctx->gk5e->cksumlength)); + + nblocks = (subbuf.len + blocksize - 1) / blocksize; + + cbcbytes = 0; + if (nblocks > 2) + cbcbytes = (nblocks - 2) * blocksize; + + memset(desc.iv, 0, sizeof(desc.iv)); + + if (cbcbytes) { + SKCIPHER_REQUEST_ON_STACK(req, aux_cipher); + + desc.fragno = 0; + desc.fraglen = 0; + desc.req = req; + + skcipher_request_set_tfm(req, aux_cipher); + skcipher_request_set_callback(req, 0, NULL, NULL); + + sg_init_table(desc.frags, 4); + + ret = xdr_process_buf(&subbuf, 0, cbcbytes, decryptor, &desc); + skcipher_request_zero(req); + if (ret) + goto out_err; + } + + /* Make sure IV carries forward from any CBC results. */ + ret = gss_krb5_cts_crypt(cipher, &subbuf, cbcbytes, desc.iv, NULL, 0); + if (ret) + goto out_err; + + + /* Calculate our hmac over the plaintext data */ + our_hmac_obj.len = sizeof(our_hmac); + our_hmac_obj.data = our_hmac; + + ret = make_checksum_v2(kctx, NULL, 0, &subbuf, 0, + cksum_key, usage, &our_hmac_obj); + if (ret) + goto out_err; + + /* Get the packet's hmac value */ + ret = read_bytes_from_xdr_buf(buf, buf->len - kctx->gk5e->cksumlength, + pkt_hmac, kctx->gk5e->cksumlength); + if (ret) + goto out_err; + + if (crypto_memneq(pkt_hmac, our_hmac, kctx->gk5e->cksumlength) != 0) { + ret = GSS_S_BAD_SIG; + goto out_err; + } + *headskip = kctx->gk5e->conflen; + *tailskip = kctx->gk5e->cksumlength; +out_err: + if (ret && ret != GSS_S_BAD_SIG) + ret = GSS_S_FAILURE; + return ret; +} + +/* + * Compute Kseq given the initial session key and the checksum. + * Set the key of the given cipher. + */ +int +krb5_rc4_setup_seq_key(struct krb5_ctx *kctx, struct crypto_skcipher *cipher, + unsigned char *cksum) +{ + struct crypto_shash *hmac; + struct shash_desc *desc; + u8 Kseq[GSS_KRB5_MAX_KEYLEN]; + u32 zeroconstant = 0; + int err; + + dprintk("%s: entered\n", __func__); + + hmac = crypto_alloc_shash(kctx->gk5e->cksum_name, 0, 0); + if (IS_ERR(hmac)) { + dprintk("%s: error %ld, allocating hash '%s'\n", + __func__, PTR_ERR(hmac), kctx->gk5e->cksum_name); + return PTR_ERR(hmac); + } + + desc = kmalloc(sizeof(*desc) + crypto_shash_descsize(hmac), + GFP_NOFS); + if (!desc) { + dprintk("%s: failed to allocate shash descriptor for '%s'\n", + __func__, kctx->gk5e->cksum_name); + crypto_free_shash(hmac); + return -ENOMEM; + } + + desc->tfm = hmac; + desc->flags = 0; + + /* Compute intermediate Kseq from session key */ + err = crypto_shash_setkey(hmac, kctx->Ksess, kctx->gk5e->keylength); + if (err) + goto out_err; + + err = crypto_shash_digest(desc, (u8 *)&zeroconstant, 4, Kseq); + if (err) + goto out_err; + + /* Compute final Kseq from the checksum and intermediate Kseq */ + err = crypto_shash_setkey(hmac, Kseq, kctx->gk5e->keylength); + if (err) + goto out_err; + + err = crypto_shash_digest(desc, cksum, 8, Kseq); + if (err) + goto out_err; + + err = crypto_skcipher_setkey(cipher, Kseq, kctx->gk5e->keylength); + if (err) + goto out_err; + + err = 0; + +out_err: + kzfree(desc); + crypto_free_shash(hmac); + dprintk("%s: returning %d\n", __func__, err); + return err; +} + +/* + * Compute Kcrypt given the initial session key and the plaintext seqnum. + * Set the key of cipher kctx->enc. + */ +int +krb5_rc4_setup_enc_key(struct krb5_ctx *kctx, struct crypto_skcipher *cipher, + s32 seqnum) +{ + struct crypto_shash *hmac; + struct shash_desc *desc; + u8 Kcrypt[GSS_KRB5_MAX_KEYLEN]; + u8 zeroconstant[4] = {0}; + u8 seqnumarray[4]; + int err, i; + + dprintk("%s: entered, seqnum %u\n", __func__, seqnum); + + hmac = crypto_alloc_shash(kctx->gk5e->cksum_name, 0, 0); + if (IS_ERR(hmac)) { + dprintk("%s: error %ld, allocating hash '%s'\n", + __func__, PTR_ERR(hmac), kctx->gk5e->cksum_name); + return PTR_ERR(hmac); + } + + desc = kmalloc(sizeof(*desc) + crypto_shash_descsize(hmac), + GFP_NOFS); + if (!desc) { + dprintk("%s: failed to allocate shash descriptor for '%s'\n", + __func__, kctx->gk5e->cksum_name); + crypto_free_shash(hmac); + return -ENOMEM; + } + + desc->tfm = hmac; + desc->flags = 0; + + /* Compute intermediate Kcrypt from session key */ + for (i = 0; i < kctx->gk5e->keylength; i++) + Kcrypt[i] = kctx->Ksess[i] ^ 0xf0; + + err = crypto_shash_setkey(hmac, Kcrypt, kctx->gk5e->keylength); + if (err) + goto out_err; + + err = crypto_shash_digest(desc, zeroconstant, 4, Kcrypt); + if (err) + goto out_err; + + /* Compute final Kcrypt from the seqnum and intermediate Kcrypt */ + err = crypto_shash_setkey(hmac, Kcrypt, kctx->gk5e->keylength); + if (err) + goto out_err; + + seqnumarray[0] = (unsigned char) ((seqnum >> 24) & 0xff); + seqnumarray[1] = (unsigned char) ((seqnum >> 16) & 0xff); + seqnumarray[2] = (unsigned char) ((seqnum >> 8) & 0xff); + seqnumarray[3] = (unsigned char) ((seqnum >> 0) & 0xff); + + err = crypto_shash_digest(desc, seqnumarray, 4, Kcrypt); + if (err) + goto out_err; + + err = crypto_skcipher_setkey(cipher, Kcrypt, kctx->gk5e->keylength); + if (err) + goto out_err; + + err = 0; + +out_err: + kzfree(desc); + crypto_free_shash(hmac); + dprintk("%s: returning %d\n", __func__, err); + return err; +} diff --git a/net/sunrpc/auth_gss/gss_krb5_keys.c b/net/sunrpc/auth_gss/gss_krb5_keys.c new file mode 100644 index 000000000..f7fe2d2b8 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_keys.c @@ -0,0 +1,326 @@ +/* + * COPYRIGHT (c) 2008 + * The Regents of the University of Michigan + * ALL RIGHTS RESERVED + * + * Permission is granted to use, copy, create derivative works + * and redistribute this software and such derivative works + * for any purpose, so long as the name of The University of + * Michigan is not used in any advertising or publicity + * pertaining to the use of distribution of this software + * without specific, written prior authorization. If the + * above copyright notice or any other identification of the + * University of Michigan is included in any copy of any + * portion of this software, then the disclaimer below must + * also be included. + * + * THIS SOFTWARE IS PROVIDED AS IS, WITHOUT REPRESENTATION + * FROM THE UNIVERSITY OF MICHIGAN AS TO ITS FITNESS FOR ANY + * PURPOSE, AND WITHOUT WARRANTY BY THE UNIVERSITY OF + * MICHIGAN OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * WITHOUT LIMITATION THE IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE + * REGENTS OF THE UNIVERSITY OF MICHIGAN SHALL NOT BE LIABLE + * FOR ANY DAMAGES, INCLUDING SPECIAL, INDIRECT, INCIDENTAL, OR + * CONSEQUENTIAL DAMAGES, WITH RESPECT TO ANY CLAIM ARISING + * OUT OF OR IN CONNECTION WITH THE USE OF THE SOFTWARE, EVEN + * IF IT HAS BEEN OR IS HEREAFTER ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGES. + */ + +/* + * Copyright (C) 1998 by the FundsXpress, INC. + * + * All rights reserved. + * + * Export of this software from the United States of America may require + * a specific license from the United States Government. It is the + * responsibility of any person or organization contemplating export to + * obtain such a license before exporting. + * + * WITHIN THAT CONSTRAINT, permission to use, copy, modify, and + * distribute this software and its documentation for any purpose and + * without fee is hereby granted, provided that the above copyright + * notice appear in all copies and that both that copyright notice and + * this permission notice appear in supporting documentation, and that + * the name of FundsXpress. not be used in advertising or publicity pertaining + * to distribution of the software without specific, written prior + * permission. FundsXpress makes no representations about the suitability of + * this software for any purpose. It is provided "as is" without express + * or implied warranty. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED + * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE. + */ + +#include <crypto/skcipher.h> +#include <linux/err.h> +#include <linux/types.h> +#include <linux/sunrpc/gss_krb5.h> +#include <linux/sunrpc/xdr.h> +#include <linux/lcm.h> + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +/* + * This is the n-fold function as described in rfc3961, sec 5.1 + * Taken from MIT Kerberos and modified. + */ + +static void krb5_nfold(u32 inbits, const u8 *in, + u32 outbits, u8 *out) +{ + unsigned long ulcm; + int byte, i, msbit; + + /* the code below is more readable if I make these bytes + instead of bits */ + + inbits >>= 3; + outbits >>= 3; + + /* first compute lcm(n,k) */ + ulcm = lcm(inbits, outbits); + + /* now do the real work */ + + memset(out, 0, outbits); + byte = 0; + + /* this will end up cycling through k lcm(k,n)/k times, which + is correct */ + for (i = ulcm-1; i >= 0; i--) { + /* compute the msbit in k which gets added into this byte */ + msbit = ( + /* first, start with the msbit in the first, + * unrotated byte */ + ((inbits << 3) - 1) + /* then, for each byte, shift to the right + * for each repetition */ + + (((inbits << 3) + 13) * (i/inbits)) + /* last, pick out the correct byte within + * that shifted repetition */ + + ((inbits - (i % inbits)) << 3) + ) % (inbits << 3); + + /* pull out the byte value itself */ + byte += (((in[((inbits - 1) - (msbit >> 3)) % inbits] << 8)| + (in[((inbits) - (msbit >> 3)) % inbits])) + >> ((msbit & 7) + 1)) & 0xff; + + /* do the addition */ + byte += out[i % outbits]; + out[i % outbits] = byte & 0xff; + + /* keep around the carry bit, if any */ + byte >>= 8; + + } + + /* if there's a carry bit left over, add it back in */ + if (byte) { + for (i = outbits - 1; i >= 0; i--) { + /* do the addition */ + byte += out[i]; + out[i] = byte & 0xff; + + /* keep around the carry bit, if any */ + byte >>= 8; + } + } +} + +/* + * This is the DK (derive_key) function as described in rfc3961, sec 5.1 + * Taken from MIT Kerberos and modified. + */ + +u32 krb5_derive_key(const struct gss_krb5_enctype *gk5e, + const struct xdr_netobj *inkey, + struct xdr_netobj *outkey, + const struct xdr_netobj *in_constant, + gfp_t gfp_mask) +{ + size_t blocksize, keybytes, keylength, n; + unsigned char *inblockdata, *outblockdata, *rawkey; + struct xdr_netobj inblock, outblock; + struct crypto_skcipher *cipher; + u32 ret = EINVAL; + + blocksize = gk5e->blocksize; + keybytes = gk5e->keybytes; + keylength = gk5e->keylength; + + if ((inkey->len != keylength) || (outkey->len != keylength)) + goto err_return; + + cipher = crypto_alloc_skcipher(gk5e->encrypt_name, 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(cipher)) + goto err_return; + if (crypto_skcipher_setkey(cipher, inkey->data, inkey->len)) + goto err_return; + + /* allocate and set up buffers */ + + ret = ENOMEM; + inblockdata = kmalloc(blocksize, gfp_mask); + if (inblockdata == NULL) + goto err_free_cipher; + + outblockdata = kmalloc(blocksize, gfp_mask); + if (outblockdata == NULL) + goto err_free_in; + + rawkey = kmalloc(keybytes, gfp_mask); + if (rawkey == NULL) + goto err_free_out; + + inblock.data = (char *) inblockdata; + inblock.len = blocksize; + + outblock.data = (char *) outblockdata; + outblock.len = blocksize; + + /* initialize the input block */ + + if (in_constant->len == inblock.len) { + memcpy(inblock.data, in_constant->data, inblock.len); + } else { + krb5_nfold(in_constant->len * 8, in_constant->data, + inblock.len * 8, inblock.data); + } + + /* loop encrypting the blocks until enough key bytes are generated */ + + n = 0; + while (n < keybytes) { + (*(gk5e->encrypt))(cipher, NULL, inblock.data, + outblock.data, inblock.len); + + if ((keybytes - n) <= outblock.len) { + memcpy(rawkey + n, outblock.data, (keybytes - n)); + break; + } + + memcpy(rawkey + n, outblock.data, outblock.len); + memcpy(inblock.data, outblock.data, outblock.len); + n += outblock.len; + } + + /* postprocess the key */ + + inblock.data = (char *) rawkey; + inblock.len = keybytes; + + BUG_ON(gk5e->mk_key == NULL); + ret = (*(gk5e->mk_key))(gk5e, &inblock, outkey); + if (ret) { + dprintk("%s: got %d from mk_key function for '%s'\n", + __func__, ret, gk5e->encrypt_name); + goto err_free_raw; + } + + /* clean memory, free resources and exit */ + + ret = 0; + +err_free_raw: + memset(rawkey, 0, keybytes); + kfree(rawkey); +err_free_out: + memset(outblockdata, 0, blocksize); + kfree(outblockdata); +err_free_in: + memset(inblockdata, 0, blocksize); + kfree(inblockdata); +err_free_cipher: + crypto_free_skcipher(cipher); +err_return: + return ret; +} + +#define smask(step) ((1<<step)-1) +#define pstep(x, step) (((x)&smask(step))^(((x)>>step)&smask(step))) +#define parity_char(x) pstep(pstep(pstep((x), 4), 2), 1) + +static void mit_des_fixup_key_parity(u8 key[8]) +{ + int i; + for (i = 0; i < 8; i++) { + key[i] &= 0xfe; + key[i] |= 1^parity_char(key[i]); + } +} + +/* + * This is the des3 key derivation postprocess function + */ +u32 gss_krb5_des3_make_key(const struct gss_krb5_enctype *gk5e, + struct xdr_netobj *randombits, + struct xdr_netobj *key) +{ + int i; + u32 ret = EINVAL; + + if (key->len != 24) { + dprintk("%s: key->len is %d\n", __func__, key->len); + goto err_out; + } + if (randombits->len != 21) { + dprintk("%s: randombits->len is %d\n", + __func__, randombits->len); + goto err_out; + } + + /* take the seven bytes, move them around into the top 7 bits of the + 8 key bytes, then compute the parity bits. Do this three times. */ + + for (i = 0; i < 3; i++) { + memcpy(key->data + i*8, randombits->data + i*7, 7); + key->data[i*8+7] = (((key->data[i*8]&1)<<1) | + ((key->data[i*8+1]&1)<<2) | + ((key->data[i*8+2]&1)<<3) | + ((key->data[i*8+3]&1)<<4) | + ((key->data[i*8+4]&1)<<5) | + ((key->data[i*8+5]&1)<<6) | + ((key->data[i*8+6]&1)<<7)); + + mit_des_fixup_key_parity(key->data + i*8); + } + ret = 0; +err_out: + return ret; +} + +/* + * This is the aes key derivation postprocess function + */ +u32 gss_krb5_aes_make_key(const struct gss_krb5_enctype *gk5e, + struct xdr_netobj *randombits, + struct xdr_netobj *key) +{ + u32 ret = EINVAL; + + if (key->len != 16 && key->len != 32) { + dprintk("%s: key->len is %d\n", __func__, key->len); + goto err_out; + } + if (randombits->len != 16 && randombits->len != 32) { + dprintk("%s: randombits->len is %d\n", + __func__, randombits->len); + goto err_out; + } + if (randombits->len != key->len) { + dprintk("%s: randombits->len is %d, key->len is %d\n", + __func__, randombits->len, key->len); + goto err_out; + } + memcpy(key->data, randombits->data, key->len); + ret = 0; +err_out: + return ret; +} diff --git a/net/sunrpc/auth_gss/gss_krb5_mech.c b/net/sunrpc/auth_gss/gss_krb5_mech.c new file mode 100644 index 000000000..14f2823ad --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_mech.c @@ -0,0 +1,766 @@ +/* + * linux/net/sunrpc/gss_krb5_mech.c + * + * Copyright (c) 2001-2008 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson <andros@umich.edu> + * J. Bruce Fields <bfields@umich.edu> + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the University nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED + * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR + * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#include <crypto/hash.h> +#include <crypto/skcipher.h> +#include <linux/err.h> +#include <linux/module.h> +#include <linux/init.h> +#include <linux/types.h> +#include <linux/slab.h> +#include <linux/sunrpc/auth.h> +#include <linux/sunrpc/gss_krb5.h> +#include <linux/sunrpc/xdr.h> +#include <linux/sunrpc/gss_krb5_enctypes.h> + +#include "auth_gss_internal.h" + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +static struct gss_api_mech gss_kerberos_mech; /* forward declaration */ + +static const struct gss_krb5_enctype supported_gss_krb5_enctypes[] = { + /* + * DES (All DES enctypes are mapped to the same gss functionality) + */ + { + .etype = ENCTYPE_DES_CBC_RAW, + .ctype = CKSUMTYPE_RSA_MD5, + .name = "des-cbc-crc", + .encrypt_name = "cbc(des)", + .cksum_name = "md5", + .encrypt = krb5_encrypt, + .decrypt = krb5_decrypt, + .mk_key = NULL, + .signalg = SGN_ALG_DES_MAC_MD5, + .sealalg = SEAL_ALG_DES, + .keybytes = 7, + .keylength = 8, + .blocksize = 8, + .conflen = 8, + .cksumlength = 8, + .keyed_cksum = 0, + }, + /* + * RC4-HMAC + */ + { + .etype = ENCTYPE_ARCFOUR_HMAC, + .ctype = CKSUMTYPE_HMAC_MD5_ARCFOUR, + .name = "rc4-hmac", + .encrypt_name = "ecb(arc4)", + .cksum_name = "hmac(md5)", + .encrypt = krb5_encrypt, + .decrypt = krb5_decrypt, + .mk_key = NULL, + .signalg = SGN_ALG_HMAC_MD5, + .sealalg = SEAL_ALG_MICROSOFT_RC4, + .keybytes = 16, + .keylength = 16, + .blocksize = 1, + .conflen = 8, + .cksumlength = 8, + .keyed_cksum = 1, + }, + /* + * 3DES + */ + { + .etype = ENCTYPE_DES3_CBC_RAW, + .ctype = CKSUMTYPE_HMAC_SHA1_DES3, + .name = "des3-hmac-sha1", + .encrypt_name = "cbc(des3_ede)", + .cksum_name = "hmac(sha1)", + .encrypt = krb5_encrypt, + .decrypt = krb5_decrypt, + .mk_key = gss_krb5_des3_make_key, + .signalg = SGN_ALG_HMAC_SHA1_DES3_KD, + .sealalg = SEAL_ALG_DES3KD, + .keybytes = 21, + .keylength = 24, + .blocksize = 8, + .conflen = 8, + .cksumlength = 20, + .keyed_cksum = 1, + }, + /* + * AES128 + */ + { + .etype = ENCTYPE_AES128_CTS_HMAC_SHA1_96, + .ctype = CKSUMTYPE_HMAC_SHA1_96_AES128, + .name = "aes128-cts", + .encrypt_name = "cts(cbc(aes))", + .cksum_name = "hmac(sha1)", + .encrypt = krb5_encrypt, + .decrypt = krb5_decrypt, + .mk_key = gss_krb5_aes_make_key, + .encrypt_v2 = gss_krb5_aes_encrypt, + .decrypt_v2 = gss_krb5_aes_decrypt, + .signalg = -1, + .sealalg = -1, + .keybytes = 16, + .keylength = 16, + .blocksize = 16, + .conflen = 16, + .cksumlength = 12, + .keyed_cksum = 1, + }, + /* + * AES256 + */ + { + .etype = ENCTYPE_AES256_CTS_HMAC_SHA1_96, + .ctype = CKSUMTYPE_HMAC_SHA1_96_AES256, + .name = "aes256-cts", + .encrypt_name = "cts(cbc(aes))", + .cksum_name = "hmac(sha1)", + .encrypt = krb5_encrypt, + .decrypt = krb5_decrypt, + .mk_key = gss_krb5_aes_make_key, + .encrypt_v2 = gss_krb5_aes_encrypt, + .decrypt_v2 = gss_krb5_aes_decrypt, + .signalg = -1, + .sealalg = -1, + .keybytes = 32, + .keylength = 32, + .blocksize = 16, + .conflen = 16, + .cksumlength = 12, + .keyed_cksum = 1, + }, +}; + +static const int num_supported_enctypes = + ARRAY_SIZE(supported_gss_krb5_enctypes); + +static int +supported_gss_krb5_enctype(int etype) +{ + int i; + for (i = 0; i < num_supported_enctypes; i++) + if (supported_gss_krb5_enctypes[i].etype == etype) + return 1; + return 0; +} + +static const struct gss_krb5_enctype * +get_gss_krb5_enctype(int etype) +{ + int i; + for (i = 0; i < num_supported_enctypes; i++) + if (supported_gss_krb5_enctypes[i].etype == etype) + return &supported_gss_krb5_enctypes[i]; + return NULL; +} + +static inline const void * +get_key(const void *p, const void *end, + struct krb5_ctx *ctx, struct crypto_skcipher **res) +{ + struct xdr_netobj key; + int alg; + + p = simple_get_bytes(p, end, &alg, sizeof(alg)); + if (IS_ERR(p)) + goto out_err; + + switch (alg) { + case ENCTYPE_DES_CBC_CRC: + case ENCTYPE_DES_CBC_MD4: + case ENCTYPE_DES_CBC_MD5: + /* Map all these key types to ENCTYPE_DES_CBC_RAW */ + alg = ENCTYPE_DES_CBC_RAW; + break; + } + + if (!supported_gss_krb5_enctype(alg)) { + printk(KERN_WARNING "gss_kerberos_mech: unsupported " + "encryption key algorithm %d\n", alg); + p = ERR_PTR(-EINVAL); + goto out_err; + } + p = simple_get_netobj(p, end, &key); + if (IS_ERR(p)) + goto out_err; + + *res = crypto_alloc_skcipher(ctx->gk5e->encrypt_name, 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(*res)) { + printk(KERN_WARNING "gss_kerberos_mech: unable to initialize " + "crypto algorithm %s\n", ctx->gk5e->encrypt_name); + *res = NULL; + goto out_err_free_key; + } + if (crypto_skcipher_setkey(*res, key.data, key.len)) { + printk(KERN_WARNING "gss_kerberos_mech: error setting key for " + "crypto algorithm %s\n", ctx->gk5e->encrypt_name); + goto out_err_free_tfm; + } + + kfree(key.data); + return p; + +out_err_free_tfm: + crypto_free_skcipher(*res); +out_err_free_key: + kfree(key.data); + p = ERR_PTR(-EINVAL); +out_err: + return p; +} + +static int +gss_import_v1_context(const void *p, const void *end, struct krb5_ctx *ctx) +{ + int tmp; + + p = simple_get_bytes(p, end, &ctx->initiate, sizeof(ctx->initiate)); + if (IS_ERR(p)) + goto out_err; + + /* Old format supports only DES! Any other enctype uses new format */ + ctx->enctype = ENCTYPE_DES_CBC_RAW; + + ctx->gk5e = get_gss_krb5_enctype(ctx->enctype); + if (ctx->gk5e == NULL) { + p = ERR_PTR(-EINVAL); + goto out_err; + } + + /* The downcall format was designed before we completely understood + * the uses of the context fields; so it includes some stuff we + * just give some minimal sanity-checking, and some we ignore + * completely (like the next twenty bytes): */ + if (unlikely(p + 20 > end || p + 20 < p)) { + p = ERR_PTR(-EFAULT); + goto out_err; + } + p += 20; + p = simple_get_bytes(p, end, &tmp, sizeof(tmp)); + if (IS_ERR(p)) + goto out_err; + if (tmp != SGN_ALG_DES_MAC_MD5) { + p = ERR_PTR(-ENOSYS); + goto out_err; + } + p = simple_get_bytes(p, end, &tmp, sizeof(tmp)); + if (IS_ERR(p)) + goto out_err; + if (tmp != SEAL_ALG_DES) { + p = ERR_PTR(-ENOSYS); + goto out_err; + } + p = simple_get_bytes(p, end, &ctx->endtime, sizeof(ctx->endtime)); + if (IS_ERR(p)) + goto out_err; + p = simple_get_bytes(p, end, &ctx->seq_send, sizeof(ctx->seq_send)); + if (IS_ERR(p)) + goto out_err; + p = simple_get_netobj(p, end, &ctx->mech_used); + if (IS_ERR(p)) + goto out_err; + p = get_key(p, end, ctx, &ctx->enc); + if (IS_ERR(p)) + goto out_err_free_mech; + p = get_key(p, end, ctx, &ctx->seq); + if (IS_ERR(p)) + goto out_err_free_key1; + if (p != end) { + p = ERR_PTR(-EFAULT); + goto out_err_free_key2; + } + + return 0; + +out_err_free_key2: + crypto_free_skcipher(ctx->seq); +out_err_free_key1: + crypto_free_skcipher(ctx->enc); +out_err_free_mech: + kfree(ctx->mech_used.data); +out_err: + return PTR_ERR(p); +} + +static struct crypto_skcipher * +context_v2_alloc_cipher(struct krb5_ctx *ctx, const char *cname, u8 *key) +{ + struct crypto_skcipher *cp; + + cp = crypto_alloc_skcipher(cname, 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(cp)) { + dprintk("gss_kerberos_mech: unable to initialize " + "crypto algorithm %s\n", cname); + return NULL; + } + if (crypto_skcipher_setkey(cp, key, ctx->gk5e->keylength)) { + dprintk("gss_kerberos_mech: error setting key for " + "crypto algorithm %s\n", cname); + crypto_free_skcipher(cp); + return NULL; + } + return cp; +} + +static inline void +set_cdata(u8 cdata[GSS_KRB5_K5CLENGTH], u32 usage, u8 seed) +{ + cdata[0] = (usage>>24)&0xff; + cdata[1] = (usage>>16)&0xff; + cdata[2] = (usage>>8)&0xff; + cdata[3] = usage&0xff; + cdata[4] = seed; +} + +static int +context_derive_keys_des3(struct krb5_ctx *ctx, gfp_t gfp_mask) +{ + struct xdr_netobj c, keyin, keyout; + u8 cdata[GSS_KRB5_K5CLENGTH]; + u32 err; + + c.len = GSS_KRB5_K5CLENGTH; + c.data = cdata; + + keyin.data = ctx->Ksess; + keyin.len = ctx->gk5e->keylength; + keyout.len = ctx->gk5e->keylength; + + /* seq uses the raw key */ + ctx->seq = context_v2_alloc_cipher(ctx, ctx->gk5e->encrypt_name, + ctx->Ksess); + if (ctx->seq == NULL) + goto out_err; + + ctx->enc = context_v2_alloc_cipher(ctx, ctx->gk5e->encrypt_name, + ctx->Ksess); + if (ctx->enc == NULL) + goto out_free_seq; + + /* derive cksum */ + set_cdata(cdata, KG_USAGE_SIGN, KEY_USAGE_SEED_CHECKSUM); + keyout.data = ctx->cksum; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving cksum key\n", + __func__, err); + goto out_free_enc; + } + + return 0; + +out_free_enc: + crypto_free_skcipher(ctx->enc); +out_free_seq: + crypto_free_skcipher(ctx->seq); +out_err: + return -EINVAL; +} + +/* + * Note that RC4 depends on deriving keys using the sequence + * number or the checksum of a token. Therefore, the final keys + * cannot be calculated until the token is being constructed! + */ +static int +context_derive_keys_rc4(struct krb5_ctx *ctx) +{ + struct crypto_shash *hmac; + char sigkeyconstant[] = "signaturekey"; + int slen = strlen(sigkeyconstant) + 1; /* include null terminator */ + struct shash_desc *desc; + int err; + + dprintk("RPC: %s: entered\n", __func__); + /* + * derive cksum (aka Ksign) key + */ + hmac = crypto_alloc_shash(ctx->gk5e->cksum_name, 0, 0); + if (IS_ERR(hmac)) { + dprintk("%s: error %ld allocating hash '%s'\n", + __func__, PTR_ERR(hmac), ctx->gk5e->cksum_name); + err = PTR_ERR(hmac); + goto out_err; + } + + err = crypto_shash_setkey(hmac, ctx->Ksess, ctx->gk5e->keylength); + if (err) + goto out_err_free_hmac; + + + desc = kmalloc(sizeof(*desc) + crypto_shash_descsize(hmac), GFP_NOFS); + if (!desc) { + dprintk("%s: failed to allocate hash descriptor for '%s'\n", + __func__, ctx->gk5e->cksum_name); + err = -ENOMEM; + goto out_err_free_hmac; + } + + desc->tfm = hmac; + desc->flags = 0; + + err = crypto_shash_digest(desc, sigkeyconstant, slen, ctx->cksum); + kzfree(desc); + if (err) + goto out_err_free_hmac; + /* + * allocate hash, and skciphers for data and seqnum encryption + */ + ctx->enc = crypto_alloc_skcipher(ctx->gk5e->encrypt_name, 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(ctx->enc)) { + err = PTR_ERR(ctx->enc); + goto out_err_free_hmac; + } + + ctx->seq = crypto_alloc_skcipher(ctx->gk5e->encrypt_name, 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(ctx->seq)) { + crypto_free_skcipher(ctx->enc); + err = PTR_ERR(ctx->seq); + goto out_err_free_hmac; + } + + dprintk("RPC: %s: returning success\n", __func__); + + err = 0; + +out_err_free_hmac: + crypto_free_shash(hmac); +out_err: + dprintk("RPC: %s: returning %d\n", __func__, err); + return err; +} + +static int +context_derive_keys_new(struct krb5_ctx *ctx, gfp_t gfp_mask) +{ + struct xdr_netobj c, keyin, keyout; + u8 cdata[GSS_KRB5_K5CLENGTH]; + u32 err; + + c.len = GSS_KRB5_K5CLENGTH; + c.data = cdata; + + keyin.data = ctx->Ksess; + keyin.len = ctx->gk5e->keylength; + keyout.len = ctx->gk5e->keylength; + + /* initiator seal encryption */ + set_cdata(cdata, KG_USAGE_INITIATOR_SEAL, KEY_USAGE_SEED_ENCRYPTION); + keyout.data = ctx->initiator_seal; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving initiator_seal key\n", + __func__, err); + goto out_err; + } + ctx->initiator_enc = context_v2_alloc_cipher(ctx, + ctx->gk5e->encrypt_name, + ctx->initiator_seal); + if (ctx->initiator_enc == NULL) + goto out_err; + + /* acceptor seal encryption */ + set_cdata(cdata, KG_USAGE_ACCEPTOR_SEAL, KEY_USAGE_SEED_ENCRYPTION); + keyout.data = ctx->acceptor_seal; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving acceptor_seal key\n", + __func__, err); + goto out_free_initiator_enc; + } + ctx->acceptor_enc = context_v2_alloc_cipher(ctx, + ctx->gk5e->encrypt_name, + ctx->acceptor_seal); + if (ctx->acceptor_enc == NULL) + goto out_free_initiator_enc; + + /* initiator sign checksum */ + set_cdata(cdata, KG_USAGE_INITIATOR_SIGN, KEY_USAGE_SEED_CHECKSUM); + keyout.data = ctx->initiator_sign; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving initiator_sign key\n", + __func__, err); + goto out_free_acceptor_enc; + } + + /* acceptor sign checksum */ + set_cdata(cdata, KG_USAGE_ACCEPTOR_SIGN, KEY_USAGE_SEED_CHECKSUM); + keyout.data = ctx->acceptor_sign; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving acceptor_sign key\n", + __func__, err); + goto out_free_acceptor_enc; + } + + /* initiator seal integrity */ + set_cdata(cdata, KG_USAGE_INITIATOR_SEAL, KEY_USAGE_SEED_INTEGRITY); + keyout.data = ctx->initiator_integ; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving initiator_integ key\n", + __func__, err); + goto out_free_acceptor_enc; + } + + /* acceptor seal integrity */ + set_cdata(cdata, KG_USAGE_ACCEPTOR_SEAL, KEY_USAGE_SEED_INTEGRITY); + keyout.data = ctx->acceptor_integ; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving acceptor_integ key\n", + __func__, err); + goto out_free_acceptor_enc; + } + + switch (ctx->enctype) { + case ENCTYPE_AES128_CTS_HMAC_SHA1_96: + case ENCTYPE_AES256_CTS_HMAC_SHA1_96: + ctx->initiator_enc_aux = + context_v2_alloc_cipher(ctx, "cbc(aes)", + ctx->initiator_seal); + if (ctx->initiator_enc_aux == NULL) + goto out_free_acceptor_enc; + ctx->acceptor_enc_aux = + context_v2_alloc_cipher(ctx, "cbc(aes)", + ctx->acceptor_seal); + if (ctx->acceptor_enc_aux == NULL) { + crypto_free_skcipher(ctx->initiator_enc_aux); + goto out_free_acceptor_enc; + } + } + + return 0; + +out_free_acceptor_enc: + crypto_free_skcipher(ctx->acceptor_enc); +out_free_initiator_enc: + crypto_free_skcipher(ctx->initiator_enc); +out_err: + return -EINVAL; +} + +static int +gss_import_v2_context(const void *p, const void *end, struct krb5_ctx *ctx, + gfp_t gfp_mask) +{ + int keylen; + + p = simple_get_bytes(p, end, &ctx->flags, sizeof(ctx->flags)); + if (IS_ERR(p)) + goto out_err; + ctx->initiate = ctx->flags & KRB5_CTX_FLAG_INITIATOR; + + p = simple_get_bytes(p, end, &ctx->endtime, sizeof(ctx->endtime)); + if (IS_ERR(p)) + goto out_err; + p = simple_get_bytes(p, end, &ctx->seq_send64, sizeof(ctx->seq_send64)); + if (IS_ERR(p)) + goto out_err; + /* set seq_send for use by "older" enctypes */ + ctx->seq_send = ctx->seq_send64; + if (ctx->seq_send64 != ctx->seq_send) { + dprintk("%s: seq_send64 %lx, seq_send %x overflow?\n", __func__, + (unsigned long)ctx->seq_send64, ctx->seq_send); + p = ERR_PTR(-EINVAL); + goto out_err; + } + p = simple_get_bytes(p, end, &ctx->enctype, sizeof(ctx->enctype)); + if (IS_ERR(p)) + goto out_err; + /* Map ENCTYPE_DES3_CBC_SHA1 to ENCTYPE_DES3_CBC_RAW */ + if (ctx->enctype == ENCTYPE_DES3_CBC_SHA1) + ctx->enctype = ENCTYPE_DES3_CBC_RAW; + ctx->gk5e = get_gss_krb5_enctype(ctx->enctype); + if (ctx->gk5e == NULL) { + dprintk("gss_kerberos_mech: unsupported krb5 enctype %u\n", + ctx->enctype); + p = ERR_PTR(-EINVAL); + goto out_err; + } + keylen = ctx->gk5e->keylength; + + p = simple_get_bytes(p, end, ctx->Ksess, keylen); + if (IS_ERR(p)) + goto out_err; + + if (p != end) { + p = ERR_PTR(-EINVAL); + goto out_err; + } + + ctx->mech_used.data = kmemdup(gss_kerberos_mech.gm_oid.data, + gss_kerberos_mech.gm_oid.len, gfp_mask); + if (unlikely(ctx->mech_used.data == NULL)) { + p = ERR_PTR(-ENOMEM); + goto out_err; + } + ctx->mech_used.len = gss_kerberos_mech.gm_oid.len; + + switch (ctx->enctype) { + case ENCTYPE_DES3_CBC_RAW: + return context_derive_keys_des3(ctx, gfp_mask); + case ENCTYPE_ARCFOUR_HMAC: + return context_derive_keys_rc4(ctx); + case ENCTYPE_AES128_CTS_HMAC_SHA1_96: + case ENCTYPE_AES256_CTS_HMAC_SHA1_96: + return context_derive_keys_new(ctx, gfp_mask); + default: + return -EINVAL; + } + +out_err: + return PTR_ERR(p); +} + +static int +gss_import_sec_context_kerberos(const void *p, size_t len, + struct gss_ctx *ctx_id, + time_t *endtime, + gfp_t gfp_mask) +{ + const void *end = (const void *)((const char *)p + len); + struct krb5_ctx *ctx; + int ret; + + ctx = kzalloc(sizeof(*ctx), gfp_mask); + if (ctx == NULL) + return -ENOMEM; + + if (len == 85) + ret = gss_import_v1_context(p, end, ctx); + else + ret = gss_import_v2_context(p, end, ctx, gfp_mask); + + if (ret == 0) { + ctx_id->internal_ctx_id = ctx; + if (endtime) + *endtime = ctx->endtime; + } else + kfree(ctx); + + dprintk("RPC: %s: returning %d\n", __func__, ret); + return ret; +} + +static void +gss_delete_sec_context_kerberos(void *internal_ctx) { + struct krb5_ctx *kctx = internal_ctx; + + crypto_free_skcipher(kctx->seq); + crypto_free_skcipher(kctx->enc); + crypto_free_skcipher(kctx->acceptor_enc); + crypto_free_skcipher(kctx->initiator_enc); + crypto_free_skcipher(kctx->acceptor_enc_aux); + crypto_free_skcipher(kctx->initiator_enc_aux); + kfree(kctx->mech_used.data); + kfree(kctx); +} + +static const struct gss_api_ops gss_kerberos_ops = { + .gss_import_sec_context = gss_import_sec_context_kerberos, + .gss_get_mic = gss_get_mic_kerberos, + .gss_verify_mic = gss_verify_mic_kerberos, + .gss_wrap = gss_wrap_kerberos, + .gss_unwrap = gss_unwrap_kerberos, + .gss_delete_sec_context = gss_delete_sec_context_kerberos, +}; + +static struct pf_desc gss_kerberos_pfs[] = { + [0] = { + .pseudoflavor = RPC_AUTH_GSS_KRB5, + .qop = GSS_C_QOP_DEFAULT, + .service = RPC_GSS_SVC_NONE, + .name = "krb5", + }, + [1] = { + .pseudoflavor = RPC_AUTH_GSS_KRB5I, + .qop = GSS_C_QOP_DEFAULT, + .service = RPC_GSS_SVC_INTEGRITY, + .name = "krb5i", + .datatouch = true, + }, + [2] = { + .pseudoflavor = RPC_AUTH_GSS_KRB5P, + .qop = GSS_C_QOP_DEFAULT, + .service = RPC_GSS_SVC_PRIVACY, + .name = "krb5p", + .datatouch = true, + }, +}; + +MODULE_ALIAS("rpc-auth-gss-krb5"); +MODULE_ALIAS("rpc-auth-gss-krb5i"); +MODULE_ALIAS("rpc-auth-gss-krb5p"); +MODULE_ALIAS("rpc-auth-gss-390003"); +MODULE_ALIAS("rpc-auth-gss-390004"); +MODULE_ALIAS("rpc-auth-gss-390005"); +MODULE_ALIAS("rpc-auth-gss-1.2.840.113554.1.2.2"); + +static struct gss_api_mech gss_kerberos_mech = { + .gm_name = "krb5", + .gm_owner = THIS_MODULE, + .gm_oid = { 9, "\x2a\x86\x48\x86\xf7\x12\x01\x02\x02" }, + .gm_ops = &gss_kerberos_ops, + .gm_pf_num = ARRAY_SIZE(gss_kerberos_pfs), + .gm_pfs = gss_kerberos_pfs, + .gm_upcall_enctypes = KRB5_SUPPORTED_ENCTYPES, +}; + +static int __init init_kerberos_module(void) +{ + int status; + + status = gss_mech_register(&gss_kerberos_mech); + if (status) + printk("Failed to register kerberos gss mechanism!\n"); + return status; +} + +static void __exit cleanup_kerberos_module(void) +{ + gss_mech_unregister(&gss_kerberos_mech); +} + +MODULE_LICENSE("GPL"); +module_init(init_kerberos_module); +module_exit(cleanup_kerberos_module); diff --git a/net/sunrpc/auth_gss/gss_krb5_seal.c b/net/sunrpc/auth_gss/gss_krb5_seal.c new file mode 100644 index 000000000..e1f057184 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_seal.c @@ -0,0 +1,232 @@ +/* + * linux/net/sunrpc/gss_krb5_seal.c + * + * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/krb5/k5seal.c + * + * Copyright (c) 2000-2008 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson <andros@umich.edu> + * J. Bruce Fields <bfields@umich.edu> + */ + +/* + * Copyright 1993 by OpenVision Technologies, Inc. + * + * Permission to use, copy, modify, distribute, and sell this software + * and its documentation for any purpose is hereby granted without fee, + * provided that the above copyright notice appears in all copies and + * that both that copyright notice and this permission notice appear in + * supporting documentation, and that the name of OpenVision not be used + * in advertising or publicity pertaining to distribution of the software + * without specific, written prior permission. OpenVision makes no + * representations about the suitability of this software for any + * purpose. It is provided "as is" without express or implied warranty. + * + * OPENVISION DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, + * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO + * EVENT SHALL OPENVISION BE LIABLE FOR ANY SPECIAL, INDIRECT OR + * CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF + * USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR + * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. + */ + +/* + * Copyright (C) 1998 by the FundsXpress, INC. + * + * All rights reserved. + * + * Export of this software from the United States of America may require + * a specific license from the United States Government. It is the + * responsibility of any person or organization contemplating export to + * obtain such a license before exporting. + * + * WITHIN THAT CONSTRAINT, permission to use, copy, modify, and + * distribute this software and its documentation for any purpose and + * without fee is hereby granted, provided that the above copyright + * notice appear in all copies and that both that copyright notice and + * this permission notice appear in supporting documentation, and that + * the name of FundsXpress. not be used in advertising or publicity pertaining + * to distribution of the software without specific, written prior + * permission. FundsXpress makes no representations about the suitability of + * this software for any purpose. It is provided "as is" without express + * or implied warranty. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED + * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE. + */ + +#include <linux/types.h> +#include <linux/jiffies.h> +#include <linux/sunrpc/gss_krb5.h> +#include <linux/random.h> +#include <linux/crypto.h> +#include <linux/atomic.h> + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +DEFINE_SPINLOCK(krb5_seq_lock); + +static void * +setup_token(struct krb5_ctx *ctx, struct xdr_netobj *token) +{ + u16 *ptr; + void *krb5_hdr; + int body_size = GSS_KRB5_TOK_HDR_LEN + ctx->gk5e->cksumlength; + + token->len = g_token_size(&ctx->mech_used, body_size); + + ptr = (u16 *)token->data; + g_make_token_header(&ctx->mech_used, body_size, (unsigned char **)&ptr); + + /* ptr now at start of header described in rfc 1964, section 1.2.1: */ + krb5_hdr = ptr; + *ptr++ = KG_TOK_MIC_MSG; + /* + * signalg is stored as if it were converted from LE to host endian, even + * though it's an opaque pair of bytes according to the RFC. + */ + *ptr++ = (__force u16)cpu_to_le16(ctx->gk5e->signalg); + *ptr++ = SEAL_ALG_NONE; + *ptr = 0xffff; + + return krb5_hdr; +} + +static void * +setup_token_v2(struct krb5_ctx *ctx, struct xdr_netobj *token) +{ + u16 *ptr; + void *krb5_hdr; + u8 *p, flags = 0x00; + + if ((ctx->flags & KRB5_CTX_FLAG_INITIATOR) == 0) + flags |= 0x01; + if (ctx->flags & KRB5_CTX_FLAG_ACCEPTOR_SUBKEY) + flags |= 0x04; + + /* Per rfc 4121, sec 4.2.6.1, there is no header, + * just start the token */ + krb5_hdr = ptr = (u16 *)token->data; + + *ptr++ = KG2_TOK_MIC; + p = (u8 *)ptr; + *p++ = flags; + *p++ = 0xff; + ptr = (u16 *)p; + *ptr++ = 0xffff; + *ptr = 0xffff; + + token->len = GSS_KRB5_TOK_HDR_LEN + ctx->gk5e->cksumlength; + return krb5_hdr; +} + +static u32 +gss_get_mic_v1(struct krb5_ctx *ctx, struct xdr_buf *text, + struct xdr_netobj *token) +{ + char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + struct xdr_netobj md5cksum = {.len = sizeof(cksumdata), + .data = cksumdata}; + void *ptr; + s32 now; + u32 seq_send; + u8 *cksumkey; + + dprintk("RPC: %s\n", __func__); + BUG_ON(ctx == NULL); + + now = get_seconds(); + + ptr = setup_token(ctx, token); + + if (ctx->gk5e->keyed_cksum) + cksumkey = ctx->cksum; + else + cksumkey = NULL; + + if (make_checksum(ctx, ptr, 8, text, 0, cksumkey, + KG_USAGE_SIGN, &md5cksum)) + return GSS_S_FAILURE; + + memcpy(ptr + GSS_KRB5_TOK_HDR_LEN, md5cksum.data, md5cksum.len); + + spin_lock(&krb5_seq_lock); + seq_send = ctx->seq_send++; + spin_unlock(&krb5_seq_lock); + + if (krb5_make_seq_num(ctx, ctx->seq, ctx->initiate ? 0 : 0xff, + seq_send, ptr + GSS_KRB5_TOK_HDR_LEN, ptr + 8)) + return GSS_S_FAILURE; + + return (ctx->endtime < now) ? GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE; +} + +static u32 +gss_get_mic_v2(struct krb5_ctx *ctx, struct xdr_buf *text, + struct xdr_netobj *token) +{ + char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + struct xdr_netobj cksumobj = { .len = sizeof(cksumdata), + .data = cksumdata}; + void *krb5_hdr; + s32 now; + u64 seq_send; + u8 *cksumkey; + unsigned int cksum_usage; + __be64 seq_send_be64; + + dprintk("RPC: %s\n", __func__); + + krb5_hdr = setup_token_v2(ctx, token); + + /* Set up the sequence number. Now 64-bits in clear + * text and w/o direction indicator */ + spin_lock(&krb5_seq_lock); + seq_send = ctx->seq_send64++; + spin_unlock(&krb5_seq_lock); + + seq_send_be64 = cpu_to_be64(seq_send); + memcpy(krb5_hdr + 8, (char *) &seq_send_be64, 8); + + if (ctx->initiate) { + cksumkey = ctx->initiator_sign; + cksum_usage = KG_USAGE_INITIATOR_SIGN; + } else { + cksumkey = ctx->acceptor_sign; + cksum_usage = KG_USAGE_ACCEPTOR_SIGN; + } + + if (make_checksum_v2(ctx, krb5_hdr, GSS_KRB5_TOK_HDR_LEN, + text, 0, cksumkey, cksum_usage, &cksumobj)) + return GSS_S_FAILURE; + + memcpy(krb5_hdr + GSS_KRB5_TOK_HDR_LEN, cksumobj.data, cksumobj.len); + + now = get_seconds(); + + return (ctx->endtime < now) ? GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE; +} + +u32 +gss_get_mic_kerberos(struct gss_ctx *gss_ctx, struct xdr_buf *text, + struct xdr_netobj *token) +{ + struct krb5_ctx *ctx = gss_ctx->internal_ctx_id; + + switch (ctx->enctype) { + default: + BUG(); + case ENCTYPE_DES_CBC_RAW: + case ENCTYPE_DES3_CBC_RAW: + case ENCTYPE_ARCFOUR_HMAC: + return gss_get_mic_v1(ctx, text, token); + case ENCTYPE_AES128_CTS_HMAC_SHA1_96: + case ENCTYPE_AES256_CTS_HMAC_SHA1_96: + return gss_get_mic_v2(ctx, text, token); + } +} diff --git a/net/sunrpc/auth_gss/gss_krb5_seqnum.c b/net/sunrpc/auth_gss/gss_krb5_seqnum.c new file mode 100644 index 000000000..2d2ed6772 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_seqnum.c @@ -0,0 +1,193 @@ +/* + * linux/net/sunrpc/gss_krb5_seqnum.c + * + * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/krb5/util_seqnum.c + * + * Copyright (c) 2000 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson <andros@umich.edu> + */ + +/* + * Copyright 1993 by OpenVision Technologies, Inc. + * + * Permission to use, copy, modify, distribute, and sell this software + * and its documentation for any purpose is hereby granted without fee, + * provided that the above copyright notice appears in all copies and + * that both that copyright notice and this permission notice appear in + * supporting documentation, and that the name of OpenVision not be used + * in advertising or publicity pertaining to distribution of the software + * without specific, written prior permission. OpenVision makes no + * representations about the suitability of this software for any + * purpose. It is provided "as is" without express or implied warranty. + * + * OPENVISION DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, + * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO + * EVENT SHALL OPENVISION BE LIABLE FOR ANY SPECIAL, INDIRECT OR + * CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF + * USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR + * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. + */ + +#include <crypto/skcipher.h> +#include <linux/types.h> +#include <linux/sunrpc/gss_krb5.h> + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +static s32 +krb5_make_rc4_seq_num(struct krb5_ctx *kctx, int direction, s32 seqnum, + unsigned char *cksum, unsigned char *buf) +{ + struct crypto_skcipher *cipher; + unsigned char *plain; + s32 code; + + dprintk("RPC: %s:\n", __func__); + cipher = crypto_alloc_skcipher(kctx->gk5e->encrypt_name, 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(cipher)) + return PTR_ERR(cipher); + + plain = kmalloc(8, GFP_NOFS); + if (!plain) + return -ENOMEM; + + plain[0] = (unsigned char) ((seqnum >> 24) & 0xff); + plain[1] = (unsigned char) ((seqnum >> 16) & 0xff); + plain[2] = (unsigned char) ((seqnum >> 8) & 0xff); + plain[3] = (unsigned char) ((seqnum >> 0) & 0xff); + plain[4] = direction; + plain[5] = direction; + plain[6] = direction; + plain[7] = direction; + + code = krb5_rc4_setup_seq_key(kctx, cipher, cksum); + if (code) + goto out; + + code = krb5_encrypt(cipher, cksum, plain, buf, 8); +out: + crypto_free_skcipher(cipher); + kfree(plain); + return code; +} +s32 +krb5_make_seq_num(struct krb5_ctx *kctx, + struct crypto_skcipher *key, + int direction, + u32 seqnum, + unsigned char *cksum, unsigned char *buf) +{ + unsigned char *plain; + s32 code; + + if (kctx->enctype == ENCTYPE_ARCFOUR_HMAC) + return krb5_make_rc4_seq_num(kctx, direction, seqnum, + cksum, buf); + + plain = kmalloc(8, GFP_NOFS); + if (!plain) + return -ENOMEM; + + plain[0] = (unsigned char) (seqnum & 0xff); + plain[1] = (unsigned char) ((seqnum >> 8) & 0xff); + plain[2] = (unsigned char) ((seqnum >> 16) & 0xff); + plain[3] = (unsigned char) ((seqnum >> 24) & 0xff); + + plain[4] = direction; + plain[5] = direction; + plain[6] = direction; + plain[7] = direction; + + code = krb5_encrypt(key, cksum, plain, buf, 8); + kfree(plain); + return code; +} + +static s32 +krb5_get_rc4_seq_num(struct krb5_ctx *kctx, unsigned char *cksum, + unsigned char *buf, int *direction, s32 *seqnum) +{ + struct crypto_skcipher *cipher; + unsigned char *plain; + s32 code; + + dprintk("RPC: %s:\n", __func__); + cipher = crypto_alloc_skcipher(kctx->gk5e->encrypt_name, 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(cipher)) + return PTR_ERR(cipher); + + code = krb5_rc4_setup_seq_key(kctx, cipher, cksum); + if (code) + goto out; + + plain = kmalloc(8, GFP_NOFS); + if (!plain) { + code = -ENOMEM; + goto out; + } + + code = krb5_decrypt(cipher, cksum, buf, plain, 8); + if (code) + goto out_plain; + + if ((plain[4] != plain[5]) || (plain[4] != plain[6]) + || (plain[4] != plain[7])) { + code = (s32)KG_BAD_SEQ; + goto out_plain; + } + + *direction = plain[4]; + + *seqnum = ((plain[0] << 24) | (plain[1] << 16) | + (plain[2] << 8) | (plain[3])); +out_plain: + kfree(plain); +out: + crypto_free_skcipher(cipher); + return code; +} + +s32 +krb5_get_seq_num(struct krb5_ctx *kctx, + unsigned char *cksum, + unsigned char *buf, + int *direction, u32 *seqnum) +{ + s32 code; + struct crypto_skcipher *key = kctx->seq; + unsigned char *plain; + + dprintk("RPC: krb5_get_seq_num:\n"); + + if (kctx->enctype == ENCTYPE_ARCFOUR_HMAC) + return krb5_get_rc4_seq_num(kctx, cksum, buf, + direction, seqnum); + plain = kmalloc(8, GFP_NOFS); + if (!plain) + return -ENOMEM; + + if ((code = krb5_decrypt(key, cksum, buf, plain, 8))) + goto out; + + if ((plain[4] != plain[5]) || (plain[4] != plain[6]) || + (plain[4] != plain[7])) { + code = (s32)KG_BAD_SEQ; + goto out; + } + + *direction = plain[4]; + + *seqnum = ((plain[0]) | + (plain[1] << 8) | (plain[2] << 16) | (plain[3] << 24)); + +out: + kfree(plain); + return code; +} diff --git a/net/sunrpc/auth_gss/gss_krb5_unseal.c b/net/sunrpc/auth_gss/gss_krb5_unseal.c new file mode 100644 index 000000000..ef2b25b86 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_unseal.c @@ -0,0 +1,227 @@ +/* + * linux/net/sunrpc/gss_krb5_unseal.c + * + * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/krb5/k5unseal.c + * + * Copyright (c) 2000-2008 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson <andros@umich.edu> + */ + +/* + * Copyright 1993 by OpenVision Technologies, Inc. + * + * Permission to use, copy, modify, distribute, and sell this software + * and its documentation for any purpose is hereby granted without fee, + * provided that the above copyright notice appears in all copies and + * that both that copyright notice and this permission notice appear in + * supporting documentation, and that the name of OpenVision not be used + * in advertising or publicity pertaining to distribution of the software + * without specific, written prior permission. OpenVision makes no + * representations about the suitability of this software for any + * purpose. It is provided "as is" without express or implied warranty. + * + * OPENVISION DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, + * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO + * EVENT SHALL OPENVISION BE LIABLE FOR ANY SPECIAL, INDIRECT OR + * CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF + * USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR + * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. + */ + +/* + * Copyright (C) 1998 by the FundsXpress, INC. + * + * All rights reserved. + * + * Export of this software from the United States of America may require + * a specific license from the United States Government. It is the + * responsibility of any person or organization contemplating export to + * obtain such a license before exporting. + * + * WITHIN THAT CONSTRAINT, permission to use, copy, modify, and + * distribute this software and its documentation for any purpose and + * without fee is hereby granted, provided that the above copyright + * notice appear in all copies and that both that copyright notice and + * this permission notice appear in supporting documentation, and that + * the name of FundsXpress. not be used in advertising or publicity pertaining + * to distribution of the software without specific, written prior + * permission. FundsXpress makes no representations about the suitability of + * this software for any purpose. It is provided "as is" without express + * or implied warranty. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED + * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE. + */ + +#include <linux/types.h> +#include <linux/jiffies.h> +#include <linux/sunrpc/gss_krb5.h> +#include <linux/crypto.h> + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + + +/* read_token is a mic token, and message_buffer is the data that the mic was + * supposedly taken over. */ + +static u32 +gss_verify_mic_v1(struct krb5_ctx *ctx, + struct xdr_buf *message_buffer, struct xdr_netobj *read_token) +{ + int signalg; + int sealalg; + char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + struct xdr_netobj md5cksum = {.len = sizeof(cksumdata), + .data = cksumdata}; + s32 now; + int direction; + u32 seqnum; + unsigned char *ptr = (unsigned char *)read_token->data; + int bodysize; + u8 *cksumkey; + + dprintk("RPC: krb5_read_token\n"); + + if (g_verify_token_header(&ctx->mech_used, &bodysize, &ptr, + read_token->len)) + return GSS_S_DEFECTIVE_TOKEN; + + if ((ptr[0] != ((KG_TOK_MIC_MSG >> 8) & 0xff)) || + (ptr[1] != (KG_TOK_MIC_MSG & 0xff))) + return GSS_S_DEFECTIVE_TOKEN; + + /* XXX sanity-check bodysize?? */ + + signalg = ptr[2] + (ptr[3] << 8); + if (signalg != ctx->gk5e->signalg) + return GSS_S_DEFECTIVE_TOKEN; + + sealalg = ptr[4] + (ptr[5] << 8); + if (sealalg != SEAL_ALG_NONE) + return GSS_S_DEFECTIVE_TOKEN; + + if ((ptr[6] != 0xff) || (ptr[7] != 0xff)) + return GSS_S_DEFECTIVE_TOKEN; + + if (ctx->gk5e->keyed_cksum) + cksumkey = ctx->cksum; + else + cksumkey = NULL; + + if (make_checksum(ctx, ptr, 8, message_buffer, 0, + cksumkey, KG_USAGE_SIGN, &md5cksum)) + return GSS_S_FAILURE; + + if (memcmp(md5cksum.data, ptr + GSS_KRB5_TOK_HDR_LEN, + ctx->gk5e->cksumlength)) + return GSS_S_BAD_SIG; + + /* it got through unscathed. Make sure the context is unexpired */ + + now = get_seconds(); + + if (now > ctx->endtime) + return GSS_S_CONTEXT_EXPIRED; + + /* do sequencing checks */ + + if (krb5_get_seq_num(ctx, ptr + GSS_KRB5_TOK_HDR_LEN, ptr + 8, + &direction, &seqnum)) + return GSS_S_FAILURE; + + if ((ctx->initiate && direction != 0xff) || + (!ctx->initiate && direction != 0)) + return GSS_S_BAD_SIG; + + return GSS_S_COMPLETE; +} + +static u32 +gss_verify_mic_v2(struct krb5_ctx *ctx, + struct xdr_buf *message_buffer, struct xdr_netobj *read_token) +{ + char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + struct xdr_netobj cksumobj = {.len = sizeof(cksumdata), + .data = cksumdata}; + s32 now; + u8 *ptr = read_token->data; + u8 *cksumkey; + u8 flags; + int i; + unsigned int cksum_usage; + __be16 be16_ptr; + + dprintk("RPC: %s\n", __func__); + + memcpy(&be16_ptr, (char *) ptr, 2); + if (be16_to_cpu(be16_ptr) != KG2_TOK_MIC) + return GSS_S_DEFECTIVE_TOKEN; + + flags = ptr[2]; + if ((!ctx->initiate && (flags & KG2_TOKEN_FLAG_SENTBYACCEPTOR)) || + (ctx->initiate && !(flags & KG2_TOKEN_FLAG_SENTBYACCEPTOR))) + return GSS_S_BAD_SIG; + + if (flags & KG2_TOKEN_FLAG_SEALED) { + dprintk("%s: token has unexpected sealed flag\n", __func__); + return GSS_S_FAILURE; + } + + for (i = 3; i < 8; i++) + if (ptr[i] != 0xff) + return GSS_S_DEFECTIVE_TOKEN; + + if (ctx->initiate) { + cksumkey = ctx->acceptor_sign; + cksum_usage = KG_USAGE_ACCEPTOR_SIGN; + } else { + cksumkey = ctx->initiator_sign; + cksum_usage = KG_USAGE_INITIATOR_SIGN; + } + + if (make_checksum_v2(ctx, ptr, GSS_KRB5_TOK_HDR_LEN, message_buffer, 0, + cksumkey, cksum_usage, &cksumobj)) + return GSS_S_FAILURE; + + if (memcmp(cksumobj.data, ptr + GSS_KRB5_TOK_HDR_LEN, + ctx->gk5e->cksumlength)) + return GSS_S_BAD_SIG; + + /* it got through unscathed. Make sure the context is unexpired */ + now = get_seconds(); + if (now > ctx->endtime) + return GSS_S_CONTEXT_EXPIRED; + + /* + * NOTE: the sequence number at ptr + 8 is skipped, rpcsec_gss + * doesn't want it checked; see page 6 of rfc 2203. + */ + + return GSS_S_COMPLETE; +} + +u32 +gss_verify_mic_kerberos(struct gss_ctx *gss_ctx, + struct xdr_buf *message_buffer, + struct xdr_netobj *read_token) +{ + struct krb5_ctx *ctx = gss_ctx->internal_ctx_id; + + switch (ctx->enctype) { + default: + BUG(); + case ENCTYPE_DES_CBC_RAW: + case ENCTYPE_DES3_CBC_RAW: + case ENCTYPE_ARCFOUR_HMAC: + return gss_verify_mic_v1(ctx, message_buffer, read_token); + case ENCTYPE_AES128_CTS_HMAC_SHA1_96: + case ENCTYPE_AES256_CTS_HMAC_SHA1_96: + return gss_verify_mic_v2(ctx, message_buffer, read_token); + } +} diff --git a/net/sunrpc/auth_gss/gss_krb5_wrap.c b/net/sunrpc/auth_gss/gss_krb5_wrap.c new file mode 100644 index 000000000..39a2e6729 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_wrap.c @@ -0,0 +1,623 @@ +/* + * COPYRIGHT (c) 2008 + * The Regents of the University of Michigan + * ALL RIGHTS RESERVED + * + * Permission is granted to use, copy, create derivative works + * and redistribute this software and such derivative works + * for any purpose, so long as the name of The University of + * Michigan is not used in any advertising or publicity + * pertaining to the use of distribution of this software + * without specific, written prior authorization. If the + * above copyright notice or any other identification of the + * University of Michigan is included in any copy of any + * portion of this software, then the disclaimer below must + * also be included. + * + * THIS SOFTWARE IS PROVIDED AS IS, WITHOUT REPRESENTATION + * FROM THE UNIVERSITY OF MICHIGAN AS TO ITS FITNESS FOR ANY + * PURPOSE, AND WITHOUT WARRANTY BY THE UNIVERSITY OF + * MICHIGAN OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * WITHOUT LIMITATION THE IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE + * REGENTS OF THE UNIVERSITY OF MICHIGAN SHALL NOT BE LIABLE + * FOR ANY DAMAGES, INCLUDING SPECIAL, INDIRECT, INCIDENTAL, OR + * CONSEQUENTIAL DAMAGES, WITH RESPECT TO ANY CLAIM ARISING + * OUT OF OR IN CONNECTION WITH THE USE OF THE SOFTWARE, EVEN + * IF IT HAS BEEN OR IS HEREAFTER ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGES. + */ + +#include <crypto/skcipher.h> +#include <linux/types.h> +#include <linux/jiffies.h> +#include <linux/sunrpc/gss_krb5.h> +#include <linux/random.h> +#include <linux/pagemap.h> + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +static inline int +gss_krb5_padding(int blocksize, int length) +{ + return blocksize - (length % blocksize); +} + +static inline void +gss_krb5_add_padding(struct xdr_buf *buf, int offset, int blocksize) +{ + int padding = gss_krb5_padding(blocksize, buf->len - offset); + char *p; + struct kvec *iov; + + if (buf->page_len || buf->tail[0].iov_len) + iov = &buf->tail[0]; + else + iov = &buf->head[0]; + p = iov->iov_base + iov->iov_len; + iov->iov_len += padding; + buf->len += padding; + memset(p, padding, padding); +} + +static inline int +gss_krb5_remove_padding(struct xdr_buf *buf, int blocksize) +{ + u8 *ptr; + u8 pad; + size_t len = buf->len; + + if (len <= buf->head[0].iov_len) { + pad = *(u8 *)(buf->head[0].iov_base + len - 1); + if (pad > buf->head[0].iov_len) + return -EINVAL; + buf->head[0].iov_len -= pad; + goto out; + } else + len -= buf->head[0].iov_len; + if (len <= buf->page_len) { + unsigned int last = (buf->page_base + len - 1) + >>PAGE_SHIFT; + unsigned int offset = (buf->page_base + len - 1) + & (PAGE_SIZE - 1); + ptr = kmap_atomic(buf->pages[last]); + pad = *(ptr + offset); + kunmap_atomic(ptr); + goto out; + } else + len -= buf->page_len; + BUG_ON(len > buf->tail[0].iov_len); + pad = *(u8 *)(buf->tail[0].iov_base + len - 1); +out: + /* XXX: NOTE: we do not adjust the page lengths--they represent + * a range of data in the real filesystem page cache, and we need + * to know that range so the xdr code can properly place read data. + * However adjusting the head length, as we do above, is harmless. + * In the case of a request that fits into a single page, the server + * also uses length and head length together to determine the original + * start of the request to copy the request for deferal; so it's + * easier on the server if we adjust head and tail length in tandem. + * It's not really a problem that we don't fool with the page and + * tail lengths, though--at worst badly formed xdr might lead the + * server to attempt to parse the padding. + * XXX: Document all these weird requirements for gss mechanism + * wrap/unwrap functions. */ + if (pad > blocksize) + return -EINVAL; + if (buf->len > pad) + buf->len -= pad; + else + return -EINVAL; + return 0; +} + +void +gss_krb5_make_confounder(char *p, u32 conflen) +{ + static u64 i = 0; + u64 *q = (u64 *)p; + + /* rfc1964 claims this should be "random". But all that's really + * necessary is that it be unique. And not even that is necessary in + * our case since our "gssapi" implementation exists only to support + * rpcsec_gss, so we know that the only buffers we will ever encrypt + * already begin with a unique sequence number. Just to hedge my bets + * I'll make a half-hearted attempt at something unique, but ensuring + * uniqueness would mean worrying about atomicity and rollover, and I + * don't care enough. */ + + /* initialize to random value */ + if (i == 0) { + i = prandom_u32(); + i = (i << 32) | prandom_u32(); + } + + switch (conflen) { + case 16: + *q++ = i++; + /* fall through */ + case 8: + *q++ = i++; + break; + default: + BUG(); + } +} + +/* Assumptions: the head and tail of inbuf are ours to play with. + * The pages, however, may be real pages in the page cache and we replace + * them with scratch pages from **pages before writing to them. */ +/* XXX: obviously the above should be documentation of wrap interface, + * and shouldn't be in this kerberos-specific file. */ + +/* XXX factor out common code with seal/unseal. */ + +static u32 +gss_wrap_kerberos_v1(struct krb5_ctx *kctx, int offset, + struct xdr_buf *buf, struct page **pages) +{ + char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + struct xdr_netobj md5cksum = {.len = sizeof(cksumdata), + .data = cksumdata}; + int blocksize = 0, plainlen; + unsigned char *ptr, *msg_start; + s32 now; + int headlen; + struct page **tmp_pages; + u32 seq_send; + u8 *cksumkey; + u32 conflen = kctx->gk5e->conflen; + + dprintk("RPC: %s\n", __func__); + + now = get_seconds(); + + blocksize = crypto_skcipher_blocksize(kctx->enc); + gss_krb5_add_padding(buf, offset, blocksize); + BUG_ON((buf->len - offset) % blocksize); + plainlen = conflen + buf->len - offset; + + headlen = g_token_size(&kctx->mech_used, + GSS_KRB5_TOK_HDR_LEN + kctx->gk5e->cksumlength + plainlen) - + (buf->len - offset); + + ptr = buf->head[0].iov_base + offset; + /* shift data to make room for header. */ + xdr_extend_head(buf, offset, headlen); + + /* XXX Would be cleverer to encrypt while copying. */ + BUG_ON((buf->len - offset - headlen) % blocksize); + + g_make_token_header(&kctx->mech_used, + GSS_KRB5_TOK_HDR_LEN + + kctx->gk5e->cksumlength + plainlen, &ptr); + + + /* ptr now at header described in rfc 1964, section 1.2.1: */ + ptr[0] = (unsigned char) ((KG_TOK_WRAP_MSG >> 8) & 0xff); + ptr[1] = (unsigned char) (KG_TOK_WRAP_MSG & 0xff); + + msg_start = ptr + GSS_KRB5_TOK_HDR_LEN + kctx->gk5e->cksumlength; + + /* + * signalg and sealalg are stored as if they were converted from LE + * to host endian, even though they're opaque pairs of bytes according + * to the RFC. + */ + *(__le16 *)(ptr + 2) = cpu_to_le16(kctx->gk5e->signalg); + *(__le16 *)(ptr + 4) = cpu_to_le16(kctx->gk5e->sealalg); + ptr[6] = 0xff; + ptr[7] = 0xff; + + gss_krb5_make_confounder(msg_start, conflen); + + if (kctx->gk5e->keyed_cksum) + cksumkey = kctx->cksum; + else + cksumkey = NULL; + + /* XXXJBF: UGH!: */ + tmp_pages = buf->pages; + buf->pages = pages; + if (make_checksum(kctx, ptr, 8, buf, offset + headlen - conflen, + cksumkey, KG_USAGE_SEAL, &md5cksum)) + return GSS_S_FAILURE; + buf->pages = tmp_pages; + + memcpy(ptr + GSS_KRB5_TOK_HDR_LEN, md5cksum.data, md5cksum.len); + + spin_lock(&krb5_seq_lock); + seq_send = kctx->seq_send++; + spin_unlock(&krb5_seq_lock); + + /* XXX would probably be more efficient to compute checksum + * and encrypt at the same time: */ + if ((krb5_make_seq_num(kctx, kctx->seq, kctx->initiate ? 0 : 0xff, + seq_send, ptr + GSS_KRB5_TOK_HDR_LEN, ptr + 8))) + return GSS_S_FAILURE; + + if (kctx->enctype == ENCTYPE_ARCFOUR_HMAC) { + struct crypto_skcipher *cipher; + int err; + cipher = crypto_alloc_skcipher(kctx->gk5e->encrypt_name, 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(cipher)) + return GSS_S_FAILURE; + + krb5_rc4_setup_enc_key(kctx, cipher, seq_send); + + err = gss_encrypt_xdr_buf(cipher, buf, + offset + headlen - conflen, pages); + crypto_free_skcipher(cipher); + if (err) + return GSS_S_FAILURE; + } else { + if (gss_encrypt_xdr_buf(kctx->enc, buf, + offset + headlen - conflen, pages)) + return GSS_S_FAILURE; + } + + return (kctx->endtime < now) ? GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE; +} + +static u32 +gss_unwrap_kerberos_v1(struct krb5_ctx *kctx, int offset, struct xdr_buf *buf) +{ + int signalg; + int sealalg; + char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + struct xdr_netobj md5cksum = {.len = sizeof(cksumdata), + .data = cksumdata}; + s32 now; + int direction; + s32 seqnum; + unsigned char *ptr; + int bodysize; + void *data_start, *orig_start; + int data_len; + int blocksize; + u32 conflen = kctx->gk5e->conflen; + int crypt_offset; + u8 *cksumkey; + + dprintk("RPC: gss_unwrap_kerberos\n"); + + ptr = (u8 *)buf->head[0].iov_base + offset; + if (g_verify_token_header(&kctx->mech_used, &bodysize, &ptr, + buf->len - offset)) + return GSS_S_DEFECTIVE_TOKEN; + + if ((ptr[0] != ((KG_TOK_WRAP_MSG >> 8) & 0xff)) || + (ptr[1] != (KG_TOK_WRAP_MSG & 0xff))) + return GSS_S_DEFECTIVE_TOKEN; + + /* XXX sanity-check bodysize?? */ + + /* get the sign and seal algorithms */ + + signalg = ptr[2] + (ptr[3] << 8); + if (signalg != kctx->gk5e->signalg) + return GSS_S_DEFECTIVE_TOKEN; + + sealalg = ptr[4] + (ptr[5] << 8); + if (sealalg != kctx->gk5e->sealalg) + return GSS_S_DEFECTIVE_TOKEN; + + if ((ptr[6] != 0xff) || (ptr[7] != 0xff)) + return GSS_S_DEFECTIVE_TOKEN; + + /* + * Data starts after token header and checksum. ptr points + * to the beginning of the token header + */ + crypt_offset = ptr + (GSS_KRB5_TOK_HDR_LEN + kctx->gk5e->cksumlength) - + (unsigned char *)buf->head[0].iov_base; + + /* + * Need plaintext seqnum to derive encryption key for arcfour-hmac + */ + if (krb5_get_seq_num(kctx, ptr + GSS_KRB5_TOK_HDR_LEN, + ptr + 8, &direction, &seqnum)) + return GSS_S_BAD_SIG; + + if ((kctx->initiate && direction != 0xff) || + (!kctx->initiate && direction != 0)) + return GSS_S_BAD_SIG; + + if (kctx->enctype == ENCTYPE_ARCFOUR_HMAC) { + struct crypto_skcipher *cipher; + int err; + + cipher = crypto_alloc_skcipher(kctx->gk5e->encrypt_name, 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(cipher)) + return GSS_S_FAILURE; + + krb5_rc4_setup_enc_key(kctx, cipher, seqnum); + + err = gss_decrypt_xdr_buf(cipher, buf, crypt_offset); + crypto_free_skcipher(cipher); + if (err) + return GSS_S_DEFECTIVE_TOKEN; + } else { + if (gss_decrypt_xdr_buf(kctx->enc, buf, crypt_offset)) + return GSS_S_DEFECTIVE_TOKEN; + } + + if (kctx->gk5e->keyed_cksum) + cksumkey = kctx->cksum; + else + cksumkey = NULL; + + if (make_checksum(kctx, ptr, 8, buf, crypt_offset, + cksumkey, KG_USAGE_SEAL, &md5cksum)) + return GSS_S_FAILURE; + + if (memcmp(md5cksum.data, ptr + GSS_KRB5_TOK_HDR_LEN, + kctx->gk5e->cksumlength)) + return GSS_S_BAD_SIG; + + /* it got through unscathed. Make sure the context is unexpired */ + + now = get_seconds(); + + if (now > kctx->endtime) + return GSS_S_CONTEXT_EXPIRED; + + /* do sequencing checks */ + + /* Copy the data back to the right position. XXX: Would probably be + * better to copy and encrypt at the same time. */ + + blocksize = crypto_skcipher_blocksize(kctx->enc); + data_start = ptr + (GSS_KRB5_TOK_HDR_LEN + kctx->gk5e->cksumlength) + + conflen; + orig_start = buf->head[0].iov_base + offset; + data_len = (buf->head[0].iov_base + buf->head[0].iov_len) - data_start; + memmove(orig_start, data_start, data_len); + buf->head[0].iov_len -= (data_start - orig_start); + buf->len -= (data_start - orig_start); + + if (gss_krb5_remove_padding(buf, blocksize)) + return GSS_S_DEFECTIVE_TOKEN; + + return GSS_S_COMPLETE; +} + +/* + * We can shift data by up to LOCAL_BUF_LEN bytes in a pass. If we need + * to do more than that, we shift repeatedly. Kevin Coffman reports + * seeing 28 bytes as the value used by Microsoft clients and servers + * with AES, so this constant is chosen to allow handling 28 in one pass + * without using too much stack space. + * + * If that proves to a problem perhaps we could use a more clever + * algorithm. + */ +#define LOCAL_BUF_LEN 32u + +static void rotate_buf_a_little(struct xdr_buf *buf, unsigned int shift) +{ + char head[LOCAL_BUF_LEN]; + char tmp[LOCAL_BUF_LEN]; + unsigned int this_len, i; + + BUG_ON(shift > LOCAL_BUF_LEN); + + read_bytes_from_xdr_buf(buf, 0, head, shift); + for (i = 0; i + shift < buf->len; i += LOCAL_BUF_LEN) { + this_len = min(LOCAL_BUF_LEN, buf->len - (i + shift)); + read_bytes_from_xdr_buf(buf, i+shift, tmp, this_len); + write_bytes_to_xdr_buf(buf, i, tmp, this_len); + } + write_bytes_to_xdr_buf(buf, buf->len - shift, head, shift); +} + +static void _rotate_left(struct xdr_buf *buf, unsigned int shift) +{ + int shifted = 0; + int this_shift; + + shift %= buf->len; + while (shifted < shift) { + this_shift = min(shift - shifted, LOCAL_BUF_LEN); + rotate_buf_a_little(buf, this_shift); + shifted += this_shift; + } +} + +static void rotate_left(u32 base, struct xdr_buf *buf, unsigned int shift) +{ + struct xdr_buf subbuf; + + xdr_buf_subsegment(buf, &subbuf, base, buf->len - base); + _rotate_left(&subbuf, shift); +} + +static u32 +gss_wrap_kerberos_v2(struct krb5_ctx *kctx, u32 offset, + struct xdr_buf *buf, struct page **pages) +{ + u8 *ptr, *plainhdr; + s32 now; + u8 flags = 0x00; + __be16 *be16ptr; + __be64 *be64ptr; + u32 err; + + dprintk("RPC: %s\n", __func__); + + if (kctx->gk5e->encrypt_v2 == NULL) + return GSS_S_FAILURE; + + /* make room for gss token header */ + if (xdr_extend_head(buf, offset, GSS_KRB5_TOK_HDR_LEN)) + return GSS_S_FAILURE; + + /* construct gss token header */ + ptr = plainhdr = buf->head[0].iov_base + offset; + *ptr++ = (unsigned char) ((KG2_TOK_WRAP>>8) & 0xff); + *ptr++ = (unsigned char) (KG2_TOK_WRAP & 0xff); + + if ((kctx->flags & KRB5_CTX_FLAG_INITIATOR) == 0) + flags |= KG2_TOKEN_FLAG_SENTBYACCEPTOR; + if ((kctx->flags & KRB5_CTX_FLAG_ACCEPTOR_SUBKEY) != 0) + flags |= KG2_TOKEN_FLAG_ACCEPTORSUBKEY; + /* We always do confidentiality in wrap tokens */ + flags |= KG2_TOKEN_FLAG_SEALED; + + *ptr++ = flags; + *ptr++ = 0xff; + be16ptr = (__be16 *)ptr; + + *be16ptr++ = 0; + /* "inner" token header always uses 0 for RRC */ + *be16ptr++ = 0; + + be64ptr = (__be64 *)be16ptr; + spin_lock(&krb5_seq_lock); + *be64ptr = cpu_to_be64(kctx->seq_send64++); + spin_unlock(&krb5_seq_lock); + + err = (*kctx->gk5e->encrypt_v2)(kctx, offset, buf, pages); + if (err) + return err; + + now = get_seconds(); + return (kctx->endtime < now) ? GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE; +} + +static u32 +gss_unwrap_kerberos_v2(struct krb5_ctx *kctx, int offset, struct xdr_buf *buf) +{ + s32 now; + u8 *ptr; + u8 flags = 0x00; + u16 ec, rrc; + int err; + u32 headskip, tailskip; + u8 decrypted_hdr[GSS_KRB5_TOK_HDR_LEN]; + unsigned int movelen; + + + dprintk("RPC: %s\n", __func__); + + if (kctx->gk5e->decrypt_v2 == NULL) + return GSS_S_FAILURE; + + ptr = buf->head[0].iov_base + offset; + + if (be16_to_cpu(*((__be16 *)ptr)) != KG2_TOK_WRAP) + return GSS_S_DEFECTIVE_TOKEN; + + flags = ptr[2]; + if ((!kctx->initiate && (flags & KG2_TOKEN_FLAG_SENTBYACCEPTOR)) || + (kctx->initiate && !(flags & KG2_TOKEN_FLAG_SENTBYACCEPTOR))) + return GSS_S_BAD_SIG; + + if ((flags & KG2_TOKEN_FLAG_SEALED) == 0) { + dprintk("%s: token missing expected sealed flag\n", __func__); + return GSS_S_DEFECTIVE_TOKEN; + } + + if (ptr[3] != 0xff) + return GSS_S_DEFECTIVE_TOKEN; + + ec = be16_to_cpup((__be16 *)(ptr + 4)); + rrc = be16_to_cpup((__be16 *)(ptr + 6)); + + /* + * NOTE: the sequence number at ptr + 8 is skipped, rpcsec_gss + * doesn't want it checked; see page 6 of rfc 2203. + */ + + if (rrc != 0) + rotate_left(offset + 16, buf, rrc); + + err = (*kctx->gk5e->decrypt_v2)(kctx, offset, buf, + &headskip, &tailskip); + if (err) + return GSS_S_FAILURE; + + /* + * Retrieve the decrypted gss token header and verify + * it against the original + */ + err = read_bytes_from_xdr_buf(buf, + buf->len - GSS_KRB5_TOK_HDR_LEN - tailskip, + decrypted_hdr, GSS_KRB5_TOK_HDR_LEN); + if (err) { + dprintk("%s: error %u getting decrypted_hdr\n", __func__, err); + return GSS_S_FAILURE; + } + if (memcmp(ptr, decrypted_hdr, 6) + || memcmp(ptr + 8, decrypted_hdr + 8, 8)) { + dprintk("%s: token hdr, plaintext hdr mismatch!\n", __func__); + return GSS_S_FAILURE; + } + + /* do sequencing checks */ + + /* it got through unscathed. Make sure the context is unexpired */ + now = get_seconds(); + if (now > kctx->endtime) + return GSS_S_CONTEXT_EXPIRED; + + /* + * Move the head data back to the right position in xdr_buf. + * We ignore any "ec" data since it might be in the head or + * the tail, and we really don't need to deal with it. + * Note that buf->head[0].iov_len may indicate the available + * head buffer space rather than that actually occupied. + */ + movelen = min_t(unsigned int, buf->head[0].iov_len, buf->len); + movelen -= offset + GSS_KRB5_TOK_HDR_LEN + headskip; + BUG_ON(offset + GSS_KRB5_TOK_HDR_LEN + headskip + movelen > + buf->head[0].iov_len); + memmove(ptr, ptr + GSS_KRB5_TOK_HDR_LEN + headskip, movelen); + buf->head[0].iov_len -= GSS_KRB5_TOK_HDR_LEN + headskip; + buf->len -= GSS_KRB5_TOK_HDR_LEN + headskip; + + /* Trim off the trailing "extra count" and checksum blob */ + xdr_buf_trim(buf, ec + GSS_KRB5_TOK_HDR_LEN + tailskip); + return GSS_S_COMPLETE; +} + +u32 +gss_wrap_kerberos(struct gss_ctx *gctx, int offset, + struct xdr_buf *buf, struct page **pages) +{ + struct krb5_ctx *kctx = gctx->internal_ctx_id; + + switch (kctx->enctype) { + default: + BUG(); + case ENCTYPE_DES_CBC_RAW: + case ENCTYPE_DES3_CBC_RAW: + case ENCTYPE_ARCFOUR_HMAC: + return gss_wrap_kerberos_v1(kctx, offset, buf, pages); + case ENCTYPE_AES128_CTS_HMAC_SHA1_96: + case ENCTYPE_AES256_CTS_HMAC_SHA1_96: + return gss_wrap_kerberos_v2(kctx, offset, buf, pages); + } +} + +u32 +gss_unwrap_kerberos(struct gss_ctx *gctx, int offset, struct xdr_buf *buf) +{ + struct krb5_ctx *kctx = gctx->internal_ctx_id; + + switch (kctx->enctype) { + default: + BUG(); + case ENCTYPE_DES_CBC_RAW: + case ENCTYPE_DES3_CBC_RAW: + case ENCTYPE_ARCFOUR_HMAC: + return gss_unwrap_kerberos_v1(kctx, offset, buf); + case ENCTYPE_AES128_CTS_HMAC_SHA1_96: + case ENCTYPE_AES256_CTS_HMAC_SHA1_96: + return gss_unwrap_kerberos_v2(kctx, offset, buf); + } +} diff --git a/net/sunrpc/auth_gss/gss_mech_switch.c b/net/sunrpc/auth_gss/gss_mech_switch.c new file mode 100644 index 000000000..c7d88f979 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_mech_switch.c @@ -0,0 +1,499 @@ +/* + * linux/net/sunrpc/gss_mech_switch.c + * + * Copyright (c) 2001 The Regents of the University of Michigan. + * All rights reserved. + * + * J. Bruce Fields <bfields@umich.edu> + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the University nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED + * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR + * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#include <linux/types.h> +#include <linux/slab.h> +#include <linux/module.h> +#include <linux/oid_registry.h> +#include <linux/sunrpc/msg_prot.h> +#include <linux/sunrpc/gss_asn1.h> +#include <linux/sunrpc/auth_gss.h> +#include <linux/sunrpc/svcauth_gss.h> +#include <linux/sunrpc/gss_err.h> +#include <linux/sunrpc/sched.h> +#include <linux/sunrpc/gss_api.h> +#include <linux/sunrpc/clnt.h> + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +static LIST_HEAD(registered_mechs); +static DEFINE_SPINLOCK(registered_mechs_lock); + +static void +gss_mech_free(struct gss_api_mech *gm) +{ + struct pf_desc *pf; + int i; + + for (i = 0; i < gm->gm_pf_num; i++) { + pf = &gm->gm_pfs[i]; + if (pf->domain) + auth_domain_put(pf->domain); + kfree(pf->auth_domain_name); + pf->auth_domain_name = NULL; + } +} + +static inline char * +make_auth_domain_name(char *name) +{ + static char *prefix = "gss/"; + char *new; + + new = kmalloc(strlen(name) + strlen(prefix) + 1, GFP_KERNEL); + if (new) { + strcpy(new, prefix); + strcat(new, name); + } + return new; +} + +static int +gss_mech_svc_setup(struct gss_api_mech *gm) +{ + struct auth_domain *dom; + struct pf_desc *pf; + int i, status; + + for (i = 0; i < gm->gm_pf_num; i++) { + pf = &gm->gm_pfs[i]; + pf->auth_domain_name = make_auth_domain_name(pf->name); + status = -ENOMEM; + if (pf->auth_domain_name == NULL) + goto out; + dom = svcauth_gss_register_pseudoflavor( + pf->pseudoflavor, pf->auth_domain_name); + if (IS_ERR(dom)) { + status = PTR_ERR(dom); + goto out; + } + pf->domain = dom; + } + return 0; +out: + gss_mech_free(gm); + return status; +} + +/** + * gss_mech_register - register a GSS mechanism + * @gm: GSS mechanism handle + * + * Returns zero if successful, or a negative errno. + */ +int gss_mech_register(struct gss_api_mech *gm) +{ + int status; + + status = gss_mech_svc_setup(gm); + if (status) + return status; + spin_lock(®istered_mechs_lock); + list_add(&gm->gm_list, ®istered_mechs); + spin_unlock(®istered_mechs_lock); + dprintk("RPC: registered gss mechanism %s\n", gm->gm_name); + return 0; +} +EXPORT_SYMBOL_GPL(gss_mech_register); + +/** + * gss_mech_unregister - release a GSS mechanism + * @gm: GSS mechanism handle + * + */ +void gss_mech_unregister(struct gss_api_mech *gm) +{ + spin_lock(®istered_mechs_lock); + list_del(&gm->gm_list); + spin_unlock(®istered_mechs_lock); + dprintk("RPC: unregistered gss mechanism %s\n", gm->gm_name); + gss_mech_free(gm); +} +EXPORT_SYMBOL_GPL(gss_mech_unregister); + +struct gss_api_mech *gss_mech_get(struct gss_api_mech *gm) +{ + __module_get(gm->gm_owner); + return gm; +} +EXPORT_SYMBOL(gss_mech_get); + +static struct gss_api_mech * +_gss_mech_get_by_name(const char *name) +{ + struct gss_api_mech *pos, *gm = NULL; + + spin_lock(®istered_mechs_lock); + list_for_each_entry(pos, ®istered_mechs, gm_list) { + if (0 == strcmp(name, pos->gm_name)) { + if (try_module_get(pos->gm_owner)) + gm = pos; + break; + } + } + spin_unlock(®istered_mechs_lock); + return gm; + +} + +struct gss_api_mech * gss_mech_get_by_name(const char *name) +{ + struct gss_api_mech *gm = NULL; + + gm = _gss_mech_get_by_name(name); + if (!gm) { + request_module("rpc-auth-gss-%s", name); + gm = _gss_mech_get_by_name(name); + } + return gm; +} + +struct gss_api_mech *gss_mech_get_by_OID(struct rpcsec_gss_oid *obj) +{ + struct gss_api_mech *pos, *gm = NULL; + char buf[32]; + + if (sprint_oid(obj->data, obj->len, buf, sizeof(buf)) < 0) + return NULL; + dprintk("RPC: %s(%s)\n", __func__, buf); + request_module("rpc-auth-gss-%s", buf); + + spin_lock(®istered_mechs_lock); + list_for_each_entry(pos, ®istered_mechs, gm_list) { + if (obj->len == pos->gm_oid.len) { + if (0 == memcmp(obj->data, pos->gm_oid.data, obj->len)) { + if (try_module_get(pos->gm_owner)) + gm = pos; + break; + } + } + } + spin_unlock(®istered_mechs_lock); + return gm; +} + +static inline int +mech_supports_pseudoflavor(struct gss_api_mech *gm, u32 pseudoflavor) +{ + int i; + + for (i = 0; i < gm->gm_pf_num; i++) { + if (gm->gm_pfs[i].pseudoflavor == pseudoflavor) + return 1; + } + return 0; +} + +static struct gss_api_mech *_gss_mech_get_by_pseudoflavor(u32 pseudoflavor) +{ + struct gss_api_mech *gm = NULL, *pos; + + spin_lock(®istered_mechs_lock); + list_for_each_entry(pos, ®istered_mechs, gm_list) { + if (!mech_supports_pseudoflavor(pos, pseudoflavor)) + continue; + if (try_module_get(pos->gm_owner)) + gm = pos; + break; + } + spin_unlock(®istered_mechs_lock); + return gm; +} + +struct gss_api_mech * +gss_mech_get_by_pseudoflavor(u32 pseudoflavor) +{ + struct gss_api_mech *gm; + + gm = _gss_mech_get_by_pseudoflavor(pseudoflavor); + + if (!gm) { + request_module("rpc-auth-gss-%u", pseudoflavor); + gm = _gss_mech_get_by_pseudoflavor(pseudoflavor); + } + return gm; +} + +/** + * gss_mech_list_pseudoflavors - Discover registered GSS pseudoflavors + * @array: array to fill in + * @size: size of "array" + * + * Returns the number of array items filled in, or a negative errno. + * + * The returned array is not sorted by any policy. Callers should not + * rely on the order of the items in the returned array. + */ +int gss_mech_list_pseudoflavors(rpc_authflavor_t *array_ptr, int size) +{ + struct gss_api_mech *pos = NULL; + int j, i = 0; + + spin_lock(®istered_mechs_lock); + list_for_each_entry(pos, ®istered_mechs, gm_list) { + for (j = 0; j < pos->gm_pf_num; j++) { + if (i >= size) { + spin_unlock(®istered_mechs_lock); + return -ENOMEM; + } + array_ptr[i++] = pos->gm_pfs[j].pseudoflavor; + } + } + spin_unlock(®istered_mechs_lock); + return i; +} + +/** + * gss_svc_to_pseudoflavor - map a GSS service number to a pseudoflavor + * @gm: GSS mechanism handle + * @qop: GSS quality-of-protection value + * @service: GSS service value + * + * Returns a matching security flavor, or RPC_AUTH_MAXFLAVOR if none is found. + */ +rpc_authflavor_t gss_svc_to_pseudoflavor(struct gss_api_mech *gm, u32 qop, + u32 service) +{ + int i; + + for (i = 0; i < gm->gm_pf_num; i++) { + if (gm->gm_pfs[i].qop == qop && + gm->gm_pfs[i].service == service) { + return gm->gm_pfs[i].pseudoflavor; + } + } + return RPC_AUTH_MAXFLAVOR; +} + +/** + * gss_mech_info2flavor - look up a pseudoflavor given a GSS tuple + * @info: a GSS mech OID, quality of protection, and service value + * + * Returns a matching pseudoflavor, or RPC_AUTH_MAXFLAVOR if the tuple is + * not supported. + */ +rpc_authflavor_t gss_mech_info2flavor(struct rpcsec_gss_info *info) +{ + rpc_authflavor_t pseudoflavor; + struct gss_api_mech *gm; + + gm = gss_mech_get_by_OID(&info->oid); + if (gm == NULL) + return RPC_AUTH_MAXFLAVOR; + + pseudoflavor = gss_svc_to_pseudoflavor(gm, info->qop, info->service); + + gss_mech_put(gm); + return pseudoflavor; +} + +/** + * gss_mech_flavor2info - look up a GSS tuple for a given pseudoflavor + * @pseudoflavor: GSS pseudoflavor to match + * @info: rpcsec_gss_info structure to fill in + * + * Returns zero and fills in "info" if pseudoflavor matches a + * supported mechanism. Otherwise a negative errno is returned. + */ +int gss_mech_flavor2info(rpc_authflavor_t pseudoflavor, + struct rpcsec_gss_info *info) +{ + struct gss_api_mech *gm; + int i; + + gm = gss_mech_get_by_pseudoflavor(pseudoflavor); + if (gm == NULL) + return -ENOENT; + + for (i = 0; i < gm->gm_pf_num; i++) { + if (gm->gm_pfs[i].pseudoflavor == pseudoflavor) { + memcpy(info->oid.data, gm->gm_oid.data, gm->gm_oid.len); + info->oid.len = gm->gm_oid.len; + info->qop = gm->gm_pfs[i].qop; + info->service = gm->gm_pfs[i].service; + gss_mech_put(gm); + return 0; + } + } + + gss_mech_put(gm); + return -ENOENT; +} + +u32 +gss_pseudoflavor_to_service(struct gss_api_mech *gm, u32 pseudoflavor) +{ + int i; + + for (i = 0; i < gm->gm_pf_num; i++) { + if (gm->gm_pfs[i].pseudoflavor == pseudoflavor) + return gm->gm_pfs[i].service; + } + return 0; +} +EXPORT_SYMBOL(gss_pseudoflavor_to_service); + +bool +gss_pseudoflavor_to_datatouch(struct gss_api_mech *gm, u32 pseudoflavor) +{ + int i; + + for (i = 0; i < gm->gm_pf_num; i++) { + if (gm->gm_pfs[i].pseudoflavor == pseudoflavor) + return gm->gm_pfs[i].datatouch; + } + return false; +} + +char * +gss_service_to_auth_domain_name(struct gss_api_mech *gm, u32 service) +{ + int i; + + for (i = 0; i < gm->gm_pf_num; i++) { + if (gm->gm_pfs[i].service == service) + return gm->gm_pfs[i].auth_domain_name; + } + return NULL; +} + +void +gss_mech_put(struct gss_api_mech * gm) +{ + if (gm) + module_put(gm->gm_owner); +} +EXPORT_SYMBOL(gss_mech_put); + +/* The mech could probably be determined from the token instead, but it's just + * as easy for now to pass it in. */ +int +gss_import_sec_context(const void *input_token, size_t bufsize, + struct gss_api_mech *mech, + struct gss_ctx **ctx_id, + time_t *endtime, + gfp_t gfp_mask) +{ + if (!(*ctx_id = kzalloc(sizeof(**ctx_id), gfp_mask))) + return -ENOMEM; + (*ctx_id)->mech_type = gss_mech_get(mech); + + return mech->gm_ops->gss_import_sec_context(input_token, bufsize, + *ctx_id, endtime, gfp_mask); +} + +/* gss_get_mic: compute a mic over message and return mic_token. */ + +u32 +gss_get_mic(struct gss_ctx *context_handle, + struct xdr_buf *message, + struct xdr_netobj *mic_token) +{ + return context_handle->mech_type->gm_ops + ->gss_get_mic(context_handle, + message, + mic_token); +} + +/* gss_verify_mic: check whether the provided mic_token verifies message. */ + +u32 +gss_verify_mic(struct gss_ctx *context_handle, + struct xdr_buf *message, + struct xdr_netobj *mic_token) +{ + return context_handle->mech_type->gm_ops + ->gss_verify_mic(context_handle, + message, + mic_token); +} + +/* + * This function is called from both the client and server code. + * Each makes guarantees about how much "slack" space is available + * for the underlying function in "buf"'s head and tail while + * performing the wrap. + * + * The client and server code allocate RPC_MAX_AUTH_SIZE extra + * space in both the head and tail which is available for use by + * the wrap function. + * + * Underlying functions should verify they do not use more than + * RPC_MAX_AUTH_SIZE of extra space in either the head or tail + * when performing the wrap. + */ +u32 +gss_wrap(struct gss_ctx *ctx_id, + int offset, + struct xdr_buf *buf, + struct page **inpages) +{ + return ctx_id->mech_type->gm_ops + ->gss_wrap(ctx_id, offset, buf, inpages); +} + +u32 +gss_unwrap(struct gss_ctx *ctx_id, + int offset, + struct xdr_buf *buf) +{ + return ctx_id->mech_type->gm_ops + ->gss_unwrap(ctx_id, offset, buf); +} + + +/* gss_delete_sec_context: free all resources associated with context_handle. + * Note this differs from the RFC 2744-specified prototype in that we don't + * bother returning an output token, since it would never be used anyway. */ + +u32 +gss_delete_sec_context(struct gss_ctx **context_handle) +{ + dprintk("RPC: gss_delete_sec_context deleting %p\n", + *context_handle); + + if (!*context_handle) + return GSS_S_NO_CONTEXT; + if ((*context_handle)->internal_ctx_id) + (*context_handle)->mech_type->gm_ops + ->gss_delete_sec_context((*context_handle) + ->internal_ctx_id); + gss_mech_put((*context_handle)->mech_type); + kfree(*context_handle); + *context_handle=NULL; + return GSS_S_COMPLETE; +} diff --git a/net/sunrpc/auth_gss/gss_rpc_upcall.c b/net/sunrpc/auth_gss/gss_rpc_upcall.c new file mode 100644 index 000000000..73dcda060 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_rpc_upcall.c @@ -0,0 +1,410 @@ +/* + * linux/net/sunrpc/gss_rpc_upcall.c + * + * Copyright (C) 2012 Simo Sorce <simo@redhat.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. + */ + +#include <linux/types.h> +#include <linux/un.h> + +#include <linux/sunrpc/svcauth.h> +#include "gss_rpc_upcall.h" + +#define GSSPROXY_SOCK_PATHNAME "/var/run/gssproxy.sock" + +#define GSSPROXY_PROGRAM (400112u) +#define GSSPROXY_VERS_1 (1u) + +/* + * Encoding/Decoding functions + */ + +enum { + GSSX_NULL = 0, /* Unused */ + GSSX_INDICATE_MECHS = 1, + GSSX_GET_CALL_CONTEXT = 2, + GSSX_IMPORT_AND_CANON_NAME = 3, + GSSX_EXPORT_CRED = 4, + GSSX_IMPORT_CRED = 5, + GSSX_ACQUIRE_CRED = 6, + GSSX_STORE_CRED = 7, + GSSX_INIT_SEC_CONTEXT = 8, + GSSX_ACCEPT_SEC_CONTEXT = 9, + GSSX_RELEASE_HANDLE = 10, + GSSX_GET_MIC = 11, + GSSX_VERIFY = 12, + GSSX_WRAP = 13, + GSSX_UNWRAP = 14, + GSSX_WRAP_SIZE_LIMIT = 15, +}; + +#define PROC(proc, name) \ +[GSSX_##proc] = { \ + .p_proc = GSSX_##proc, \ + .p_encode = gssx_enc_##name, \ + .p_decode = gssx_dec_##name, \ + .p_arglen = GSSX_ARG_##name##_sz, \ + .p_replen = GSSX_RES_##name##_sz, \ + .p_statidx = GSSX_##proc, \ + .p_name = #proc, \ +} + +static const struct rpc_procinfo gssp_procedures[] = { + PROC(INDICATE_MECHS, indicate_mechs), + PROC(GET_CALL_CONTEXT, get_call_context), + PROC(IMPORT_AND_CANON_NAME, import_and_canon_name), + PROC(EXPORT_CRED, export_cred), + PROC(IMPORT_CRED, import_cred), + PROC(ACQUIRE_CRED, acquire_cred), + PROC(STORE_CRED, store_cred), + PROC(INIT_SEC_CONTEXT, init_sec_context), + PROC(ACCEPT_SEC_CONTEXT, accept_sec_context), + PROC(RELEASE_HANDLE, release_handle), + PROC(GET_MIC, get_mic), + PROC(VERIFY, verify), + PROC(WRAP, wrap), + PROC(UNWRAP, unwrap), + PROC(WRAP_SIZE_LIMIT, wrap_size_limit), +}; + + + +/* + * Common transport functions + */ + +static const struct rpc_program gssp_program; + +static int gssp_rpc_create(struct net *net, struct rpc_clnt **_clnt) +{ + static const struct sockaddr_un gssp_localaddr = { + .sun_family = AF_LOCAL, + .sun_path = GSSPROXY_SOCK_PATHNAME, + }; + struct rpc_create_args args = { + .net = net, + .protocol = XPRT_TRANSPORT_LOCAL, + .address = (struct sockaddr *)&gssp_localaddr, + .addrsize = sizeof(gssp_localaddr), + .servername = "localhost", + .program = &gssp_program, + .version = GSSPROXY_VERS_1, + .authflavor = RPC_AUTH_NULL, + /* + * Note we want connection to be done in the caller's + * filesystem namespace. We therefore turn off the idle + * timeout, which would result in reconnections being + * done without the correct namespace: + */ + .flags = RPC_CLNT_CREATE_NOPING | + RPC_CLNT_CREATE_NO_IDLE_TIMEOUT + }; + struct rpc_clnt *clnt; + int result = 0; + + clnt = rpc_create(&args); + if (IS_ERR(clnt)) { + dprintk("RPC: failed to create AF_LOCAL gssproxy " + "client (errno %ld).\n", PTR_ERR(clnt)); + result = PTR_ERR(clnt); + *_clnt = NULL; + goto out; + } + + dprintk("RPC: created new gssp local client (gssp_local_clnt: " + "%p)\n", clnt); + *_clnt = clnt; + +out: + return result; +} + +void init_gssp_clnt(struct sunrpc_net *sn) +{ + mutex_init(&sn->gssp_lock); + sn->gssp_clnt = NULL; +} + +int set_gssp_clnt(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_clnt *clnt; + int ret; + + mutex_lock(&sn->gssp_lock); + ret = gssp_rpc_create(net, &clnt); + if (!ret) { + if (sn->gssp_clnt) + rpc_shutdown_client(sn->gssp_clnt); + sn->gssp_clnt = clnt; + } + mutex_unlock(&sn->gssp_lock); + return ret; +} + +void clear_gssp_clnt(struct sunrpc_net *sn) +{ + mutex_lock(&sn->gssp_lock); + if (sn->gssp_clnt) { + rpc_shutdown_client(sn->gssp_clnt); + sn->gssp_clnt = NULL; + } + mutex_unlock(&sn->gssp_lock); +} + +static struct rpc_clnt *get_gssp_clnt(struct sunrpc_net *sn) +{ + struct rpc_clnt *clnt; + + mutex_lock(&sn->gssp_lock); + clnt = sn->gssp_clnt; + if (clnt) + atomic_inc(&clnt->cl_count); + mutex_unlock(&sn->gssp_lock); + return clnt; +} + +static int gssp_call(struct net *net, struct rpc_message *msg) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_clnt *clnt; + int status; + + clnt = get_gssp_clnt(sn); + if (!clnt) + return -EIO; + status = rpc_call_sync(clnt, msg, 0); + if (status < 0) { + dprintk("gssp: rpc_call returned error %d\n", -status); + switch (status) { + case -EPROTONOSUPPORT: + status = -EINVAL; + break; + case -ECONNREFUSED: + case -ETIMEDOUT: + case -ENOTCONN: + status = -EAGAIN; + break; + case -ERESTARTSYS: + if (signalled ()) + status = -EINTR; + break; + default: + break; + } + } + rpc_release_client(clnt); + return status; +} + +static void gssp_free_receive_pages(struct gssx_arg_accept_sec_context *arg) +{ + int i; + + for (i = 0; i < arg->npages && arg->pages[i]; i++) + __free_page(arg->pages[i]); + + kfree(arg->pages); +} + +static int gssp_alloc_receive_pages(struct gssx_arg_accept_sec_context *arg) +{ + arg->npages = DIV_ROUND_UP(NGROUPS_MAX * 4, PAGE_SIZE); + arg->pages = kcalloc(arg->npages, sizeof(struct page *), GFP_KERNEL); + /* + * XXX: actual pages are allocated by xdr layer in + * xdr_partial_copy_from_skb. + */ + if (!arg->pages) + return -ENOMEM; + return 0; +} + +static char *gssp_stringify(struct xdr_netobj *netobj) +{ + return kstrndup(netobj->data, netobj->len, GFP_KERNEL); +} + +static void gssp_hostbased_service(char **principal) +{ + char *c; + + if (!*principal) + return; + + /* terminate and remove realm part */ + c = strchr(*principal, '@'); + if (c) { + *c = '\0'; + + /* change service-hostname delimiter */ + c = strchr(*principal, '/'); + if (c) + *c = '@'; + } + if (!c) { + /* not a service principal */ + kfree(*principal); + *principal = NULL; + } +} + +/* + * Public functions + */ + +/* numbers somewhat arbitrary but large enough for current needs */ +#define GSSX_MAX_OUT_HANDLE 128 +#define GSSX_MAX_SRC_PRINC 256 +#define GSSX_KMEMBUF (GSSX_max_output_handle_sz + \ + GSSX_max_oid_sz + \ + GSSX_max_princ_sz + \ + sizeof(struct svc_cred)) + +int gssp_accept_sec_context_upcall(struct net *net, + struct gssp_upcall_data *data) +{ + struct gssx_ctx ctxh = { + .state = data->in_handle + }; + struct gssx_arg_accept_sec_context arg = { + .input_token = data->in_token, + }; + struct gssx_ctx rctxh = { + /* + * pass in the max length we expect for each of these + * buffers but let the xdr code kmalloc them: + */ + .exported_context_token.len = GSSX_max_output_handle_sz, + .mech.len = GSS_OID_MAX_LEN, + .targ_name.display_name.len = GSSX_max_princ_sz, + .src_name.display_name.len = GSSX_max_princ_sz + }; + struct gssx_res_accept_sec_context res = { + .context_handle = &rctxh, + .output_token = &data->out_token + }; + struct rpc_message msg = { + .rpc_proc = &gssp_procedures[GSSX_ACCEPT_SEC_CONTEXT], + .rpc_argp = &arg, + .rpc_resp = &res, + .rpc_cred = NULL, /* FIXME ? */ + }; + struct xdr_netobj client_name = { 0 , NULL }; + struct xdr_netobj target_name = { 0, NULL }; + int ret; + + if (data->in_handle.len != 0) + arg.context_handle = &ctxh; + res.output_token->len = GSSX_max_output_token_sz; + + ret = gssp_alloc_receive_pages(&arg); + if (ret) + return ret; + + ret = gssp_call(net, &msg); + + gssp_free_receive_pages(&arg); + + /* we need to fetch all data even in case of error so + * that we can free special strctures is they have been allocated */ + data->major_status = res.status.major_status; + data->minor_status = res.status.minor_status; + if (res.context_handle) { + data->out_handle = rctxh.exported_context_token; + data->mech_oid.len = rctxh.mech.len; + if (rctxh.mech.data) { + memcpy(data->mech_oid.data, rctxh.mech.data, + data->mech_oid.len); + kfree(rctxh.mech.data); + } + client_name = rctxh.src_name.display_name; + target_name = rctxh.targ_name.display_name; + } + + if (res.options.count == 1) { + gssx_buffer *value = &res.options.data[0].value; + /* Currently we only decode CREDS_VALUE, if we add + * anything else we'll have to loop and match on the + * option name */ + if (value->len == 1) { + /* steal group info from struct svc_cred */ + data->creds = *(struct svc_cred *)value->data; + data->found_creds = 1; + } + /* whether we use it or not, free data */ + kfree(value->data); + } + + if (res.options.count != 0) { + kfree(res.options.data); + } + + /* convert to GSS_NT_HOSTBASED_SERVICE form and set into creds */ + if (data->found_creds) { + if (client_name.data) { + data->creds.cr_raw_principal = + gssp_stringify(&client_name); + data->creds.cr_principal = + gssp_stringify(&client_name); + gssp_hostbased_service(&data->creds.cr_principal); + } + if (target_name.data) { + data->creds.cr_targ_princ = + gssp_stringify(&target_name); + gssp_hostbased_service(&data->creds.cr_targ_princ); + } + } + kfree(client_name.data); + kfree(target_name.data); + + return ret; +} + +void gssp_free_upcall_data(struct gssp_upcall_data *data) +{ + kfree(data->in_handle.data); + kfree(data->out_handle.data); + kfree(data->out_token.data); + free_svc_cred(&data->creds); +} + +/* + * Initialization stuff + */ +static unsigned int gssp_version1_counts[ARRAY_SIZE(gssp_procedures)]; +static const struct rpc_version gssp_version1 = { + .number = GSSPROXY_VERS_1, + .nrprocs = ARRAY_SIZE(gssp_procedures), + .procs = gssp_procedures, + .counts = gssp_version1_counts, +}; + +static const struct rpc_version *gssp_version[] = { + NULL, + &gssp_version1, +}; + +static struct rpc_stat gssp_stats; + +static const struct rpc_program gssp_program = { + .name = "gssproxy", + .number = GSSPROXY_PROGRAM, + .nrvers = ARRAY_SIZE(gssp_version), + .version = gssp_version, + .stats = &gssp_stats, +}; diff --git a/net/sunrpc/auth_gss/gss_rpc_upcall.h b/net/sunrpc/auth_gss/gss_rpc_upcall.h new file mode 100644 index 000000000..1e542aded --- /dev/null +++ b/net/sunrpc/auth_gss/gss_rpc_upcall.h @@ -0,0 +1,48 @@ +/* + * linux/net/sunrpc/gss_rpc_upcall.h + * + * Copyright (C) 2012 Simo Sorce <simo@redhat.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. + */ + +#ifndef _GSS_RPC_UPCALL_H +#define _GSS_RPC_UPCALL_H + +#include <linux/sunrpc/gss_api.h> +#include <linux/sunrpc/auth_gss.h> +#include "gss_rpc_xdr.h" +#include "../netns.h" + +struct gssp_upcall_data { + struct xdr_netobj in_handle; + struct gssp_in_token in_token; + struct xdr_netobj out_handle; + struct xdr_netobj out_token; + struct rpcsec_gss_oid mech_oid; + struct svc_cred creds; + int found_creds; + int major_status; + int minor_status; +}; + +int gssp_accept_sec_context_upcall(struct net *net, + struct gssp_upcall_data *data); +void gssp_free_upcall_data(struct gssp_upcall_data *data); + +void init_gssp_clnt(struct sunrpc_net *); +int set_gssp_clnt(struct net *); +void clear_gssp_clnt(struct sunrpc_net *); +#endif /* _GSS_RPC_UPCALL_H */ diff --git a/net/sunrpc/auth_gss/gss_rpc_xdr.c b/net/sunrpc/auth_gss/gss_rpc_xdr.c new file mode 100644 index 000000000..444380f96 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_rpc_xdr.c @@ -0,0 +1,851 @@ +/* + * GSS Proxy upcall module + * + * Copyright (C) 2012 Simo Sorce <simo@redhat.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. + */ + +#include <linux/sunrpc/svcauth.h> +#include "gss_rpc_xdr.h" + +static int gssx_enc_bool(struct xdr_stream *xdr, int v) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + *p = v ? xdr_one : xdr_zero; + return 0; +} + +static int gssx_dec_bool(struct xdr_stream *xdr, u32 *v) +{ + __be32 *p; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + *v = be32_to_cpu(*p); + return 0; +} + +static int gssx_enc_buffer(struct xdr_stream *xdr, + const gssx_buffer *buf) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, sizeof(u32) + buf->len); + if (!p) + return -ENOSPC; + xdr_encode_opaque(p, buf->data, buf->len); + return 0; +} + +static int gssx_enc_in_token(struct xdr_stream *xdr, + const struct gssp_in_token *in) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, 4); + if (!p) + return -ENOSPC; + *p = cpu_to_be32(in->page_len); + + /* all we need to do is to write pages */ + xdr_write_pages(xdr, in->pages, in->page_base, in->page_len); + + return 0; +} + + +static int gssx_dec_buffer(struct xdr_stream *xdr, + gssx_buffer *buf) +{ + u32 length; + __be32 *p; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + + length = be32_to_cpup(p); + p = xdr_inline_decode(xdr, length); + if (unlikely(p == NULL)) + return -ENOSPC; + + if (buf->len == 0) { + /* we intentionally are not interested in this buffer */ + return 0; + } + if (length > buf->len) + return -ENOSPC; + + if (!buf->data) { + buf->data = kmemdup(p, length, GFP_KERNEL); + if (!buf->data) + return -ENOMEM; + } else { + memcpy(buf->data, p, length); + } + buf->len = length; + return 0; +} + +static int gssx_enc_option(struct xdr_stream *xdr, + struct gssx_option *opt) +{ + int err; + + err = gssx_enc_buffer(xdr, &opt->option); + if (err) + return err; + err = gssx_enc_buffer(xdr, &opt->value); + return err; +} + +static int gssx_dec_option(struct xdr_stream *xdr, + struct gssx_option *opt) +{ + int err; + + err = gssx_dec_buffer(xdr, &opt->option); + if (err) + return err; + err = gssx_dec_buffer(xdr, &opt->value); + return err; +} + +static int dummy_enc_opt_array(struct xdr_stream *xdr, + const struct gssx_option_array *oa) +{ + __be32 *p; + + if (oa->count != 0) + return -EINVAL; + + p = xdr_reserve_space(xdr, 4); + if (!p) + return -ENOSPC; + *p = 0; + + return 0; +} + +static int dummy_dec_opt_array(struct xdr_stream *xdr, + struct gssx_option_array *oa) +{ + struct gssx_option dummy; + u32 count, i; + __be32 *p; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + count = be32_to_cpup(p++); + memset(&dummy, 0, sizeof(dummy)); + for (i = 0; i < count; i++) { + gssx_dec_option(xdr, &dummy); + } + + oa->count = 0; + oa->data = NULL; + return 0; +} + +static int get_host_u32(struct xdr_stream *xdr, u32 *res) +{ + __be32 *p; + + p = xdr_inline_decode(xdr, 4); + if (!p) + return -EINVAL; + /* Contents of linux creds are all host-endian: */ + memcpy(res, p, sizeof(u32)); + return 0; +} + +static int gssx_dec_linux_creds(struct xdr_stream *xdr, + struct svc_cred *creds) +{ + u32 length; + __be32 *p; + u32 tmp; + u32 N; + int i, err; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + + length = be32_to_cpup(p); + + if (length > (3 + NGROUPS_MAX) * sizeof(u32)) + return -ENOSPC; + + /* uid */ + err = get_host_u32(xdr, &tmp); + if (err) + return err; + creds->cr_uid = make_kuid(&init_user_ns, tmp); + + /* gid */ + err = get_host_u32(xdr, &tmp); + if (err) + return err; + creds->cr_gid = make_kgid(&init_user_ns, tmp); + + /* number of additional gid's */ + err = get_host_u32(xdr, &tmp); + if (err) + return err; + N = tmp; + if ((3 + N) * sizeof(u32) != length) + return -EINVAL; + creds->cr_group_info = groups_alloc(N); + if (creds->cr_group_info == NULL) + return -ENOMEM; + + /* gid's */ + for (i = 0; i < N; i++) { + kgid_t kgid; + err = get_host_u32(xdr, &tmp); + if (err) + goto out_free_groups; + err = -EINVAL; + kgid = make_kgid(&init_user_ns, tmp); + if (!gid_valid(kgid)) + goto out_free_groups; + creds->cr_group_info->gid[i] = kgid; + } + groups_sort(creds->cr_group_info); + + return 0; +out_free_groups: + groups_free(creds->cr_group_info); + return err; +} + +static int gssx_dec_option_array(struct xdr_stream *xdr, + struct gssx_option_array *oa) +{ + struct svc_cred *creds; + u32 count, i; + __be32 *p; + int err; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + count = be32_to_cpup(p++); + if (!count) + return 0; + + /* we recognize only 1 currently: CREDS_VALUE */ + oa->count = 1; + + oa->data = kmalloc(sizeof(struct gssx_option), GFP_KERNEL); + if (!oa->data) + return -ENOMEM; + + creds = kzalloc(sizeof(struct svc_cred), GFP_KERNEL); + if (!creds) { + kfree(oa->data); + return -ENOMEM; + } + + oa->data[0].option.data = CREDS_VALUE; + oa->data[0].option.len = sizeof(CREDS_VALUE); + oa->data[0].value.data = (void *)creds; + oa->data[0].value.len = 0; + + for (i = 0; i < count; i++) { + gssx_buffer dummy = { 0, NULL }; + u32 length; + + /* option buffer */ + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + + length = be32_to_cpup(p); + p = xdr_inline_decode(xdr, length); + if (unlikely(p == NULL)) + return -ENOSPC; + + if (length == sizeof(CREDS_VALUE) && + memcmp(p, CREDS_VALUE, sizeof(CREDS_VALUE)) == 0) { + /* We have creds here. parse them */ + err = gssx_dec_linux_creds(xdr, creds); + if (err) + return err; + oa->data[0].value.len = 1; /* presence */ + } else { + /* consume uninteresting buffer */ + err = gssx_dec_buffer(xdr, &dummy); + if (err) + return err; + } + } + return 0; +} + +static int gssx_dec_status(struct xdr_stream *xdr, + struct gssx_status *status) +{ + __be32 *p; + int err; + + /* status->major_status */ + p = xdr_inline_decode(xdr, 8); + if (unlikely(p == NULL)) + return -ENOSPC; + p = xdr_decode_hyper(p, &status->major_status); + + /* status->mech */ + err = gssx_dec_buffer(xdr, &status->mech); + if (err) + return err; + + /* status->minor_status */ + p = xdr_inline_decode(xdr, 8); + if (unlikely(p == NULL)) + return -ENOSPC; + p = xdr_decode_hyper(p, &status->minor_status); + + /* status->major_status_string */ + err = gssx_dec_buffer(xdr, &status->major_status_string); + if (err) + return err; + + /* status->minor_status_string */ + err = gssx_dec_buffer(xdr, &status->minor_status_string); + if (err) + return err; + + /* status->server_ctx */ + err = gssx_dec_buffer(xdr, &status->server_ctx); + if (err) + return err; + + /* we assume we have no options for now, so simply consume them */ + /* status->options */ + err = dummy_dec_opt_array(xdr, &status->options); + + return err; +} + +static int gssx_enc_call_ctx(struct xdr_stream *xdr, + const struct gssx_call_ctx *ctx) +{ + struct gssx_option opt; + __be32 *p; + int err; + + /* ctx->locale */ + err = gssx_enc_buffer(xdr, &ctx->locale); + if (err) + return err; + + /* ctx->server_ctx */ + err = gssx_enc_buffer(xdr, &ctx->server_ctx); + if (err) + return err; + + /* we always want to ask for lucid contexts */ + /* ctx->options */ + p = xdr_reserve_space(xdr, 4); + *p = cpu_to_be32(2); + + /* we want a lucid_v1 context */ + opt.option.data = LUCID_OPTION; + opt.option.len = sizeof(LUCID_OPTION); + opt.value.data = LUCID_VALUE; + opt.value.len = sizeof(LUCID_VALUE); + err = gssx_enc_option(xdr, &opt); + + /* ..and user creds */ + opt.option.data = CREDS_OPTION; + opt.option.len = sizeof(CREDS_OPTION); + opt.value.data = CREDS_VALUE; + opt.value.len = sizeof(CREDS_VALUE); + err = gssx_enc_option(xdr, &opt); + + return err; +} + +static int gssx_dec_name_attr(struct xdr_stream *xdr, + struct gssx_name_attr *attr) +{ + int err; + + /* attr->attr */ + err = gssx_dec_buffer(xdr, &attr->attr); + if (err) + return err; + + /* attr->value */ + err = gssx_dec_buffer(xdr, &attr->value); + if (err) + return err; + + /* attr->extensions */ + err = dummy_dec_opt_array(xdr, &attr->extensions); + + return err; +} + +static int dummy_enc_nameattr_array(struct xdr_stream *xdr, + struct gssx_name_attr_array *naa) +{ + __be32 *p; + + if (naa->count != 0) + return -EINVAL; + + p = xdr_reserve_space(xdr, 4); + if (!p) + return -ENOSPC; + *p = 0; + + return 0; +} + +static int dummy_dec_nameattr_array(struct xdr_stream *xdr, + struct gssx_name_attr_array *naa) +{ + struct gssx_name_attr dummy = { .attr = {.len = 0} }; + u32 count, i; + __be32 *p; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + count = be32_to_cpup(p++); + for (i = 0; i < count; i++) { + gssx_dec_name_attr(xdr, &dummy); + } + + naa->count = 0; + naa->data = NULL; + return 0; +} + +static struct xdr_netobj zero_netobj = {}; + +static struct gssx_name_attr_array zero_name_attr_array = {}; + +static struct gssx_option_array zero_option_array = {}; + +static int gssx_enc_name(struct xdr_stream *xdr, + struct gssx_name *name) +{ + int err; + + /* name->display_name */ + err = gssx_enc_buffer(xdr, &name->display_name); + if (err) + return err; + + /* name->name_type */ + err = gssx_enc_buffer(xdr, &zero_netobj); + if (err) + return err; + + /* name->exported_name */ + err = gssx_enc_buffer(xdr, &zero_netobj); + if (err) + return err; + + /* name->exported_composite_name */ + err = gssx_enc_buffer(xdr, &zero_netobj); + if (err) + return err; + + /* leave name_attributes empty for now, will add once we have any + * to pass up at all */ + /* name->name_attributes */ + err = dummy_enc_nameattr_array(xdr, &zero_name_attr_array); + if (err) + return err; + + /* leave options empty for now, will add once we have any options + * to pass up at all */ + /* name->extensions */ + err = dummy_enc_opt_array(xdr, &zero_option_array); + + return err; +} + + +static int gssx_dec_name(struct xdr_stream *xdr, + struct gssx_name *name) +{ + struct xdr_netobj dummy_netobj = { .len = 0 }; + struct gssx_name_attr_array dummy_name_attr_array = { .count = 0 }; + struct gssx_option_array dummy_option_array = { .count = 0 }; + int err; + + /* name->display_name */ + err = gssx_dec_buffer(xdr, &name->display_name); + if (err) + return err; + + /* name->name_type */ + err = gssx_dec_buffer(xdr, &dummy_netobj); + if (err) + return err; + + /* name->exported_name */ + err = gssx_dec_buffer(xdr, &dummy_netobj); + if (err) + return err; + + /* name->exported_composite_name */ + err = gssx_dec_buffer(xdr, &dummy_netobj); + if (err) + return err; + + /* we assume we have no attributes for now, so simply consume them */ + /* name->name_attributes */ + err = dummy_dec_nameattr_array(xdr, &dummy_name_attr_array); + if (err) + return err; + + /* we assume we have no options for now, so simply consume them */ + /* name->extensions */ + err = dummy_dec_opt_array(xdr, &dummy_option_array); + + return err; +} + +static int dummy_enc_credel_array(struct xdr_stream *xdr, + struct gssx_cred_element_array *cea) +{ + __be32 *p; + + if (cea->count != 0) + return -EINVAL; + + p = xdr_reserve_space(xdr, 4); + if (!p) + return -ENOSPC; + *p = 0; + + return 0; +} + +static int gssx_enc_cred(struct xdr_stream *xdr, + struct gssx_cred *cred) +{ + int err; + + /* cred->desired_name */ + err = gssx_enc_name(xdr, &cred->desired_name); + if (err) + return err; + + /* cred->elements */ + err = dummy_enc_credel_array(xdr, &cred->elements); + if (err) + return err; + + /* cred->cred_handle_reference */ + err = gssx_enc_buffer(xdr, &cred->cred_handle_reference); + if (err) + return err; + + /* cred->needs_release */ + err = gssx_enc_bool(xdr, cred->needs_release); + + return err; +} + +static int gssx_enc_ctx(struct xdr_stream *xdr, + struct gssx_ctx *ctx) +{ + __be32 *p; + int err; + + /* ctx->exported_context_token */ + err = gssx_enc_buffer(xdr, &ctx->exported_context_token); + if (err) + return err; + + /* ctx->state */ + err = gssx_enc_buffer(xdr, &ctx->state); + if (err) + return err; + + /* ctx->need_release */ + err = gssx_enc_bool(xdr, ctx->need_release); + if (err) + return err; + + /* ctx->mech */ + err = gssx_enc_buffer(xdr, &ctx->mech); + if (err) + return err; + + /* ctx->src_name */ + err = gssx_enc_name(xdr, &ctx->src_name); + if (err) + return err; + + /* ctx->targ_name */ + err = gssx_enc_name(xdr, &ctx->targ_name); + if (err) + return err; + + /* ctx->lifetime */ + p = xdr_reserve_space(xdr, 8+8); + if (!p) + return -ENOSPC; + p = xdr_encode_hyper(p, ctx->lifetime); + + /* ctx->ctx_flags */ + p = xdr_encode_hyper(p, ctx->ctx_flags); + + /* ctx->locally_initiated */ + err = gssx_enc_bool(xdr, ctx->locally_initiated); + if (err) + return err; + + /* ctx->open */ + err = gssx_enc_bool(xdr, ctx->open); + if (err) + return err; + + /* leave options empty for now, will add once we have any options + * to pass up at all */ + /* ctx->options */ + err = dummy_enc_opt_array(xdr, &ctx->options); + + return err; +} + +static int gssx_dec_ctx(struct xdr_stream *xdr, + struct gssx_ctx *ctx) +{ + __be32 *p; + int err; + + /* ctx->exported_context_token */ + err = gssx_dec_buffer(xdr, &ctx->exported_context_token); + if (err) + return err; + + /* ctx->state */ + err = gssx_dec_buffer(xdr, &ctx->state); + if (err) + return err; + + /* ctx->need_release */ + err = gssx_dec_bool(xdr, &ctx->need_release); + if (err) + return err; + + /* ctx->mech */ + err = gssx_dec_buffer(xdr, &ctx->mech); + if (err) + return err; + + /* ctx->src_name */ + err = gssx_dec_name(xdr, &ctx->src_name); + if (err) + return err; + + /* ctx->targ_name */ + err = gssx_dec_name(xdr, &ctx->targ_name); + if (err) + return err; + + /* ctx->lifetime */ + p = xdr_inline_decode(xdr, 8+8); + if (unlikely(p == NULL)) + return -ENOSPC; + p = xdr_decode_hyper(p, &ctx->lifetime); + + /* ctx->ctx_flags */ + p = xdr_decode_hyper(p, &ctx->ctx_flags); + + /* ctx->locally_initiated */ + err = gssx_dec_bool(xdr, &ctx->locally_initiated); + if (err) + return err; + + /* ctx->open */ + err = gssx_dec_bool(xdr, &ctx->open); + if (err) + return err; + + /* we assume we have no options for now, so simply consume them */ + /* ctx->options */ + err = dummy_dec_opt_array(xdr, &ctx->options); + + return err; +} + +static int gssx_enc_cb(struct xdr_stream *xdr, struct gssx_cb *cb) +{ + __be32 *p; + int err; + + /* cb->initiator_addrtype */ + p = xdr_reserve_space(xdr, 8); + if (!p) + return -ENOSPC; + p = xdr_encode_hyper(p, cb->initiator_addrtype); + + /* cb->initiator_address */ + err = gssx_enc_buffer(xdr, &cb->initiator_address); + if (err) + return err; + + /* cb->acceptor_addrtype */ + p = xdr_reserve_space(xdr, 8); + if (!p) + return -ENOSPC; + p = xdr_encode_hyper(p, cb->acceptor_addrtype); + + /* cb->acceptor_address */ + err = gssx_enc_buffer(xdr, &cb->acceptor_address); + if (err) + return err; + + /* cb->application_data */ + err = gssx_enc_buffer(xdr, &cb->application_data); + + return err; +} + +void gssx_enc_accept_sec_context(struct rpc_rqst *req, + struct xdr_stream *xdr, + const void *data) +{ + const struct gssx_arg_accept_sec_context *arg = data; + int err; + + err = gssx_enc_call_ctx(xdr, &arg->call_ctx); + if (err) + goto done; + + /* arg->context_handle */ + if (arg->context_handle) + err = gssx_enc_ctx(xdr, arg->context_handle); + else + err = gssx_enc_bool(xdr, 0); + if (err) + goto done; + + /* arg->cred_handle */ + if (arg->cred_handle) + err = gssx_enc_cred(xdr, arg->cred_handle); + else + err = gssx_enc_bool(xdr, 0); + if (err) + goto done; + + /* arg->input_token */ + err = gssx_enc_in_token(xdr, &arg->input_token); + if (err) + goto done; + + /* arg->input_cb */ + if (arg->input_cb) + err = gssx_enc_cb(xdr, arg->input_cb); + else + err = gssx_enc_bool(xdr, 0); + if (err) + goto done; + + err = gssx_enc_bool(xdr, arg->ret_deleg_cred); + if (err) + goto done; + + /* leave options empty for now, will add once we have any options + * to pass up at all */ + /* arg->options */ + err = dummy_enc_opt_array(xdr, &arg->options); + + xdr_inline_pages(&req->rq_rcv_buf, + PAGE_SIZE/2 /* pretty arbitrary */, + arg->pages, 0 /* page base */, arg->npages * PAGE_SIZE); +done: + if (err) + dprintk("RPC: gssx_enc_accept_sec_context: %d\n", err); +} + +int gssx_dec_accept_sec_context(struct rpc_rqst *rqstp, + struct xdr_stream *xdr, + void *data) +{ + struct gssx_res_accept_sec_context *res = data; + u32 value_follows; + int err; + struct page *scratch; + + scratch = alloc_page(GFP_KERNEL); + if (!scratch) + return -ENOMEM; + xdr_set_scratch_buffer(xdr, page_address(scratch), PAGE_SIZE); + + /* res->status */ + err = gssx_dec_status(xdr, &res->status); + if (err) + goto out_free; + + /* res->context_handle */ + err = gssx_dec_bool(xdr, &value_follows); + if (err) + goto out_free; + if (value_follows) { + err = gssx_dec_ctx(xdr, res->context_handle); + if (err) + goto out_free; + } else { + res->context_handle = NULL; + } + + /* res->output_token */ + err = gssx_dec_bool(xdr, &value_follows); + if (err) + goto out_free; + if (value_follows) { + err = gssx_dec_buffer(xdr, res->output_token); + if (err) + goto out_free; + } else { + res->output_token = NULL; + } + + /* res->delegated_cred_handle */ + err = gssx_dec_bool(xdr, &value_follows); + if (err) + goto out_free; + if (value_follows) { + /* we do not support upcall servers sending this data. */ + err = -EINVAL; + goto out_free; + } + + /* res->options */ + err = gssx_dec_option_array(xdr, &res->options); + +out_free: + __free_page(scratch); + return err; +} diff --git a/net/sunrpc/auth_gss/gss_rpc_xdr.h b/net/sunrpc/auth_gss/gss_rpc_xdr.h new file mode 100644 index 000000000..146c31032 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_rpc_xdr.h @@ -0,0 +1,267 @@ +/* + * GSS Proxy upcall module + * + * Copyright (C) 2012 Simo Sorce <simo@redhat.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. + */ + +#ifndef _LINUX_GSS_RPC_XDR_H +#define _LINUX_GSS_RPC_XDR_H + +#include <linux/sunrpc/xdr.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/xprtsock.h> + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +#define LUCID_OPTION "exported_context_type" +#define LUCID_VALUE "linux_lucid_v1" +#define CREDS_OPTION "exported_creds_type" +#define CREDS_VALUE "linux_creds_v1" + +typedef struct xdr_netobj gssx_buffer; +typedef struct xdr_netobj utf8string; +typedef struct xdr_netobj gssx_OID; + +enum gssx_cred_usage { + GSSX_C_INITIATE = 1, + GSSX_C_ACCEPT = 2, + GSSX_C_BOTH = 3, +}; + +struct gssx_option { + gssx_buffer option; + gssx_buffer value; +}; + +struct gssx_option_array { + u32 count; + struct gssx_option *data; +}; + +struct gssx_status { + u64 major_status; + gssx_OID mech; + u64 minor_status; + utf8string major_status_string; + utf8string minor_status_string; + gssx_buffer server_ctx; + struct gssx_option_array options; +}; + +struct gssx_call_ctx { + utf8string locale; + gssx_buffer server_ctx; + struct gssx_option_array options; +}; + +struct gssx_name_attr { + gssx_buffer attr; + gssx_buffer value; + struct gssx_option_array extensions; +}; + +struct gssx_name_attr_array { + u32 count; + struct gssx_name_attr *data; +}; + +struct gssx_name { + gssx_buffer display_name; +}; +typedef struct gssx_name gssx_name; + +struct gssx_cred_element { + gssx_name MN; + gssx_OID mech; + u32 cred_usage; + u64 initiator_time_rec; + u64 acceptor_time_rec; + struct gssx_option_array options; +}; + +struct gssx_cred_element_array { + u32 count; + struct gssx_cred_element *data; +}; + +struct gssx_cred { + gssx_name desired_name; + struct gssx_cred_element_array elements; + gssx_buffer cred_handle_reference; + u32 needs_release; +}; + +struct gssx_ctx { + gssx_buffer exported_context_token; + gssx_buffer state; + u32 need_release; + gssx_OID mech; + gssx_name src_name; + gssx_name targ_name; + u64 lifetime; + u64 ctx_flags; + u32 locally_initiated; + u32 open; + struct gssx_option_array options; +}; + +struct gssx_cb { + u64 initiator_addrtype; + gssx_buffer initiator_address; + u64 acceptor_addrtype; + gssx_buffer acceptor_address; + gssx_buffer application_data; +}; + + +/* This structure is not defined in the protocol. + * It is used in the kernel to carry around a big buffer + * as a set of pages */ +struct gssp_in_token { + struct page **pages; /* Array of contiguous pages */ + unsigned int page_base; /* Start of page data */ + unsigned int page_len; /* Length of page data */ +}; + +struct gssx_arg_accept_sec_context { + struct gssx_call_ctx call_ctx; + struct gssx_ctx *context_handle; + struct gssx_cred *cred_handle; + struct gssp_in_token input_token; + struct gssx_cb *input_cb; + u32 ret_deleg_cred; + struct gssx_option_array options; + struct page **pages; + unsigned int npages; +}; + +struct gssx_res_accept_sec_context { + struct gssx_status status; + struct gssx_ctx *context_handle; + gssx_buffer *output_token; + /* struct gssx_cred *delegated_cred_handle; not used in kernel */ + struct gssx_option_array options; +}; + + + +#define gssx_enc_indicate_mechs NULL +#define gssx_dec_indicate_mechs NULL +#define gssx_enc_get_call_context NULL +#define gssx_dec_get_call_context NULL +#define gssx_enc_import_and_canon_name NULL +#define gssx_dec_import_and_canon_name NULL +#define gssx_enc_export_cred NULL +#define gssx_dec_export_cred NULL +#define gssx_enc_import_cred NULL +#define gssx_dec_import_cred NULL +#define gssx_enc_acquire_cred NULL +#define gssx_dec_acquire_cred NULL +#define gssx_enc_store_cred NULL +#define gssx_dec_store_cred NULL +#define gssx_enc_init_sec_context NULL +#define gssx_dec_init_sec_context NULL +void gssx_enc_accept_sec_context(struct rpc_rqst *req, + struct xdr_stream *xdr, + const void *data); +int gssx_dec_accept_sec_context(struct rpc_rqst *rqstp, + struct xdr_stream *xdr, + void *data); +#define gssx_enc_release_handle NULL +#define gssx_dec_release_handle NULL +#define gssx_enc_get_mic NULL +#define gssx_dec_get_mic NULL +#define gssx_enc_verify NULL +#define gssx_dec_verify NULL +#define gssx_enc_wrap NULL +#define gssx_dec_wrap NULL +#define gssx_enc_unwrap NULL +#define gssx_dec_unwrap NULL +#define gssx_enc_wrap_size_limit NULL +#define gssx_dec_wrap_size_limit NULL + +/* non implemented calls are set to 0 size */ +#define GSSX_ARG_indicate_mechs_sz 0 +#define GSSX_RES_indicate_mechs_sz 0 +#define GSSX_ARG_get_call_context_sz 0 +#define GSSX_RES_get_call_context_sz 0 +#define GSSX_ARG_import_and_canon_name_sz 0 +#define GSSX_RES_import_and_canon_name_sz 0 +#define GSSX_ARG_export_cred_sz 0 +#define GSSX_RES_export_cred_sz 0 +#define GSSX_ARG_import_cred_sz 0 +#define GSSX_RES_import_cred_sz 0 +#define GSSX_ARG_acquire_cred_sz 0 +#define GSSX_RES_acquire_cred_sz 0 +#define GSSX_ARG_store_cred_sz 0 +#define GSSX_RES_store_cred_sz 0 +#define GSSX_ARG_init_sec_context_sz 0 +#define GSSX_RES_init_sec_context_sz 0 + +#define GSSX_default_in_call_ctx_sz (4 + 4 + 4 + \ + 8 + sizeof(LUCID_OPTION) + sizeof(LUCID_VALUE) + \ + 8 + sizeof(CREDS_OPTION) + sizeof(CREDS_VALUE)) +#define GSSX_default_in_ctx_hndl_sz (4 + 4+8 + 4 + 4 + 6*4 + 6*4 + 8 + 8 + \ + 4 + 4 + 4) +#define GSSX_default_in_cred_sz 4 /* we send in no cred_handle */ +#define GSSX_default_in_token_sz 4 /* does *not* include token data */ +#define GSSX_default_in_cb_sz 4 /* we do not use channel bindings */ +#define GSSX_ARG_accept_sec_context_sz (GSSX_default_in_call_ctx_sz + \ + GSSX_default_in_ctx_hndl_sz + \ + GSSX_default_in_cred_sz + \ + GSSX_default_in_token_sz + \ + GSSX_default_in_cb_sz + \ + 4 /* no deleg creds boolean */ + \ + 4) /* empty options */ + +/* somewhat arbitrary numbers but large enough (we ignore some of the data + * sent down, but it is part of the protocol so we need enough space to take + * it in) */ +#define GSSX_default_status_sz 8 + 24 + 8 + 256 + 256 + 16 + 4 +#define GSSX_max_output_handle_sz 128 +#define GSSX_max_oid_sz 16 +#define GSSX_max_princ_sz 256 +#define GSSX_default_ctx_sz (GSSX_max_output_handle_sz + \ + 16 + 4 + GSSX_max_oid_sz + \ + 2 * GSSX_max_princ_sz + \ + 8 + 8 + 4 + 4 + 4) +#define GSSX_max_output_token_sz 1024 +/* grouplist not included; we allocate separate pages for that: */ +#define GSSX_max_creds_sz (4 + 4 + 4 /* + NGROUPS_MAX*4 */) +#define GSSX_RES_accept_sec_context_sz (GSSX_default_status_sz + \ + GSSX_default_ctx_sz + \ + GSSX_max_output_token_sz + \ + 4 + GSSX_max_creds_sz) + +#define GSSX_ARG_release_handle_sz 0 +#define GSSX_RES_release_handle_sz 0 +#define GSSX_ARG_get_mic_sz 0 +#define GSSX_RES_get_mic_sz 0 +#define GSSX_ARG_verify_sz 0 +#define GSSX_RES_verify_sz 0 +#define GSSX_ARG_wrap_sz 0 +#define GSSX_RES_wrap_sz 0 +#define GSSX_ARG_unwrap_sz 0 +#define GSSX_RES_unwrap_sz 0 +#define GSSX_ARG_wrap_size_limit_sz 0 +#define GSSX_RES_wrap_size_limit_sz 0 + + + +#endif /* _LINUX_GSS_RPC_XDR_H */ diff --git a/net/sunrpc/auth_gss/svcauth_gss.c b/net/sunrpc/auth_gss/svcauth_gss.c new file mode 100644 index 000000000..d9d03881e --- /dev/null +++ b/net/sunrpc/auth_gss/svcauth_gss.c @@ -0,0 +1,1941 @@ +/* + * Neil Brown <neilb@cse.unsw.edu.au> + * J. Bruce Fields <bfields@umich.edu> + * Andy Adamson <andros@umich.edu> + * Dug Song <dugsong@monkey.org> + * + * RPCSEC_GSS server authentication. + * This implements RPCSEC_GSS as defined in rfc2203 (rpcsec_gss) and rfc2078 + * (gssapi) + * + * The RPCSEC_GSS involves three stages: + * 1/ context creation + * 2/ data exchange + * 3/ context destruction + * + * Context creation is handled largely by upcalls to user-space. + * In particular, GSS_Accept_sec_context is handled by an upcall + * Data exchange is handled entirely within the kernel + * In particular, GSS_GetMIC, GSS_VerifyMIC, GSS_Seal, GSS_Unseal are in-kernel. + * Context destruction is handled in-kernel + * GSS_Delete_sec_context is in-kernel + * + * Context creation is initiated by a RPCSEC_GSS_INIT request arriving. + * The context handle and gss_token are used as a key into the rpcsec_init cache. + * The content of this cache includes some of the outputs of GSS_Accept_sec_context, + * being major_status, minor_status, context_handle, reply_token. + * These are sent back to the client. + * Sequence window management is handled by the kernel. The window size if currently + * a compile time constant. + * + * When user-space is happy that a context is established, it places an entry + * in the rpcsec_context cache. The key for this cache is the context_handle. + * The content includes: + * uid/gidlist - for determining access rights + * mechanism type + * mechanism specific information, such as a key + * + */ + +#include <linux/slab.h> +#include <linux/types.h> +#include <linux/module.h> +#include <linux/pagemap.h> +#include <linux/user_namespace.h> + +#include <linux/sunrpc/auth_gss.h> +#include <linux/sunrpc/gss_err.h> +#include <linux/sunrpc/svcauth.h> +#include <linux/sunrpc/svcauth_gss.h> +#include <linux/sunrpc/cache.h> +#include "gss_rpc_upcall.h" + + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +/* The rpcsec_init cache is used for mapping RPCSEC_GSS_{,CONT_}INIT requests + * into replies. + * + * Key is context handle (\x if empty) and gss_token. + * Content is major_status minor_status (integers) context_handle, reply_token. + * + */ + +static int netobj_equal(struct xdr_netobj *a, struct xdr_netobj *b) +{ + return a->len == b->len && 0 == memcmp(a->data, b->data, a->len); +} + +#define RSI_HASHBITS 6 +#define RSI_HASHMAX (1<<RSI_HASHBITS) + +struct rsi { + struct cache_head h; + struct xdr_netobj in_handle, in_token; + struct xdr_netobj out_handle, out_token; + int major_status, minor_status; +}; + +static struct rsi *rsi_update(struct cache_detail *cd, struct rsi *new, struct rsi *old); +static struct rsi *rsi_lookup(struct cache_detail *cd, struct rsi *item); + +static void rsi_free(struct rsi *rsii) +{ + kfree(rsii->in_handle.data); + kfree(rsii->in_token.data); + kfree(rsii->out_handle.data); + kfree(rsii->out_token.data); +} + +static void rsi_put(struct kref *ref) +{ + struct rsi *rsii = container_of(ref, struct rsi, h.ref); + rsi_free(rsii); + kfree(rsii); +} + +static inline int rsi_hash(struct rsi *item) +{ + return hash_mem(item->in_handle.data, item->in_handle.len, RSI_HASHBITS) + ^ hash_mem(item->in_token.data, item->in_token.len, RSI_HASHBITS); +} + +static int rsi_match(struct cache_head *a, struct cache_head *b) +{ + struct rsi *item = container_of(a, struct rsi, h); + struct rsi *tmp = container_of(b, struct rsi, h); + return netobj_equal(&item->in_handle, &tmp->in_handle) && + netobj_equal(&item->in_token, &tmp->in_token); +} + +static int dup_to_netobj(struct xdr_netobj *dst, char *src, int len) +{ + dst->len = len; + dst->data = (len ? kmemdup(src, len, GFP_KERNEL) : NULL); + if (len && !dst->data) + return -ENOMEM; + return 0; +} + +static inline int dup_netobj(struct xdr_netobj *dst, struct xdr_netobj *src) +{ + return dup_to_netobj(dst, src->data, src->len); +} + +static void rsi_init(struct cache_head *cnew, struct cache_head *citem) +{ + struct rsi *new = container_of(cnew, struct rsi, h); + struct rsi *item = container_of(citem, struct rsi, h); + + new->out_handle.data = NULL; + new->out_handle.len = 0; + new->out_token.data = NULL; + new->out_token.len = 0; + new->in_handle.len = item->in_handle.len; + item->in_handle.len = 0; + new->in_token.len = item->in_token.len; + item->in_token.len = 0; + new->in_handle.data = item->in_handle.data; + item->in_handle.data = NULL; + new->in_token.data = item->in_token.data; + item->in_token.data = NULL; +} + +static void update_rsi(struct cache_head *cnew, struct cache_head *citem) +{ + struct rsi *new = container_of(cnew, struct rsi, h); + struct rsi *item = container_of(citem, struct rsi, h); + + BUG_ON(new->out_handle.data || new->out_token.data); + new->out_handle.len = item->out_handle.len; + item->out_handle.len = 0; + new->out_token.len = item->out_token.len; + item->out_token.len = 0; + new->out_handle.data = item->out_handle.data; + item->out_handle.data = NULL; + new->out_token.data = item->out_token.data; + item->out_token.data = NULL; + + new->major_status = item->major_status; + new->minor_status = item->minor_status; +} + +static struct cache_head *rsi_alloc(void) +{ + struct rsi *rsii = kmalloc(sizeof(*rsii), GFP_KERNEL); + if (rsii) + return &rsii->h; + else + return NULL; +} + +static void rsi_request(struct cache_detail *cd, + struct cache_head *h, + char **bpp, int *blen) +{ + struct rsi *rsii = container_of(h, struct rsi, h); + + qword_addhex(bpp, blen, rsii->in_handle.data, rsii->in_handle.len); + qword_addhex(bpp, blen, rsii->in_token.data, rsii->in_token.len); + (*bpp)[-1] = '\n'; +} + +static int rsi_parse(struct cache_detail *cd, + char *mesg, int mlen) +{ + /* context token expiry major minor context token */ + char *buf = mesg; + char *ep; + int len; + struct rsi rsii, *rsip = NULL; + time_t expiry; + int status = -EINVAL; + + memset(&rsii, 0, sizeof(rsii)); + /* handle */ + len = qword_get(&mesg, buf, mlen); + if (len < 0) + goto out; + status = -ENOMEM; + if (dup_to_netobj(&rsii.in_handle, buf, len)) + goto out; + + /* token */ + len = qword_get(&mesg, buf, mlen); + status = -EINVAL; + if (len < 0) + goto out; + status = -ENOMEM; + if (dup_to_netobj(&rsii.in_token, buf, len)) + goto out; + + rsip = rsi_lookup(cd, &rsii); + if (!rsip) + goto out; + + rsii.h.flags = 0; + /* expiry */ + expiry = get_expiry(&mesg); + status = -EINVAL; + if (expiry == 0) + goto out; + + /* major/minor */ + len = qword_get(&mesg, buf, mlen); + if (len <= 0) + goto out; + rsii.major_status = simple_strtoul(buf, &ep, 10); + if (*ep) + goto out; + len = qword_get(&mesg, buf, mlen); + if (len <= 0) + goto out; + rsii.minor_status = simple_strtoul(buf, &ep, 10); + if (*ep) + goto out; + + /* out_handle */ + len = qword_get(&mesg, buf, mlen); + if (len < 0) + goto out; + status = -ENOMEM; + if (dup_to_netobj(&rsii.out_handle, buf, len)) + goto out; + + /* out_token */ + len = qword_get(&mesg, buf, mlen); + status = -EINVAL; + if (len < 0) + goto out; + status = -ENOMEM; + if (dup_to_netobj(&rsii.out_token, buf, len)) + goto out; + rsii.h.expiry_time = expiry; + rsip = rsi_update(cd, &rsii, rsip); + status = 0; +out: + rsi_free(&rsii); + if (rsip) + cache_put(&rsip->h, cd); + else + status = -ENOMEM; + return status; +} + +static const struct cache_detail rsi_cache_template = { + .owner = THIS_MODULE, + .hash_size = RSI_HASHMAX, + .name = "auth.rpcsec.init", + .cache_put = rsi_put, + .cache_request = rsi_request, + .cache_parse = rsi_parse, + .match = rsi_match, + .init = rsi_init, + .update = update_rsi, + .alloc = rsi_alloc, +}; + +static struct rsi *rsi_lookup(struct cache_detail *cd, struct rsi *item) +{ + struct cache_head *ch; + int hash = rsi_hash(item); + + ch = sunrpc_cache_lookup(cd, &item->h, hash); + if (ch) + return container_of(ch, struct rsi, h); + else + return NULL; +} + +static struct rsi *rsi_update(struct cache_detail *cd, struct rsi *new, struct rsi *old) +{ + struct cache_head *ch; + int hash = rsi_hash(new); + + ch = sunrpc_cache_update(cd, &new->h, + &old->h, hash); + if (ch) + return container_of(ch, struct rsi, h); + else + return NULL; +} + + +/* + * The rpcsec_context cache is used to store a context that is + * used in data exchange. + * The key is a context handle. The content is: + * uid, gidlist, mechanism, service-set, mech-specific-data + */ + +#define RSC_HASHBITS 10 +#define RSC_HASHMAX (1<<RSC_HASHBITS) + +#define GSS_SEQ_WIN 128 + +struct gss_svc_seq_data { + /* highest seq number seen so far: */ + int sd_max; + /* for i such that sd_max-GSS_SEQ_WIN < i <= sd_max, the i-th bit of + * sd_win is nonzero iff sequence number i has been seen already: */ + unsigned long sd_win[GSS_SEQ_WIN/BITS_PER_LONG]; + spinlock_t sd_lock; +}; + +struct rsc { + struct cache_head h; + struct xdr_netobj handle; + struct svc_cred cred; + struct gss_svc_seq_data seqdata; + struct gss_ctx *mechctx; +}; + +static struct rsc *rsc_update(struct cache_detail *cd, struct rsc *new, struct rsc *old); +static struct rsc *rsc_lookup(struct cache_detail *cd, struct rsc *item); + +static void rsc_free(struct rsc *rsci) +{ + kfree(rsci->handle.data); + if (rsci->mechctx) + gss_delete_sec_context(&rsci->mechctx); + free_svc_cred(&rsci->cred); +} + +static void rsc_put(struct kref *ref) +{ + struct rsc *rsci = container_of(ref, struct rsc, h.ref); + + rsc_free(rsci); + kfree(rsci); +} + +static inline int +rsc_hash(struct rsc *rsci) +{ + return hash_mem(rsci->handle.data, rsci->handle.len, RSC_HASHBITS); +} + +static int +rsc_match(struct cache_head *a, struct cache_head *b) +{ + struct rsc *new = container_of(a, struct rsc, h); + struct rsc *tmp = container_of(b, struct rsc, h); + + return netobj_equal(&new->handle, &tmp->handle); +} + +static void +rsc_init(struct cache_head *cnew, struct cache_head *ctmp) +{ + struct rsc *new = container_of(cnew, struct rsc, h); + struct rsc *tmp = container_of(ctmp, struct rsc, h); + + new->handle.len = tmp->handle.len; + tmp->handle.len = 0; + new->handle.data = tmp->handle.data; + tmp->handle.data = NULL; + new->mechctx = NULL; + init_svc_cred(&new->cred); +} + +static void +update_rsc(struct cache_head *cnew, struct cache_head *ctmp) +{ + struct rsc *new = container_of(cnew, struct rsc, h); + struct rsc *tmp = container_of(ctmp, struct rsc, h); + + new->mechctx = tmp->mechctx; + tmp->mechctx = NULL; + memset(&new->seqdata, 0, sizeof(new->seqdata)); + spin_lock_init(&new->seqdata.sd_lock); + new->cred = tmp->cred; + init_svc_cred(&tmp->cred); +} + +static struct cache_head * +rsc_alloc(void) +{ + struct rsc *rsci = kmalloc(sizeof(*rsci), GFP_KERNEL); + if (rsci) + return &rsci->h; + else + return NULL; +} + +static int rsc_parse(struct cache_detail *cd, + char *mesg, int mlen) +{ + /* contexthandle expiry [ uid gid N <n gids> mechname ...mechdata... ] */ + char *buf = mesg; + int id; + int len, rv; + struct rsc rsci, *rscp = NULL; + time_t expiry; + int status = -EINVAL; + struct gss_api_mech *gm = NULL; + + memset(&rsci, 0, sizeof(rsci)); + /* context handle */ + len = qword_get(&mesg, buf, mlen); + if (len < 0) goto out; + status = -ENOMEM; + if (dup_to_netobj(&rsci.handle, buf, len)) + goto out; + + rsci.h.flags = 0; + /* expiry */ + expiry = get_expiry(&mesg); + status = -EINVAL; + if (expiry == 0) + goto out; + + rscp = rsc_lookup(cd, &rsci); + if (!rscp) + goto out; + + /* uid, or NEGATIVE */ + rv = get_int(&mesg, &id); + if (rv == -EINVAL) + goto out; + if (rv == -ENOENT) + set_bit(CACHE_NEGATIVE, &rsci.h.flags); + else { + int N, i; + + /* + * NOTE: we skip uid_valid()/gid_valid() checks here: + * instead, * -1 id's are later mapped to the + * (export-specific) anonymous id by nfsd_setuser. + * + * (But supplementary gid's get no such special + * treatment so are checked for validity here.) + */ + /* uid */ + rsci.cred.cr_uid = make_kuid(&init_user_ns, id); + + /* gid */ + if (get_int(&mesg, &id)) + goto out; + rsci.cred.cr_gid = make_kgid(&init_user_ns, id); + + /* number of additional gid's */ + if (get_int(&mesg, &N)) + goto out; + if (N < 0 || N > NGROUPS_MAX) + goto out; + status = -ENOMEM; + rsci.cred.cr_group_info = groups_alloc(N); + if (rsci.cred.cr_group_info == NULL) + goto out; + + /* gid's */ + status = -EINVAL; + for (i=0; i<N; i++) { + kgid_t kgid; + if (get_int(&mesg, &id)) + goto out; + kgid = make_kgid(&init_user_ns, id); + if (!gid_valid(kgid)) + goto out; + rsci.cred.cr_group_info->gid[i] = kgid; + } + groups_sort(rsci.cred.cr_group_info); + + /* mech name */ + len = qword_get(&mesg, buf, mlen); + if (len < 0) + goto out; + gm = rsci.cred.cr_gss_mech = gss_mech_get_by_name(buf); + status = -EOPNOTSUPP; + if (!gm) + goto out; + + status = -EINVAL; + /* mech-specific data: */ + len = qword_get(&mesg, buf, mlen); + if (len < 0) + goto out; + status = gss_import_sec_context(buf, len, gm, &rsci.mechctx, + NULL, GFP_KERNEL); + if (status) + goto out; + + /* get client name */ + len = qword_get(&mesg, buf, mlen); + if (len > 0) { + rsci.cred.cr_principal = kstrdup(buf, GFP_KERNEL); + if (!rsci.cred.cr_principal) { + status = -ENOMEM; + goto out; + } + } + + } + rsci.h.expiry_time = expiry; + rscp = rsc_update(cd, &rsci, rscp); + status = 0; +out: + rsc_free(&rsci); + if (rscp) + cache_put(&rscp->h, cd); + else + status = -ENOMEM; + return status; +} + +static const struct cache_detail rsc_cache_template = { + .owner = THIS_MODULE, + .hash_size = RSC_HASHMAX, + .name = "auth.rpcsec.context", + .cache_put = rsc_put, + .cache_parse = rsc_parse, + .match = rsc_match, + .init = rsc_init, + .update = update_rsc, + .alloc = rsc_alloc, +}; + +static struct rsc *rsc_lookup(struct cache_detail *cd, struct rsc *item) +{ + struct cache_head *ch; + int hash = rsc_hash(item); + + ch = sunrpc_cache_lookup(cd, &item->h, hash); + if (ch) + return container_of(ch, struct rsc, h); + else + return NULL; +} + +static struct rsc *rsc_update(struct cache_detail *cd, struct rsc *new, struct rsc *old) +{ + struct cache_head *ch; + int hash = rsc_hash(new); + + ch = sunrpc_cache_update(cd, &new->h, + &old->h, hash); + if (ch) + return container_of(ch, struct rsc, h); + else + return NULL; +} + + +static struct rsc * +gss_svc_searchbyctx(struct cache_detail *cd, struct xdr_netobj *handle) +{ + struct rsc rsci; + struct rsc *found; + + memset(&rsci, 0, sizeof(rsci)); + if (dup_to_netobj(&rsci.handle, handle->data, handle->len)) + return NULL; + found = rsc_lookup(cd, &rsci); + rsc_free(&rsci); + if (!found) + return NULL; + if (cache_check(cd, &found->h, NULL)) + return NULL; + return found; +} + +/* Implements sequence number algorithm as specified in RFC 2203. */ +static int +gss_check_seq_num(struct rsc *rsci, int seq_num) +{ + struct gss_svc_seq_data *sd = &rsci->seqdata; + + spin_lock(&sd->sd_lock); + if (seq_num > sd->sd_max) { + if (seq_num >= sd->sd_max + GSS_SEQ_WIN) { + memset(sd->sd_win,0,sizeof(sd->sd_win)); + sd->sd_max = seq_num; + } else while (sd->sd_max < seq_num) { + sd->sd_max++; + __clear_bit(sd->sd_max % GSS_SEQ_WIN, sd->sd_win); + } + __set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win); + goto ok; + } else if (seq_num <= sd->sd_max - GSS_SEQ_WIN) { + goto drop; + } + /* sd_max - GSS_SEQ_WIN < seq_num <= sd_max */ + if (__test_and_set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win)) + goto drop; +ok: + spin_unlock(&sd->sd_lock); + return 1; +drop: + spin_unlock(&sd->sd_lock); + return 0; +} + +static inline u32 round_up_to_quad(u32 i) +{ + return (i + 3 ) & ~3; +} + +static inline int +svc_safe_getnetobj(struct kvec *argv, struct xdr_netobj *o) +{ + int l; + + if (argv->iov_len < 4) + return -1; + o->len = svc_getnl(argv); + l = round_up_to_quad(o->len); + if (argv->iov_len < l) + return -1; + o->data = argv->iov_base; + argv->iov_base += l; + argv->iov_len -= l; + return 0; +} + +static inline int +svc_safe_putnetobj(struct kvec *resv, struct xdr_netobj *o) +{ + u8 *p; + + if (resv->iov_len + 4 > PAGE_SIZE) + return -1; + svc_putnl(resv, o->len); + p = resv->iov_base + resv->iov_len; + resv->iov_len += round_up_to_quad(o->len); + if (resv->iov_len > PAGE_SIZE) + return -1; + memcpy(p, o->data, o->len); + memset(p + o->len, 0, round_up_to_quad(o->len) - o->len); + return 0; +} + +/* + * Verify the checksum on the header and return SVC_OK on success. + * Otherwise, return SVC_DROP (in the case of a bad sequence number) + * or return SVC_DENIED and indicate error in authp. + */ +static int +gss_verify_header(struct svc_rqst *rqstp, struct rsc *rsci, + __be32 *rpcstart, struct rpc_gss_wire_cred *gc, __be32 *authp) +{ + struct gss_ctx *ctx_id = rsci->mechctx; + struct xdr_buf rpchdr; + struct xdr_netobj checksum; + u32 flavor = 0; + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec iov; + + /* data to compute the checksum over: */ + iov.iov_base = rpcstart; + iov.iov_len = (u8 *)argv->iov_base - (u8 *)rpcstart; + xdr_buf_from_iov(&iov, &rpchdr); + + *authp = rpc_autherr_badverf; + if (argv->iov_len < 4) + return SVC_DENIED; + flavor = svc_getnl(argv); + if (flavor != RPC_AUTH_GSS) + return SVC_DENIED; + if (svc_safe_getnetobj(argv, &checksum)) + return SVC_DENIED; + + if (rqstp->rq_deferred) /* skip verification of revisited request */ + return SVC_OK; + if (gss_verify_mic(ctx_id, &rpchdr, &checksum) != GSS_S_COMPLETE) { + *authp = rpcsec_gsserr_credproblem; + return SVC_DENIED; + } + + if (gc->gc_seq > MAXSEQ) { + dprintk("RPC: svcauth_gss: discarding request with " + "large sequence number %d\n", gc->gc_seq); + *authp = rpcsec_gsserr_ctxproblem; + return SVC_DENIED; + } + if (!gss_check_seq_num(rsci, gc->gc_seq)) { + dprintk("RPC: svcauth_gss: discarding request with " + "old sequence number %d\n", gc->gc_seq); + return SVC_DROP; + } + return SVC_OK; +} + +static int +gss_write_null_verf(struct svc_rqst *rqstp) +{ + __be32 *p; + + svc_putnl(rqstp->rq_res.head, RPC_AUTH_NULL); + p = rqstp->rq_res.head->iov_base + rqstp->rq_res.head->iov_len; + /* don't really need to check if head->iov_len > PAGE_SIZE ... */ + *p++ = 0; + if (!xdr_ressize_check(rqstp, p)) + return -1; + return 0; +} + +static int +gss_write_verf(struct svc_rqst *rqstp, struct gss_ctx *ctx_id, u32 seq) +{ + __be32 *xdr_seq; + u32 maj_stat; + struct xdr_buf verf_data; + struct xdr_netobj mic; + __be32 *p; + struct kvec iov; + int err = -1; + + svc_putnl(rqstp->rq_res.head, RPC_AUTH_GSS); + xdr_seq = kmalloc(4, GFP_KERNEL); + if (!xdr_seq) + return -1; + *xdr_seq = htonl(seq); + + iov.iov_base = xdr_seq; + iov.iov_len = 4; + xdr_buf_from_iov(&iov, &verf_data); + p = rqstp->rq_res.head->iov_base + rqstp->rq_res.head->iov_len; + mic.data = (u8 *)(p + 1); + maj_stat = gss_get_mic(ctx_id, &verf_data, &mic); + if (maj_stat != GSS_S_COMPLETE) + goto out; + *p++ = htonl(mic.len); + memset((u8 *)p + mic.len, 0, round_up_to_quad(mic.len) - mic.len); + p += XDR_QUADLEN(mic.len); + if (!xdr_ressize_check(rqstp, p)) + goto out; + err = 0; +out: + kfree(xdr_seq); + return err; +} + +struct gss_domain { + struct auth_domain h; + u32 pseudoflavor; +}; + +static struct auth_domain * +find_gss_auth_domain(struct gss_ctx *ctx, u32 svc) +{ + char *name; + + name = gss_service_to_auth_domain_name(ctx->mech_type, svc); + if (!name) + return NULL; + return auth_domain_find(name); +} + +static struct auth_ops svcauthops_gss; + +u32 svcauth_gss_flavor(struct auth_domain *dom) +{ + struct gss_domain *gd = container_of(dom, struct gss_domain, h); + + return gd->pseudoflavor; +} + +EXPORT_SYMBOL_GPL(svcauth_gss_flavor); + +struct auth_domain * +svcauth_gss_register_pseudoflavor(u32 pseudoflavor, char * name) +{ + struct gss_domain *new; + struct auth_domain *test; + int stat = -ENOMEM; + + new = kmalloc(sizeof(*new), GFP_KERNEL); + if (!new) + goto out; + kref_init(&new->h.ref); + new->h.name = kstrdup(name, GFP_KERNEL); + if (!new->h.name) + goto out_free_dom; + new->h.flavour = &svcauthops_gss; + new->pseudoflavor = pseudoflavor; + + test = auth_domain_lookup(name, &new->h); + if (test != &new->h) { + pr_warn("svc: duplicate registration of gss pseudo flavour %s.\n", + name); + stat = -EADDRINUSE; + auth_domain_put(test); + goto out_free_name; + } + return test; + +out_free_name: + kfree(new->h.name); +out_free_dom: + kfree(new); +out: + return ERR_PTR(stat); +} +EXPORT_SYMBOL_GPL(svcauth_gss_register_pseudoflavor); + +static inline int +read_u32_from_xdr_buf(struct xdr_buf *buf, int base, u32 *obj) +{ + __be32 raw; + int status; + + status = read_bytes_from_xdr_buf(buf, base, &raw, sizeof(*obj)); + if (status) + return status; + *obj = ntohl(raw); + return 0; +} + +/* It would be nice if this bit of code could be shared with the client. + * Obstacles: + * The client shouldn't malloc(), would have to pass in own memory. + * The server uses base of head iovec as read pointer, while the + * client uses separate pointer. */ +static int +unwrap_integ_data(struct svc_rqst *rqstp, struct xdr_buf *buf, u32 seq, struct gss_ctx *ctx) +{ + int stat = -EINVAL; + u32 integ_len, maj_stat; + struct xdr_netobj mic; + struct xdr_buf integ_buf; + + /* NFS READ normally uses splice to send data in-place. However + * the data in cache can change after the reply's MIC is computed + * but before the RPC reply is sent. To prevent the client from + * rejecting the server-computed MIC in this somewhat rare case, + * do not use splice with the GSS integrity service. + */ + clear_bit(RQ_SPLICE_OK, &rqstp->rq_flags); + + /* Did we already verify the signature on the original pass through? */ + if (rqstp->rq_deferred) + return 0; + + integ_len = svc_getnl(&buf->head[0]); + if (integ_len & 3) + return stat; + if (integ_len > buf->len) + return stat; + if (xdr_buf_subsegment(buf, &integ_buf, 0, integ_len)) { + WARN_ON_ONCE(1); + return stat; + } + /* copy out mic... */ + if (read_u32_from_xdr_buf(buf, integ_len, &mic.len)) + return stat; + if (mic.len > RPC_MAX_AUTH_SIZE) + return stat; + mic.data = kmalloc(mic.len, GFP_KERNEL); + if (!mic.data) + return stat; + if (read_bytes_from_xdr_buf(buf, integ_len + 4, mic.data, mic.len)) + goto out; + maj_stat = gss_verify_mic(ctx, &integ_buf, &mic); + if (maj_stat != GSS_S_COMPLETE) + goto out; + if (svc_getnl(&buf->head[0]) != seq) + goto out; + /* trim off the mic and padding at the end before returning */ + xdr_buf_trim(buf, round_up_to_quad(mic.len) + 4); + stat = 0; +out: + kfree(mic.data); + return stat; +} + +static inline int +total_buf_len(struct xdr_buf *buf) +{ + return buf->head[0].iov_len + buf->page_len + buf->tail[0].iov_len; +} + +static void +fix_priv_head(struct xdr_buf *buf, int pad) +{ + if (buf->page_len == 0) { + /* We need to adjust head and buf->len in tandem in this + * case to make svc_defer() work--it finds the original + * buffer start using buf->len - buf->head[0].iov_len. */ + buf->head[0].iov_len -= pad; + } +} + +static int +unwrap_priv_data(struct svc_rqst *rqstp, struct xdr_buf *buf, u32 seq, struct gss_ctx *ctx) +{ + u32 priv_len, maj_stat; + int pad, saved_len, remaining_len, offset; + + clear_bit(RQ_SPLICE_OK, &rqstp->rq_flags); + + priv_len = svc_getnl(&buf->head[0]); + if (rqstp->rq_deferred) { + /* Already decrypted last time through! The sequence number + * check at out_seq is unnecessary but harmless: */ + goto out_seq; + } + /* buf->len is the number of bytes from the original start of the + * request to the end, where head[0].iov_len is just the bytes + * not yet read from the head, so these two values are different: */ + remaining_len = total_buf_len(buf); + if (priv_len > remaining_len) + return -EINVAL; + pad = remaining_len - priv_len; + buf->len -= pad; + fix_priv_head(buf, pad); + + /* Maybe it would be better to give gss_unwrap a length parameter: */ + saved_len = buf->len; + buf->len = priv_len; + maj_stat = gss_unwrap(ctx, 0, buf); + pad = priv_len - buf->len; + buf->len = saved_len; + buf->len -= pad; + /* The upper layers assume the buffer is aligned on 4-byte boundaries. + * In the krb5p case, at least, the data ends up offset, so we need to + * move it around. */ + /* XXX: This is very inefficient. It would be better to either do + * this while we encrypt, or maybe in the receive code, if we can peak + * ahead and work out the service and mechanism there. */ + offset = buf->head[0].iov_len % 4; + if (offset) { + buf->buflen = RPCSVC_MAXPAYLOAD; + xdr_shift_buf(buf, offset); + fix_priv_head(buf, pad); + } + if (maj_stat != GSS_S_COMPLETE) + return -EINVAL; +out_seq: + if (svc_getnl(&buf->head[0]) != seq) + return -EINVAL; + return 0; +} + +struct gss_svc_data { + /* decoded gss client cred: */ + struct rpc_gss_wire_cred clcred; + /* save a pointer to the beginning of the encoded verifier, + * for use in encryption/checksumming in svcauth_gss_release: */ + __be32 *verf_start; + struct rsc *rsci; +}; + +static int +svcauth_gss_set_client(struct svc_rqst *rqstp) +{ + struct gss_svc_data *svcdata = rqstp->rq_auth_data; + struct rsc *rsci = svcdata->rsci; + struct rpc_gss_wire_cred *gc = &svcdata->clcred; + int stat; + + /* + * A gss export can be specified either by: + * export *(sec=krb5,rw) + * or by + * export gss/krb5(rw) + * The latter is deprecated; but for backwards compatibility reasons + * the nfsd code will still fall back on trying it if the former + * doesn't work; so we try to make both available to nfsd, below. + */ + rqstp->rq_gssclient = find_gss_auth_domain(rsci->mechctx, gc->gc_svc); + if (rqstp->rq_gssclient == NULL) + return SVC_DENIED; + stat = svcauth_unix_set_client(rqstp); + if (stat == SVC_DROP || stat == SVC_CLOSE) + return stat; + return SVC_OK; +} + +static inline int +gss_write_init_verf(struct cache_detail *cd, struct svc_rqst *rqstp, + struct xdr_netobj *out_handle, int *major_status) +{ + struct rsc *rsci; + int rc; + + if (*major_status != GSS_S_COMPLETE) + return gss_write_null_verf(rqstp); + rsci = gss_svc_searchbyctx(cd, out_handle); + if (rsci == NULL) { + *major_status = GSS_S_NO_CONTEXT; + return gss_write_null_verf(rqstp); + } + rc = gss_write_verf(rqstp, rsci->mechctx, GSS_SEQ_WIN); + cache_put(&rsci->h, cd); + return rc; +} + +static inline int +gss_read_common_verf(struct rpc_gss_wire_cred *gc, + struct kvec *argv, __be32 *authp, + struct xdr_netobj *in_handle) +{ + /* Read the verifier; should be NULL: */ + *authp = rpc_autherr_badverf; + if (argv->iov_len < 2 * 4) + return SVC_DENIED; + if (svc_getnl(argv) != RPC_AUTH_NULL) + return SVC_DENIED; + if (svc_getnl(argv) != 0) + return SVC_DENIED; + /* Martial context handle and token for upcall: */ + *authp = rpc_autherr_badcred; + if (gc->gc_proc == RPC_GSS_PROC_INIT && gc->gc_ctx.len != 0) + return SVC_DENIED; + if (dup_netobj(in_handle, &gc->gc_ctx)) + return SVC_CLOSE; + *authp = rpc_autherr_badverf; + + return 0; +} + +static inline int +gss_read_verf(struct rpc_gss_wire_cred *gc, + struct kvec *argv, __be32 *authp, + struct xdr_netobj *in_handle, + struct xdr_netobj *in_token) +{ + struct xdr_netobj tmpobj; + int res; + + res = gss_read_common_verf(gc, argv, authp, in_handle); + if (res) + return res; + + if (svc_safe_getnetobj(argv, &tmpobj)) { + kfree(in_handle->data); + return SVC_DENIED; + } + if (dup_netobj(in_token, &tmpobj)) { + kfree(in_handle->data); + return SVC_CLOSE; + } + + return 0; +} + +static void gss_free_in_token_pages(struct gssp_in_token *in_token) +{ + u32 inlen; + int i; + + i = 0; + inlen = in_token->page_len; + while (inlen) { + if (in_token->pages[i]) + put_page(in_token->pages[i]); + inlen -= inlen > PAGE_SIZE ? PAGE_SIZE : inlen; + } + + kfree(in_token->pages); + in_token->pages = NULL; +} + +static int gss_read_proxy_verf(struct svc_rqst *rqstp, + struct rpc_gss_wire_cred *gc, __be32 *authp, + struct xdr_netobj *in_handle, + struct gssp_in_token *in_token) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + unsigned int length, pgto_offs, pgfrom_offs; + int pages, i, res, pgto, pgfrom; + size_t inlen, to_offs, from_offs; + + res = gss_read_common_verf(gc, argv, authp, in_handle); + if (res) + return res; + + inlen = svc_getnl(argv); + if (inlen > (argv->iov_len + rqstp->rq_arg.page_len)) + return SVC_DENIED; + + pages = DIV_ROUND_UP(inlen, PAGE_SIZE); + in_token->pages = kcalloc(pages, sizeof(struct page *), GFP_KERNEL); + if (!in_token->pages) + return SVC_DENIED; + in_token->page_base = 0; + in_token->page_len = inlen; + for (i = 0; i < pages; i++) { + in_token->pages[i] = alloc_page(GFP_KERNEL); + if (!in_token->pages[i]) { + gss_free_in_token_pages(in_token); + return SVC_DENIED; + } + } + + length = min_t(unsigned int, inlen, argv->iov_len); + memcpy(page_address(in_token->pages[0]), argv->iov_base, length); + inlen -= length; + + to_offs = length; + from_offs = rqstp->rq_arg.page_base; + while (inlen) { + pgto = to_offs >> PAGE_SHIFT; + pgfrom = from_offs >> PAGE_SHIFT; + pgto_offs = to_offs & ~PAGE_MASK; + pgfrom_offs = from_offs & ~PAGE_MASK; + + length = min_t(unsigned int, inlen, + min_t(unsigned int, PAGE_SIZE - pgto_offs, + PAGE_SIZE - pgfrom_offs)); + memcpy(page_address(in_token->pages[pgto]) + pgto_offs, + page_address(rqstp->rq_arg.pages[pgfrom]) + pgfrom_offs, + length); + + to_offs += length; + from_offs += length; + inlen -= length; + } + return 0; +} + +static inline int +gss_write_resv(struct kvec *resv, size_t size_limit, + struct xdr_netobj *out_handle, struct xdr_netobj *out_token, + int major_status, int minor_status) +{ + if (resv->iov_len + 4 > size_limit) + return -1; + svc_putnl(resv, RPC_SUCCESS); + if (svc_safe_putnetobj(resv, out_handle)) + return -1; + if (resv->iov_len + 3 * 4 > size_limit) + return -1; + svc_putnl(resv, major_status); + svc_putnl(resv, minor_status); + svc_putnl(resv, GSS_SEQ_WIN); + if (svc_safe_putnetobj(resv, out_token)) + return -1; + return 0; +} + +/* + * Having read the cred already and found we're in the context + * initiation case, read the verifier and initiate (or check the results + * of) upcalls to userspace for help with context initiation. If + * the upcall results are available, write the verifier and result. + * Otherwise, drop the request pending an answer to the upcall. + */ +static int svcauth_gss_legacy_init(struct svc_rqst *rqstp, + struct rpc_gss_wire_cred *gc, __be32 *authp) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec *resv = &rqstp->rq_res.head[0]; + struct rsi *rsip, rsikey; + int ret; + struct sunrpc_net *sn = net_generic(SVC_NET(rqstp), sunrpc_net_id); + + memset(&rsikey, 0, sizeof(rsikey)); + ret = gss_read_verf(gc, argv, authp, + &rsikey.in_handle, &rsikey.in_token); + if (ret) + return ret; + + /* Perform upcall, or find upcall result: */ + rsip = rsi_lookup(sn->rsi_cache, &rsikey); + rsi_free(&rsikey); + if (!rsip) + return SVC_CLOSE; + if (cache_check(sn->rsi_cache, &rsip->h, &rqstp->rq_chandle) < 0) + /* No upcall result: */ + return SVC_CLOSE; + + ret = SVC_CLOSE; + /* Got an answer to the upcall; use it: */ + if (gss_write_init_verf(sn->rsc_cache, rqstp, + &rsip->out_handle, &rsip->major_status)) + goto out; + if (gss_write_resv(resv, PAGE_SIZE, + &rsip->out_handle, &rsip->out_token, + rsip->major_status, rsip->minor_status)) + goto out; + + ret = SVC_COMPLETE; +out: + cache_put(&rsip->h, sn->rsi_cache); + return ret; +} + +static int gss_proxy_save_rsc(struct cache_detail *cd, + struct gssp_upcall_data *ud, + uint64_t *handle) +{ + struct rsc rsci, *rscp = NULL; + static atomic64_t ctxhctr; + long long ctxh; + struct gss_api_mech *gm = NULL; + time_t expiry; + int status = -EINVAL; + + memset(&rsci, 0, sizeof(rsci)); + /* context handle */ + status = -ENOMEM; + /* the handle needs to be just a unique id, + * use a static counter */ + ctxh = atomic64_inc_return(&ctxhctr); + + /* make a copy for the caller */ + *handle = ctxh; + + /* make a copy for the rsc cache */ + if (dup_to_netobj(&rsci.handle, (char *)handle, sizeof(uint64_t))) + goto out; + rscp = rsc_lookup(cd, &rsci); + if (!rscp) + goto out; + + /* creds */ + if (!ud->found_creds) { + /* userspace seem buggy, we should always get at least a + * mapping to nobody */ + dprintk("RPC: No creds found!\n"); + goto out; + } else { + struct timespec64 boot; + + /* steal creds */ + rsci.cred = ud->creds; + memset(&ud->creds, 0, sizeof(struct svc_cred)); + + status = -EOPNOTSUPP; + /* get mech handle from OID */ + gm = gss_mech_get_by_OID(&ud->mech_oid); + if (!gm) + goto out; + rsci.cred.cr_gss_mech = gm; + + status = -EINVAL; + /* mech-specific data: */ + status = gss_import_sec_context(ud->out_handle.data, + ud->out_handle.len, + gm, &rsci.mechctx, + &expiry, GFP_KERNEL); + if (status) + goto out; + + getboottime64(&boot); + expiry -= boot.tv_sec; + } + + rsci.h.expiry_time = expiry; + rscp = rsc_update(cd, &rsci, rscp); + status = 0; +out: + rsc_free(&rsci); + if (rscp) + cache_put(&rscp->h, cd); + else + status = -ENOMEM; + return status; +} + +static int svcauth_gss_proxy_init(struct svc_rqst *rqstp, + struct rpc_gss_wire_cred *gc, __be32 *authp) +{ + struct kvec *resv = &rqstp->rq_res.head[0]; + struct xdr_netobj cli_handle; + struct gssp_upcall_data ud; + uint64_t handle; + int status; + int ret; + struct net *net = SVC_NET(rqstp); + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + memset(&ud, 0, sizeof(ud)); + ret = gss_read_proxy_verf(rqstp, gc, authp, + &ud.in_handle, &ud.in_token); + if (ret) + return ret; + + ret = SVC_CLOSE; + + /* Perform synchronous upcall to gss-proxy */ + status = gssp_accept_sec_context_upcall(net, &ud); + if (status) + goto out; + + dprintk("RPC: svcauth_gss: gss major status = %d " + "minor status = %d\n", + ud.major_status, ud.minor_status); + + switch (ud.major_status) { + case GSS_S_CONTINUE_NEEDED: + cli_handle = ud.out_handle; + break; + case GSS_S_COMPLETE: + status = gss_proxy_save_rsc(sn->rsc_cache, &ud, &handle); + if (status) { + pr_info("%s: gss_proxy_save_rsc failed (%d)\n", + __func__, status); + goto out; + } + cli_handle.data = (u8 *)&handle; + cli_handle.len = sizeof(handle); + break; + default: + ret = SVC_CLOSE; + goto out; + } + + /* Got an answer to the upcall; use it: */ + if (gss_write_init_verf(sn->rsc_cache, rqstp, + &cli_handle, &ud.major_status)) { + pr_info("%s: gss_write_init_verf failed\n", __func__); + goto out; + } + if (gss_write_resv(resv, PAGE_SIZE, + &cli_handle, &ud.out_token, + ud.major_status, ud.minor_status)) { + pr_info("%s: gss_write_resv failed\n", __func__); + goto out; + } + + ret = SVC_COMPLETE; +out: + gss_free_in_token_pages(&ud.in_token); + gssp_free_upcall_data(&ud); + return ret; +} + +/* + * Try to set the sn->use_gss_proxy variable to a new value. We only allow + * it to be changed if it's currently undefined (-1). If it's any other value + * then return -EBUSY unless the type wouldn't have changed anyway. + */ +static int set_gss_proxy(struct net *net, int type) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + int ret; + + WARN_ON_ONCE(type != 0 && type != 1); + ret = cmpxchg(&sn->use_gss_proxy, -1, type); + if (ret != -1 && ret != type) + return -EBUSY; + return 0; +} + +static bool use_gss_proxy(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + /* If use_gss_proxy is still undefined, then try to disable it */ + if (sn->use_gss_proxy == -1) + set_gss_proxy(net, 0); + return sn->use_gss_proxy; +} + +#ifdef CONFIG_PROC_FS + +static ssize_t write_gssp(struct file *file, const char __user *buf, + size_t count, loff_t *ppos) +{ + struct net *net = PDE_DATA(file_inode(file)); + char tbuf[20]; + unsigned long i; + int res; + + if (*ppos || count > sizeof(tbuf)-1) + return -EINVAL; + if (copy_from_user(tbuf, buf, count)) + return -EFAULT; + + tbuf[count] = 0; + res = kstrtoul(tbuf, 0, &i); + if (res) + return res; + if (i != 1) + return -EINVAL; + res = set_gssp_clnt(net); + if (res) + return res; + res = set_gss_proxy(net, 1); + if (res) + return res; + return count; +} + +static ssize_t read_gssp(struct file *file, char __user *buf, + size_t count, loff_t *ppos) +{ + struct net *net = PDE_DATA(file_inode(file)); + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + unsigned long p = *ppos; + char tbuf[10]; + size_t len; + + snprintf(tbuf, sizeof(tbuf), "%d\n", sn->use_gss_proxy); + len = strlen(tbuf); + if (p >= len) + return 0; + len -= p; + if (len > count) + len = count; + if (copy_to_user(buf, (void *)(tbuf+p), len)) + return -EFAULT; + *ppos += len; + return len; +} + +static const struct file_operations use_gss_proxy_ops = { + .open = nonseekable_open, + .write = write_gssp, + .read = read_gssp, +}; + +static int create_use_gss_proxy_proc_entry(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct proc_dir_entry **p = &sn->use_gssp_proc; + + sn->use_gss_proxy = -1; + *p = proc_create_data("use-gss-proxy", S_IFREG | 0600, + sn->proc_net_rpc, + &use_gss_proxy_ops, net); + if (!*p) + return -ENOMEM; + init_gssp_clnt(sn); + return 0; +} + +static void destroy_use_gss_proxy_proc_entry(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + if (sn->use_gssp_proc) { + remove_proc_entry("use-gss-proxy", sn->proc_net_rpc); + clear_gssp_clnt(sn); + } +} +#else /* CONFIG_PROC_FS */ + +static int create_use_gss_proxy_proc_entry(struct net *net) +{ + return 0; +} + +static void destroy_use_gss_proxy_proc_entry(struct net *net) {} + +#endif /* CONFIG_PROC_FS */ + +/* + * Accept an rpcsec packet. + * If context establishment, punt to user space + * If data exchange, verify/decrypt + * If context destruction, handle here + * In the context establishment and destruction case we encode + * response here and return SVC_COMPLETE. + */ +static int +svcauth_gss_accept(struct svc_rqst *rqstp, __be32 *authp) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec *resv = &rqstp->rq_res.head[0]; + u32 crlen; + struct gss_svc_data *svcdata = rqstp->rq_auth_data; + struct rpc_gss_wire_cred *gc; + struct rsc *rsci = NULL; + __be32 *rpcstart; + __be32 *reject_stat = resv->iov_base + resv->iov_len; + int ret; + struct sunrpc_net *sn = net_generic(SVC_NET(rqstp), sunrpc_net_id); + + dprintk("RPC: svcauth_gss: argv->iov_len = %zd\n", + argv->iov_len); + + *authp = rpc_autherr_badcred; + if (!svcdata) + svcdata = kmalloc(sizeof(*svcdata), GFP_KERNEL); + if (!svcdata) + goto auth_err; + rqstp->rq_auth_data = svcdata; + svcdata->verf_start = NULL; + svcdata->rsci = NULL; + gc = &svcdata->clcred; + + /* start of rpc packet is 7 u32's back from here: + * xid direction rpcversion prog vers proc flavour + */ + rpcstart = argv->iov_base; + rpcstart -= 7; + + /* credential is: + * version(==1), proc(0,1,2,3), seq, service (1,2,3), handle + * at least 5 u32s, and is preceded by length, so that makes 6. + */ + + if (argv->iov_len < 5 * 4) + goto auth_err; + crlen = svc_getnl(argv); + if (svc_getnl(argv) != RPC_GSS_VERSION) + goto auth_err; + gc->gc_proc = svc_getnl(argv); + gc->gc_seq = svc_getnl(argv); + gc->gc_svc = svc_getnl(argv); + if (svc_safe_getnetobj(argv, &gc->gc_ctx)) + goto auth_err; + if (crlen != round_up_to_quad(gc->gc_ctx.len) + 5 * 4) + goto auth_err; + + if ((gc->gc_proc != RPC_GSS_PROC_DATA) && (rqstp->rq_proc != 0)) + goto auth_err; + + *authp = rpc_autherr_badverf; + switch (gc->gc_proc) { + case RPC_GSS_PROC_INIT: + case RPC_GSS_PROC_CONTINUE_INIT: + if (use_gss_proxy(SVC_NET(rqstp))) + return svcauth_gss_proxy_init(rqstp, gc, authp); + else + return svcauth_gss_legacy_init(rqstp, gc, authp); + case RPC_GSS_PROC_DATA: + case RPC_GSS_PROC_DESTROY: + /* Look up the context, and check the verifier: */ + *authp = rpcsec_gsserr_credproblem; + rsci = gss_svc_searchbyctx(sn->rsc_cache, &gc->gc_ctx); + if (!rsci) + goto auth_err; + switch (gss_verify_header(rqstp, rsci, rpcstart, gc, authp)) { + case SVC_OK: + break; + case SVC_DENIED: + goto auth_err; + case SVC_DROP: + goto drop; + } + break; + default: + *authp = rpc_autherr_rejectedcred; + goto auth_err; + } + + /* now act upon the command: */ + switch (gc->gc_proc) { + case RPC_GSS_PROC_DESTROY: + if (gss_write_verf(rqstp, rsci->mechctx, gc->gc_seq)) + goto auth_err; + /* Delete the entry from the cache_list and call cache_put */ + sunrpc_cache_unhash(sn->rsc_cache, &rsci->h); + if (resv->iov_len + 4 > PAGE_SIZE) + goto drop; + svc_putnl(resv, RPC_SUCCESS); + goto complete; + case RPC_GSS_PROC_DATA: + *authp = rpcsec_gsserr_ctxproblem; + svcdata->verf_start = resv->iov_base + resv->iov_len; + if (gss_write_verf(rqstp, rsci->mechctx, gc->gc_seq)) + goto auth_err; + rqstp->rq_cred = rsci->cred; + get_group_info(rsci->cred.cr_group_info); + *authp = rpc_autherr_badcred; + switch (gc->gc_svc) { + case RPC_GSS_SVC_NONE: + break; + case RPC_GSS_SVC_INTEGRITY: + /* placeholders for length and seq. number: */ + svc_putnl(resv, 0); + svc_putnl(resv, 0); + if (unwrap_integ_data(rqstp, &rqstp->rq_arg, + gc->gc_seq, rsci->mechctx)) + goto garbage_args; + rqstp->rq_auth_slack = RPC_MAX_AUTH_SIZE; + break; + case RPC_GSS_SVC_PRIVACY: + /* placeholders for length and seq. number: */ + svc_putnl(resv, 0); + svc_putnl(resv, 0); + if (unwrap_priv_data(rqstp, &rqstp->rq_arg, + gc->gc_seq, rsci->mechctx)) + goto garbage_args; + rqstp->rq_auth_slack = RPC_MAX_AUTH_SIZE * 2; + break; + default: + goto auth_err; + } + svcdata->rsci = rsci; + cache_get(&rsci->h); + rqstp->rq_cred.cr_flavor = gss_svc_to_pseudoflavor( + rsci->mechctx->mech_type, + GSS_C_QOP_DEFAULT, + gc->gc_svc); + ret = SVC_OK; + goto out; + } +garbage_args: + ret = SVC_GARBAGE; + goto out; +auth_err: + /* Restore write pointer to its original value: */ + xdr_ressize_check(rqstp, reject_stat); + ret = SVC_DENIED; + goto out; +complete: + ret = SVC_COMPLETE; + goto out; +drop: + ret = SVC_CLOSE; +out: + if (rsci) + cache_put(&rsci->h, sn->rsc_cache); + return ret; +} + +static __be32 * +svcauth_gss_prepare_to_wrap(struct xdr_buf *resbuf, struct gss_svc_data *gsd) +{ + __be32 *p; + u32 verf_len; + + p = gsd->verf_start; + gsd->verf_start = NULL; + + /* If the reply stat is nonzero, don't wrap: */ + if (*(p-1) != rpc_success) + return NULL; + /* Skip the verifier: */ + p += 1; + verf_len = ntohl(*p++); + p += XDR_QUADLEN(verf_len); + /* move accept_stat to right place: */ + memcpy(p, p + 2, 4); + /* Also don't wrap if the accept stat is nonzero: */ + if (*p != rpc_success) { + resbuf->head[0].iov_len -= 2 * 4; + return NULL; + } + p++; + return p; +} + +static inline int +svcauth_gss_wrap_resp_integ(struct svc_rqst *rqstp) +{ + struct gss_svc_data *gsd = (struct gss_svc_data *)rqstp->rq_auth_data; + struct rpc_gss_wire_cred *gc = &gsd->clcred; + struct xdr_buf *resbuf = &rqstp->rq_res; + struct xdr_buf integ_buf; + struct xdr_netobj mic; + struct kvec *resv; + __be32 *p; + int integ_offset, integ_len; + int stat = -EINVAL; + + p = svcauth_gss_prepare_to_wrap(resbuf, gsd); + if (p == NULL) + goto out; + integ_offset = (u8 *)(p + 1) - (u8 *)resbuf->head[0].iov_base; + integ_len = resbuf->len - integ_offset; + BUG_ON(integ_len % 4); + *p++ = htonl(integ_len); + *p++ = htonl(gc->gc_seq); + if (xdr_buf_subsegment(resbuf, &integ_buf, integ_offset, integ_len)) { + WARN_ON_ONCE(1); + goto out_err; + } + if (resbuf->tail[0].iov_base == NULL) { + if (resbuf->head[0].iov_len + RPC_MAX_AUTH_SIZE > PAGE_SIZE) + goto out_err; + resbuf->tail[0].iov_base = resbuf->head[0].iov_base + + resbuf->head[0].iov_len; + resbuf->tail[0].iov_len = 0; + } + resv = &resbuf->tail[0]; + mic.data = (u8 *)resv->iov_base + resv->iov_len + 4; + if (gss_get_mic(gsd->rsci->mechctx, &integ_buf, &mic)) + goto out_err; + svc_putnl(resv, mic.len); + memset(mic.data + mic.len, 0, + round_up_to_quad(mic.len) - mic.len); + resv->iov_len += XDR_QUADLEN(mic.len) << 2; + /* not strictly required: */ + resbuf->len += XDR_QUADLEN(mic.len) << 2; + BUG_ON(resv->iov_len > PAGE_SIZE); +out: + stat = 0; +out_err: + return stat; +} + +static inline int +svcauth_gss_wrap_resp_priv(struct svc_rqst *rqstp) +{ + struct gss_svc_data *gsd = (struct gss_svc_data *)rqstp->rq_auth_data; + struct rpc_gss_wire_cred *gc = &gsd->clcred; + struct xdr_buf *resbuf = &rqstp->rq_res; + struct page **inpages = NULL; + __be32 *p, *len; + int offset; + int pad; + + p = svcauth_gss_prepare_to_wrap(resbuf, gsd); + if (p == NULL) + return 0; + len = p++; + offset = (u8 *)p - (u8 *)resbuf->head[0].iov_base; + *p++ = htonl(gc->gc_seq); + inpages = resbuf->pages; + /* XXX: Would be better to write some xdr helper functions for + * nfs{2,3,4}xdr.c that place the data right, instead of copying: */ + + /* + * If there is currently tail data, make sure there is + * room for the head, tail, and 2 * RPC_MAX_AUTH_SIZE in + * the page, and move the current tail data such that + * there is RPC_MAX_AUTH_SIZE slack space available in + * both the head and tail. + */ + if (resbuf->tail[0].iov_base) { + BUG_ON(resbuf->tail[0].iov_base >= resbuf->head[0].iov_base + + PAGE_SIZE); + BUG_ON(resbuf->tail[0].iov_base < resbuf->head[0].iov_base); + if (resbuf->tail[0].iov_len + resbuf->head[0].iov_len + + 2 * RPC_MAX_AUTH_SIZE > PAGE_SIZE) + return -ENOMEM; + memmove(resbuf->tail[0].iov_base + RPC_MAX_AUTH_SIZE, + resbuf->tail[0].iov_base, + resbuf->tail[0].iov_len); + resbuf->tail[0].iov_base += RPC_MAX_AUTH_SIZE; + } + /* + * If there is no current tail data, make sure there is + * room for the head data, and 2 * RPC_MAX_AUTH_SIZE in the + * allotted page, and set up tail information such that there + * is RPC_MAX_AUTH_SIZE slack space available in both the + * head and tail. + */ + if (resbuf->tail[0].iov_base == NULL) { + if (resbuf->head[0].iov_len + 2*RPC_MAX_AUTH_SIZE > PAGE_SIZE) + return -ENOMEM; + resbuf->tail[0].iov_base = resbuf->head[0].iov_base + + resbuf->head[0].iov_len + RPC_MAX_AUTH_SIZE; + resbuf->tail[0].iov_len = 0; + } + if (gss_wrap(gsd->rsci->mechctx, offset, resbuf, inpages)) + return -ENOMEM; + *len = htonl(resbuf->len - offset); + pad = 3 - ((resbuf->len - offset - 1)&3); + p = (__be32 *)(resbuf->tail[0].iov_base + resbuf->tail[0].iov_len); + memset(p, 0, pad); + resbuf->tail[0].iov_len += pad; + resbuf->len += pad; + return 0; +} + +static int +svcauth_gss_release(struct svc_rqst *rqstp) +{ + struct gss_svc_data *gsd = (struct gss_svc_data *)rqstp->rq_auth_data; + struct rpc_gss_wire_cred *gc; + struct xdr_buf *resbuf = &rqstp->rq_res; + int stat = -EINVAL; + struct sunrpc_net *sn = net_generic(SVC_NET(rqstp), sunrpc_net_id); + + if (!gsd) + goto out; + gc = &gsd->clcred; + if (gc->gc_proc != RPC_GSS_PROC_DATA) + goto out; + /* Release can be called twice, but we only wrap once. */ + if (gsd->verf_start == NULL) + goto out; + /* normally not set till svc_send, but we need it here: */ + /* XXX: what for? Do we mess it up the moment we call svc_putu32 + * or whatever? */ + resbuf->len = total_buf_len(resbuf); + switch (gc->gc_svc) { + case RPC_GSS_SVC_NONE: + break; + case RPC_GSS_SVC_INTEGRITY: + stat = svcauth_gss_wrap_resp_integ(rqstp); + if (stat) + goto out_err; + break; + case RPC_GSS_SVC_PRIVACY: + stat = svcauth_gss_wrap_resp_priv(rqstp); + if (stat) + goto out_err; + break; + /* + * For any other gc_svc value, svcauth_gss_accept() already set + * the auth_error appropriately; just fall through: + */ + } + +out: + stat = 0; +out_err: + if (rqstp->rq_client) + auth_domain_put(rqstp->rq_client); + rqstp->rq_client = NULL; + if (rqstp->rq_gssclient) + auth_domain_put(rqstp->rq_gssclient); + rqstp->rq_gssclient = NULL; + if (rqstp->rq_cred.cr_group_info) + put_group_info(rqstp->rq_cred.cr_group_info); + rqstp->rq_cred.cr_group_info = NULL; + if (gsd && gsd->rsci) { + cache_put(&gsd->rsci->h, sn->rsc_cache); + gsd->rsci = NULL; + } + return stat; +} + +static void +svcauth_gss_domain_release(struct auth_domain *dom) +{ + struct gss_domain *gd = container_of(dom, struct gss_domain, h); + + kfree(dom->name); + kfree(gd); +} + +static struct auth_ops svcauthops_gss = { + .name = "rpcsec_gss", + .owner = THIS_MODULE, + .flavour = RPC_AUTH_GSS, + .accept = svcauth_gss_accept, + .release = svcauth_gss_release, + .domain_release = svcauth_gss_domain_release, + .set_client = svcauth_gss_set_client, +}; + +static int rsi_cache_create_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd; + int err; + + cd = cache_create_net(&rsi_cache_template, net); + if (IS_ERR(cd)) + return PTR_ERR(cd); + err = cache_register_net(cd, net); + if (err) { + cache_destroy_net(cd, net); + return err; + } + sn->rsi_cache = cd; + return 0; +} + +static void rsi_cache_destroy_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd = sn->rsi_cache; + + sn->rsi_cache = NULL; + cache_purge(cd); + cache_unregister_net(cd, net); + cache_destroy_net(cd, net); +} + +static int rsc_cache_create_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd; + int err; + + cd = cache_create_net(&rsc_cache_template, net); + if (IS_ERR(cd)) + return PTR_ERR(cd); + err = cache_register_net(cd, net); + if (err) { + cache_destroy_net(cd, net); + return err; + } + sn->rsc_cache = cd; + return 0; +} + +static void rsc_cache_destroy_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd = sn->rsc_cache; + + sn->rsc_cache = NULL; + cache_purge(cd); + cache_unregister_net(cd, net); + cache_destroy_net(cd, net); +} + +int +gss_svc_init_net(struct net *net) +{ + int rv; + + rv = rsc_cache_create_net(net); + if (rv) + return rv; + rv = rsi_cache_create_net(net); + if (rv) + goto out1; + rv = create_use_gss_proxy_proc_entry(net); + if (rv) + goto out2; + return 0; +out2: + rsi_cache_destroy_net(net); +out1: + rsc_cache_destroy_net(net); + return rv; +} + +void +gss_svc_shutdown_net(struct net *net) +{ + destroy_use_gss_proxy_proc_entry(net); + rsi_cache_destroy_net(net); + rsc_cache_destroy_net(net); +} + +int +gss_svc_init(void) +{ + return svc_auth_register(RPC_AUTH_GSS, &svcauthops_gss); +} + +void +gss_svc_shutdown(void) +{ + svc_auth_unregister(RPC_AUTH_GSS); +} diff --git a/net/sunrpc/auth_null.c b/net/sunrpc/auth_null.c new file mode 100644 index 000000000..4b48228ee --- /dev/null +++ b/net/sunrpc/auth_null.c @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * linux/net/sunrpc/auth_null.c + * + * AUTH_NULL authentication. Really :-) + * + * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de> + */ + +#include <linux/types.h> +#include <linux/module.h> +#include <linux/sunrpc/clnt.h> + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +static struct rpc_auth null_auth; +static struct rpc_cred null_cred; + +static struct rpc_auth * +nul_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) +{ + atomic_inc(&null_auth.au_count); + return &null_auth; +} + +static void +nul_destroy(struct rpc_auth *auth) +{ +} + +/* + * Lookup NULL creds for current process + */ +static struct rpc_cred * +nul_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) +{ + if (flags & RPCAUTH_LOOKUP_RCU) + return &null_cred; + return get_rpccred(&null_cred); +} + +/* + * Destroy cred handle. + */ +static void +nul_destroy_cred(struct rpc_cred *cred) +{ +} + +/* + * Match cred handle against current process + */ +static int +nul_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags) +{ + return 1; +} + +/* + * Marshal credential. + */ +static __be32 * +nul_marshal(struct rpc_task *task, __be32 *p) +{ + *p++ = htonl(RPC_AUTH_NULL); + *p++ = 0; + *p++ = htonl(RPC_AUTH_NULL); + *p++ = 0; + + return p; +} + +/* + * Refresh credential. This is a no-op for AUTH_NULL + */ +static int +nul_refresh(struct rpc_task *task) +{ + set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags); + return 0; +} + +static __be32 * +nul_validate(struct rpc_task *task, __be32 *p) +{ + rpc_authflavor_t flavor; + u32 size; + + flavor = ntohl(*p++); + if (flavor != RPC_AUTH_NULL) { + printk("RPC: bad verf flavor: %u\n", flavor); + return ERR_PTR(-EIO); + } + + size = ntohl(*p++); + if (size != 0) { + printk("RPC: bad verf size: %u\n", size); + return ERR_PTR(-EIO); + } + + return p; +} + +const struct rpc_authops authnull_ops = { + .owner = THIS_MODULE, + .au_flavor = RPC_AUTH_NULL, + .au_name = "NULL", + .create = nul_create, + .destroy = nul_destroy, + .lookup_cred = nul_lookup_cred, +}; + +static +struct rpc_auth null_auth = { + .au_cslack = NUL_CALLSLACK, + .au_rslack = NUL_REPLYSLACK, + .au_flags = RPCAUTH_AUTH_NO_CRKEY_TIMEOUT, + .au_ops = &authnull_ops, + .au_flavor = RPC_AUTH_NULL, + .au_count = ATOMIC_INIT(0), +}; + +static +const struct rpc_credops null_credops = { + .cr_name = "AUTH_NULL", + .crdestroy = nul_destroy_cred, + .crbind = rpcauth_generic_bind_cred, + .crmatch = nul_match, + .crmarshal = nul_marshal, + .crrefresh = nul_refresh, + .crvalidate = nul_validate, +}; + +static +struct rpc_cred null_cred = { + .cr_lru = LIST_HEAD_INIT(null_cred.cr_lru), + .cr_auth = &null_auth, + .cr_ops = &null_credops, + .cr_count = ATOMIC_INIT(1), + .cr_flags = 1UL << RPCAUTH_CRED_UPTODATE, +}; diff --git a/net/sunrpc/auth_unix.c b/net/sunrpc/auth_unix.c new file mode 100644 index 000000000..185e56d4f --- /dev/null +++ b/net/sunrpc/auth_unix.c @@ -0,0 +1,254 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * linux/net/sunrpc/auth_unix.c + * + * UNIX-style authentication; no AUTH_SHORT support + * + * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de> + */ + +#include <linux/slab.h> +#include <linux/types.h> +#include <linux/sched.h> +#include <linux/module.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/auth.h> +#include <linux/user_namespace.h> + +struct unx_cred { + struct rpc_cred uc_base; + kgid_t uc_gid; + kgid_t uc_gids[UNX_NGROUPS]; +}; +#define uc_uid uc_base.cr_uid + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +static struct rpc_auth unix_auth; +static const struct rpc_credops unix_credops; + +static struct rpc_auth * +unx_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) +{ + dprintk("RPC: creating UNIX authenticator for client %p\n", + clnt); + atomic_inc(&unix_auth.au_count); + return &unix_auth; +} + +static void +unx_destroy(struct rpc_auth *auth) +{ + dprintk("RPC: destroying UNIX authenticator %p\n", auth); + rpcauth_clear_credcache(auth->au_credcache); +} + +static int +unx_hash_cred(struct auth_cred *acred, unsigned int hashbits) +{ + return hash_64(from_kgid(&init_user_ns, acred->gid) | + ((u64)from_kuid(&init_user_ns, acred->uid) << + (sizeof(gid_t) * 8)), hashbits); +} + +/* + * Lookup AUTH_UNIX creds for current process + */ +static struct rpc_cred * +unx_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) +{ + return rpcauth_lookup_credcache(auth, acred, flags, GFP_NOFS); +} + +static struct rpc_cred * +unx_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags, gfp_t gfp) +{ + struct unx_cred *cred; + unsigned int groups = 0; + unsigned int i; + + dprintk("RPC: allocating UNIX cred for uid %d gid %d\n", + from_kuid(&init_user_ns, acred->uid), + from_kgid(&init_user_ns, acred->gid)); + + if (!(cred = kmalloc(sizeof(*cred), gfp))) + return ERR_PTR(-ENOMEM); + + rpcauth_init_cred(&cred->uc_base, acred, auth, &unix_credops); + cred->uc_base.cr_flags = 1UL << RPCAUTH_CRED_UPTODATE; + + if (acred->group_info != NULL) + groups = acred->group_info->ngroups; + if (groups > UNX_NGROUPS) + groups = UNX_NGROUPS; + + cred->uc_gid = acred->gid; + for (i = 0; i < groups; i++) + cred->uc_gids[i] = acred->group_info->gid[i]; + if (i < UNX_NGROUPS) + cred->uc_gids[i] = INVALID_GID; + + return &cred->uc_base; +} + +static void +unx_free_cred(struct unx_cred *unx_cred) +{ + dprintk("RPC: unx_free_cred %p\n", unx_cred); + kfree(unx_cred); +} + +static void +unx_free_cred_callback(struct rcu_head *head) +{ + struct unx_cred *unx_cred = container_of(head, struct unx_cred, uc_base.cr_rcu); + unx_free_cred(unx_cred); +} + +static void +unx_destroy_cred(struct rpc_cred *cred) +{ + call_rcu(&cred->cr_rcu, unx_free_cred_callback); +} + +/* + * Match credentials against current process creds. + * The root_override argument takes care of cases where the caller may + * request root creds (e.g. for NFS swapping). + */ +static int +unx_match(struct auth_cred *acred, struct rpc_cred *rcred, int flags) +{ + struct unx_cred *cred = container_of(rcred, struct unx_cred, uc_base); + unsigned int groups = 0; + unsigned int i; + + + if (!uid_eq(cred->uc_uid, acred->uid) || !gid_eq(cred->uc_gid, acred->gid)) + return 0; + + if (acred->group_info != NULL) + groups = acred->group_info->ngroups; + if (groups > UNX_NGROUPS) + groups = UNX_NGROUPS; + for (i = 0; i < groups ; i++) + if (!gid_eq(cred->uc_gids[i], acred->group_info->gid[i])) + return 0; + if (groups < UNX_NGROUPS && gid_valid(cred->uc_gids[groups])) + return 0; + return 1; +} + +/* + * Marshal credentials. + * Maybe we should keep a cached credential for performance reasons. + */ +static __be32 * +unx_marshal(struct rpc_task *task, __be32 *p) +{ + struct rpc_clnt *clnt = task->tk_client; + struct unx_cred *cred = container_of(task->tk_rqstp->rq_cred, struct unx_cred, uc_base); + __be32 *base, *hold; + int i; + + *p++ = htonl(RPC_AUTH_UNIX); + base = p++; + *p++ = htonl(jiffies/HZ); + + /* + * Copy the UTS nodename captured when the client was created. + */ + p = xdr_encode_array(p, clnt->cl_nodename, clnt->cl_nodelen); + + *p++ = htonl((u32) from_kuid(&init_user_ns, cred->uc_uid)); + *p++ = htonl((u32) from_kgid(&init_user_ns, cred->uc_gid)); + hold = p++; + for (i = 0; i < UNX_NGROUPS && gid_valid(cred->uc_gids[i]); i++) + *p++ = htonl((u32) from_kgid(&init_user_ns, cred->uc_gids[i])); + *hold = htonl(p - hold - 1); /* gid array length */ + *base = htonl((p - base - 1) << 2); /* cred length */ + + *p++ = htonl(RPC_AUTH_NULL); + *p++ = htonl(0); + + return p; +} + +/* + * Refresh credentials. This is a no-op for AUTH_UNIX + */ +static int +unx_refresh(struct rpc_task *task) +{ + set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags); + return 0; +} + +static __be32 * +unx_validate(struct rpc_task *task, __be32 *p) +{ + rpc_authflavor_t flavor; + u32 size; + + flavor = ntohl(*p++); + if (flavor != RPC_AUTH_NULL && + flavor != RPC_AUTH_UNIX && + flavor != RPC_AUTH_SHORT) { + printk("RPC: bad verf flavor: %u\n", flavor); + return ERR_PTR(-EIO); + } + + size = ntohl(*p++); + if (size > RPC_MAX_AUTH_SIZE) { + printk("RPC: giant verf size: %u\n", size); + return ERR_PTR(-EIO); + } + task->tk_rqstp->rq_cred->cr_auth->au_rslack = (size >> 2) + 2; + p += (size >> 2); + + return p; +} + +int __init rpc_init_authunix(void) +{ + return rpcauth_init_credcache(&unix_auth); +} + +void rpc_destroy_authunix(void) +{ + rpcauth_destroy_credcache(&unix_auth); +} + +const struct rpc_authops authunix_ops = { + .owner = THIS_MODULE, + .au_flavor = RPC_AUTH_UNIX, + .au_name = "UNIX", + .create = unx_create, + .destroy = unx_destroy, + .hash_cred = unx_hash_cred, + .lookup_cred = unx_lookup_cred, + .crcreate = unx_create_cred, +}; + +static +struct rpc_auth unix_auth = { + .au_cslack = UNX_CALLSLACK, + .au_rslack = NUL_REPLYSLACK, + .au_flags = RPCAUTH_AUTH_NO_CRKEY_TIMEOUT, + .au_ops = &authunix_ops, + .au_flavor = RPC_AUTH_UNIX, + .au_count = ATOMIC_INIT(0), +}; + +static +const struct rpc_credops unix_credops = { + .cr_name = "AUTH_UNIX", + .crdestroy = unx_destroy_cred, + .crbind = rpcauth_generic_bind_cred, + .crmatch = unx_match, + .crmarshal = unx_marshal, + .crrefresh = unx_refresh, + .crvalidate = unx_validate, +}; diff --git a/net/sunrpc/backchannel_rqst.c b/net/sunrpc/backchannel_rqst.c new file mode 100644 index 000000000..3c15a99b9 --- /dev/null +++ b/net/sunrpc/backchannel_rqst.c @@ -0,0 +1,364 @@ +/****************************************************************************** + +(c) 2007 Network Appliance, Inc. All Rights Reserved. +(c) 2009 NetApp. All Rights Reserved. + +NetApp provides this source code under the GPL v2 License. +The GPL v2 license is available at +http://opensource.org/licenses/gpl-license.php. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +******************************************************************************/ + +#include <linux/tcp.h> +#include <linux/slab.h> +#include <linux/sunrpc/xprt.h> +#include <linux/export.h> +#include <linux/sunrpc/bc_xprt.h> + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +#define RPCDBG_FACILITY RPCDBG_TRANS +#endif + +/* + * Helper routines that track the number of preallocation elements + * on the transport. + */ +static inline int xprt_need_to_requeue(struct rpc_xprt *xprt) +{ + return xprt->bc_alloc_count < atomic_read(&xprt->bc_free_slots); +} + +static inline void xprt_inc_alloc_count(struct rpc_xprt *xprt, unsigned int n) +{ + atomic_add(n, &xprt->bc_free_slots); + xprt->bc_alloc_count += n; +} + +static inline int xprt_dec_alloc_count(struct rpc_xprt *xprt, unsigned int n) +{ + atomic_sub(n, &xprt->bc_free_slots); + return xprt->bc_alloc_count -= n; +} + +/* + * Free the preallocated rpc_rqst structure and the memory + * buffers hanging off of it. + */ +static void xprt_free_allocation(struct rpc_rqst *req) +{ + struct xdr_buf *xbufp; + + dprintk("RPC: free allocations for req= %p\n", req); + WARN_ON_ONCE(test_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state)); + xbufp = &req->rq_rcv_buf; + free_page((unsigned long)xbufp->head[0].iov_base); + xbufp = &req->rq_snd_buf; + free_page((unsigned long)xbufp->head[0].iov_base); + kfree(req); +} + +static int xprt_alloc_xdr_buf(struct xdr_buf *buf, gfp_t gfp_flags) +{ + struct page *page; + /* Preallocate one XDR receive buffer */ + page = alloc_page(gfp_flags); + if (page == NULL) + return -ENOMEM; + xdr_buf_init(buf, page_address(page), PAGE_SIZE); + return 0; +} + +static +struct rpc_rqst *xprt_alloc_bc_req(struct rpc_xprt *xprt, gfp_t gfp_flags) +{ + struct rpc_rqst *req; + + /* Pre-allocate one backchannel rpc_rqst */ + req = kzalloc(sizeof(*req), gfp_flags); + if (req == NULL) + return NULL; + + req->rq_xprt = xprt; + INIT_LIST_HEAD(&req->rq_list); + INIT_LIST_HEAD(&req->rq_bc_list); + + /* Preallocate one XDR receive buffer */ + if (xprt_alloc_xdr_buf(&req->rq_rcv_buf, gfp_flags) < 0) { + printk(KERN_ERR "Failed to create bc receive xbuf\n"); + goto out_free; + } + req->rq_rcv_buf.len = PAGE_SIZE; + + /* Preallocate one XDR send buffer */ + if (xprt_alloc_xdr_buf(&req->rq_snd_buf, gfp_flags) < 0) { + printk(KERN_ERR "Failed to create bc snd xbuf\n"); + goto out_free; + } + return req; +out_free: + xprt_free_allocation(req); + return NULL; +} + +/* + * Preallocate up to min_reqs structures and related buffers for use + * by the backchannel. This function can be called multiple times + * when creating new sessions that use the same rpc_xprt. The + * preallocated buffers are added to the pool of resources used by + * the rpc_xprt. Anyone of these resources may be used used by an + * incoming callback request. It's up to the higher levels in the + * stack to enforce that the maximum number of session slots is not + * being exceeded. + * + * Some callback arguments can be large. For example, a pNFS server + * using multiple deviceids. The list can be unbound, but the client + * has the ability to tell the server the maximum size of the callback + * requests. Each deviceID is 16 bytes, so allocate one page + * for the arguments to have enough room to receive a number of these + * deviceIDs. The NFS client indicates to the pNFS server that its + * callback requests can be up to 4096 bytes in size. + */ +int xprt_setup_backchannel(struct rpc_xprt *xprt, unsigned int min_reqs) +{ + if (!xprt->ops->bc_setup) + return 0; + return xprt->ops->bc_setup(xprt, min_reqs); +} +EXPORT_SYMBOL_GPL(xprt_setup_backchannel); + +int xprt_setup_bc(struct rpc_xprt *xprt, unsigned int min_reqs) +{ + struct rpc_rqst *req; + struct list_head tmp_list; + int i; + + dprintk("RPC: setup backchannel transport\n"); + + /* + * We use a temporary list to keep track of the preallocated + * buffers. Once we're done building the list we splice it + * into the backchannel preallocation list off of the rpc_xprt + * struct. This helps minimize the amount of time the list + * lock is held on the rpc_xprt struct. It also makes cleanup + * easier in case of memory allocation errors. + */ + INIT_LIST_HEAD(&tmp_list); + for (i = 0; i < min_reqs; i++) { + /* Pre-allocate one backchannel rpc_rqst */ + req = xprt_alloc_bc_req(xprt, GFP_KERNEL); + if (req == NULL) { + printk(KERN_ERR "Failed to create bc rpc_rqst\n"); + goto out_free; + } + + /* Add the allocated buffer to the tmp list */ + dprintk("RPC: adding req= %p\n", req); + list_add(&req->rq_bc_pa_list, &tmp_list); + } + + /* + * Add the temporary list to the backchannel preallocation list + */ + spin_lock(&xprt->bc_pa_lock); + list_splice(&tmp_list, &xprt->bc_pa_list); + xprt_inc_alloc_count(xprt, min_reqs); + spin_unlock(&xprt->bc_pa_lock); + + dprintk("RPC: setup backchannel transport done\n"); + return 0; + +out_free: + /* + * Memory allocation failed, free the temporary list + */ + while (!list_empty(&tmp_list)) { + req = list_first_entry(&tmp_list, + struct rpc_rqst, + rq_bc_pa_list); + list_del(&req->rq_bc_pa_list); + xprt_free_allocation(req); + } + + dprintk("RPC: setup backchannel transport failed\n"); + return -ENOMEM; +} + +/** + * xprt_destroy_backchannel - Destroys the backchannel preallocated structures. + * @xprt: the transport holding the preallocated strucures + * @max_reqs the maximum number of preallocated structures to destroy + * + * Since these structures may have been allocated by multiple calls + * to xprt_setup_backchannel, we only destroy up to the maximum number + * of reqs specified by the caller. + */ +void xprt_destroy_backchannel(struct rpc_xprt *xprt, unsigned int max_reqs) +{ + if (xprt->ops->bc_destroy) + xprt->ops->bc_destroy(xprt, max_reqs); +} +EXPORT_SYMBOL_GPL(xprt_destroy_backchannel); + +void xprt_destroy_bc(struct rpc_xprt *xprt, unsigned int max_reqs) +{ + struct rpc_rqst *req = NULL, *tmp = NULL; + + dprintk("RPC: destroy backchannel transport\n"); + + if (max_reqs == 0) + goto out; + + spin_lock_bh(&xprt->bc_pa_lock); + xprt_dec_alloc_count(xprt, max_reqs); + list_for_each_entry_safe(req, tmp, &xprt->bc_pa_list, rq_bc_pa_list) { + dprintk("RPC: req=%p\n", req); + list_del(&req->rq_bc_pa_list); + xprt_free_allocation(req); + if (--max_reqs == 0) + break; + } + spin_unlock_bh(&xprt->bc_pa_lock); + +out: + dprintk("RPC: backchannel list empty= %s\n", + list_empty(&xprt->bc_pa_list) ? "true" : "false"); +} + +static struct rpc_rqst *xprt_alloc_bc_request(struct rpc_xprt *xprt, __be32 xid) +{ + struct rpc_rqst *req = NULL; + + dprintk("RPC: allocate a backchannel request\n"); + if (atomic_read(&xprt->bc_free_slots) <= 0) + goto not_found; + if (list_empty(&xprt->bc_pa_list)) { + req = xprt_alloc_bc_req(xprt, GFP_ATOMIC); + if (!req) + goto not_found; + list_add_tail(&req->rq_bc_pa_list, &xprt->bc_pa_list); + xprt->bc_alloc_count++; + } + req = list_first_entry(&xprt->bc_pa_list, struct rpc_rqst, + rq_bc_pa_list); + req->rq_reply_bytes_recvd = 0; + req->rq_bytes_sent = 0; + memcpy(&req->rq_private_buf, &req->rq_rcv_buf, + sizeof(req->rq_private_buf)); + req->rq_xid = xid; + req->rq_connect_cookie = xprt->connect_cookie; +not_found: + dprintk("RPC: backchannel req=%p\n", req); + return req; +} + +/* + * Return the preallocated rpc_rqst structure and XDR buffers + * associated with this rpc_task. + */ +void xprt_free_bc_request(struct rpc_rqst *req) +{ + struct rpc_xprt *xprt = req->rq_xprt; + + xprt->ops->bc_free_rqst(req); +} + +void xprt_free_bc_rqst(struct rpc_rqst *req) +{ + struct rpc_xprt *xprt = req->rq_xprt; + + dprintk("RPC: free backchannel req=%p\n", req); + + req->rq_connect_cookie = xprt->connect_cookie - 1; + smp_mb__before_atomic(); + clear_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state); + smp_mb__after_atomic(); + + /* + * Return it to the list of preallocations so that it + * may be reused by a new callback request. + */ + spin_lock_bh(&xprt->bc_pa_lock); + if (xprt_need_to_requeue(xprt)) { + list_add_tail(&req->rq_bc_pa_list, &xprt->bc_pa_list); + xprt->bc_alloc_count++; + req = NULL; + } + spin_unlock_bh(&xprt->bc_pa_lock); + if (req != NULL) { + /* + * The last remaining session was destroyed while this + * entry was in use. Free the entry and don't attempt + * to add back to the list because there is no need to + * have anymore preallocated entries. + */ + dprintk("RPC: Last session removed req=%p\n", req); + xprt_free_allocation(req); + return; + } +} + +/* + * One or more rpc_rqst structure have been preallocated during the + * backchannel setup. Buffer space for the send and private XDR buffers + * has been preallocated as well. Use xprt_alloc_bc_request to allocate + * to this request. Use xprt_free_bc_request to return it. + * + * We know that we're called in soft interrupt context, grab the spin_lock + * since there is no need to grab the bottom half spin_lock. + * + * Return an available rpc_rqst, otherwise NULL if non are available. + */ +struct rpc_rqst *xprt_lookup_bc_request(struct rpc_xprt *xprt, __be32 xid) +{ + struct rpc_rqst *req; + + spin_lock(&xprt->bc_pa_lock); + list_for_each_entry(req, &xprt->bc_pa_list, rq_bc_pa_list) { + if (req->rq_connect_cookie != xprt->connect_cookie) + continue; + if (req->rq_xid == xid) + goto found; + } + req = xprt_alloc_bc_request(xprt, xid); +found: + spin_unlock(&xprt->bc_pa_lock); + return req; +} + +/* + * Add callback request to callback list. The callback + * service sleeps on the sv_cb_waitq waiting for new + * requests. Wake it up after adding enqueing the + * request. + */ +void xprt_complete_bc_request(struct rpc_rqst *req, uint32_t copied) +{ + struct rpc_xprt *xprt = req->rq_xprt; + struct svc_serv *bc_serv = xprt->bc_serv; + + spin_lock(&xprt->bc_pa_lock); + list_del(&req->rq_bc_pa_list); + xprt_dec_alloc_count(xprt, 1); + spin_unlock(&xprt->bc_pa_lock); + + req->rq_private_buf.len = copied; + set_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state); + + dprintk("RPC: add callback request to list\n"); + spin_lock(&bc_serv->sv_cb_lock); + list_add(&req->rq_bc_list, &bc_serv->sv_cb_list); + wake_up(&bc_serv->sv_cb_waitq); + spin_unlock(&bc_serv->sv_cb_lock); +} diff --git a/net/sunrpc/cache.c b/net/sunrpc/cache.c new file mode 100644 index 000000000..3a28e150b --- /dev/null +++ b/net/sunrpc/cache.c @@ -0,0 +1,1861 @@ +/* + * net/sunrpc/cache.c + * + * Generic code for various authentication-related caches + * used by sunrpc clients and servers. + * + * Copyright (C) 2002 Neil Brown <neilb@cse.unsw.edu.au> + * + * Released under terms in GPL version 2. See COPYING. + * + */ + +#include <linux/types.h> +#include <linux/fs.h> +#include <linux/file.h> +#include <linux/slab.h> +#include <linux/signal.h> +#include <linux/sched.h> +#include <linux/kmod.h> +#include <linux/list.h> +#include <linux/module.h> +#include <linux/ctype.h> +#include <linux/string_helpers.h> +#include <linux/uaccess.h> +#include <linux/poll.h> +#include <linux/seq_file.h> +#include <linux/proc_fs.h> +#include <linux/net.h> +#include <linux/workqueue.h> +#include <linux/mutex.h> +#include <linux/pagemap.h> +#include <asm/ioctls.h> +#include <linux/sunrpc/types.h> +#include <linux/sunrpc/cache.h> +#include <linux/sunrpc/stats.h> +#include <linux/sunrpc/rpc_pipe_fs.h> +#include "netns.h" + +#define RPCDBG_FACILITY RPCDBG_CACHE + +static bool cache_defer_req(struct cache_req *req, struct cache_head *item); +static void cache_revisit_request(struct cache_head *item); + +static void cache_init(struct cache_head *h, struct cache_detail *detail) +{ + time_t now = seconds_since_boot(); + INIT_HLIST_NODE(&h->cache_list); + h->flags = 0; + kref_init(&h->ref); + h->expiry_time = now + CACHE_NEW_EXPIRY; + if (now <= detail->flush_time) + /* ensure it isn't already expired */ + now = detail->flush_time + 1; + h->last_refresh = now; +} + +static void cache_fresh_unlocked(struct cache_head *head, + struct cache_detail *detail); + +struct cache_head *sunrpc_cache_lookup(struct cache_detail *detail, + struct cache_head *key, int hash) +{ + struct cache_head *new = NULL, *freeme = NULL, *tmp = NULL; + struct hlist_head *head; + + head = &detail->hash_table[hash]; + + read_lock(&detail->hash_lock); + + hlist_for_each_entry(tmp, head, cache_list) { + if (detail->match(tmp, key)) { + if (cache_is_expired(detail, tmp)) + /* This entry is expired, we will discard it. */ + break; + cache_get(tmp); + read_unlock(&detail->hash_lock); + return tmp; + } + } + read_unlock(&detail->hash_lock); + /* Didn't find anything, insert an empty entry */ + + new = detail->alloc(); + if (!new) + return NULL; + /* must fully initialise 'new', else + * we might get lose if we need to + * cache_put it soon. + */ + cache_init(new, detail); + detail->init(new, key); + + write_lock(&detail->hash_lock); + + /* check if entry appeared while we slept */ + hlist_for_each_entry(tmp, head, cache_list) { + if (detail->match(tmp, key)) { + if (cache_is_expired(detail, tmp)) { + hlist_del_init(&tmp->cache_list); + detail->entries --; + freeme = tmp; + break; + } + cache_get(tmp); + write_unlock(&detail->hash_lock); + cache_put(new, detail); + return tmp; + } + } + + hlist_add_head(&new->cache_list, head); + detail->entries++; + cache_get(new); + write_unlock(&detail->hash_lock); + + if (freeme) { + cache_fresh_unlocked(freeme, detail); + cache_put(freeme, detail); + } + return new; +} +EXPORT_SYMBOL_GPL(sunrpc_cache_lookup); + + +static void cache_dequeue(struct cache_detail *detail, struct cache_head *ch); + +static void cache_fresh_locked(struct cache_head *head, time_t expiry, + struct cache_detail *detail) +{ + time_t now = seconds_since_boot(); + if (now <= detail->flush_time) + /* ensure it isn't immediately treated as expired */ + now = detail->flush_time + 1; + head->expiry_time = expiry; + head->last_refresh = now; + smp_wmb(); /* paired with smp_rmb() in cache_is_valid() */ + set_bit(CACHE_VALID, &head->flags); +} + +static void cache_fresh_unlocked(struct cache_head *head, + struct cache_detail *detail) +{ + if (test_and_clear_bit(CACHE_PENDING, &head->flags)) { + cache_revisit_request(head); + cache_dequeue(detail, head); + } +} + +struct cache_head *sunrpc_cache_update(struct cache_detail *detail, + struct cache_head *new, struct cache_head *old, int hash) +{ + /* The 'old' entry is to be replaced by 'new'. + * If 'old' is not VALID, we update it directly, + * otherwise we need to replace it + */ + struct cache_head *tmp; + + if (!test_bit(CACHE_VALID, &old->flags)) { + write_lock(&detail->hash_lock); + if (!test_bit(CACHE_VALID, &old->flags)) { + if (test_bit(CACHE_NEGATIVE, &new->flags)) + set_bit(CACHE_NEGATIVE, &old->flags); + else + detail->update(old, new); + cache_fresh_locked(old, new->expiry_time, detail); + write_unlock(&detail->hash_lock); + cache_fresh_unlocked(old, detail); + return old; + } + write_unlock(&detail->hash_lock); + } + /* We need to insert a new entry */ + tmp = detail->alloc(); + if (!tmp) { + cache_put(old, detail); + return NULL; + } + cache_init(tmp, detail); + detail->init(tmp, old); + + write_lock(&detail->hash_lock); + if (test_bit(CACHE_NEGATIVE, &new->flags)) + set_bit(CACHE_NEGATIVE, &tmp->flags); + else + detail->update(tmp, new); + hlist_add_head(&tmp->cache_list, &detail->hash_table[hash]); + detail->entries++; + cache_get(tmp); + cache_fresh_locked(tmp, new->expiry_time, detail); + cache_fresh_locked(old, 0, detail); + write_unlock(&detail->hash_lock); + cache_fresh_unlocked(tmp, detail); + cache_fresh_unlocked(old, detail); + cache_put(old, detail); + return tmp; +} +EXPORT_SYMBOL_GPL(sunrpc_cache_update); + +static int cache_make_upcall(struct cache_detail *cd, struct cache_head *h) +{ + if (cd->cache_upcall) + return cd->cache_upcall(cd, h); + return sunrpc_cache_pipe_upcall(cd, h); +} + +static inline int cache_is_valid(struct cache_head *h) +{ + if (!test_bit(CACHE_VALID, &h->flags)) + return -EAGAIN; + else { + /* entry is valid */ + if (test_bit(CACHE_NEGATIVE, &h->flags)) + return -ENOENT; + else { + /* + * In combination with write barrier in + * sunrpc_cache_update, ensures that anyone + * using the cache entry after this sees the + * updated contents: + */ + smp_rmb(); + return 0; + } + } +} + +static int try_to_negate_entry(struct cache_detail *detail, struct cache_head *h) +{ + int rv; + + write_lock(&detail->hash_lock); + rv = cache_is_valid(h); + if (rv == -EAGAIN) { + set_bit(CACHE_NEGATIVE, &h->flags); + cache_fresh_locked(h, seconds_since_boot()+CACHE_NEW_EXPIRY, + detail); + rv = -ENOENT; + } + write_unlock(&detail->hash_lock); + cache_fresh_unlocked(h, detail); + return rv; +} + +/* + * This is the generic cache management routine for all + * the authentication caches. + * It checks the currency of a cache item and will (later) + * initiate an upcall to fill it if needed. + * + * + * Returns 0 if the cache_head can be used, or cache_puts it and returns + * -EAGAIN if upcall is pending and request has been queued + * -ETIMEDOUT if upcall failed or request could not be queue or + * upcall completed but item is still invalid (implying that + * the cache item has been replaced with a newer one). + * -ENOENT if cache entry was negative + */ +int cache_check(struct cache_detail *detail, + struct cache_head *h, struct cache_req *rqstp) +{ + int rv; + long refresh_age, age; + + /* First decide return status as best we can */ + rv = cache_is_valid(h); + + /* now see if we want to start an upcall */ + refresh_age = (h->expiry_time - h->last_refresh); + age = seconds_since_boot() - h->last_refresh; + + if (rqstp == NULL) { + if (rv == -EAGAIN) + rv = -ENOENT; + } else if (rv == -EAGAIN || + (h->expiry_time != 0 && age > refresh_age/2)) { + dprintk("RPC: Want update, refage=%ld, age=%ld\n", + refresh_age, age); + if (!test_and_set_bit(CACHE_PENDING, &h->flags)) { + switch (cache_make_upcall(detail, h)) { + case -EINVAL: + rv = try_to_negate_entry(detail, h); + break; + case -EAGAIN: + cache_fresh_unlocked(h, detail); + break; + } + } + } + + if (rv == -EAGAIN) { + if (!cache_defer_req(rqstp, h)) { + /* + * Request was not deferred; handle it as best + * we can ourselves: + */ + rv = cache_is_valid(h); + if (rv == -EAGAIN) + rv = -ETIMEDOUT; + } + } + if (rv) + cache_put(h, detail); + return rv; +} +EXPORT_SYMBOL_GPL(cache_check); + +/* + * caches need to be periodically cleaned. + * For this we maintain a list of cache_detail and + * a current pointer into that list and into the table + * for that entry. + * + * Each time cache_clean is called it finds the next non-empty entry + * in the current table and walks the list in that entry + * looking for entries that can be removed. + * + * An entry gets removed if: + * - The expiry is before current time + * - The last_refresh time is before the flush_time for that cache + * + * later we might drop old entries with non-NEVER expiry if that table + * is getting 'full' for some definition of 'full' + * + * The question of "how often to scan a table" is an interesting one + * and is answered in part by the use of the "nextcheck" field in the + * cache_detail. + * When a scan of a table begins, the nextcheck field is set to a time + * that is well into the future. + * While scanning, if an expiry time is found that is earlier than the + * current nextcheck time, nextcheck is set to that expiry time. + * If the flush_time is ever set to a time earlier than the nextcheck + * time, the nextcheck time is then set to that flush_time. + * + * A table is then only scanned if the current time is at least + * the nextcheck time. + * + */ + +static LIST_HEAD(cache_list); +static DEFINE_SPINLOCK(cache_list_lock); +static struct cache_detail *current_detail; +static int current_index; + +static void do_cache_clean(struct work_struct *work); +static struct delayed_work cache_cleaner; + +void sunrpc_init_cache_detail(struct cache_detail *cd) +{ + rwlock_init(&cd->hash_lock); + INIT_LIST_HEAD(&cd->queue); + spin_lock(&cache_list_lock); + cd->nextcheck = 0; + cd->entries = 0; + atomic_set(&cd->readers, 0); + cd->last_close = 0; + cd->last_warn = -1; + list_add(&cd->others, &cache_list); + spin_unlock(&cache_list_lock); + + /* start the cleaning process */ + queue_delayed_work(system_power_efficient_wq, &cache_cleaner, 0); +} +EXPORT_SYMBOL_GPL(sunrpc_init_cache_detail); + +void sunrpc_destroy_cache_detail(struct cache_detail *cd) +{ + cache_purge(cd); + spin_lock(&cache_list_lock); + write_lock(&cd->hash_lock); + if (current_detail == cd) + current_detail = NULL; + list_del_init(&cd->others); + write_unlock(&cd->hash_lock); + spin_unlock(&cache_list_lock); + if (list_empty(&cache_list)) { + /* module must be being unloaded so its safe to kill the worker */ + cancel_delayed_work_sync(&cache_cleaner); + } +} +EXPORT_SYMBOL_GPL(sunrpc_destroy_cache_detail); + +/* clean cache tries to find something to clean + * and cleans it. + * It returns 1 if it cleaned something, + * 0 if it didn't find anything this time + * -1 if it fell off the end of the list. + */ +static int cache_clean(void) +{ + int rv = 0; + struct list_head *next; + + spin_lock(&cache_list_lock); + + /* find a suitable table if we don't already have one */ + while (current_detail == NULL || + current_index >= current_detail->hash_size) { + if (current_detail) + next = current_detail->others.next; + else + next = cache_list.next; + if (next == &cache_list) { + current_detail = NULL; + spin_unlock(&cache_list_lock); + return -1; + } + current_detail = list_entry(next, struct cache_detail, others); + if (current_detail->nextcheck > seconds_since_boot()) + current_index = current_detail->hash_size; + else { + current_index = 0; + current_detail->nextcheck = seconds_since_boot()+30*60; + } + } + + /* find a non-empty bucket in the table */ + while (current_detail && + current_index < current_detail->hash_size && + hlist_empty(¤t_detail->hash_table[current_index])) + current_index++; + + /* find a cleanable entry in the bucket and clean it, or set to next bucket */ + + if (current_detail && current_index < current_detail->hash_size) { + struct cache_head *ch = NULL; + struct cache_detail *d; + struct hlist_head *head; + struct hlist_node *tmp; + + write_lock(¤t_detail->hash_lock); + + /* Ok, now to clean this strand */ + + head = ¤t_detail->hash_table[current_index]; + hlist_for_each_entry_safe(ch, tmp, head, cache_list) { + if (current_detail->nextcheck > ch->expiry_time) + current_detail->nextcheck = ch->expiry_time+1; + if (!cache_is_expired(current_detail, ch)) + continue; + + hlist_del_init(&ch->cache_list); + current_detail->entries--; + rv = 1; + break; + } + + write_unlock(¤t_detail->hash_lock); + d = current_detail; + if (!ch) + current_index ++; + spin_unlock(&cache_list_lock); + if (ch) { + set_bit(CACHE_CLEANED, &ch->flags); + cache_fresh_unlocked(ch, d); + cache_put(ch, d); + } + } else + spin_unlock(&cache_list_lock); + + return rv; +} + +/* + * We want to regularly clean the cache, so we need to schedule some work ... + */ +static void do_cache_clean(struct work_struct *work) +{ + int delay = 5; + if (cache_clean() == -1) + delay = round_jiffies_relative(30*HZ); + + if (list_empty(&cache_list)) + delay = 0; + + if (delay) + queue_delayed_work(system_power_efficient_wq, + &cache_cleaner, delay); +} + + +/* + * Clean all caches promptly. This just calls cache_clean + * repeatedly until we are sure that every cache has had a chance to + * be fully cleaned + */ +void cache_flush(void) +{ + while (cache_clean() != -1) + cond_resched(); + while (cache_clean() != -1) + cond_resched(); +} +EXPORT_SYMBOL_GPL(cache_flush); + +void cache_purge(struct cache_detail *detail) +{ + struct cache_head *ch = NULL; + struct hlist_head *head = NULL; + struct hlist_node *tmp = NULL; + int i = 0; + + write_lock(&detail->hash_lock); + if (!detail->entries) { + write_unlock(&detail->hash_lock); + return; + } + + dprintk("RPC: %d entries in %s cache\n", detail->entries, detail->name); + for (i = 0; i < detail->hash_size; i++) { + head = &detail->hash_table[i]; + hlist_for_each_entry_safe(ch, tmp, head, cache_list) { + hlist_del_init(&ch->cache_list); + detail->entries--; + + set_bit(CACHE_CLEANED, &ch->flags); + write_unlock(&detail->hash_lock); + cache_fresh_unlocked(ch, detail); + cache_put(ch, detail); + write_lock(&detail->hash_lock); + } + } + write_unlock(&detail->hash_lock); +} +EXPORT_SYMBOL_GPL(cache_purge); + + +/* + * Deferral and Revisiting of Requests. + * + * If a cache lookup finds a pending entry, we + * need to defer the request and revisit it later. + * All deferred requests are stored in a hash table, + * indexed by "struct cache_head *". + * As it may be wasteful to store a whole request + * structure, we allow the request to provide a + * deferred form, which must contain a + * 'struct cache_deferred_req' + * This cache_deferred_req contains a method to allow + * it to be revisited when cache info is available + */ + +#define DFR_HASHSIZE (PAGE_SIZE/sizeof(struct list_head)) +#define DFR_HASH(item) ((((long)item)>>4 ^ (((long)item)>>13)) % DFR_HASHSIZE) + +#define DFR_MAX 300 /* ??? */ + +static DEFINE_SPINLOCK(cache_defer_lock); +static LIST_HEAD(cache_defer_list); +static struct hlist_head cache_defer_hash[DFR_HASHSIZE]; +static int cache_defer_cnt; + +static void __unhash_deferred_req(struct cache_deferred_req *dreq) +{ + hlist_del_init(&dreq->hash); + if (!list_empty(&dreq->recent)) { + list_del_init(&dreq->recent); + cache_defer_cnt--; + } +} + +static void __hash_deferred_req(struct cache_deferred_req *dreq, struct cache_head *item) +{ + int hash = DFR_HASH(item); + + INIT_LIST_HEAD(&dreq->recent); + hlist_add_head(&dreq->hash, &cache_defer_hash[hash]); +} + +static void setup_deferral(struct cache_deferred_req *dreq, + struct cache_head *item, + int count_me) +{ + + dreq->item = item; + + spin_lock(&cache_defer_lock); + + __hash_deferred_req(dreq, item); + + if (count_me) { + cache_defer_cnt++; + list_add(&dreq->recent, &cache_defer_list); + } + + spin_unlock(&cache_defer_lock); + +} + +struct thread_deferred_req { + struct cache_deferred_req handle; + struct completion completion; +}; + +static void cache_restart_thread(struct cache_deferred_req *dreq, int too_many) +{ + struct thread_deferred_req *dr = + container_of(dreq, struct thread_deferred_req, handle); + complete(&dr->completion); +} + +static void cache_wait_req(struct cache_req *req, struct cache_head *item) +{ + struct thread_deferred_req sleeper; + struct cache_deferred_req *dreq = &sleeper.handle; + + sleeper.completion = COMPLETION_INITIALIZER_ONSTACK(sleeper.completion); + dreq->revisit = cache_restart_thread; + + setup_deferral(dreq, item, 0); + + if (!test_bit(CACHE_PENDING, &item->flags) || + wait_for_completion_interruptible_timeout( + &sleeper.completion, req->thread_wait) <= 0) { + /* The completion wasn't completed, so we need + * to clean up + */ + spin_lock(&cache_defer_lock); + if (!hlist_unhashed(&sleeper.handle.hash)) { + __unhash_deferred_req(&sleeper.handle); + spin_unlock(&cache_defer_lock); + } else { + /* cache_revisit_request already removed + * this from the hash table, but hasn't + * called ->revisit yet. It will very soon + * and we need to wait for it. + */ + spin_unlock(&cache_defer_lock); + wait_for_completion(&sleeper.completion); + } + } +} + +static void cache_limit_defers(void) +{ + /* Make sure we haven't exceed the limit of allowed deferred + * requests. + */ + struct cache_deferred_req *discard = NULL; + + if (cache_defer_cnt <= DFR_MAX) + return; + + spin_lock(&cache_defer_lock); + + /* Consider removing either the first or the last */ + if (cache_defer_cnt > DFR_MAX) { + if (prandom_u32() & 1) + discard = list_entry(cache_defer_list.next, + struct cache_deferred_req, recent); + else + discard = list_entry(cache_defer_list.prev, + struct cache_deferred_req, recent); + __unhash_deferred_req(discard); + } + spin_unlock(&cache_defer_lock); + if (discard) + discard->revisit(discard, 1); +} + +/* Return true if and only if a deferred request is queued. */ +static bool cache_defer_req(struct cache_req *req, struct cache_head *item) +{ + struct cache_deferred_req *dreq; + + if (req->thread_wait) { + cache_wait_req(req, item); + if (!test_bit(CACHE_PENDING, &item->flags)) + return false; + } + dreq = req->defer(req); + if (dreq == NULL) + return false; + setup_deferral(dreq, item, 1); + if (!test_bit(CACHE_PENDING, &item->flags)) + /* Bit could have been cleared before we managed to + * set up the deferral, so need to revisit just in case + */ + cache_revisit_request(item); + + cache_limit_defers(); + return true; +} + +static void cache_revisit_request(struct cache_head *item) +{ + struct cache_deferred_req *dreq; + struct list_head pending; + struct hlist_node *tmp; + int hash = DFR_HASH(item); + + INIT_LIST_HEAD(&pending); + spin_lock(&cache_defer_lock); + + hlist_for_each_entry_safe(dreq, tmp, &cache_defer_hash[hash], hash) + if (dreq->item == item) { + __unhash_deferred_req(dreq); + list_add(&dreq->recent, &pending); + } + + spin_unlock(&cache_defer_lock); + + while (!list_empty(&pending)) { + dreq = list_entry(pending.next, struct cache_deferred_req, recent); + list_del_init(&dreq->recent); + dreq->revisit(dreq, 0); + } +} + +void cache_clean_deferred(void *owner) +{ + struct cache_deferred_req *dreq, *tmp; + struct list_head pending; + + + INIT_LIST_HEAD(&pending); + spin_lock(&cache_defer_lock); + + list_for_each_entry_safe(dreq, tmp, &cache_defer_list, recent) { + if (dreq->owner == owner) { + __unhash_deferred_req(dreq); + list_add(&dreq->recent, &pending); + } + } + spin_unlock(&cache_defer_lock); + + while (!list_empty(&pending)) { + dreq = list_entry(pending.next, struct cache_deferred_req, recent); + list_del_init(&dreq->recent); + dreq->revisit(dreq, 1); + } +} + +/* + * communicate with user-space + * + * We have a magic /proc file - /proc/net/rpc/<cachename>/channel. + * On read, you get a full request, or block. + * On write, an update request is processed. + * Poll works if anything to read, and always allows write. + * + * Implemented by linked list of requests. Each open file has + * a ->private that also exists in this list. New requests are added + * to the end and may wakeup and preceding readers. + * New readers are added to the head. If, on read, an item is found with + * CACHE_UPCALLING clear, we free it from the list. + * + */ + +static DEFINE_SPINLOCK(queue_lock); +static DEFINE_MUTEX(queue_io_mutex); + +struct cache_queue { + struct list_head list; + int reader; /* if 0, then request */ +}; +struct cache_request { + struct cache_queue q; + struct cache_head *item; + char * buf; + int len; + int readers; +}; +struct cache_reader { + struct cache_queue q; + int offset; /* if non-0, we have a refcnt on next request */ +}; + +static int cache_request(struct cache_detail *detail, + struct cache_request *crq) +{ + char *bp = crq->buf; + int len = PAGE_SIZE; + + detail->cache_request(detail, crq->item, &bp, &len); + if (len < 0) + return -EAGAIN; + return PAGE_SIZE - len; +} + +static ssize_t cache_read(struct file *filp, char __user *buf, size_t count, + loff_t *ppos, struct cache_detail *cd) +{ + struct cache_reader *rp = filp->private_data; + struct cache_request *rq; + struct inode *inode = file_inode(filp); + int err; + + if (count == 0) + return 0; + + inode_lock(inode); /* protect against multiple concurrent + * readers on this file */ + again: + spin_lock(&queue_lock); + /* need to find next request */ + while (rp->q.list.next != &cd->queue && + list_entry(rp->q.list.next, struct cache_queue, list) + ->reader) { + struct list_head *next = rp->q.list.next; + list_move(&rp->q.list, next); + } + if (rp->q.list.next == &cd->queue) { + spin_unlock(&queue_lock); + inode_unlock(inode); + WARN_ON_ONCE(rp->offset); + return 0; + } + rq = container_of(rp->q.list.next, struct cache_request, q.list); + WARN_ON_ONCE(rq->q.reader); + if (rp->offset == 0) + rq->readers++; + spin_unlock(&queue_lock); + + if (rq->len == 0) { + err = cache_request(cd, rq); + if (err < 0) + goto out; + rq->len = err; + } + + if (rp->offset == 0 && !test_bit(CACHE_PENDING, &rq->item->flags)) { + err = -EAGAIN; + spin_lock(&queue_lock); + list_move(&rp->q.list, &rq->q.list); + spin_unlock(&queue_lock); + } else { + if (rp->offset + count > rq->len) + count = rq->len - rp->offset; + err = -EFAULT; + if (copy_to_user(buf, rq->buf + rp->offset, count)) + goto out; + rp->offset += count; + if (rp->offset >= rq->len) { + rp->offset = 0; + spin_lock(&queue_lock); + list_move(&rp->q.list, &rq->q.list); + spin_unlock(&queue_lock); + } + err = 0; + } + out: + if (rp->offset == 0) { + /* need to release rq */ + spin_lock(&queue_lock); + rq->readers--; + if (rq->readers == 0 && + !test_bit(CACHE_PENDING, &rq->item->flags)) { + list_del(&rq->q.list); + spin_unlock(&queue_lock); + cache_put(rq->item, cd); + kfree(rq->buf); + kfree(rq); + } else + spin_unlock(&queue_lock); + } + if (err == -EAGAIN) + goto again; + inode_unlock(inode); + return err ? err : count; +} + +static ssize_t cache_do_downcall(char *kaddr, const char __user *buf, + size_t count, struct cache_detail *cd) +{ + ssize_t ret; + + if (count == 0) + return -EINVAL; + if (copy_from_user(kaddr, buf, count)) + return -EFAULT; + kaddr[count] = '\0'; + ret = cd->cache_parse(cd, kaddr, count); + if (!ret) + ret = count; + return ret; +} + +static ssize_t cache_slow_downcall(const char __user *buf, + size_t count, struct cache_detail *cd) +{ + static char write_buf[8192]; /* protected by queue_io_mutex */ + ssize_t ret = -EINVAL; + + if (count >= sizeof(write_buf)) + goto out; + mutex_lock(&queue_io_mutex); + ret = cache_do_downcall(write_buf, buf, count, cd); + mutex_unlock(&queue_io_mutex); +out: + return ret; +} + +static ssize_t cache_downcall(struct address_space *mapping, + const char __user *buf, + size_t count, struct cache_detail *cd) +{ + struct page *page; + char *kaddr; + ssize_t ret = -ENOMEM; + + if (count >= PAGE_SIZE) + goto out_slow; + + page = find_or_create_page(mapping, 0, GFP_KERNEL); + if (!page) + goto out_slow; + + kaddr = kmap(page); + ret = cache_do_downcall(kaddr, buf, count, cd); + kunmap(page); + unlock_page(page); + put_page(page); + return ret; +out_slow: + return cache_slow_downcall(buf, count, cd); +} + +static ssize_t cache_write(struct file *filp, const char __user *buf, + size_t count, loff_t *ppos, + struct cache_detail *cd) +{ + struct address_space *mapping = filp->f_mapping; + struct inode *inode = file_inode(filp); + ssize_t ret = -EINVAL; + + if (!cd->cache_parse) + goto out; + + inode_lock(inode); + ret = cache_downcall(mapping, buf, count, cd); + inode_unlock(inode); +out: + return ret; +} + +static DECLARE_WAIT_QUEUE_HEAD(queue_wait); + +static __poll_t cache_poll(struct file *filp, poll_table *wait, + struct cache_detail *cd) +{ + __poll_t mask; + struct cache_reader *rp = filp->private_data; + struct cache_queue *cq; + + poll_wait(filp, &queue_wait, wait); + + /* alway allow write */ + mask = EPOLLOUT | EPOLLWRNORM; + + if (!rp) + return mask; + + spin_lock(&queue_lock); + + for (cq= &rp->q; &cq->list != &cd->queue; + cq = list_entry(cq->list.next, struct cache_queue, list)) + if (!cq->reader) { + mask |= EPOLLIN | EPOLLRDNORM; + break; + } + spin_unlock(&queue_lock); + return mask; +} + +static int cache_ioctl(struct inode *ino, struct file *filp, + unsigned int cmd, unsigned long arg, + struct cache_detail *cd) +{ + int len = 0; + struct cache_reader *rp = filp->private_data; + struct cache_queue *cq; + + if (cmd != FIONREAD || !rp) + return -EINVAL; + + spin_lock(&queue_lock); + + /* only find the length remaining in current request, + * or the length of the next request + */ + for (cq= &rp->q; &cq->list != &cd->queue; + cq = list_entry(cq->list.next, struct cache_queue, list)) + if (!cq->reader) { + struct cache_request *cr = + container_of(cq, struct cache_request, q); + len = cr->len - rp->offset; + break; + } + spin_unlock(&queue_lock); + + return put_user(len, (int __user *)arg); +} + +static int cache_open(struct inode *inode, struct file *filp, + struct cache_detail *cd) +{ + struct cache_reader *rp = NULL; + + if (!cd || !try_module_get(cd->owner)) + return -EACCES; + nonseekable_open(inode, filp); + if (filp->f_mode & FMODE_READ) { + rp = kmalloc(sizeof(*rp), GFP_KERNEL); + if (!rp) { + module_put(cd->owner); + return -ENOMEM; + } + rp->offset = 0; + rp->q.reader = 1; + atomic_inc(&cd->readers); + spin_lock(&queue_lock); + list_add(&rp->q.list, &cd->queue); + spin_unlock(&queue_lock); + } + filp->private_data = rp; + return 0; +} + +static int cache_release(struct inode *inode, struct file *filp, + struct cache_detail *cd) +{ + struct cache_reader *rp = filp->private_data; + + if (rp) { + spin_lock(&queue_lock); + if (rp->offset) { + struct cache_queue *cq; + for (cq= &rp->q; &cq->list != &cd->queue; + cq = list_entry(cq->list.next, struct cache_queue, list)) + if (!cq->reader) { + container_of(cq, struct cache_request, q) + ->readers--; + break; + } + rp->offset = 0; + } + list_del(&rp->q.list); + spin_unlock(&queue_lock); + + filp->private_data = NULL; + kfree(rp); + + cd->last_close = seconds_since_boot(); + atomic_dec(&cd->readers); + } + module_put(cd->owner); + return 0; +} + + + +static void cache_dequeue(struct cache_detail *detail, struct cache_head *ch) +{ + struct cache_queue *cq, *tmp; + struct cache_request *cr; + struct list_head dequeued; + + INIT_LIST_HEAD(&dequeued); + spin_lock(&queue_lock); + list_for_each_entry_safe(cq, tmp, &detail->queue, list) + if (!cq->reader) { + cr = container_of(cq, struct cache_request, q); + if (cr->item != ch) + continue; + if (test_bit(CACHE_PENDING, &ch->flags)) + /* Lost a race and it is pending again */ + break; + if (cr->readers != 0) + continue; + list_move(&cr->q.list, &dequeued); + } + spin_unlock(&queue_lock); + while (!list_empty(&dequeued)) { + cr = list_entry(dequeued.next, struct cache_request, q.list); + list_del(&cr->q.list); + cache_put(cr->item, detail); + kfree(cr->buf); + kfree(cr); + } +} + +/* + * Support routines for text-based upcalls. + * Fields are separated by spaces. + * Fields are either mangled to quote space tab newline slosh with slosh + * or a hexified with a leading \x + * Record is terminated with newline. + * + */ + +void qword_add(char **bpp, int *lp, char *str) +{ + char *bp = *bpp; + int len = *lp; + int ret; + + if (len < 0) return; + + ret = string_escape_str(str, bp, len, ESCAPE_OCTAL, "\\ \n\t"); + if (ret >= len) { + bp += len; + len = -1; + } else { + bp += ret; + len -= ret; + *bp++ = ' '; + len--; + } + *bpp = bp; + *lp = len; +} +EXPORT_SYMBOL_GPL(qword_add); + +void qword_addhex(char **bpp, int *lp, char *buf, int blen) +{ + char *bp = *bpp; + int len = *lp; + + if (len < 0) return; + + if (len > 2) { + *bp++ = '\\'; + *bp++ = 'x'; + len -= 2; + while (blen && len >= 2) { + bp = hex_byte_pack(bp, *buf++); + len -= 2; + blen--; + } + } + if (blen || len<1) len = -1; + else { + *bp++ = ' '; + len--; + } + *bpp = bp; + *lp = len; +} +EXPORT_SYMBOL_GPL(qword_addhex); + +static void warn_no_listener(struct cache_detail *detail) +{ + if (detail->last_warn != detail->last_close) { + detail->last_warn = detail->last_close; + if (detail->warn_no_listener) + detail->warn_no_listener(detail, detail->last_close != 0); + } +} + +static bool cache_listeners_exist(struct cache_detail *detail) +{ + if (atomic_read(&detail->readers)) + return true; + if (detail->last_close == 0) + /* This cache was never opened */ + return false; + if (detail->last_close < seconds_since_boot() - 30) + /* + * We allow for the possibility that someone might + * restart a userspace daemon without restarting the + * server; but after 30 seconds, we give up. + */ + return false; + return true; +} + +/* + * register an upcall request to user-space and queue it up for read() by the + * upcall daemon. + * + * Each request is at most one page long. + */ +int sunrpc_cache_pipe_upcall(struct cache_detail *detail, struct cache_head *h) +{ + + char *buf; + struct cache_request *crq; + int ret = 0; + + if (!detail->cache_request) + return -EINVAL; + + if (!cache_listeners_exist(detail)) { + warn_no_listener(detail); + return -EINVAL; + } + if (test_bit(CACHE_CLEANED, &h->flags)) + /* Too late to make an upcall */ + return -EAGAIN; + + buf = kmalloc(PAGE_SIZE, GFP_KERNEL); + if (!buf) + return -EAGAIN; + + crq = kmalloc(sizeof (*crq), GFP_KERNEL); + if (!crq) { + kfree(buf); + return -EAGAIN; + } + + crq->q.reader = 0; + crq->buf = buf; + crq->len = 0; + crq->readers = 0; + spin_lock(&queue_lock); + if (test_bit(CACHE_PENDING, &h->flags)) { + crq->item = cache_get(h); + list_add_tail(&crq->q.list, &detail->queue); + } else + /* Lost a race, no longer PENDING, so don't enqueue */ + ret = -EAGAIN; + spin_unlock(&queue_lock); + wake_up(&queue_wait); + if (ret == -EAGAIN) { + kfree(buf); + kfree(crq); + } + return ret; +} +EXPORT_SYMBOL_GPL(sunrpc_cache_pipe_upcall); + +/* + * parse a message from user-space and pass it + * to an appropriate cache + * Messages are, like requests, separated into fields by + * spaces and dequotes as \xHEXSTRING or embedded \nnn octal + * + * Message is + * reply cachename expiry key ... content.... + * + * key and content are both parsed by cache + */ + +int qword_get(char **bpp, char *dest, int bufsize) +{ + /* return bytes copied, or -1 on error */ + char *bp = *bpp; + int len = 0; + + while (*bp == ' ') bp++; + + if (bp[0] == '\\' && bp[1] == 'x') { + /* HEX STRING */ + bp += 2; + while (len < bufsize - 1) { + int h, l; + + h = hex_to_bin(bp[0]); + if (h < 0) + break; + + l = hex_to_bin(bp[1]); + if (l < 0) + break; + + *dest++ = (h << 4) | l; + bp += 2; + len++; + } + } else { + /* text with \nnn octal quoting */ + while (*bp != ' ' && *bp != '\n' && *bp && len < bufsize-1) { + if (*bp == '\\' && + isodigit(bp[1]) && (bp[1] <= '3') && + isodigit(bp[2]) && + isodigit(bp[3])) { + int byte = (*++bp -'0'); + bp++; + byte = (byte << 3) | (*bp++ - '0'); + byte = (byte << 3) | (*bp++ - '0'); + *dest++ = byte; + len++; + } else { + *dest++ = *bp++; + len++; + } + } + } + + if (*bp != ' ' && *bp != '\n' && *bp != '\0') + return -1; + while (*bp == ' ') bp++; + *bpp = bp; + *dest = '\0'; + return len; +} +EXPORT_SYMBOL_GPL(qword_get); + + +/* + * support /proc/net/rpc/$CACHENAME/content + * as a seqfile. + * We call ->cache_show passing NULL for the item to + * get a header, then pass each real item in the cache + */ + +void *cache_seq_start(struct seq_file *m, loff_t *pos) + __acquires(cd->hash_lock) +{ + loff_t n = *pos; + unsigned int hash, entry; + struct cache_head *ch; + struct cache_detail *cd = m->private; + + read_lock(&cd->hash_lock); + if (!n--) + return SEQ_START_TOKEN; + hash = n >> 32; + entry = n & ((1LL<<32) - 1); + + hlist_for_each_entry(ch, &cd->hash_table[hash], cache_list) + if (!entry--) + return ch; + n &= ~((1LL<<32) - 1); + do { + hash++; + n += 1LL<<32; + } while(hash < cd->hash_size && + hlist_empty(&cd->hash_table[hash])); + if (hash >= cd->hash_size) + return NULL; + *pos = n+1; + return hlist_entry_safe(cd->hash_table[hash].first, + struct cache_head, cache_list); +} +EXPORT_SYMBOL_GPL(cache_seq_start); + +void *cache_seq_next(struct seq_file *m, void *p, loff_t *pos) +{ + struct cache_head *ch = p; + int hash = (*pos >> 32); + struct cache_detail *cd = m->private; + + if (p == SEQ_START_TOKEN) + hash = 0; + else if (ch->cache_list.next == NULL) { + hash++; + *pos += 1LL<<32; + } else { + ++*pos; + return hlist_entry_safe(ch->cache_list.next, + struct cache_head, cache_list); + } + *pos &= ~((1LL<<32) - 1); + while (hash < cd->hash_size && + hlist_empty(&cd->hash_table[hash])) { + hash++; + *pos += 1LL<<32; + } + if (hash >= cd->hash_size) + return NULL; + ++*pos; + return hlist_entry_safe(cd->hash_table[hash].first, + struct cache_head, cache_list); +} +EXPORT_SYMBOL_GPL(cache_seq_next); + +void cache_seq_stop(struct seq_file *m, void *p) + __releases(cd->hash_lock) +{ + struct cache_detail *cd = m->private; + read_unlock(&cd->hash_lock); +} +EXPORT_SYMBOL_GPL(cache_seq_stop); + +static int c_show(struct seq_file *m, void *p) +{ + struct cache_head *cp = p; + struct cache_detail *cd = m->private; + + if (p == SEQ_START_TOKEN) + return cd->cache_show(m, cd, NULL); + + ifdebug(CACHE) + seq_printf(m, "# expiry=%ld refcnt=%d flags=%lx\n", + convert_to_wallclock(cp->expiry_time), + kref_read(&cp->ref), cp->flags); + cache_get(cp); + if (cache_check(cd, cp, NULL)) + /* cache_check does a cache_put on failure */ + seq_printf(m, "# "); + else { + if (cache_is_expired(cd, cp)) + seq_printf(m, "# "); + cache_put(cp, cd); + } + + return cd->cache_show(m, cd, cp); +} + +static const struct seq_operations cache_content_op = { + .start = cache_seq_start, + .next = cache_seq_next, + .stop = cache_seq_stop, + .show = c_show, +}; + +static int content_open(struct inode *inode, struct file *file, + struct cache_detail *cd) +{ + struct seq_file *seq; + int err; + + if (!cd || !try_module_get(cd->owner)) + return -EACCES; + + err = seq_open(file, &cache_content_op); + if (err) { + module_put(cd->owner); + return err; + } + + seq = file->private_data; + seq->private = cd; + return 0; +} + +static int content_release(struct inode *inode, struct file *file, + struct cache_detail *cd) +{ + int ret = seq_release(inode, file); + module_put(cd->owner); + return ret; +} + +static int open_flush(struct inode *inode, struct file *file, + struct cache_detail *cd) +{ + if (!cd || !try_module_get(cd->owner)) + return -EACCES; + return nonseekable_open(inode, file); +} + +static int release_flush(struct inode *inode, struct file *file, + struct cache_detail *cd) +{ + module_put(cd->owner); + return 0; +} + +static ssize_t read_flush(struct file *file, char __user *buf, + size_t count, loff_t *ppos, + struct cache_detail *cd) +{ + char tbuf[22]; + size_t len; + + len = snprintf(tbuf, sizeof(tbuf), "%lu\n", + convert_to_wallclock(cd->flush_time)); + return simple_read_from_buffer(buf, count, ppos, tbuf, len); +} + +static ssize_t write_flush(struct file *file, const char __user *buf, + size_t count, loff_t *ppos, + struct cache_detail *cd) +{ + char tbuf[20]; + char *ep; + time_t now; + + if (*ppos || count > sizeof(tbuf)-1) + return -EINVAL; + if (copy_from_user(tbuf, buf, count)) + return -EFAULT; + tbuf[count] = 0; + simple_strtoul(tbuf, &ep, 0); + if (*ep && *ep != '\n') + return -EINVAL; + /* Note that while we check that 'buf' holds a valid number, + * we always ignore the value and just flush everything. + * Making use of the number leads to races. + */ + + now = seconds_since_boot(); + /* Always flush everything, so behave like cache_purge() + * Do this by advancing flush_time to the current time, + * or by one second if it has already reached the current time. + * Newly added cache entries will always have ->last_refresh greater + * that ->flush_time, so they don't get flushed prematurely. + */ + + if (cd->flush_time >= now) + now = cd->flush_time + 1; + + cd->flush_time = now; + cd->nextcheck = now; + cache_flush(); + + *ppos += count; + return count; +} + +static ssize_t cache_read_procfs(struct file *filp, char __user *buf, + size_t count, loff_t *ppos) +{ + struct cache_detail *cd = PDE_DATA(file_inode(filp)); + + return cache_read(filp, buf, count, ppos, cd); +} + +static ssize_t cache_write_procfs(struct file *filp, const char __user *buf, + size_t count, loff_t *ppos) +{ + struct cache_detail *cd = PDE_DATA(file_inode(filp)); + + return cache_write(filp, buf, count, ppos, cd); +} + +static __poll_t cache_poll_procfs(struct file *filp, poll_table *wait) +{ + struct cache_detail *cd = PDE_DATA(file_inode(filp)); + + return cache_poll(filp, wait, cd); +} + +static long cache_ioctl_procfs(struct file *filp, + unsigned int cmd, unsigned long arg) +{ + struct inode *inode = file_inode(filp); + struct cache_detail *cd = PDE_DATA(inode); + + return cache_ioctl(inode, filp, cmd, arg, cd); +} + +static int cache_open_procfs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = PDE_DATA(inode); + + return cache_open(inode, filp, cd); +} + +static int cache_release_procfs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = PDE_DATA(inode); + + return cache_release(inode, filp, cd); +} + +static const struct file_operations cache_file_operations_procfs = { + .owner = THIS_MODULE, + .llseek = no_llseek, + .read = cache_read_procfs, + .write = cache_write_procfs, + .poll = cache_poll_procfs, + .unlocked_ioctl = cache_ioctl_procfs, /* for FIONREAD */ + .open = cache_open_procfs, + .release = cache_release_procfs, +}; + +static int content_open_procfs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = PDE_DATA(inode); + + return content_open(inode, filp, cd); +} + +static int content_release_procfs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = PDE_DATA(inode); + + return content_release(inode, filp, cd); +} + +static const struct file_operations content_file_operations_procfs = { + .open = content_open_procfs, + .read = seq_read, + .llseek = seq_lseek, + .release = content_release_procfs, +}; + +static int open_flush_procfs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = PDE_DATA(inode); + + return open_flush(inode, filp, cd); +} + +static int release_flush_procfs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = PDE_DATA(inode); + + return release_flush(inode, filp, cd); +} + +static ssize_t read_flush_procfs(struct file *filp, char __user *buf, + size_t count, loff_t *ppos) +{ + struct cache_detail *cd = PDE_DATA(file_inode(filp)); + + return read_flush(filp, buf, count, ppos, cd); +} + +static ssize_t write_flush_procfs(struct file *filp, + const char __user *buf, + size_t count, loff_t *ppos) +{ + struct cache_detail *cd = PDE_DATA(file_inode(filp)); + + return write_flush(filp, buf, count, ppos, cd); +} + +static const struct file_operations cache_flush_operations_procfs = { + .open = open_flush_procfs, + .read = read_flush_procfs, + .write = write_flush_procfs, + .release = release_flush_procfs, + .llseek = no_llseek, +}; + +static void remove_cache_proc_entries(struct cache_detail *cd) +{ + if (cd->procfs) { + proc_remove(cd->procfs); + cd->procfs = NULL; + } +} + +#ifdef CONFIG_PROC_FS +static int create_cache_proc_entries(struct cache_detail *cd, struct net *net) +{ + struct proc_dir_entry *p; + struct sunrpc_net *sn; + + sn = net_generic(net, sunrpc_net_id); + cd->procfs = proc_mkdir(cd->name, sn->proc_net_rpc); + if (cd->procfs == NULL) + goto out_nomem; + + p = proc_create_data("flush", S_IFREG | 0600, + cd->procfs, &cache_flush_operations_procfs, cd); + if (p == NULL) + goto out_nomem; + + if (cd->cache_request || cd->cache_parse) { + p = proc_create_data("channel", S_IFREG | 0600, cd->procfs, + &cache_file_operations_procfs, cd); + if (p == NULL) + goto out_nomem; + } + if (cd->cache_show) { + p = proc_create_data("content", S_IFREG | 0400, cd->procfs, + &content_file_operations_procfs, cd); + if (p == NULL) + goto out_nomem; + } + return 0; +out_nomem: + remove_cache_proc_entries(cd); + return -ENOMEM; +} +#else /* CONFIG_PROC_FS */ +static int create_cache_proc_entries(struct cache_detail *cd, struct net *net) +{ + return 0; +} +#endif + +void __init cache_initialize(void) +{ + INIT_DEFERRABLE_WORK(&cache_cleaner, do_cache_clean); +} + +int cache_register_net(struct cache_detail *cd, struct net *net) +{ + int ret; + + sunrpc_init_cache_detail(cd); + ret = create_cache_proc_entries(cd, net); + if (ret) + sunrpc_destroy_cache_detail(cd); + return ret; +} +EXPORT_SYMBOL_GPL(cache_register_net); + +void cache_unregister_net(struct cache_detail *cd, struct net *net) +{ + remove_cache_proc_entries(cd); + sunrpc_destroy_cache_detail(cd); +} +EXPORT_SYMBOL_GPL(cache_unregister_net); + +struct cache_detail *cache_create_net(const struct cache_detail *tmpl, struct net *net) +{ + struct cache_detail *cd; + int i; + + cd = kmemdup(tmpl, sizeof(struct cache_detail), GFP_KERNEL); + if (cd == NULL) + return ERR_PTR(-ENOMEM); + + cd->hash_table = kcalloc(cd->hash_size, sizeof(struct hlist_head), + GFP_KERNEL); + if (cd->hash_table == NULL) { + kfree(cd); + return ERR_PTR(-ENOMEM); + } + + for (i = 0; i < cd->hash_size; i++) + INIT_HLIST_HEAD(&cd->hash_table[i]); + cd->net = net; + return cd; +} +EXPORT_SYMBOL_GPL(cache_create_net); + +void cache_destroy_net(struct cache_detail *cd, struct net *net) +{ + kfree(cd->hash_table); + kfree(cd); +} +EXPORT_SYMBOL_GPL(cache_destroy_net); + +static ssize_t cache_read_pipefs(struct file *filp, char __user *buf, + size_t count, loff_t *ppos) +{ + struct cache_detail *cd = RPC_I(file_inode(filp))->private; + + return cache_read(filp, buf, count, ppos, cd); +} + +static ssize_t cache_write_pipefs(struct file *filp, const char __user *buf, + size_t count, loff_t *ppos) +{ + struct cache_detail *cd = RPC_I(file_inode(filp))->private; + + return cache_write(filp, buf, count, ppos, cd); +} + +static __poll_t cache_poll_pipefs(struct file *filp, poll_table *wait) +{ + struct cache_detail *cd = RPC_I(file_inode(filp))->private; + + return cache_poll(filp, wait, cd); +} + +static long cache_ioctl_pipefs(struct file *filp, + unsigned int cmd, unsigned long arg) +{ + struct inode *inode = file_inode(filp); + struct cache_detail *cd = RPC_I(inode)->private; + + return cache_ioctl(inode, filp, cmd, arg, cd); +} + +static int cache_open_pipefs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = RPC_I(inode)->private; + + return cache_open(inode, filp, cd); +} + +static int cache_release_pipefs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = RPC_I(inode)->private; + + return cache_release(inode, filp, cd); +} + +const struct file_operations cache_file_operations_pipefs = { + .owner = THIS_MODULE, + .llseek = no_llseek, + .read = cache_read_pipefs, + .write = cache_write_pipefs, + .poll = cache_poll_pipefs, + .unlocked_ioctl = cache_ioctl_pipefs, /* for FIONREAD */ + .open = cache_open_pipefs, + .release = cache_release_pipefs, +}; + +static int content_open_pipefs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = RPC_I(inode)->private; + + return content_open(inode, filp, cd); +} + +static int content_release_pipefs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = RPC_I(inode)->private; + + return content_release(inode, filp, cd); +} + +const struct file_operations content_file_operations_pipefs = { + .open = content_open_pipefs, + .read = seq_read, + .llseek = seq_lseek, + .release = content_release_pipefs, +}; + +static int open_flush_pipefs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = RPC_I(inode)->private; + + return open_flush(inode, filp, cd); +} + +static int release_flush_pipefs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = RPC_I(inode)->private; + + return release_flush(inode, filp, cd); +} + +static ssize_t read_flush_pipefs(struct file *filp, char __user *buf, + size_t count, loff_t *ppos) +{ + struct cache_detail *cd = RPC_I(file_inode(filp))->private; + + return read_flush(filp, buf, count, ppos, cd); +} + +static ssize_t write_flush_pipefs(struct file *filp, + const char __user *buf, + size_t count, loff_t *ppos) +{ + struct cache_detail *cd = RPC_I(file_inode(filp))->private; + + return write_flush(filp, buf, count, ppos, cd); +} + +const struct file_operations cache_flush_operations_pipefs = { + .open = open_flush_pipefs, + .read = read_flush_pipefs, + .write = write_flush_pipefs, + .release = release_flush_pipefs, + .llseek = no_llseek, +}; + +int sunrpc_cache_register_pipefs(struct dentry *parent, + const char *name, umode_t umode, + struct cache_detail *cd) +{ + struct dentry *dir = rpc_create_cache_dir(parent, name, umode, cd); + if (IS_ERR(dir)) + return PTR_ERR(dir); + cd->pipefs = dir; + return 0; +} +EXPORT_SYMBOL_GPL(sunrpc_cache_register_pipefs); + +void sunrpc_cache_unregister_pipefs(struct cache_detail *cd) +{ + if (cd->pipefs) { + rpc_remove_cache_dir(cd->pipefs); + cd->pipefs = NULL; + } +} +EXPORT_SYMBOL_GPL(sunrpc_cache_unregister_pipefs); + +void sunrpc_cache_unhash(struct cache_detail *cd, struct cache_head *h) +{ + write_lock(&cd->hash_lock); + if (!hlist_unhashed(&h->cache_list)){ + hlist_del_init(&h->cache_list); + cd->entries--; + write_unlock(&cd->hash_lock); + cache_put(h, cd); + } else + write_unlock(&cd->hash_lock); +} +EXPORT_SYMBOL_GPL(sunrpc_cache_unhash); diff --git a/net/sunrpc/clnt.c b/net/sunrpc/clnt.c new file mode 100644 index 000000000..0d7d149b1 --- /dev/null +++ b/net/sunrpc/clnt.c @@ -0,0 +1,2896 @@ +/* + * linux/net/sunrpc/clnt.c + * + * This file contains the high-level RPC interface. + * It is modeled as a finite state machine to support both synchronous + * and asynchronous requests. + * + * - RPC header generation and argument serialization. + * - Credential refresh. + * - TCP connect handling. + * - Retry of operation when it is suspected the operation failed because + * of uid squashing on the server, or when the credentials were stale + * and need to be refreshed, or when a packet was damaged in transit. + * This may be have to be moved to the VFS layer. + * + * Copyright (C) 1992,1993 Rick Sladkey <jrs@world.std.com> + * Copyright (C) 1995,1996 Olaf Kirch <okir@monad.swb.de> + */ + + +#include <linux/module.h> +#include <linux/types.h> +#include <linux/kallsyms.h> +#include <linux/mm.h> +#include <linux/namei.h> +#include <linux/mount.h> +#include <linux/slab.h> +#include <linux/rcupdate.h> +#include <linux/utsname.h> +#include <linux/workqueue.h> +#include <linux/in.h> +#include <linux/in6.h> +#include <linux/un.h> + +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/addr.h> +#include <linux/sunrpc/rpc_pipe_fs.h> +#include <linux/sunrpc/metrics.h> +#include <linux/sunrpc/bc_xprt.h> +#include <trace/events/sunrpc.h> + +#include "sunrpc.h" +#include "netns.h" + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_CALL +#endif + +#define dprint_status(t) \ + dprintk("RPC: %5u %s (status %d)\n", t->tk_pid, \ + __func__, t->tk_status) + +/* + * All RPC clients are linked into this list + */ + +static DECLARE_WAIT_QUEUE_HEAD(destroy_wait); + + +static void call_start(struct rpc_task *task); +static void call_reserve(struct rpc_task *task); +static void call_reserveresult(struct rpc_task *task); +static void call_allocate(struct rpc_task *task); +static void call_decode(struct rpc_task *task); +static void call_bind(struct rpc_task *task); +static void call_bind_status(struct rpc_task *task); +static void call_transmit(struct rpc_task *task); +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +static void call_bc_transmit(struct rpc_task *task); +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ +static void call_status(struct rpc_task *task); +static void call_transmit_status(struct rpc_task *task); +static void call_refresh(struct rpc_task *task); +static void call_refreshresult(struct rpc_task *task); +static void call_timeout(struct rpc_task *task); +static void call_connect(struct rpc_task *task); +static void call_connect_status(struct rpc_task *task); + +static __be32 *rpc_encode_header(struct rpc_task *task); +static __be32 *rpc_verify_header(struct rpc_task *task); +static int rpc_ping(struct rpc_clnt *clnt); + +static void rpc_register_client(struct rpc_clnt *clnt) +{ + struct net *net = rpc_net_ns(clnt); + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + spin_lock(&sn->rpc_client_lock); + list_add(&clnt->cl_clients, &sn->all_clients); + spin_unlock(&sn->rpc_client_lock); +} + +static void rpc_unregister_client(struct rpc_clnt *clnt) +{ + struct net *net = rpc_net_ns(clnt); + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + spin_lock(&sn->rpc_client_lock); + list_del(&clnt->cl_clients); + spin_unlock(&sn->rpc_client_lock); +} + +static void __rpc_clnt_remove_pipedir(struct rpc_clnt *clnt) +{ + rpc_remove_client_dir(clnt); +} + +static void rpc_clnt_remove_pipedir(struct rpc_clnt *clnt) +{ + struct net *net = rpc_net_ns(clnt); + struct super_block *pipefs_sb; + + pipefs_sb = rpc_get_sb_net(net); + if (pipefs_sb) { + __rpc_clnt_remove_pipedir(clnt); + rpc_put_sb_net(net); + } +} + +static struct dentry *rpc_setup_pipedir_sb(struct super_block *sb, + struct rpc_clnt *clnt) +{ + static uint32_t clntid; + const char *dir_name = clnt->cl_program->pipe_dir_name; + char name[15]; + struct dentry *dir, *dentry; + + dir = rpc_d_lookup_sb(sb, dir_name); + if (dir == NULL) { + pr_info("RPC: pipefs directory doesn't exist: %s\n", dir_name); + return dir; + } + for (;;) { + snprintf(name, sizeof(name), "clnt%x", (unsigned int)clntid++); + name[sizeof(name) - 1] = '\0'; + dentry = rpc_create_client_dir(dir, name, clnt); + if (!IS_ERR(dentry)) + break; + if (dentry == ERR_PTR(-EEXIST)) + continue; + printk(KERN_INFO "RPC: Couldn't create pipefs entry" + " %s/%s, error %ld\n", + dir_name, name, PTR_ERR(dentry)); + break; + } + dput(dir); + return dentry; +} + +static int +rpc_setup_pipedir(struct super_block *pipefs_sb, struct rpc_clnt *clnt) +{ + struct dentry *dentry; + + if (clnt->cl_program->pipe_dir_name != NULL) { + dentry = rpc_setup_pipedir_sb(pipefs_sb, clnt); + if (IS_ERR(dentry)) + return PTR_ERR(dentry); + } + return 0; +} + +static int rpc_clnt_skip_event(struct rpc_clnt *clnt, unsigned long event) +{ + if (clnt->cl_program->pipe_dir_name == NULL) + return 1; + + switch (event) { + case RPC_PIPEFS_MOUNT: + if (clnt->cl_pipedir_objects.pdh_dentry != NULL) + return 1; + if (atomic_read(&clnt->cl_count) == 0) + return 1; + break; + case RPC_PIPEFS_UMOUNT: + if (clnt->cl_pipedir_objects.pdh_dentry == NULL) + return 1; + break; + } + return 0; +} + +static int __rpc_clnt_handle_event(struct rpc_clnt *clnt, unsigned long event, + struct super_block *sb) +{ + struct dentry *dentry; + + switch (event) { + case RPC_PIPEFS_MOUNT: + dentry = rpc_setup_pipedir_sb(sb, clnt); + if (!dentry) + return -ENOENT; + if (IS_ERR(dentry)) + return PTR_ERR(dentry); + break; + case RPC_PIPEFS_UMOUNT: + __rpc_clnt_remove_pipedir(clnt); + break; + default: + printk(KERN_ERR "%s: unknown event: %ld\n", __func__, event); + return -ENOTSUPP; + } + return 0; +} + +static int __rpc_pipefs_event(struct rpc_clnt *clnt, unsigned long event, + struct super_block *sb) +{ + int error = 0; + + for (;; clnt = clnt->cl_parent) { + if (!rpc_clnt_skip_event(clnt, event)) + error = __rpc_clnt_handle_event(clnt, event, sb); + if (error || clnt == clnt->cl_parent) + break; + } + return error; +} + +static struct rpc_clnt *rpc_get_client_for_event(struct net *net, int event) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_clnt *clnt; + + spin_lock(&sn->rpc_client_lock); + list_for_each_entry(clnt, &sn->all_clients, cl_clients) { + if (rpc_clnt_skip_event(clnt, event)) + continue; + spin_unlock(&sn->rpc_client_lock); + return clnt; + } + spin_unlock(&sn->rpc_client_lock); + return NULL; +} + +static int rpc_pipefs_event(struct notifier_block *nb, unsigned long event, + void *ptr) +{ + struct super_block *sb = ptr; + struct rpc_clnt *clnt; + int error = 0; + + while ((clnt = rpc_get_client_for_event(sb->s_fs_info, event))) { + error = __rpc_pipefs_event(clnt, event, sb); + if (error) + break; + } + return error; +} + +static struct notifier_block rpc_clients_block = { + .notifier_call = rpc_pipefs_event, + .priority = SUNRPC_PIPEFS_RPC_PRIO, +}; + +int rpc_clients_notifier_register(void) +{ + return rpc_pipefs_notifier_register(&rpc_clients_block); +} + +void rpc_clients_notifier_unregister(void) +{ + return rpc_pipefs_notifier_unregister(&rpc_clients_block); +} + +static struct rpc_xprt *rpc_clnt_set_transport(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, + const struct rpc_timeout *timeout) +{ + struct rpc_xprt *old; + + spin_lock(&clnt->cl_lock); + old = rcu_dereference_protected(clnt->cl_xprt, + lockdep_is_held(&clnt->cl_lock)); + + if (!xprt_bound(xprt)) + clnt->cl_autobind = 1; + + clnt->cl_timeout = timeout; + rcu_assign_pointer(clnt->cl_xprt, xprt); + spin_unlock(&clnt->cl_lock); + + return old; +} + +static void rpc_clnt_set_nodename(struct rpc_clnt *clnt, const char *nodename) +{ + clnt->cl_nodelen = strlcpy(clnt->cl_nodename, + nodename, sizeof(clnt->cl_nodename)); +} + +static int rpc_client_register(struct rpc_clnt *clnt, + rpc_authflavor_t pseudoflavor, + const char *client_name) +{ + struct rpc_auth_create_args auth_args = { + .pseudoflavor = pseudoflavor, + .target_name = client_name, + }; + struct rpc_auth *auth; + struct net *net = rpc_net_ns(clnt); + struct super_block *pipefs_sb; + int err; + + rpc_clnt_debugfs_register(clnt); + + pipefs_sb = rpc_get_sb_net(net); + if (pipefs_sb) { + err = rpc_setup_pipedir(pipefs_sb, clnt); + if (err) + goto out; + } + + rpc_register_client(clnt); + if (pipefs_sb) + rpc_put_sb_net(net); + + auth = rpcauth_create(&auth_args, clnt); + if (IS_ERR(auth)) { + dprintk("RPC: Couldn't create auth handle (flavor %u)\n", + pseudoflavor); + err = PTR_ERR(auth); + goto err_auth; + } + return 0; +err_auth: + pipefs_sb = rpc_get_sb_net(net); + rpc_unregister_client(clnt); + __rpc_clnt_remove_pipedir(clnt); +out: + if (pipefs_sb) + rpc_put_sb_net(net); + rpc_clnt_debugfs_unregister(clnt); + return err; +} + +static DEFINE_IDA(rpc_clids); + +void rpc_cleanup_clids(void) +{ + ida_destroy(&rpc_clids); +} + +static int rpc_alloc_clid(struct rpc_clnt *clnt) +{ + int clid; + + clid = ida_simple_get(&rpc_clids, 0, 0, GFP_KERNEL); + if (clid < 0) + return clid; + clnt->cl_clid = clid; + return 0; +} + +static void rpc_free_clid(struct rpc_clnt *clnt) +{ + ida_simple_remove(&rpc_clids, clnt->cl_clid); +} + +static struct rpc_clnt * rpc_new_client(const struct rpc_create_args *args, + struct rpc_xprt_switch *xps, + struct rpc_xprt *xprt, + struct rpc_clnt *parent) +{ + const struct rpc_program *program = args->program; + const struct rpc_version *version; + struct rpc_clnt *clnt = NULL; + const struct rpc_timeout *timeout; + const char *nodename = args->nodename; + int err; + + /* sanity check the name before trying to print it */ + dprintk("RPC: creating %s client for %s (xprt %p)\n", + program->name, args->servername, xprt); + + err = rpciod_up(); + if (err) + goto out_no_rpciod; + + err = -EINVAL; + if (args->version >= program->nrvers) + goto out_err; + version = program->version[args->version]; + if (version == NULL) + goto out_err; + + err = -ENOMEM; + clnt = kzalloc(sizeof(*clnt), GFP_KERNEL); + if (!clnt) + goto out_err; + clnt->cl_parent = parent ? : clnt; + + err = rpc_alloc_clid(clnt); + if (err) + goto out_no_clid; + + clnt->cl_procinfo = version->procs; + clnt->cl_maxproc = version->nrprocs; + clnt->cl_prog = args->prognumber ? : program->number; + clnt->cl_vers = version->number; + clnt->cl_stats = program->stats; + clnt->cl_metrics = rpc_alloc_iostats(clnt); + rpc_init_pipe_dir_head(&clnt->cl_pipedir_objects); + err = -ENOMEM; + if (clnt->cl_metrics == NULL) + goto out_no_stats; + clnt->cl_program = program; + INIT_LIST_HEAD(&clnt->cl_tasks); + spin_lock_init(&clnt->cl_lock); + + timeout = xprt->timeout; + if (args->timeout != NULL) { + memcpy(&clnt->cl_timeout_default, args->timeout, + sizeof(clnt->cl_timeout_default)); + timeout = &clnt->cl_timeout_default; + } + + rpc_clnt_set_transport(clnt, xprt, timeout); + xprt_iter_init(&clnt->cl_xpi, xps); + xprt_switch_put(xps); + + clnt->cl_rtt = &clnt->cl_rtt_default; + rpc_init_rtt(&clnt->cl_rtt_default, clnt->cl_timeout->to_initval); + + atomic_set(&clnt->cl_count, 1); + + if (nodename == NULL) + nodename = utsname()->nodename; + /* save the nodename */ + rpc_clnt_set_nodename(clnt, nodename); + + err = rpc_client_register(clnt, args->authflavor, args->client_name); + if (err) + goto out_no_path; + if (parent) + atomic_inc(&parent->cl_count); + return clnt; + +out_no_path: + rpc_free_iostats(clnt->cl_metrics); +out_no_stats: + rpc_free_clid(clnt); +out_no_clid: + kfree(clnt); +out_err: + rpciod_down(); +out_no_rpciod: + xprt_switch_put(xps); + xprt_put(xprt); + return ERR_PTR(err); +} + +static struct rpc_clnt *rpc_create_xprt(struct rpc_create_args *args, + struct rpc_xprt *xprt) +{ + struct rpc_clnt *clnt = NULL; + struct rpc_xprt_switch *xps; + + if (args->bc_xprt && args->bc_xprt->xpt_bc_xps) { + WARN_ON_ONCE(!(args->protocol & XPRT_TRANSPORT_BC)); + xps = args->bc_xprt->xpt_bc_xps; + xprt_switch_get(xps); + } else { + xps = xprt_switch_alloc(xprt, GFP_KERNEL); + if (xps == NULL) { + xprt_put(xprt); + return ERR_PTR(-ENOMEM); + } + if (xprt->bc_xprt) { + xprt_switch_get(xps); + xprt->bc_xprt->xpt_bc_xps = xps; + } + } + clnt = rpc_new_client(args, xps, xprt, NULL); + if (IS_ERR(clnt)) + return clnt; + + if (!(args->flags & RPC_CLNT_CREATE_NOPING)) { + int err = rpc_ping(clnt); + if (err != 0) { + rpc_shutdown_client(clnt); + return ERR_PTR(err); + } + } + + clnt->cl_softrtry = 1; + if (args->flags & RPC_CLNT_CREATE_HARDRTRY) + clnt->cl_softrtry = 0; + + if (args->flags & RPC_CLNT_CREATE_AUTOBIND) + clnt->cl_autobind = 1; + if (args->flags & RPC_CLNT_CREATE_NO_RETRANS_TIMEOUT) + clnt->cl_noretranstimeo = 1; + if (args->flags & RPC_CLNT_CREATE_DISCRTRY) + clnt->cl_discrtry = 1; + if (!(args->flags & RPC_CLNT_CREATE_QUIET)) + clnt->cl_chatty = 1; + + return clnt; +} + +/** + * rpc_create - create an RPC client and transport with one call + * @args: rpc_clnt create argument structure + * + * Creates and initializes an RPC transport and an RPC client. + * + * It can ping the server in order to determine if it is up, and to see if + * it supports this program and version. RPC_CLNT_CREATE_NOPING disables + * this behavior so asynchronous tasks can also use rpc_create. + */ +struct rpc_clnt *rpc_create(struct rpc_create_args *args) +{ + struct rpc_xprt *xprt; + struct xprt_create xprtargs = { + .net = args->net, + .ident = args->protocol, + .srcaddr = args->saddress, + .dstaddr = args->address, + .addrlen = args->addrsize, + .servername = args->servername, + .bc_xprt = args->bc_xprt, + }; + char servername[48]; + + if (args->bc_xprt) { + WARN_ON_ONCE(!(args->protocol & XPRT_TRANSPORT_BC)); + xprt = args->bc_xprt->xpt_bc_xprt; + if (xprt) { + xprt_get(xprt); + return rpc_create_xprt(args, xprt); + } + } + + if (args->flags & RPC_CLNT_CREATE_INFINITE_SLOTS) + xprtargs.flags |= XPRT_CREATE_INFINITE_SLOTS; + if (args->flags & RPC_CLNT_CREATE_NO_IDLE_TIMEOUT) + xprtargs.flags |= XPRT_CREATE_NO_IDLE_TIMEOUT; + /* + * If the caller chooses not to specify a hostname, whip + * up a string representation of the passed-in address. + */ + if (xprtargs.servername == NULL) { + struct sockaddr_un *sun = + (struct sockaddr_un *)args->address; + struct sockaddr_in *sin = + (struct sockaddr_in *)args->address; + struct sockaddr_in6 *sin6 = + (struct sockaddr_in6 *)args->address; + + servername[0] = '\0'; + switch (args->address->sa_family) { + case AF_LOCAL: + snprintf(servername, sizeof(servername), "%s", + sun->sun_path); + break; + case AF_INET: + snprintf(servername, sizeof(servername), "%pI4", + &sin->sin_addr.s_addr); + break; + case AF_INET6: + snprintf(servername, sizeof(servername), "%pI6", + &sin6->sin6_addr); + break; + default: + /* caller wants default server name, but + * address family isn't recognized. */ + return ERR_PTR(-EINVAL); + } + xprtargs.servername = servername; + } + + xprt = xprt_create_transport(&xprtargs); + if (IS_ERR(xprt)) + return (struct rpc_clnt *)xprt; + + /* + * By default, kernel RPC client connects from a reserved port. + * CAP_NET_BIND_SERVICE will not be set for unprivileged requesters, + * but it is always enabled for rpciod, which handles the connect + * operation. + */ + xprt->resvport = 1; + if (args->flags & RPC_CLNT_CREATE_NONPRIVPORT) + xprt->resvport = 0; + + return rpc_create_xprt(args, xprt); +} +EXPORT_SYMBOL_GPL(rpc_create); + +/* + * This function clones the RPC client structure. It allows us to share the + * same transport while varying parameters such as the authentication + * flavour. + */ +static struct rpc_clnt *__rpc_clone_client(struct rpc_create_args *args, + struct rpc_clnt *clnt) +{ + struct rpc_xprt_switch *xps; + struct rpc_xprt *xprt; + struct rpc_clnt *new; + int err; + + err = -ENOMEM; + rcu_read_lock(); + xprt = xprt_get(rcu_dereference(clnt->cl_xprt)); + xps = xprt_switch_get(rcu_dereference(clnt->cl_xpi.xpi_xpswitch)); + rcu_read_unlock(); + if (xprt == NULL || xps == NULL) { + xprt_put(xprt); + xprt_switch_put(xps); + goto out_err; + } + args->servername = xprt->servername; + args->nodename = clnt->cl_nodename; + + new = rpc_new_client(args, xps, xprt, clnt); + if (IS_ERR(new)) { + err = PTR_ERR(new); + goto out_err; + } + + /* Turn off autobind on clones */ + new->cl_autobind = 0; + new->cl_softrtry = clnt->cl_softrtry; + new->cl_noretranstimeo = clnt->cl_noretranstimeo; + new->cl_discrtry = clnt->cl_discrtry; + new->cl_chatty = clnt->cl_chatty; + return new; + +out_err: + dprintk("RPC: %s: returned error %d\n", __func__, err); + return ERR_PTR(err); +} + +/** + * rpc_clone_client - Clone an RPC client structure + * + * @clnt: RPC client whose parameters are copied + * + * Returns a fresh RPC client or an ERR_PTR. + */ +struct rpc_clnt *rpc_clone_client(struct rpc_clnt *clnt) +{ + struct rpc_create_args args = { + .program = clnt->cl_program, + .prognumber = clnt->cl_prog, + .version = clnt->cl_vers, + .authflavor = clnt->cl_auth->au_flavor, + }; + return __rpc_clone_client(&args, clnt); +} +EXPORT_SYMBOL_GPL(rpc_clone_client); + +/** + * rpc_clone_client_set_auth - Clone an RPC client structure and set its auth + * + * @clnt: RPC client whose parameters are copied + * @flavor: security flavor for new client + * + * Returns a fresh RPC client or an ERR_PTR. + */ +struct rpc_clnt * +rpc_clone_client_set_auth(struct rpc_clnt *clnt, rpc_authflavor_t flavor) +{ + struct rpc_create_args args = { + .program = clnt->cl_program, + .prognumber = clnt->cl_prog, + .version = clnt->cl_vers, + .authflavor = flavor, + }; + return __rpc_clone_client(&args, clnt); +} +EXPORT_SYMBOL_GPL(rpc_clone_client_set_auth); + +/** + * rpc_switch_client_transport: switch the RPC transport on the fly + * @clnt: pointer to a struct rpc_clnt + * @args: pointer to the new transport arguments + * @timeout: pointer to the new timeout parameters + * + * This function allows the caller to switch the RPC transport for the + * rpc_clnt structure 'clnt' to allow it to connect to a mirrored NFS + * server, for instance. It assumes that the caller has ensured that + * there are no active RPC tasks by using some form of locking. + * + * Returns zero if "clnt" is now using the new xprt. Otherwise a + * negative errno is returned, and "clnt" continues to use the old + * xprt. + */ +int rpc_switch_client_transport(struct rpc_clnt *clnt, + struct xprt_create *args, + const struct rpc_timeout *timeout) +{ + const struct rpc_timeout *old_timeo; + rpc_authflavor_t pseudoflavor; + struct rpc_xprt_switch *xps, *oldxps; + struct rpc_xprt *xprt, *old; + struct rpc_clnt *parent; + int err; + + xprt = xprt_create_transport(args); + if (IS_ERR(xprt)) { + dprintk("RPC: failed to create new xprt for clnt %p\n", + clnt); + return PTR_ERR(xprt); + } + + xps = xprt_switch_alloc(xprt, GFP_KERNEL); + if (xps == NULL) { + xprt_put(xprt); + return -ENOMEM; + } + + pseudoflavor = clnt->cl_auth->au_flavor; + + old_timeo = clnt->cl_timeout; + old = rpc_clnt_set_transport(clnt, xprt, timeout); + oldxps = xprt_iter_xchg_switch(&clnt->cl_xpi, xps); + + rpc_unregister_client(clnt); + __rpc_clnt_remove_pipedir(clnt); + rpc_clnt_debugfs_unregister(clnt); + + /* + * A new transport was created. "clnt" therefore + * becomes the root of a new cl_parent tree. clnt's + * children, if it has any, still point to the old xprt. + */ + parent = clnt->cl_parent; + clnt->cl_parent = clnt; + + /* + * The old rpc_auth cache cannot be re-used. GSS + * contexts in particular are between a single + * client and server. + */ + err = rpc_client_register(clnt, pseudoflavor, NULL); + if (err) + goto out_revert; + + synchronize_rcu(); + if (parent != clnt) + rpc_release_client(parent); + xprt_switch_put(oldxps); + xprt_put(old); + dprintk("RPC: replaced xprt for clnt %p\n", clnt); + return 0; + +out_revert: + xps = xprt_iter_xchg_switch(&clnt->cl_xpi, oldxps); + rpc_clnt_set_transport(clnt, old, old_timeo); + clnt->cl_parent = parent; + rpc_client_register(clnt, pseudoflavor, NULL); + xprt_switch_put(xps); + xprt_put(xprt); + dprintk("RPC: failed to switch xprt for clnt %p\n", clnt); + return err; +} +EXPORT_SYMBOL_GPL(rpc_switch_client_transport); + +static +int rpc_clnt_xprt_iter_init(struct rpc_clnt *clnt, struct rpc_xprt_iter *xpi) +{ + struct rpc_xprt_switch *xps; + + rcu_read_lock(); + xps = xprt_switch_get(rcu_dereference(clnt->cl_xpi.xpi_xpswitch)); + rcu_read_unlock(); + if (xps == NULL) + return -EAGAIN; + xprt_iter_init_listall(xpi, xps); + xprt_switch_put(xps); + return 0; +} + +/** + * rpc_clnt_iterate_for_each_xprt - Apply a function to all transports + * @clnt: pointer to client + * @fn: function to apply + * @data: void pointer to function data + * + * Iterates through the list of RPC transports currently attached to the + * client and applies the function fn(clnt, xprt, data). + * + * On error, the iteration stops, and the function returns the error value. + */ +int rpc_clnt_iterate_for_each_xprt(struct rpc_clnt *clnt, + int (*fn)(struct rpc_clnt *, struct rpc_xprt *, void *), + void *data) +{ + struct rpc_xprt_iter xpi; + int ret; + + ret = rpc_clnt_xprt_iter_init(clnt, &xpi); + if (ret) + return ret; + for (;;) { + struct rpc_xprt *xprt = xprt_iter_get_next(&xpi); + + if (!xprt) + break; + ret = fn(clnt, xprt, data); + xprt_put(xprt); + if (ret < 0) + break; + } + xprt_iter_destroy(&xpi); + return ret; +} +EXPORT_SYMBOL_GPL(rpc_clnt_iterate_for_each_xprt); + +/* + * Kill all tasks for the given client. + * XXX: kill their descendants as well? + */ +void rpc_killall_tasks(struct rpc_clnt *clnt) +{ + struct rpc_task *rovr; + + + if (list_empty(&clnt->cl_tasks)) + return; + dprintk("RPC: killing all tasks for client %p\n", clnt); + /* + * Spin lock all_tasks to prevent changes... + */ + spin_lock(&clnt->cl_lock); + list_for_each_entry(rovr, &clnt->cl_tasks, tk_task) { + if (!RPC_IS_ACTIVATED(rovr)) + continue; + if (!(rovr->tk_flags & RPC_TASK_KILLED)) { + rovr->tk_flags |= RPC_TASK_KILLED; + rpc_exit(rovr, -EIO); + if (RPC_IS_QUEUED(rovr)) + rpc_wake_up_queued_task(rovr->tk_waitqueue, + rovr); + } + } + spin_unlock(&clnt->cl_lock); +} +EXPORT_SYMBOL_GPL(rpc_killall_tasks); + +/* + * Properly shut down an RPC client, terminating all outstanding + * requests. + */ +void rpc_shutdown_client(struct rpc_clnt *clnt) +{ + might_sleep(); + + dprintk_rcu("RPC: shutting down %s client for %s\n", + clnt->cl_program->name, + rcu_dereference(clnt->cl_xprt)->servername); + + while (!list_empty(&clnt->cl_tasks)) { + rpc_killall_tasks(clnt); + wait_event_timeout(destroy_wait, + list_empty(&clnt->cl_tasks), 1*HZ); + } + + rpc_release_client(clnt); +} +EXPORT_SYMBOL_GPL(rpc_shutdown_client); + +/* + * Free an RPC client + */ +static struct rpc_clnt * +rpc_free_client(struct rpc_clnt *clnt) +{ + struct rpc_clnt *parent = NULL; + + dprintk_rcu("RPC: destroying %s client for %s\n", + clnt->cl_program->name, + rcu_dereference(clnt->cl_xprt)->servername); + if (clnt->cl_parent != clnt) + parent = clnt->cl_parent; + rpc_clnt_debugfs_unregister(clnt); + rpc_clnt_remove_pipedir(clnt); + rpc_unregister_client(clnt); + rpc_free_iostats(clnt->cl_metrics); + clnt->cl_metrics = NULL; + xprt_put(rcu_dereference_raw(clnt->cl_xprt)); + xprt_iter_destroy(&clnt->cl_xpi); + rpciod_down(); + rpc_free_clid(clnt); + kfree(clnt); + return parent; +} + +/* + * Free an RPC client + */ +static struct rpc_clnt * +rpc_free_auth(struct rpc_clnt *clnt) +{ + if (clnt->cl_auth == NULL) + return rpc_free_client(clnt); + + /* + * Note: RPCSEC_GSS may need to send NULL RPC calls in order to + * release remaining GSS contexts. This mechanism ensures + * that it can do so safely. + */ + atomic_inc(&clnt->cl_count); + rpcauth_release(clnt->cl_auth); + clnt->cl_auth = NULL; + if (atomic_dec_and_test(&clnt->cl_count)) + return rpc_free_client(clnt); + return NULL; +} + +/* + * Release reference to the RPC client + */ +void +rpc_release_client(struct rpc_clnt *clnt) +{ + dprintk("RPC: rpc_release_client(%p)\n", clnt); + + do { + if (list_empty(&clnt->cl_tasks)) + wake_up(&destroy_wait); + if (!atomic_dec_and_test(&clnt->cl_count)) + break; + clnt = rpc_free_auth(clnt); + } while (clnt != NULL); +} +EXPORT_SYMBOL_GPL(rpc_release_client); + +/** + * rpc_bind_new_program - bind a new RPC program to an existing client + * @old: old rpc_client + * @program: rpc program to set + * @vers: rpc program version + * + * Clones the rpc client and sets up a new RPC program. This is mainly + * of use for enabling different RPC programs to share the same transport. + * The Sun NFSv2/v3 ACL protocol can do this. + */ +struct rpc_clnt *rpc_bind_new_program(struct rpc_clnt *old, + const struct rpc_program *program, + u32 vers) +{ + struct rpc_create_args args = { + .program = program, + .prognumber = program->number, + .version = vers, + .authflavor = old->cl_auth->au_flavor, + }; + struct rpc_clnt *clnt; + int err; + + clnt = __rpc_clone_client(&args, old); + if (IS_ERR(clnt)) + goto out; + err = rpc_ping(clnt); + if (err != 0) { + rpc_shutdown_client(clnt); + clnt = ERR_PTR(err); + } +out: + return clnt; +} +EXPORT_SYMBOL_GPL(rpc_bind_new_program); + +void rpc_task_release_transport(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + + if (xprt) { + task->tk_xprt = NULL; + xprt_put(xprt); + } +} +EXPORT_SYMBOL_GPL(rpc_task_release_transport); + +void rpc_task_release_client(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + + if (clnt != NULL) { + /* Remove from client task list */ + spin_lock(&clnt->cl_lock); + list_del(&task->tk_task); + spin_unlock(&clnt->cl_lock); + task->tk_client = NULL; + + rpc_release_client(clnt); + } + rpc_task_release_transport(task); +} + +static +void rpc_task_set_transport(struct rpc_task *task, struct rpc_clnt *clnt) +{ + if (!task->tk_xprt) + task->tk_xprt = xprt_iter_get_next(&clnt->cl_xpi); +} + +static +void rpc_task_set_client(struct rpc_task *task, struct rpc_clnt *clnt) +{ + + if (clnt != NULL) { + rpc_task_set_transport(task, clnt); + task->tk_client = clnt; + atomic_inc(&clnt->cl_count); + if (clnt->cl_softrtry) + task->tk_flags |= RPC_TASK_SOFT; + if (clnt->cl_noretranstimeo) + task->tk_flags |= RPC_TASK_NO_RETRANS_TIMEOUT; + if (atomic_read(&clnt->cl_swapper)) + task->tk_flags |= RPC_TASK_SWAPPER; + /* Add to the client's list of all tasks */ + spin_lock(&clnt->cl_lock); + list_add_tail(&task->tk_task, &clnt->cl_tasks); + spin_unlock(&clnt->cl_lock); + } +} + +static void +rpc_task_set_rpc_message(struct rpc_task *task, const struct rpc_message *msg) +{ + if (msg != NULL) { + task->tk_msg.rpc_proc = msg->rpc_proc; + task->tk_msg.rpc_argp = msg->rpc_argp; + task->tk_msg.rpc_resp = msg->rpc_resp; + if (msg->rpc_cred != NULL) + task->tk_msg.rpc_cred = get_rpccred(msg->rpc_cred); + } +} + +/* + * Default callback for async RPC calls + */ +static void +rpc_default_callback(struct rpc_task *task, void *data) +{ +} + +static const struct rpc_call_ops rpc_default_ops = { + .rpc_call_done = rpc_default_callback, +}; + +/** + * rpc_run_task - Allocate a new RPC task, then run rpc_execute against it + * @task_setup_data: pointer to task initialisation data + */ +struct rpc_task *rpc_run_task(const struct rpc_task_setup *task_setup_data) +{ + struct rpc_task *task; + + task = rpc_new_task(task_setup_data); + + rpc_task_set_client(task, task_setup_data->rpc_client); + rpc_task_set_rpc_message(task, task_setup_data->rpc_message); + + if (task->tk_action == NULL) + rpc_call_start(task); + + atomic_inc(&task->tk_count); + rpc_execute(task); + return task; +} +EXPORT_SYMBOL_GPL(rpc_run_task); + +/** + * rpc_call_sync - Perform a synchronous RPC call + * @clnt: pointer to RPC client + * @msg: RPC call parameters + * @flags: RPC call flags + */ +int rpc_call_sync(struct rpc_clnt *clnt, const struct rpc_message *msg, int flags) +{ + struct rpc_task *task; + struct rpc_task_setup task_setup_data = { + .rpc_client = clnt, + .rpc_message = msg, + .callback_ops = &rpc_default_ops, + .flags = flags, + }; + int status; + + WARN_ON_ONCE(flags & RPC_TASK_ASYNC); + if (flags & RPC_TASK_ASYNC) { + rpc_release_calldata(task_setup_data.callback_ops, + task_setup_data.callback_data); + return -EINVAL; + } + + task = rpc_run_task(&task_setup_data); + if (IS_ERR(task)) + return PTR_ERR(task); + status = task->tk_status; + rpc_put_task(task); + return status; +} +EXPORT_SYMBOL_GPL(rpc_call_sync); + +/** + * rpc_call_async - Perform an asynchronous RPC call + * @clnt: pointer to RPC client + * @msg: RPC call parameters + * @flags: RPC call flags + * @tk_ops: RPC call ops + * @data: user call data + */ +int +rpc_call_async(struct rpc_clnt *clnt, const struct rpc_message *msg, int flags, + const struct rpc_call_ops *tk_ops, void *data) +{ + struct rpc_task *task; + struct rpc_task_setup task_setup_data = { + .rpc_client = clnt, + .rpc_message = msg, + .callback_ops = tk_ops, + .callback_data = data, + .flags = flags|RPC_TASK_ASYNC, + }; + + task = rpc_run_task(&task_setup_data); + if (IS_ERR(task)) + return PTR_ERR(task); + rpc_put_task(task); + return 0; +} +EXPORT_SYMBOL_GPL(rpc_call_async); + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +/** + * rpc_run_bc_task - Allocate a new RPC task for backchannel use, then run + * rpc_execute against it + * @req: RPC request + */ +struct rpc_task *rpc_run_bc_task(struct rpc_rqst *req) +{ + struct rpc_task *task; + struct xdr_buf *xbufp = &req->rq_snd_buf; + struct rpc_task_setup task_setup_data = { + .callback_ops = &rpc_default_ops, + .flags = RPC_TASK_SOFTCONN, + }; + + dprintk("RPC: rpc_run_bc_task req= %p\n", req); + /* + * Create an rpc_task to send the data + */ + task = rpc_new_task(&task_setup_data); + task->tk_rqstp = req; + + /* + * Set up the xdr_buf length. + * This also indicates that the buffer is XDR encoded already. + */ + xbufp->len = xbufp->head[0].iov_len + xbufp->page_len + + xbufp->tail[0].iov_len; + + task->tk_action = call_bc_transmit; + atomic_inc(&task->tk_count); + WARN_ON_ONCE(atomic_read(&task->tk_count) != 2); + rpc_execute(task); + + dprintk("RPC: rpc_run_bc_task: task= %p\n", task); + return task; +} +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +void +rpc_call_start(struct rpc_task *task) +{ + task->tk_action = call_start; +} +EXPORT_SYMBOL_GPL(rpc_call_start); + +/** + * rpc_peeraddr - extract remote peer address from clnt's xprt + * @clnt: RPC client structure + * @buf: target buffer + * @bufsize: length of target buffer + * + * Returns the number of bytes that are actually in the stored address. + */ +size_t rpc_peeraddr(struct rpc_clnt *clnt, struct sockaddr *buf, size_t bufsize) +{ + size_t bytes; + struct rpc_xprt *xprt; + + rcu_read_lock(); + xprt = rcu_dereference(clnt->cl_xprt); + + bytes = xprt->addrlen; + if (bytes > bufsize) + bytes = bufsize; + memcpy(buf, &xprt->addr, bytes); + rcu_read_unlock(); + + return bytes; +} +EXPORT_SYMBOL_GPL(rpc_peeraddr); + +/** + * rpc_peeraddr2str - return remote peer address in printable format + * @clnt: RPC client structure + * @format: address format + * + * NB: the lifetime of the memory referenced by the returned pointer is + * the same as the rpc_xprt itself. As long as the caller uses this + * pointer, it must hold the RCU read lock. + */ +const char *rpc_peeraddr2str(struct rpc_clnt *clnt, + enum rpc_display_format_t format) +{ + struct rpc_xprt *xprt; + + xprt = rcu_dereference(clnt->cl_xprt); + + if (xprt->address_strings[format] != NULL) + return xprt->address_strings[format]; + else + return "unprintable"; +} +EXPORT_SYMBOL_GPL(rpc_peeraddr2str); + +static const struct sockaddr_in rpc_inaddr_loopback = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_ANY), +}; + +static const struct sockaddr_in6 rpc_in6addr_loopback = { + .sin6_family = AF_INET6, + .sin6_addr = IN6ADDR_ANY_INIT, +}; + +/* + * Try a getsockname() on a connected datagram socket. Using a + * connected datagram socket prevents leaving a socket in TIME_WAIT. + * This conserves the ephemeral port number space. + * + * Returns zero and fills in "buf" if successful; otherwise, a + * negative errno is returned. + */ +static int rpc_sockname(struct net *net, struct sockaddr *sap, size_t salen, + struct sockaddr *buf) +{ + struct socket *sock; + int err; + + err = __sock_create(net, sap->sa_family, + SOCK_DGRAM, IPPROTO_UDP, &sock, 1); + if (err < 0) { + dprintk("RPC: can't create UDP socket (%d)\n", err); + goto out; + } + + switch (sap->sa_family) { + case AF_INET: + err = kernel_bind(sock, + (struct sockaddr *)&rpc_inaddr_loopback, + sizeof(rpc_inaddr_loopback)); + break; + case AF_INET6: + err = kernel_bind(sock, + (struct sockaddr *)&rpc_in6addr_loopback, + sizeof(rpc_in6addr_loopback)); + break; + default: + err = -EAFNOSUPPORT; + goto out; + } + if (err < 0) { + dprintk("RPC: can't bind UDP socket (%d)\n", err); + goto out_release; + } + + err = kernel_connect(sock, sap, salen, 0); + if (err < 0) { + dprintk("RPC: can't connect UDP socket (%d)\n", err); + goto out_release; + } + + err = kernel_getsockname(sock, buf); + if (err < 0) { + dprintk("RPC: getsockname failed (%d)\n", err); + goto out_release; + } + + err = 0; + if (buf->sa_family == AF_INET6) { + struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)buf; + sin6->sin6_scope_id = 0; + } + dprintk("RPC: %s succeeded\n", __func__); + +out_release: + sock_release(sock); +out: + return err; +} + +/* + * Scraping a connected socket failed, so we don't have a useable + * local address. Fallback: generate an address that will prevent + * the server from calling us back. + * + * Returns zero and fills in "buf" if successful; otherwise, a + * negative errno is returned. + */ +static int rpc_anyaddr(int family, struct sockaddr *buf, size_t buflen) +{ + switch (family) { + case AF_INET: + if (buflen < sizeof(rpc_inaddr_loopback)) + return -EINVAL; + memcpy(buf, &rpc_inaddr_loopback, + sizeof(rpc_inaddr_loopback)); + break; + case AF_INET6: + if (buflen < sizeof(rpc_in6addr_loopback)) + return -EINVAL; + memcpy(buf, &rpc_in6addr_loopback, + sizeof(rpc_in6addr_loopback)); + break; + default: + dprintk("RPC: %s: address family not supported\n", + __func__); + return -EAFNOSUPPORT; + } + dprintk("RPC: %s: succeeded\n", __func__); + return 0; +} + +/** + * rpc_localaddr - discover local endpoint address for an RPC client + * @clnt: RPC client structure + * @buf: target buffer + * @buflen: size of target buffer, in bytes + * + * Returns zero and fills in "buf" and "buflen" if successful; + * otherwise, a negative errno is returned. + * + * This works even if the underlying transport is not currently connected, + * or if the upper layer never previously provided a source address. + * + * The result of this function call is transient: multiple calls in + * succession may give different results, depending on how local + * networking configuration changes over time. + */ +int rpc_localaddr(struct rpc_clnt *clnt, struct sockaddr *buf, size_t buflen) +{ + struct sockaddr_storage address; + struct sockaddr *sap = (struct sockaddr *)&address; + struct rpc_xprt *xprt; + struct net *net; + size_t salen; + int err; + + rcu_read_lock(); + xprt = rcu_dereference(clnt->cl_xprt); + salen = xprt->addrlen; + memcpy(sap, &xprt->addr, salen); + net = get_net(xprt->xprt_net); + rcu_read_unlock(); + + rpc_set_port(sap, 0); + err = rpc_sockname(net, sap, salen, buf); + put_net(net); + if (err != 0) + /* Couldn't discover local address, return ANYADDR */ + return rpc_anyaddr(sap->sa_family, buf, buflen); + return 0; +} +EXPORT_SYMBOL_GPL(rpc_localaddr); + +void +rpc_setbufsize(struct rpc_clnt *clnt, unsigned int sndsize, unsigned int rcvsize) +{ + struct rpc_xprt *xprt; + + rcu_read_lock(); + xprt = rcu_dereference(clnt->cl_xprt); + if (xprt->ops->set_buffer_size) + xprt->ops->set_buffer_size(xprt, sndsize, rcvsize); + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(rpc_setbufsize); + +/** + * rpc_net_ns - Get the network namespace for this RPC client + * @clnt: RPC client to query + * + */ +struct net *rpc_net_ns(struct rpc_clnt *clnt) +{ + struct net *ret; + + rcu_read_lock(); + ret = rcu_dereference(clnt->cl_xprt)->xprt_net; + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(rpc_net_ns); + +/** + * rpc_max_payload - Get maximum payload size for a transport, in bytes + * @clnt: RPC client to query + * + * For stream transports, this is one RPC record fragment (see RFC + * 1831), as we don't support multi-record requests yet. For datagram + * transports, this is the size of an IP packet minus the IP, UDP, and + * RPC header sizes. + */ +size_t rpc_max_payload(struct rpc_clnt *clnt) +{ + size_t ret; + + rcu_read_lock(); + ret = rcu_dereference(clnt->cl_xprt)->max_payload; + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(rpc_max_payload); + +/** + * rpc_max_bc_payload - Get maximum backchannel payload size, in bytes + * @clnt: RPC client to query + */ +size_t rpc_max_bc_payload(struct rpc_clnt *clnt) +{ + struct rpc_xprt *xprt; + size_t ret; + + rcu_read_lock(); + xprt = rcu_dereference(clnt->cl_xprt); + ret = xprt->ops->bc_maxpayload(xprt); + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(rpc_max_bc_payload); + +/** + * rpc_force_rebind - force transport to check that remote port is unchanged + * @clnt: client to rebind + * + */ +void rpc_force_rebind(struct rpc_clnt *clnt) +{ + if (clnt->cl_autobind) { + rcu_read_lock(); + xprt_clear_bound(rcu_dereference(clnt->cl_xprt)); + rcu_read_unlock(); + } +} +EXPORT_SYMBOL_GPL(rpc_force_rebind); + +/* + * Restart an (async) RPC call from the call_prepare state. + * Usually called from within the exit handler. + */ +int +rpc_restart_call_prepare(struct rpc_task *task) +{ + if (RPC_ASSASSINATED(task)) + return 0; + task->tk_action = call_start; + task->tk_status = 0; + if (task->tk_ops->rpc_call_prepare != NULL) + task->tk_action = rpc_prepare_task; + return 1; +} +EXPORT_SYMBOL_GPL(rpc_restart_call_prepare); + +/* + * Restart an (async) RPC call. Usually called from within the + * exit handler. + */ +int +rpc_restart_call(struct rpc_task *task) +{ + if (RPC_ASSASSINATED(task)) + return 0; + task->tk_action = call_start; + task->tk_status = 0; + return 1; +} +EXPORT_SYMBOL_GPL(rpc_restart_call); + +const char +*rpc_proc_name(const struct rpc_task *task) +{ + const struct rpc_procinfo *proc = task->tk_msg.rpc_proc; + + if (proc) { + if (proc->p_name) + return proc->p_name; + else + return "NULL"; + } else + return "no proc"; +} + +/* + * 0. Initial state + * + * Other FSM states can be visited zero or more times, but + * this state is visited exactly once for each RPC. + */ +static void +call_start(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + int idx = task->tk_msg.rpc_proc->p_statidx; + + trace_rpc_request(task); + dprintk("RPC: %5u call_start %s%d proc %s (%s)\n", task->tk_pid, + clnt->cl_program->name, clnt->cl_vers, + rpc_proc_name(task), + (RPC_IS_ASYNC(task) ? "async" : "sync")); + + /* Increment call count (version might not be valid for ping) */ + if (clnt->cl_program->version[clnt->cl_vers]) + clnt->cl_program->version[clnt->cl_vers]->counts[idx]++; + clnt->cl_stats->rpccnt++; + task->tk_action = call_reserve; + rpc_task_set_transport(task, clnt); +} + +/* + * 1. Reserve an RPC call slot + */ +static void +call_reserve(struct rpc_task *task) +{ + dprint_status(task); + + task->tk_status = 0; + task->tk_action = call_reserveresult; + xprt_reserve(task); +} + +static void call_retry_reserve(struct rpc_task *task); + +/* + * 1b. Grok the result of xprt_reserve() + */ +static void +call_reserveresult(struct rpc_task *task) +{ + int status = task->tk_status; + + dprint_status(task); + + /* + * After a call to xprt_reserve(), we must have either + * a request slot or else an error status. + */ + task->tk_status = 0; + if (status >= 0) { + if (task->tk_rqstp) { + task->tk_action = call_refresh; + return; + } + + printk(KERN_ERR "%s: status=%d, but no request slot, exiting\n", + __func__, status); + rpc_exit(task, -EIO); + return; + } + + /* + * Even though there was an error, we may have acquired + * a request slot somehow. Make sure not to leak it. + */ + if (task->tk_rqstp) { + printk(KERN_ERR "%s: status=%d, request allocated anyway\n", + __func__, status); + xprt_release(task); + } + + switch (status) { + case -ENOMEM: + rpc_delay(task, HZ >> 2); + /* fall through */ + case -EAGAIN: /* woken up; retry */ + task->tk_action = call_retry_reserve; + return; + case -EIO: /* probably a shutdown */ + break; + default: + printk(KERN_ERR "%s: unrecognized error %d, exiting\n", + __func__, status); + break; + } + rpc_exit(task, status); +} + +/* + * 1c. Retry reserving an RPC call slot + */ +static void +call_retry_reserve(struct rpc_task *task) +{ + dprint_status(task); + + task->tk_status = 0; + task->tk_action = call_reserveresult; + xprt_retry_reserve(task); +} + +/* + * 2. Bind and/or refresh the credentials + */ +static void +call_refresh(struct rpc_task *task) +{ + dprint_status(task); + + task->tk_action = call_refreshresult; + task->tk_status = 0; + task->tk_client->cl_stats->rpcauthrefresh++; + rpcauth_refreshcred(task); +} + +/* + * 2a. Process the results of a credential refresh + */ +static void +call_refreshresult(struct rpc_task *task) +{ + int status = task->tk_status; + + dprint_status(task); + + task->tk_status = 0; + task->tk_action = call_refresh; + switch (status) { + case 0: + if (rpcauth_uptodatecred(task)) { + task->tk_action = call_allocate; + return; + } + /* Use rate-limiting and a max number of retries if refresh + * had status 0 but failed to update the cred. + */ + /* fall through */ + case -ETIMEDOUT: + rpc_delay(task, 3*HZ); + /* fall through */ + case -EAGAIN: + status = -EACCES; + /* fall through */ + case -EKEYEXPIRED: + if (!task->tk_cred_retry) + break; + task->tk_cred_retry--; + dprintk("RPC: %5u %s: retry refresh creds\n", + task->tk_pid, __func__); + return; + } + dprintk("RPC: %5u %s: refresh creds failed with error %d\n", + task->tk_pid, __func__, status); + rpc_exit(task, status); +} + +/* + * 2b. Allocate the buffer. For details, see sched.c:rpc_malloc. + * (Note: buffer memory is freed in xprt_release). + */ +static void +call_allocate(struct rpc_task *task) +{ + unsigned int slack = task->tk_rqstp->rq_cred->cr_auth->au_cslack; + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + const struct rpc_procinfo *proc = task->tk_msg.rpc_proc; + int status; + + dprint_status(task); + + task->tk_status = 0; + task->tk_action = call_bind; + + if (req->rq_buffer) + return; + + if (proc->p_proc != 0) { + BUG_ON(proc->p_arglen == 0); + if (proc->p_decode != NULL) + BUG_ON(proc->p_replen == 0); + } + + /* + * Calculate the size (in quads) of the RPC call + * and reply headers, and convert both values + * to byte sizes. + */ + req->rq_callsize = RPC_CALLHDRSIZE + (slack << 1) + proc->p_arglen; + req->rq_callsize <<= 2; + req->rq_rcvsize = RPC_REPHDRSIZE + slack + proc->p_replen; + req->rq_rcvsize <<= 2; + + status = xprt->ops->buf_alloc(task); + xprt_inject_disconnect(xprt); + if (status == 0) + return; + if (status != -ENOMEM) { + rpc_exit(task, status); + return; + } + + dprintk("RPC: %5u rpc_buffer allocation failed\n", task->tk_pid); + + if (RPC_IS_ASYNC(task) || !fatal_signal_pending(current)) { + task->tk_action = call_allocate; + rpc_delay(task, HZ>>4); + return; + } + + rpc_exit(task, -ERESTARTSYS); +} + +static inline int +rpc_task_need_encode(struct rpc_task *task) +{ + return task->tk_rqstp->rq_snd_buf.len == 0; +} + +static inline void +rpc_task_force_reencode(struct rpc_task *task) +{ + task->tk_rqstp->rq_snd_buf.len = 0; + task->tk_rqstp->rq_bytes_sent = 0; +} + +/* + * 3. Encode arguments of an RPC call + */ +static void +rpc_xdr_encode(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + kxdreproc_t encode; + __be32 *p; + + dprint_status(task); + + xdr_buf_init(&req->rq_snd_buf, + req->rq_buffer, + req->rq_callsize); + xdr_buf_init(&req->rq_rcv_buf, + req->rq_rbuffer, + req->rq_rcvsize); + + p = rpc_encode_header(task); + if (p == NULL) { + printk(KERN_INFO "RPC: couldn't encode RPC header, exit EIO\n"); + rpc_exit(task, -EIO); + return; + } + + encode = task->tk_msg.rpc_proc->p_encode; + if (encode == NULL) + return; + + task->tk_status = rpcauth_wrap_req(task, encode, req, p, + task->tk_msg.rpc_argp); +} + +/* + * 4. Get the server port number if not yet set + */ +static void +call_bind(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; + + dprint_status(task); + + task->tk_action = call_connect; + if (!xprt_bound(xprt)) { + task->tk_action = call_bind_status; + task->tk_timeout = xprt->bind_timeout; + xprt->ops->rpcbind(task); + } +} + +/* + * 4a. Sort out bind result + */ +static void +call_bind_status(struct rpc_task *task) +{ + int status = -EIO; + + if (task->tk_status >= 0) { + dprint_status(task); + task->tk_status = 0; + task->tk_action = call_connect; + return; + } + + trace_rpc_bind_status(task); + switch (task->tk_status) { + case -ENOMEM: + dprintk("RPC: %5u rpcbind out of memory\n", task->tk_pid); + rpc_delay(task, HZ >> 2); + goto retry_timeout; + case -EACCES: + dprintk("RPC: %5u remote rpcbind: RPC program/version " + "unavailable\n", task->tk_pid); + /* fail immediately if this is an RPC ping */ + if (task->tk_msg.rpc_proc->p_proc == 0) { + status = -EOPNOTSUPP; + break; + } + if (task->tk_rebind_retry == 0) + break; + task->tk_rebind_retry--; + rpc_delay(task, 3*HZ); + goto retry_timeout; + case -ETIMEDOUT: + dprintk("RPC: %5u rpcbind request timed out\n", + task->tk_pid); + goto retry_timeout; + case -EPFNOSUPPORT: + /* server doesn't support any rpcbind version we know of */ + dprintk("RPC: %5u unrecognized remote rpcbind service\n", + task->tk_pid); + break; + case -EPROTONOSUPPORT: + dprintk("RPC: %5u remote rpcbind version unavailable, retrying\n", + task->tk_pid); + goto retry_timeout; + case -ECONNREFUSED: /* connection problems */ + case -ECONNRESET: + case -ECONNABORTED: + case -ENOTCONN: + case -EHOSTDOWN: + case -ENETDOWN: + case -EHOSTUNREACH: + case -ENETUNREACH: + case -ENOBUFS: + case -EPIPE: + dprintk("RPC: %5u remote rpcbind unreachable: %d\n", + task->tk_pid, task->tk_status); + if (!RPC_IS_SOFTCONN(task)) { + rpc_delay(task, 5*HZ); + goto retry_timeout; + } + status = task->tk_status; + break; + default: + dprintk("RPC: %5u unrecognized rpcbind error (%d)\n", + task->tk_pid, -task->tk_status); + } + + rpc_exit(task, status); + return; + +retry_timeout: + task->tk_status = 0; + task->tk_action = call_timeout; +} + +/* + * 4b. Connect to the RPC server + */ +static void +call_connect(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; + + dprintk("RPC: %5u call_connect xprt %p %s connected\n", + task->tk_pid, xprt, + (xprt_connected(xprt) ? "is" : "is not")); + + task->tk_action = call_transmit; + if (!xprt_connected(xprt)) { + task->tk_action = call_connect_status; + if (task->tk_status < 0) + return; + if (task->tk_flags & RPC_TASK_NOCONNECT) { + rpc_exit(task, -ENOTCONN); + return; + } + xprt_connect(task); + } +} + +/* + * 4c. Sort out connect result + */ +static void +call_connect_status(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + int status = task->tk_status; + + dprint_status(task); + + trace_rpc_connect_status(task); + task->tk_status = 0; + switch (status) { + case -ECONNREFUSED: + /* A positive refusal suggests a rebind is needed. */ + if (RPC_IS_SOFTCONN(task)) + break; + if (clnt->cl_autobind) { + rpc_force_rebind(clnt); + task->tk_action = call_bind; + return; + } + /* fall through */ + case -ECONNRESET: + case -ECONNABORTED: + case -ENETDOWN: + case -ENETUNREACH: + case -EHOSTUNREACH: + case -EADDRINUSE: + case -ENOBUFS: + case -EPIPE: + xprt_conditional_disconnect(task->tk_rqstp->rq_xprt, + task->tk_rqstp->rq_connect_cookie); + if (RPC_IS_SOFTCONN(task)) + break; + /* retry with existing socket, after a delay */ + rpc_delay(task, 3*HZ); + /* fall through */ + case -EAGAIN: + /* Check for timeouts before looping back to call_bind */ + case -ETIMEDOUT: + task->tk_action = call_timeout; + return; + case 0: + clnt->cl_stats->netreconn++; + task->tk_action = call_transmit; + return; + } + rpc_exit(task, status); +} + +/* + * 5. Transmit the RPC request, and wait for reply + */ +static void +call_transmit(struct rpc_task *task) +{ + int is_retrans = RPC_WAS_SENT(task); + + dprint_status(task); + + task->tk_action = call_status; + if (task->tk_status < 0) + return; + if (!xprt_prepare_transmit(task)) + return; + task->tk_action = call_transmit_status; + /* Encode here so that rpcsec_gss can use correct sequence number. */ + if (rpc_task_need_encode(task)) { + rpc_xdr_encode(task); + /* Did the encode result in an error condition? */ + if (task->tk_status != 0) { + /* Was the error nonfatal? */ + if (task->tk_status == -EAGAIN) + rpc_delay(task, HZ >> 4); + else + rpc_exit(task, task->tk_status); + return; + } + } + xprt_transmit(task); + if (task->tk_status < 0) + return; + if (is_retrans) + task->tk_client->cl_stats->rpcretrans++; + /* + * On success, ensure that we call xprt_end_transmit() before sleeping + * in order to allow access to the socket to other RPC requests. + */ + call_transmit_status(task); + if (rpc_reply_expected(task)) + return; + task->tk_action = rpc_exit_task; + rpc_wake_up_queued_task(&task->tk_rqstp->rq_xprt->pending, task); +} + +/* + * 5a. Handle cleanup after a transmission + */ +static void +call_transmit_status(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; + task->tk_action = call_status; + + /* + * Common case: success. Force the compiler to put this + * test first. Or, if any error and xprt_close_wait, + * release the xprt lock so the socket can close. + */ + if (task->tk_status == 0 || xprt_close_wait(xprt)) { + xprt_end_transmit(task); + rpc_task_force_reencode(task); + return; + } + + switch (task->tk_status) { + case -EAGAIN: + case -ENOBUFS: + break; + default: + dprint_status(task); + xprt_end_transmit(task); + rpc_task_force_reencode(task); + break; + /* + * Special cases: if we've been waiting on the + * socket's write_space() callback, or if the + * socket just returned a connection error, + * then hold onto the transport lock. + */ + case -ECONNREFUSED: + case -EHOSTDOWN: + case -ENETDOWN: + case -EHOSTUNREACH: + case -ENETUNREACH: + case -EPERM: + if (RPC_IS_SOFTCONN(task)) { + xprt_end_transmit(task); + if (!task->tk_msg.rpc_proc->p_proc) + trace_xprt_ping(task->tk_xprt, + task->tk_status); + rpc_exit(task, task->tk_status); + break; + } + /* fall through */ + case -ECONNRESET: + case -ECONNABORTED: + case -EADDRINUSE: + case -ENOTCONN: + case -EPIPE: + rpc_task_force_reencode(task); + } +} + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +/* + * 5b. Send the backchannel RPC reply. On error, drop the reply. In + * addition, disconnect on connectivity errors. + */ +static void +call_bc_transmit(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + + if (!xprt_prepare_transmit(task)) + goto out_retry; + + if (task->tk_status < 0) { + printk(KERN_NOTICE "RPC: Could not send backchannel reply " + "error: %d\n", task->tk_status); + goto out_done; + } + if (req->rq_connect_cookie != req->rq_xprt->connect_cookie) + req->rq_bytes_sent = 0; + + xprt_transmit(task); + + if (task->tk_status == -EAGAIN) + goto out_nospace; + + xprt_end_transmit(task); + dprint_status(task); + switch (task->tk_status) { + case 0: + /* Success */ + case -ENETDOWN: + case -EHOSTDOWN: + case -EHOSTUNREACH: + case -ENETUNREACH: + case -ECONNRESET: + case -ECONNREFUSED: + case -EADDRINUSE: + case -ENOTCONN: + case -EPIPE: + break; + case -ETIMEDOUT: + /* + * Problem reaching the server. Disconnect and let the + * forechannel reestablish the connection. The server will + * have to retransmit the backchannel request and we'll + * reprocess it. Since these ops are idempotent, there's no + * need to cache our reply at this time. + */ + printk(KERN_NOTICE "RPC: Could not send backchannel reply " + "error: %d\n", task->tk_status); + xprt_conditional_disconnect(req->rq_xprt, + req->rq_connect_cookie); + break; + default: + /* + * We were unable to reply and will have to drop the + * request. The server should reconnect and retransmit. + */ + WARN_ON_ONCE(task->tk_status == -EAGAIN); + printk(KERN_NOTICE "RPC: Could not send backchannel reply " + "error: %d\n", task->tk_status); + break; + } + rpc_wake_up_queued_task(&req->rq_xprt->pending, task); +out_done: + task->tk_action = rpc_exit_task; + return; +out_nospace: + req->rq_connect_cookie = req->rq_xprt->connect_cookie; +out_retry: + task->tk_status = 0; +} +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +/* + * 6. Sort out the RPC call status + */ +static void +call_status(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + struct rpc_rqst *req = task->tk_rqstp; + int status; + + if (!task->tk_msg.rpc_proc->p_proc) + trace_xprt_ping(task->tk_xprt, task->tk_status); + + if (req->rq_reply_bytes_recvd > 0 && !req->rq_bytes_sent) + task->tk_status = req->rq_reply_bytes_recvd; + + dprint_status(task); + + status = task->tk_status; + if (status >= 0) { + task->tk_action = call_decode; + return; + } + + trace_rpc_call_status(task); + task->tk_status = 0; + switch(status) { + case -EHOSTDOWN: + case -ENETDOWN: + case -EHOSTUNREACH: + case -ENETUNREACH: + case -EPERM: + if (RPC_IS_SOFTCONN(task)) { + rpc_exit(task, status); + break; + } + /* + * Delay any retries for 3 seconds, then handle as if it + * were a timeout. + */ + rpc_delay(task, 3*HZ); + /* fall through */ + case -ETIMEDOUT: + task->tk_action = call_timeout; + break; + case -ECONNREFUSED: + case -ECONNRESET: + case -ECONNABORTED: + rpc_force_rebind(clnt); + /* fall through */ + case -EADDRINUSE: + rpc_delay(task, 3*HZ); + /* fall through */ + case -EPIPE: + case -ENOTCONN: + task->tk_action = call_bind; + break; + case -ENOBUFS: + rpc_delay(task, HZ>>2); + /* fall through */ + case -EAGAIN: + task->tk_action = call_transmit; + break; + case -EIO: + /* shutdown or soft timeout */ + rpc_exit(task, status); + break; + default: + if (clnt->cl_chatty) + printk("%s: RPC call returned error %d\n", + clnt->cl_program->name, -status); + rpc_exit(task, status); + } +} + +/* + * 6a. Handle RPC timeout + * We do not release the request slot, so we keep using the + * same XID for all retransmits. + */ +static void +call_timeout(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + + if (xprt_adjust_timeout(task->tk_rqstp) == 0) { + dprintk("RPC: %5u call_timeout (minor)\n", task->tk_pid); + goto retry; + } + + dprintk("RPC: %5u call_timeout (major)\n", task->tk_pid); + task->tk_timeouts++; + + if (RPC_IS_SOFTCONN(task)) { + rpc_exit(task, -ETIMEDOUT); + return; + } + if (RPC_IS_SOFT(task)) { + if (clnt->cl_chatty) { + printk(KERN_NOTICE "%s: server %s not responding, timed out\n", + clnt->cl_program->name, + task->tk_xprt->servername); + } + if (task->tk_flags & RPC_TASK_TIMEOUT) + rpc_exit(task, -ETIMEDOUT); + else + rpc_exit(task, -EIO); + return; + } + + if (!(task->tk_flags & RPC_CALL_MAJORSEEN)) { + task->tk_flags |= RPC_CALL_MAJORSEEN; + if (clnt->cl_chatty) { + printk(KERN_NOTICE "%s: server %s not responding, still trying\n", + clnt->cl_program->name, + task->tk_xprt->servername); + } + } + rpc_force_rebind(clnt); + /* + * Did our request time out due to an RPCSEC_GSS out-of-sequence + * event? RFC2203 requires the server to drop all such requests. + */ + rpcauth_invalcred(task); + +retry: + task->tk_action = call_bind; + task->tk_status = 0; +} + +/* + * 7. Decode the RPC reply + */ +static void +call_decode(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + struct rpc_rqst *req = task->tk_rqstp; + kxdrdproc_t decode = task->tk_msg.rpc_proc->p_decode; + __be32 *p; + + dprint_status(task); + + if (task->tk_flags & RPC_CALL_MAJORSEEN) { + if (clnt->cl_chatty) { + printk(KERN_NOTICE "%s: server %s OK\n", + clnt->cl_program->name, + task->tk_xprt->servername); + } + task->tk_flags &= ~RPC_CALL_MAJORSEEN; + } + + /* + * Ensure that we see all writes made by xprt_complete_rqst() + * before it changed req->rq_reply_bytes_recvd. + */ + smp_rmb(); + req->rq_rcv_buf.len = req->rq_private_buf.len; + + /* Check that the softirq receive buffer is valid */ + WARN_ON(memcmp(&req->rq_rcv_buf, &req->rq_private_buf, + sizeof(req->rq_rcv_buf)) != 0); + + if (req->rq_rcv_buf.len < 12) { + if (!RPC_IS_SOFT(task)) { + task->tk_action = call_bind; + goto out_retry; + } + dprintk("RPC: %s: too small RPC reply size (%d bytes)\n", + clnt->cl_program->name, task->tk_status); + task->tk_action = call_timeout; + goto out_retry; + } + + p = rpc_verify_header(task); + if (IS_ERR(p)) { + if (p == ERR_PTR(-EAGAIN)) + goto out_retry; + return; + } + + task->tk_action = rpc_exit_task; + + if (decode) { + task->tk_status = rpcauth_unwrap_resp(task, decode, req, p, + task->tk_msg.rpc_resp); + } + dprintk("RPC: %5u call_decode result %d\n", task->tk_pid, + task->tk_status); + return; +out_retry: + task->tk_status = 0; + /* Note: rpc_verify_header() may have freed the RPC slot */ + if (task->tk_rqstp == req) { + req->rq_reply_bytes_recvd = req->rq_rcv_buf.len = 0; + if (task->tk_client->cl_discrtry) + xprt_conditional_disconnect(req->rq_xprt, + req->rq_connect_cookie); + } +} + +static __be32 * +rpc_encode_header(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + struct rpc_rqst *req = task->tk_rqstp; + __be32 *p = req->rq_svec[0].iov_base; + + /* FIXME: check buffer size? */ + + p = xprt_skip_transport_header(req->rq_xprt, p); + *p++ = req->rq_xid; /* XID */ + *p++ = htonl(RPC_CALL); /* CALL */ + *p++ = htonl(RPC_VERSION); /* RPC version */ + *p++ = htonl(clnt->cl_prog); /* program number */ + *p++ = htonl(clnt->cl_vers); /* program version */ + *p++ = htonl(task->tk_msg.rpc_proc->p_proc); /* procedure */ + p = rpcauth_marshcred(task, p); + req->rq_slen = xdr_adjust_iovec(&req->rq_svec[0], p); + return p; +} + +static __be32 * +rpc_verify_header(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + struct kvec *iov = &task->tk_rqstp->rq_rcv_buf.head[0]; + int len = task->tk_rqstp->rq_rcv_buf.len >> 2; + __be32 *p = iov->iov_base; + u32 n; + int error = -EACCES; + + if ((task->tk_rqstp->rq_rcv_buf.len & 3) != 0) { + /* RFC-1014 says that the representation of XDR data must be a + * multiple of four bytes + * - if it isn't pointer subtraction in the NFS client may give + * undefined results + */ + dprintk("RPC: %5u %s: XDR representation not a multiple of" + " 4 bytes: 0x%x\n", task->tk_pid, __func__, + task->tk_rqstp->rq_rcv_buf.len); + error = -EIO; + goto out_err; + } + if ((len -= 3) < 0) + goto out_overflow; + + p += 1; /* skip XID */ + if ((n = ntohl(*p++)) != RPC_REPLY) { + dprintk("RPC: %5u %s: not an RPC reply: %x\n", + task->tk_pid, __func__, n); + error = -EIO; + goto out_garbage; + } + + if ((n = ntohl(*p++)) != RPC_MSG_ACCEPTED) { + if (--len < 0) + goto out_overflow; + switch ((n = ntohl(*p++))) { + case RPC_AUTH_ERROR: + break; + case RPC_MISMATCH: + dprintk("RPC: %5u %s: RPC call version mismatch!\n", + task->tk_pid, __func__); + error = -EPROTONOSUPPORT; + goto out_err; + default: + dprintk("RPC: %5u %s: RPC call rejected, " + "unknown error: %x\n", + task->tk_pid, __func__, n); + error = -EIO; + goto out_err; + } + if (--len < 0) + goto out_overflow; + switch ((n = ntohl(*p++))) { + case RPC_AUTH_REJECTEDCRED: + case RPC_AUTH_REJECTEDVERF: + case RPCSEC_GSS_CREDPROBLEM: + case RPCSEC_GSS_CTXPROBLEM: + if (!task->tk_cred_retry) + break; + task->tk_cred_retry--; + dprintk("RPC: %5u %s: retry stale creds\n", + task->tk_pid, __func__); + rpcauth_invalcred(task); + /* Ensure we obtain a new XID! */ + xprt_release(task); + task->tk_action = call_reserve; + goto out_retry; + case RPC_AUTH_BADCRED: + case RPC_AUTH_BADVERF: + /* possibly garbled cred/verf? */ + if (!task->tk_garb_retry) + break; + task->tk_garb_retry--; + dprintk("RPC: %5u %s: retry garbled creds\n", + task->tk_pid, __func__); + task->tk_action = call_bind; + goto out_retry; + case RPC_AUTH_TOOWEAK: + printk(KERN_NOTICE "RPC: server %s requires stronger " + "authentication.\n", + task->tk_xprt->servername); + break; + default: + dprintk("RPC: %5u %s: unknown auth error: %x\n", + task->tk_pid, __func__, n); + error = -EIO; + } + dprintk("RPC: %5u %s: call rejected %d\n", + task->tk_pid, __func__, n); + goto out_err; + } + p = rpcauth_checkverf(task, p); + if (IS_ERR(p)) { + error = PTR_ERR(p); + dprintk("RPC: %5u %s: auth check failed with %d\n", + task->tk_pid, __func__, error); + goto out_garbage; /* bad verifier, retry */ + } + len = p - (__be32 *)iov->iov_base - 1; + if (len < 0) + goto out_overflow; + switch ((n = ntohl(*p++))) { + case RPC_SUCCESS: + return p; + case RPC_PROG_UNAVAIL: + dprintk("RPC: %5u %s: program %u is unsupported " + "by server %s\n", task->tk_pid, __func__, + (unsigned int)clnt->cl_prog, + task->tk_xprt->servername); + error = -EPFNOSUPPORT; + goto out_err; + case RPC_PROG_MISMATCH: + dprintk("RPC: %5u %s: program %u, version %u unsupported " + "by server %s\n", task->tk_pid, __func__, + (unsigned int)clnt->cl_prog, + (unsigned int)clnt->cl_vers, + task->tk_xprt->servername); + error = -EPROTONOSUPPORT; + goto out_err; + case RPC_PROC_UNAVAIL: + dprintk("RPC: %5u %s: proc %s unsupported by program %u, " + "version %u on server %s\n", + task->tk_pid, __func__, + rpc_proc_name(task), + clnt->cl_prog, clnt->cl_vers, + task->tk_xprt->servername); + error = -EOPNOTSUPP; + goto out_err; + case RPC_GARBAGE_ARGS: + dprintk("RPC: %5u %s: server saw garbage\n", + task->tk_pid, __func__); + break; /* retry */ + default: + dprintk("RPC: %5u %s: server accept status: %x\n", + task->tk_pid, __func__, n); + /* Also retry */ + } + +out_garbage: + clnt->cl_stats->rpcgarbage++; + if (task->tk_garb_retry) { + task->tk_garb_retry--; + dprintk("RPC: %5u %s: retrying\n", + task->tk_pid, __func__); + task->tk_action = call_bind; +out_retry: + return ERR_PTR(-EAGAIN); + } +out_err: + rpc_exit(task, error); + dprintk("RPC: %5u %s: call failed with error %d\n", task->tk_pid, + __func__, error); + return ERR_PTR(error); +out_overflow: + dprintk("RPC: %5u %s: server reply was truncated.\n", task->tk_pid, + __func__); + goto out_garbage; +} + +static void rpcproc_encode_null(struct rpc_rqst *rqstp, struct xdr_stream *xdr, + const void *obj) +{ +} + +static int rpcproc_decode_null(struct rpc_rqst *rqstp, struct xdr_stream *xdr, + void *obj) +{ + return 0; +} + +static const struct rpc_procinfo rpcproc_null = { + .p_encode = rpcproc_encode_null, + .p_decode = rpcproc_decode_null, +}; + +static int rpc_ping(struct rpc_clnt *clnt) +{ + struct rpc_message msg = { + .rpc_proc = &rpcproc_null, + }; + int err; + msg.rpc_cred = authnull_ops.lookup_cred(NULL, NULL, 0); + err = rpc_call_sync(clnt, &msg, RPC_TASK_SOFT | RPC_TASK_SOFTCONN); + put_rpccred(msg.rpc_cred); + return err; +} + +static +struct rpc_task *rpc_call_null_helper(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, struct rpc_cred *cred, int flags, + const struct rpc_call_ops *ops, void *data) +{ + struct rpc_message msg = { + .rpc_proc = &rpcproc_null, + .rpc_cred = cred, + }; + struct rpc_task_setup task_setup_data = { + .rpc_client = clnt, + .rpc_xprt = xprt, + .rpc_message = &msg, + .callback_ops = (ops != NULL) ? ops : &rpc_default_ops, + .callback_data = data, + .flags = flags, + }; + + return rpc_run_task(&task_setup_data); +} + +struct rpc_task *rpc_call_null(struct rpc_clnt *clnt, struct rpc_cred *cred, int flags) +{ + return rpc_call_null_helper(clnt, NULL, cred, flags, NULL, NULL); +} +EXPORT_SYMBOL_GPL(rpc_call_null); + +struct rpc_cb_add_xprt_calldata { + struct rpc_xprt_switch *xps; + struct rpc_xprt *xprt; +}; + +static void rpc_cb_add_xprt_done(struct rpc_task *task, void *calldata) +{ + struct rpc_cb_add_xprt_calldata *data = calldata; + + if (task->tk_status == 0) + rpc_xprt_switch_add_xprt(data->xps, data->xprt); +} + +static void rpc_cb_add_xprt_release(void *calldata) +{ + struct rpc_cb_add_xprt_calldata *data = calldata; + + xprt_put(data->xprt); + xprt_switch_put(data->xps); + kfree(data); +} + +static const struct rpc_call_ops rpc_cb_add_xprt_call_ops = { + .rpc_call_done = rpc_cb_add_xprt_done, + .rpc_release = rpc_cb_add_xprt_release, +}; + +/** + * rpc_clnt_test_and_add_xprt - Test and add a new transport to a rpc_clnt + * @clnt: pointer to struct rpc_clnt + * @xps: pointer to struct rpc_xprt_switch, + * @xprt: pointer struct rpc_xprt + * @dummy: unused + */ +int rpc_clnt_test_and_add_xprt(struct rpc_clnt *clnt, + struct rpc_xprt_switch *xps, struct rpc_xprt *xprt, + void *dummy) +{ + struct rpc_cb_add_xprt_calldata *data; + struct rpc_cred *cred; + struct rpc_task *task; + + data = kmalloc(sizeof(*data), GFP_NOFS); + if (!data) + return -ENOMEM; + data->xps = xprt_switch_get(xps); + data->xprt = xprt_get(xprt); + + cred = authnull_ops.lookup_cred(NULL, NULL, 0); + task = rpc_call_null_helper(clnt, xprt, cred, + RPC_TASK_SOFT|RPC_TASK_SOFTCONN|RPC_TASK_ASYNC, + &rpc_cb_add_xprt_call_ops, data); + put_rpccred(cred); + if (IS_ERR(task)) + return PTR_ERR(task); + rpc_put_task(task); + return 1; +} +EXPORT_SYMBOL_GPL(rpc_clnt_test_and_add_xprt); + +/** + * rpc_clnt_setup_test_and_add_xprt() + * + * This is an rpc_clnt_add_xprt setup() function which returns 1 so: + * 1) caller of the test function must dereference the rpc_xprt_switch + * and the rpc_xprt. + * 2) test function must call rpc_xprt_switch_add_xprt, usually in + * the rpc_call_done routine. + * + * Upon success (return of 1), the test function adds the new + * transport to the rpc_clnt xprt switch + * + * @clnt: struct rpc_clnt to get the new transport + * @xps: the rpc_xprt_switch to hold the new transport + * @xprt: the rpc_xprt to test + * @data: a struct rpc_add_xprt_test pointer that holds the test function + * and test function call data + */ +int rpc_clnt_setup_test_and_add_xprt(struct rpc_clnt *clnt, + struct rpc_xprt_switch *xps, + struct rpc_xprt *xprt, + void *data) +{ + struct rpc_cred *cred; + struct rpc_task *task; + struct rpc_add_xprt_test *xtest = (struct rpc_add_xprt_test *)data; + int status = -EADDRINUSE; + + xprt = xprt_get(xprt); + xprt_switch_get(xps); + + if (rpc_xprt_switch_has_addr(xps, (struct sockaddr *)&xprt->addr)) + goto out_err; + + /* Test the connection */ + cred = authnull_ops.lookup_cred(NULL, NULL, 0); + task = rpc_call_null_helper(clnt, xprt, cred, + RPC_TASK_SOFT | RPC_TASK_SOFTCONN, + NULL, NULL); + put_rpccred(cred); + if (IS_ERR(task)) { + status = PTR_ERR(task); + goto out_err; + } + status = task->tk_status; + rpc_put_task(task); + + if (status < 0) + goto out_err; + + /* rpc_xprt_switch and rpc_xprt are deferrenced by add_xprt_test() */ + xtest->add_xprt_test(clnt, xprt, xtest->data); + + /* so that rpc_clnt_add_xprt does not call rpc_xprt_switch_add_xprt */ + return 1; +out_err: + xprt_put(xprt); + xprt_switch_put(xps); + pr_info("RPC: rpc_clnt_test_xprt failed: %d addr %s not added\n", + status, xprt->address_strings[RPC_DISPLAY_ADDR]); + return status; +} +EXPORT_SYMBOL_GPL(rpc_clnt_setup_test_and_add_xprt); + +/** + * rpc_clnt_add_xprt - Add a new transport to a rpc_clnt + * @clnt: pointer to struct rpc_clnt + * @xprtargs: pointer to struct xprt_create + * @setup: callback to test and/or set up the connection + * @data: pointer to setup function data + * + * Creates a new transport using the parameters set in args and + * adds it to clnt. + * If ping is set, then test that connectivity succeeds before + * adding the new transport. + * + */ +int rpc_clnt_add_xprt(struct rpc_clnt *clnt, + struct xprt_create *xprtargs, + int (*setup)(struct rpc_clnt *, + struct rpc_xprt_switch *, + struct rpc_xprt *, + void *), + void *data) +{ + struct rpc_xprt_switch *xps; + struct rpc_xprt *xprt; + unsigned long connect_timeout; + unsigned long reconnect_timeout; + unsigned char resvport; + int ret = 0; + + rcu_read_lock(); + xps = xprt_switch_get(rcu_dereference(clnt->cl_xpi.xpi_xpswitch)); + xprt = xprt_iter_xprt(&clnt->cl_xpi); + if (xps == NULL || xprt == NULL) { + rcu_read_unlock(); + xprt_switch_put(xps); + return -EAGAIN; + } + resvport = xprt->resvport; + connect_timeout = xprt->connect_timeout; + reconnect_timeout = xprt->max_reconnect_timeout; + rcu_read_unlock(); + + xprt = xprt_create_transport(xprtargs); + if (IS_ERR(xprt)) { + ret = PTR_ERR(xprt); + goto out_put_switch; + } + xprt->resvport = resvport; + if (xprt->ops->set_connect_timeout != NULL) + xprt->ops->set_connect_timeout(xprt, + connect_timeout, + reconnect_timeout); + + rpc_xprt_switch_set_roundrobin(xps); + if (setup) { + ret = setup(clnt, xps, xprt, data); + if (ret != 0) + goto out_put_xprt; + } + rpc_xprt_switch_add_xprt(xps, xprt); +out_put_xprt: + xprt_put(xprt); +out_put_switch: + xprt_switch_put(xps); + return ret; +} +EXPORT_SYMBOL_GPL(rpc_clnt_add_xprt); + +struct connect_timeout_data { + unsigned long connect_timeout; + unsigned long reconnect_timeout; +}; + +static int +rpc_xprt_set_connect_timeout(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, + void *data) +{ + struct connect_timeout_data *timeo = data; + + if (xprt->ops->set_connect_timeout) + xprt->ops->set_connect_timeout(xprt, + timeo->connect_timeout, + timeo->reconnect_timeout); + return 0; +} + +void +rpc_set_connect_timeout(struct rpc_clnt *clnt, + unsigned long connect_timeout, + unsigned long reconnect_timeout) +{ + struct connect_timeout_data timeout = { + .connect_timeout = connect_timeout, + .reconnect_timeout = reconnect_timeout, + }; + rpc_clnt_iterate_for_each_xprt(clnt, + rpc_xprt_set_connect_timeout, + &timeout); +} +EXPORT_SYMBOL_GPL(rpc_set_connect_timeout); + +void rpc_clnt_xprt_switch_put(struct rpc_clnt *clnt) +{ + rcu_read_lock(); + xprt_switch_put(rcu_dereference(clnt->cl_xpi.xpi_xpswitch)); + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(rpc_clnt_xprt_switch_put); + +void rpc_clnt_xprt_switch_add_xprt(struct rpc_clnt *clnt, struct rpc_xprt *xprt) +{ + rcu_read_lock(); + rpc_xprt_switch_add_xprt(rcu_dereference(clnt->cl_xpi.xpi_xpswitch), + xprt); + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(rpc_clnt_xprt_switch_add_xprt); + +bool rpc_clnt_xprt_switch_has_addr(struct rpc_clnt *clnt, + const struct sockaddr *sap) +{ + struct rpc_xprt_switch *xps; + bool ret; + + rcu_read_lock(); + xps = rcu_dereference(clnt->cl_xpi.xpi_xpswitch); + ret = rpc_xprt_switch_has_addr(xps, sap); + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(rpc_clnt_xprt_switch_has_addr); + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +static void rpc_show_header(void) +{ + printk(KERN_INFO "-pid- flgs status -client- --rqstp- " + "-timeout ---ops--\n"); +} + +static void rpc_show_task(const struct rpc_clnt *clnt, + const struct rpc_task *task) +{ + const char *rpc_waitq = "none"; + + if (RPC_IS_QUEUED(task)) + rpc_waitq = rpc_qname(task->tk_waitqueue); + + printk(KERN_INFO "%5u %04x %6d %8p %8p %8ld %8p %sv%u %s a:%ps q:%s\n", + task->tk_pid, task->tk_flags, task->tk_status, + clnt, task->tk_rqstp, task->tk_timeout, task->tk_ops, + clnt->cl_program->name, clnt->cl_vers, rpc_proc_name(task), + task->tk_action, rpc_waitq); +} + +void rpc_show_tasks(struct net *net) +{ + struct rpc_clnt *clnt; + struct rpc_task *task; + int header = 0; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + spin_lock(&sn->rpc_client_lock); + list_for_each_entry(clnt, &sn->all_clients, cl_clients) { + spin_lock(&clnt->cl_lock); + list_for_each_entry(task, &clnt->cl_tasks, tk_task) { + if (!header) { + rpc_show_header(); + header++; + } + rpc_show_task(clnt, task); + } + spin_unlock(&clnt->cl_lock); + } + spin_unlock(&sn->rpc_client_lock); +} +#endif + +#if IS_ENABLED(CONFIG_SUNRPC_SWAP) +static int +rpc_clnt_swap_activate_callback(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, + void *dummy) +{ + return xprt_enable_swap(xprt); +} + +int +rpc_clnt_swap_activate(struct rpc_clnt *clnt) +{ + if (atomic_inc_return(&clnt->cl_swapper) == 1) + return rpc_clnt_iterate_for_each_xprt(clnt, + rpc_clnt_swap_activate_callback, NULL); + return 0; +} +EXPORT_SYMBOL_GPL(rpc_clnt_swap_activate); + +static int +rpc_clnt_swap_deactivate_callback(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, + void *dummy) +{ + xprt_disable_swap(xprt); + return 0; +} + +void +rpc_clnt_swap_deactivate(struct rpc_clnt *clnt) +{ + if (atomic_dec_if_positive(&clnt->cl_swapper) == 0) + rpc_clnt_iterate_for_each_xprt(clnt, + rpc_clnt_swap_deactivate_callback, NULL); +} +EXPORT_SYMBOL_GPL(rpc_clnt_swap_deactivate); +#endif /* CONFIG_SUNRPC_SWAP */ diff --git a/net/sunrpc/debugfs.c b/net/sunrpc/debugfs.c new file mode 100644 index 000000000..45a033329 --- /dev/null +++ b/net/sunrpc/debugfs.c @@ -0,0 +1,362 @@ +// SPDX-License-Identifier: GPL-2.0 +/** + * debugfs interface for sunrpc + * + * (c) 2014 Jeff Layton <jlayton@primarydata.com> + */ + +#include <linux/debugfs.h> +#include <linux/sunrpc/sched.h> +#include <linux/sunrpc/clnt.h> +#include "netns.h" + +static struct dentry *topdir; +static struct dentry *rpc_fault_dir; +static struct dentry *rpc_clnt_dir; +static struct dentry *rpc_xprt_dir; + +unsigned int rpc_inject_disconnect; + +static int +tasks_show(struct seq_file *f, void *v) +{ + u32 xid = 0; + struct rpc_task *task = v; + struct rpc_clnt *clnt = task->tk_client; + const char *rpc_waitq = "none"; + + if (RPC_IS_QUEUED(task)) + rpc_waitq = rpc_qname(task->tk_waitqueue); + + if (task->tk_rqstp) + xid = be32_to_cpu(task->tk_rqstp->rq_xid); + + seq_printf(f, "%5u %04x %6d 0x%x 0x%x %8ld %ps %sv%u %s a:%ps q:%s\n", + task->tk_pid, task->tk_flags, task->tk_status, + clnt->cl_clid, xid, task->tk_timeout, task->tk_ops, + clnt->cl_program->name, clnt->cl_vers, rpc_proc_name(task), + task->tk_action, rpc_waitq); + return 0; +} + +static void * +tasks_start(struct seq_file *f, loff_t *ppos) + __acquires(&clnt->cl_lock) +{ + struct rpc_clnt *clnt = f->private; + loff_t pos = *ppos; + struct rpc_task *task; + + spin_lock(&clnt->cl_lock); + list_for_each_entry(task, &clnt->cl_tasks, tk_task) + if (pos-- == 0) + return task; + return NULL; +} + +static void * +tasks_next(struct seq_file *f, void *v, loff_t *pos) +{ + struct rpc_clnt *clnt = f->private; + struct rpc_task *task = v; + struct list_head *next = task->tk_task.next; + + ++*pos; + + /* If there's another task on list, return it */ + if (next == &clnt->cl_tasks) + return NULL; + return list_entry(next, struct rpc_task, tk_task); +} + +static void +tasks_stop(struct seq_file *f, void *v) + __releases(&clnt->cl_lock) +{ + struct rpc_clnt *clnt = f->private; + spin_unlock(&clnt->cl_lock); +} + +static const struct seq_operations tasks_seq_operations = { + .start = tasks_start, + .next = tasks_next, + .stop = tasks_stop, + .show = tasks_show, +}; + +static int tasks_open(struct inode *inode, struct file *filp) +{ + int ret = seq_open(filp, &tasks_seq_operations); + if (!ret) { + struct seq_file *seq = filp->private_data; + struct rpc_clnt *clnt = seq->private = inode->i_private; + + if (!atomic_inc_not_zero(&clnt->cl_count)) { + seq_release(inode, filp); + ret = -EINVAL; + } + } + + return ret; +} + +static int +tasks_release(struct inode *inode, struct file *filp) +{ + struct seq_file *seq = filp->private_data; + struct rpc_clnt *clnt = seq->private; + + rpc_release_client(clnt); + return seq_release(inode, filp); +} + +static const struct file_operations tasks_fops = { + .owner = THIS_MODULE, + .open = tasks_open, + .read = seq_read, + .llseek = seq_lseek, + .release = tasks_release, +}; + +void +rpc_clnt_debugfs_register(struct rpc_clnt *clnt) +{ + int len; + char name[24]; /* enough for "../../rpc_xprt/ + 8 hex digits + NULL */ + struct rpc_xprt *xprt; + + /* Already registered? */ + if (clnt->cl_debugfs || !rpc_clnt_dir) + return; + + len = snprintf(name, sizeof(name), "%x", clnt->cl_clid); + if (len >= sizeof(name)) + return; + + /* make the per-client dir */ + clnt->cl_debugfs = debugfs_create_dir(name, rpc_clnt_dir); + if (!clnt->cl_debugfs) + return; + + /* make tasks file */ + if (!debugfs_create_file("tasks", S_IFREG | 0400, clnt->cl_debugfs, + clnt, &tasks_fops)) + goto out_err; + + rcu_read_lock(); + xprt = rcu_dereference(clnt->cl_xprt); + /* no "debugfs" dentry? Don't bother with the symlink. */ + if (!xprt->debugfs) { + rcu_read_unlock(); + return; + } + len = snprintf(name, sizeof(name), "../../rpc_xprt/%s", + xprt->debugfs->d_name.name); + rcu_read_unlock(); + + if (len >= sizeof(name)) + goto out_err; + + if (!debugfs_create_symlink("xprt", clnt->cl_debugfs, name)) + goto out_err; + + return; +out_err: + debugfs_remove_recursive(clnt->cl_debugfs); + clnt->cl_debugfs = NULL; +} + +void +rpc_clnt_debugfs_unregister(struct rpc_clnt *clnt) +{ + debugfs_remove_recursive(clnt->cl_debugfs); + clnt->cl_debugfs = NULL; +} + +static int +xprt_info_show(struct seq_file *f, void *v) +{ + struct rpc_xprt *xprt = f->private; + + seq_printf(f, "netid: %s\n", xprt->address_strings[RPC_DISPLAY_NETID]); + seq_printf(f, "addr: %s\n", xprt->address_strings[RPC_DISPLAY_ADDR]); + seq_printf(f, "port: %s\n", xprt->address_strings[RPC_DISPLAY_PORT]); + seq_printf(f, "state: 0x%lx\n", xprt->state); + return 0; +} + +static int +xprt_info_open(struct inode *inode, struct file *filp) +{ + int ret; + struct rpc_xprt *xprt = inode->i_private; + + ret = single_open(filp, xprt_info_show, xprt); + + if (!ret) { + if (!xprt_get(xprt)) { + single_release(inode, filp); + ret = -EINVAL; + } + } + return ret; +} + +static int +xprt_info_release(struct inode *inode, struct file *filp) +{ + struct rpc_xprt *xprt = inode->i_private; + + xprt_put(xprt); + return single_release(inode, filp); +} + +static const struct file_operations xprt_info_fops = { + .owner = THIS_MODULE, + .open = xprt_info_open, + .read = seq_read, + .llseek = seq_lseek, + .release = xprt_info_release, +}; + +void +rpc_xprt_debugfs_register(struct rpc_xprt *xprt) +{ + int len, id; + static atomic_t cur_id; + char name[9]; /* 8 hex digits + NULL term */ + + if (!rpc_xprt_dir) + return; + + id = (unsigned int)atomic_inc_return(&cur_id); + + len = snprintf(name, sizeof(name), "%x", id); + if (len >= sizeof(name)) + return; + + /* make the per-client dir */ + xprt->debugfs = debugfs_create_dir(name, rpc_xprt_dir); + if (!xprt->debugfs) + return; + + /* make tasks file */ + if (!debugfs_create_file("info", S_IFREG | 0400, xprt->debugfs, + xprt, &xprt_info_fops)) { + debugfs_remove_recursive(xprt->debugfs); + xprt->debugfs = NULL; + } + + atomic_set(&xprt->inject_disconnect, rpc_inject_disconnect); +} + +void +rpc_xprt_debugfs_unregister(struct rpc_xprt *xprt) +{ + debugfs_remove_recursive(xprt->debugfs); + xprt->debugfs = NULL; +} + +static int +fault_open(struct inode *inode, struct file *filp) +{ + filp->private_data = kmalloc(128, GFP_KERNEL); + if (!filp->private_data) + return -ENOMEM; + return 0; +} + +static int +fault_release(struct inode *inode, struct file *filp) +{ + kfree(filp->private_data); + return 0; +} + +static ssize_t +fault_disconnect_read(struct file *filp, char __user *user_buf, + size_t len, loff_t *offset) +{ + char *buffer = (char *)filp->private_data; + size_t size; + + size = sprintf(buffer, "%u\n", rpc_inject_disconnect); + return simple_read_from_buffer(user_buf, len, offset, buffer, size); +} + +static ssize_t +fault_disconnect_write(struct file *filp, const char __user *user_buf, + size_t len, loff_t *offset) +{ + char buffer[16]; + + if (len >= sizeof(buffer)) + len = sizeof(buffer) - 1; + if (copy_from_user(buffer, user_buf, len)) + return -EFAULT; + buffer[len] = '\0'; + if (kstrtouint(buffer, 10, &rpc_inject_disconnect)) + return -EINVAL; + return len; +} + +static const struct file_operations fault_disconnect_fops = { + .owner = THIS_MODULE, + .open = fault_open, + .read = fault_disconnect_read, + .write = fault_disconnect_write, + .release = fault_release, +}; + +static struct dentry * +inject_fault_dir(struct dentry *topdir) +{ + struct dentry *faultdir; + + faultdir = debugfs_create_dir("inject_fault", topdir); + if (!faultdir) + return NULL; + + if (!debugfs_create_file("disconnect", S_IFREG | 0400, faultdir, + NULL, &fault_disconnect_fops)) + return NULL; + + return faultdir; +} + +void __exit +sunrpc_debugfs_exit(void) +{ + debugfs_remove_recursive(topdir); + topdir = NULL; + rpc_fault_dir = NULL; + rpc_clnt_dir = NULL; + rpc_xprt_dir = NULL; +} + +void __init +sunrpc_debugfs_init(void) +{ + topdir = debugfs_create_dir("sunrpc", NULL); + if (!topdir) + return; + + rpc_fault_dir = inject_fault_dir(topdir); + if (!rpc_fault_dir) + goto out_remove; + + rpc_clnt_dir = debugfs_create_dir("rpc_clnt", topdir); + if (!rpc_clnt_dir) + goto out_remove; + + rpc_xprt_dir = debugfs_create_dir("rpc_xprt", topdir); + if (!rpc_xprt_dir) + goto out_remove; + + return; +out_remove: + debugfs_remove_recursive(topdir); + topdir = NULL; + rpc_fault_dir = NULL; + rpc_clnt_dir = NULL; +} diff --git a/net/sunrpc/netns.h b/net/sunrpc/netns.h new file mode 100644 index 000000000..7ec10b92b --- /dev/null +++ b/net/sunrpc/netns.h @@ -0,0 +1,43 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __SUNRPC_NETNS_H__ +#define __SUNRPC_NETNS_H__ + +#include <net/net_namespace.h> +#include <net/netns/generic.h> + +struct cache_detail; + +struct sunrpc_net { + struct proc_dir_entry *proc_net_rpc; + struct cache_detail *ip_map_cache; + struct cache_detail *unix_gid_cache; + struct cache_detail *rsc_cache; + struct cache_detail *rsi_cache; + + struct super_block *pipefs_sb; + struct rpc_pipe *gssd_dummy; + struct mutex pipefs_sb_lock; + + struct list_head all_clients; + spinlock_t rpc_client_lock; + + struct rpc_clnt *rpcb_local_clnt; + struct rpc_clnt *rpcb_local_clnt4; + spinlock_t rpcb_clnt_lock; + unsigned int rpcb_users; + unsigned int rpcb_is_af_local : 1; + + struct mutex gssp_lock; + struct rpc_clnt *gssp_clnt; + int use_gss_proxy; + int pipe_version; + atomic_t pipe_users; + struct proc_dir_entry *use_gssp_proc; +}; + +extern unsigned int sunrpc_net_id; + +int ip_map_cache_create(struct net *); +void ip_map_cache_destroy(struct net *); + +#endif diff --git a/net/sunrpc/rpc_pipe.c b/net/sunrpc/rpc_pipe.c new file mode 100644 index 000000000..285eab5b4 --- /dev/null +++ b/net/sunrpc/rpc_pipe.c @@ -0,0 +1,1513 @@ +/* + * net/sunrpc/rpc_pipe.c + * + * Userland/kernel interface for rpcauth_gss. + * Code shamelessly plagiarized from fs/nfsd/nfsctl.c + * and fs/sysfs/inode.c + * + * Copyright (c) 2002, Trond Myklebust <trond.myklebust@fys.uio.no> + * + */ +#include <linux/module.h> +#include <linux/slab.h> +#include <linux/string.h> +#include <linux/pagemap.h> +#include <linux/mount.h> +#include <linux/namei.h> +#include <linux/fsnotify.h> +#include <linux/kernel.h> +#include <linux/rcupdate.h> +#include <linux/utsname.h> + +#include <asm/ioctls.h> +#include <linux/poll.h> +#include <linux/wait.h> +#include <linux/seq_file.h> + +#include <linux/sunrpc/clnt.h> +#include <linux/workqueue.h> +#include <linux/sunrpc/rpc_pipe_fs.h> +#include <linux/sunrpc/cache.h> +#include <linux/nsproxy.h> +#include <linux/notifier.h> + +#include "netns.h" +#include "sunrpc.h" + +#define RPCDBG_FACILITY RPCDBG_DEBUG + +#define NET_NAME(net) ((net == &init_net) ? " (init_net)" : "") + +static struct file_system_type rpc_pipe_fs_type; +static const struct rpc_pipe_ops gssd_dummy_pipe_ops; + +static struct kmem_cache *rpc_inode_cachep __read_mostly; + +#define RPC_UPCALL_TIMEOUT (30*HZ) + +static BLOCKING_NOTIFIER_HEAD(rpc_pipefs_notifier_list); + +int rpc_pipefs_notifier_register(struct notifier_block *nb) +{ + return blocking_notifier_chain_cond_register(&rpc_pipefs_notifier_list, nb); +} +EXPORT_SYMBOL_GPL(rpc_pipefs_notifier_register); + +void rpc_pipefs_notifier_unregister(struct notifier_block *nb) +{ + blocking_notifier_chain_unregister(&rpc_pipefs_notifier_list, nb); +} +EXPORT_SYMBOL_GPL(rpc_pipefs_notifier_unregister); + +static void rpc_purge_list(wait_queue_head_t *waitq, struct list_head *head, + void (*destroy_msg)(struct rpc_pipe_msg *), int err) +{ + struct rpc_pipe_msg *msg; + + if (list_empty(head)) + return; + do { + msg = list_entry(head->next, struct rpc_pipe_msg, list); + list_del_init(&msg->list); + msg->errno = err; + destroy_msg(msg); + } while (!list_empty(head)); + + if (waitq) + wake_up(waitq); +} + +static void +rpc_timeout_upcall_queue(struct work_struct *work) +{ + LIST_HEAD(free_list); + struct rpc_pipe *pipe = + container_of(work, struct rpc_pipe, queue_timeout.work); + void (*destroy_msg)(struct rpc_pipe_msg *); + struct dentry *dentry; + + spin_lock(&pipe->lock); + destroy_msg = pipe->ops->destroy_msg; + if (pipe->nreaders == 0) { + list_splice_init(&pipe->pipe, &free_list); + pipe->pipelen = 0; + } + dentry = dget(pipe->dentry); + spin_unlock(&pipe->lock); + rpc_purge_list(dentry ? &RPC_I(d_inode(dentry))->waitq : NULL, + &free_list, destroy_msg, -ETIMEDOUT); + dput(dentry); +} + +ssize_t rpc_pipe_generic_upcall(struct file *filp, struct rpc_pipe_msg *msg, + char __user *dst, size_t buflen) +{ + char *data = (char *)msg->data + msg->copied; + size_t mlen = min(msg->len - msg->copied, buflen); + unsigned long left; + + left = copy_to_user(dst, data, mlen); + if (left == mlen) { + msg->errno = -EFAULT; + return -EFAULT; + } + + mlen -= left; + msg->copied += mlen; + msg->errno = 0; + return mlen; +} +EXPORT_SYMBOL_GPL(rpc_pipe_generic_upcall); + +/** + * rpc_queue_upcall - queue an upcall message to userspace + * @pipe: upcall pipe on which to queue given message + * @msg: message to queue + * + * Call with an @inode created by rpc_mkpipe() to queue an upcall. + * A userspace process may then later read the upcall by performing a + * read on an open file for this inode. It is up to the caller to + * initialize the fields of @msg (other than @msg->list) appropriately. + */ +int +rpc_queue_upcall(struct rpc_pipe *pipe, struct rpc_pipe_msg *msg) +{ + int res = -EPIPE; + struct dentry *dentry; + + spin_lock(&pipe->lock); + if (pipe->nreaders) { + list_add_tail(&msg->list, &pipe->pipe); + pipe->pipelen += msg->len; + res = 0; + } else if (pipe->flags & RPC_PIPE_WAIT_FOR_OPEN) { + if (list_empty(&pipe->pipe)) + queue_delayed_work(rpciod_workqueue, + &pipe->queue_timeout, + RPC_UPCALL_TIMEOUT); + list_add_tail(&msg->list, &pipe->pipe); + pipe->pipelen += msg->len; + res = 0; + } + dentry = dget(pipe->dentry); + spin_unlock(&pipe->lock); + if (dentry) { + wake_up(&RPC_I(d_inode(dentry))->waitq); + dput(dentry); + } + return res; +} +EXPORT_SYMBOL_GPL(rpc_queue_upcall); + +static inline void +rpc_inode_setowner(struct inode *inode, void *private) +{ + RPC_I(inode)->private = private; +} + +static void +rpc_close_pipes(struct inode *inode) +{ + struct rpc_pipe *pipe = RPC_I(inode)->pipe; + int need_release; + LIST_HEAD(free_list); + + inode_lock(inode); + spin_lock(&pipe->lock); + need_release = pipe->nreaders != 0 || pipe->nwriters != 0; + pipe->nreaders = 0; + list_splice_init(&pipe->in_upcall, &free_list); + list_splice_init(&pipe->pipe, &free_list); + pipe->pipelen = 0; + pipe->dentry = NULL; + spin_unlock(&pipe->lock); + rpc_purge_list(&RPC_I(inode)->waitq, &free_list, pipe->ops->destroy_msg, -EPIPE); + pipe->nwriters = 0; + if (need_release && pipe->ops->release_pipe) + pipe->ops->release_pipe(inode); + cancel_delayed_work_sync(&pipe->queue_timeout); + rpc_inode_setowner(inode, NULL); + RPC_I(inode)->pipe = NULL; + inode_unlock(inode); +} + +static struct inode * +rpc_alloc_inode(struct super_block *sb) +{ + struct rpc_inode *rpci; + rpci = kmem_cache_alloc(rpc_inode_cachep, GFP_KERNEL); + if (!rpci) + return NULL; + return &rpci->vfs_inode; +} + +static void +rpc_i_callback(struct rcu_head *head) +{ + struct inode *inode = container_of(head, struct inode, i_rcu); + kmem_cache_free(rpc_inode_cachep, RPC_I(inode)); +} + +static void +rpc_destroy_inode(struct inode *inode) +{ + call_rcu(&inode->i_rcu, rpc_i_callback); +} + +static int +rpc_pipe_open(struct inode *inode, struct file *filp) +{ + struct rpc_pipe *pipe; + int first_open; + int res = -ENXIO; + + inode_lock(inode); + pipe = RPC_I(inode)->pipe; + if (pipe == NULL) + goto out; + first_open = pipe->nreaders == 0 && pipe->nwriters == 0; + if (first_open && pipe->ops->open_pipe) { + res = pipe->ops->open_pipe(inode); + if (res) + goto out; + } + if (filp->f_mode & FMODE_READ) + pipe->nreaders++; + if (filp->f_mode & FMODE_WRITE) + pipe->nwriters++; + res = 0; +out: + inode_unlock(inode); + return res; +} + +static int +rpc_pipe_release(struct inode *inode, struct file *filp) +{ + struct rpc_pipe *pipe; + struct rpc_pipe_msg *msg; + int last_close; + + inode_lock(inode); + pipe = RPC_I(inode)->pipe; + if (pipe == NULL) + goto out; + msg = filp->private_data; + if (msg != NULL) { + spin_lock(&pipe->lock); + msg->errno = -EAGAIN; + list_del_init(&msg->list); + spin_unlock(&pipe->lock); + pipe->ops->destroy_msg(msg); + } + if (filp->f_mode & FMODE_WRITE) + pipe->nwriters --; + if (filp->f_mode & FMODE_READ) { + pipe->nreaders --; + if (pipe->nreaders == 0) { + LIST_HEAD(free_list); + spin_lock(&pipe->lock); + list_splice_init(&pipe->pipe, &free_list); + pipe->pipelen = 0; + spin_unlock(&pipe->lock); + rpc_purge_list(&RPC_I(inode)->waitq, &free_list, + pipe->ops->destroy_msg, -EAGAIN); + } + } + last_close = pipe->nwriters == 0 && pipe->nreaders == 0; + if (last_close && pipe->ops->release_pipe) + pipe->ops->release_pipe(inode); +out: + inode_unlock(inode); + return 0; +} + +static ssize_t +rpc_pipe_read(struct file *filp, char __user *buf, size_t len, loff_t *offset) +{ + struct inode *inode = file_inode(filp); + struct rpc_pipe *pipe; + struct rpc_pipe_msg *msg; + int res = 0; + + inode_lock(inode); + pipe = RPC_I(inode)->pipe; + if (pipe == NULL) { + res = -EPIPE; + goto out_unlock; + } + msg = filp->private_data; + if (msg == NULL) { + spin_lock(&pipe->lock); + if (!list_empty(&pipe->pipe)) { + msg = list_entry(pipe->pipe.next, + struct rpc_pipe_msg, + list); + list_move(&msg->list, &pipe->in_upcall); + pipe->pipelen -= msg->len; + filp->private_data = msg; + msg->copied = 0; + } + spin_unlock(&pipe->lock); + if (msg == NULL) + goto out_unlock; + } + /* NOTE: it is up to the callback to update msg->copied */ + res = pipe->ops->upcall(filp, msg, buf, len); + if (res < 0 || msg->len == msg->copied) { + filp->private_data = NULL; + spin_lock(&pipe->lock); + list_del_init(&msg->list); + spin_unlock(&pipe->lock); + pipe->ops->destroy_msg(msg); + } +out_unlock: + inode_unlock(inode); + return res; +} + +static ssize_t +rpc_pipe_write(struct file *filp, const char __user *buf, size_t len, loff_t *offset) +{ + struct inode *inode = file_inode(filp); + int res; + + inode_lock(inode); + res = -EPIPE; + if (RPC_I(inode)->pipe != NULL) + res = RPC_I(inode)->pipe->ops->downcall(filp, buf, len); + inode_unlock(inode); + return res; +} + +static __poll_t +rpc_pipe_poll(struct file *filp, struct poll_table_struct *wait) +{ + struct inode *inode = file_inode(filp); + struct rpc_inode *rpci = RPC_I(inode); + __poll_t mask = EPOLLOUT | EPOLLWRNORM; + + poll_wait(filp, &rpci->waitq, wait); + + inode_lock(inode); + if (rpci->pipe == NULL) + mask |= EPOLLERR | EPOLLHUP; + else if (filp->private_data || !list_empty(&rpci->pipe->pipe)) + mask |= EPOLLIN | EPOLLRDNORM; + inode_unlock(inode); + return mask; +} + +static long +rpc_pipe_ioctl(struct file *filp, unsigned int cmd, unsigned long arg) +{ + struct inode *inode = file_inode(filp); + struct rpc_pipe *pipe; + int len; + + switch (cmd) { + case FIONREAD: + inode_lock(inode); + pipe = RPC_I(inode)->pipe; + if (pipe == NULL) { + inode_unlock(inode); + return -EPIPE; + } + spin_lock(&pipe->lock); + len = pipe->pipelen; + if (filp->private_data) { + struct rpc_pipe_msg *msg; + msg = filp->private_data; + len += msg->len - msg->copied; + } + spin_unlock(&pipe->lock); + inode_unlock(inode); + return put_user(len, (int __user *)arg); + default: + return -EINVAL; + } +} + +static const struct file_operations rpc_pipe_fops = { + .owner = THIS_MODULE, + .llseek = no_llseek, + .read = rpc_pipe_read, + .write = rpc_pipe_write, + .poll = rpc_pipe_poll, + .unlocked_ioctl = rpc_pipe_ioctl, + .open = rpc_pipe_open, + .release = rpc_pipe_release, +}; + +static int +rpc_show_info(struct seq_file *m, void *v) +{ + struct rpc_clnt *clnt = m->private; + + rcu_read_lock(); + seq_printf(m, "RPC server: %s\n", + rcu_dereference(clnt->cl_xprt)->servername); + seq_printf(m, "service: %s (%d) version %d\n", clnt->cl_program->name, + clnt->cl_prog, clnt->cl_vers); + seq_printf(m, "address: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_ADDR)); + seq_printf(m, "protocol: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_PROTO)); + seq_printf(m, "port: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_PORT)); + rcu_read_unlock(); + return 0; +} + +static int +rpc_info_open(struct inode *inode, struct file *file) +{ + struct rpc_clnt *clnt = NULL; + int ret = single_open(file, rpc_show_info, NULL); + + if (!ret) { + struct seq_file *m = file->private_data; + + spin_lock(&file->f_path.dentry->d_lock); + if (!d_unhashed(file->f_path.dentry)) + clnt = RPC_I(inode)->private; + if (clnt != NULL && atomic_inc_not_zero(&clnt->cl_count)) { + spin_unlock(&file->f_path.dentry->d_lock); + m->private = clnt; + } else { + spin_unlock(&file->f_path.dentry->d_lock); + single_release(inode, file); + ret = -EINVAL; + } + } + return ret; +} + +static int +rpc_info_release(struct inode *inode, struct file *file) +{ + struct seq_file *m = file->private_data; + struct rpc_clnt *clnt = (struct rpc_clnt *)m->private; + + if (clnt) + rpc_release_client(clnt); + return single_release(inode, file); +} + +static const struct file_operations rpc_info_operations = { + .owner = THIS_MODULE, + .open = rpc_info_open, + .read = seq_read, + .llseek = seq_lseek, + .release = rpc_info_release, +}; + + +/* + * Description of fs contents. + */ +struct rpc_filelist { + const char *name; + const struct file_operations *i_fop; + umode_t mode; +}; + +static struct inode * +rpc_get_inode(struct super_block *sb, umode_t mode) +{ + struct inode *inode = new_inode(sb); + if (!inode) + return NULL; + inode->i_ino = get_next_ino(); + inode->i_mode = mode; + inode->i_atime = inode->i_mtime = inode->i_ctime = current_time(inode); + switch (mode & S_IFMT) { + case S_IFDIR: + inode->i_fop = &simple_dir_operations; + inode->i_op = &simple_dir_inode_operations; + inc_nlink(inode); + default: + break; + } + return inode; +} + +static int __rpc_create_common(struct inode *dir, struct dentry *dentry, + umode_t mode, + const struct file_operations *i_fop, + void *private) +{ + struct inode *inode; + + d_drop(dentry); + inode = rpc_get_inode(dir->i_sb, mode); + if (!inode) + goto out_err; + inode->i_ino = iunique(dir->i_sb, 100); + if (i_fop) + inode->i_fop = i_fop; + if (private) + rpc_inode_setowner(inode, private); + d_add(dentry, inode); + return 0; +out_err: + printk(KERN_WARNING "%s: %s failed to allocate inode for dentry %pd\n", + __FILE__, __func__, dentry); + dput(dentry); + return -ENOMEM; +} + +static int __rpc_create(struct inode *dir, struct dentry *dentry, + umode_t mode, + const struct file_operations *i_fop, + void *private) +{ + int err; + + err = __rpc_create_common(dir, dentry, S_IFREG | mode, i_fop, private); + if (err) + return err; + fsnotify_create(dir, dentry); + return 0; +} + +static int __rpc_mkdir(struct inode *dir, struct dentry *dentry, + umode_t mode, + const struct file_operations *i_fop, + void *private) +{ + int err; + + err = __rpc_create_common(dir, dentry, S_IFDIR | mode, i_fop, private); + if (err) + return err; + inc_nlink(dir); + fsnotify_mkdir(dir, dentry); + return 0; +} + +static void +init_pipe(struct rpc_pipe *pipe) +{ + pipe->nreaders = 0; + pipe->nwriters = 0; + INIT_LIST_HEAD(&pipe->in_upcall); + INIT_LIST_HEAD(&pipe->in_downcall); + INIT_LIST_HEAD(&pipe->pipe); + pipe->pipelen = 0; + INIT_DELAYED_WORK(&pipe->queue_timeout, + rpc_timeout_upcall_queue); + pipe->ops = NULL; + spin_lock_init(&pipe->lock); + pipe->dentry = NULL; +} + +void rpc_destroy_pipe_data(struct rpc_pipe *pipe) +{ + kfree(pipe); +} +EXPORT_SYMBOL_GPL(rpc_destroy_pipe_data); + +struct rpc_pipe *rpc_mkpipe_data(const struct rpc_pipe_ops *ops, int flags) +{ + struct rpc_pipe *pipe; + + pipe = kzalloc(sizeof(struct rpc_pipe), GFP_KERNEL); + if (!pipe) + return ERR_PTR(-ENOMEM); + init_pipe(pipe); + pipe->ops = ops; + pipe->flags = flags; + return pipe; +} +EXPORT_SYMBOL_GPL(rpc_mkpipe_data); + +static int __rpc_mkpipe_dentry(struct inode *dir, struct dentry *dentry, + umode_t mode, + const struct file_operations *i_fop, + void *private, + struct rpc_pipe *pipe) +{ + struct rpc_inode *rpci; + int err; + + err = __rpc_create_common(dir, dentry, S_IFIFO | mode, i_fop, private); + if (err) + return err; + rpci = RPC_I(d_inode(dentry)); + rpci->private = private; + rpci->pipe = pipe; + fsnotify_create(dir, dentry); + return 0; +} + +static int __rpc_rmdir(struct inode *dir, struct dentry *dentry) +{ + int ret; + + dget(dentry); + ret = simple_rmdir(dir, dentry); + d_delete(dentry); + dput(dentry); + return ret; +} + +static int __rpc_unlink(struct inode *dir, struct dentry *dentry) +{ + int ret; + + dget(dentry); + ret = simple_unlink(dir, dentry); + d_delete(dentry); + dput(dentry); + return ret; +} + +static int __rpc_rmpipe(struct inode *dir, struct dentry *dentry) +{ + struct inode *inode = d_inode(dentry); + + rpc_close_pipes(inode); + return __rpc_unlink(dir, dentry); +} + +static struct dentry *__rpc_lookup_create_exclusive(struct dentry *parent, + const char *name) +{ + struct qstr q = QSTR_INIT(name, strlen(name)); + struct dentry *dentry = d_hash_and_lookup(parent, &q); + if (!dentry) { + dentry = d_alloc(parent, &q); + if (!dentry) + return ERR_PTR(-ENOMEM); + } + if (d_really_is_negative(dentry)) + return dentry; + dput(dentry); + return ERR_PTR(-EEXIST); +} + +/* + * FIXME: This probably has races. + */ +static void __rpc_depopulate(struct dentry *parent, + const struct rpc_filelist *files, + int start, int eof) +{ + struct inode *dir = d_inode(parent); + struct dentry *dentry; + struct qstr name; + int i; + + for (i = start; i < eof; i++) { + name.name = files[i].name; + name.len = strlen(files[i].name); + dentry = d_hash_and_lookup(parent, &name); + + if (dentry == NULL) + continue; + if (d_really_is_negative(dentry)) + goto next; + switch (d_inode(dentry)->i_mode & S_IFMT) { + default: + BUG(); + case S_IFREG: + __rpc_unlink(dir, dentry); + break; + case S_IFDIR: + __rpc_rmdir(dir, dentry); + } +next: + dput(dentry); + } +} + +static void rpc_depopulate(struct dentry *parent, + const struct rpc_filelist *files, + int start, int eof) +{ + struct inode *dir = d_inode(parent); + + inode_lock_nested(dir, I_MUTEX_CHILD); + __rpc_depopulate(parent, files, start, eof); + inode_unlock(dir); +} + +static int rpc_populate(struct dentry *parent, + const struct rpc_filelist *files, + int start, int eof, + void *private) +{ + struct inode *dir = d_inode(parent); + struct dentry *dentry; + int i, err; + + inode_lock(dir); + for (i = start; i < eof; i++) { + dentry = __rpc_lookup_create_exclusive(parent, files[i].name); + err = PTR_ERR(dentry); + if (IS_ERR(dentry)) + goto out_bad; + switch (files[i].mode & S_IFMT) { + default: + BUG(); + case S_IFREG: + err = __rpc_create(dir, dentry, + files[i].mode, + files[i].i_fop, + private); + break; + case S_IFDIR: + err = __rpc_mkdir(dir, dentry, + files[i].mode, + NULL, + private); + } + if (err != 0) + goto out_bad; + } + inode_unlock(dir); + return 0; +out_bad: + __rpc_depopulate(parent, files, start, eof); + inode_unlock(dir); + printk(KERN_WARNING "%s: %s failed to populate directory %pd\n", + __FILE__, __func__, parent); + return err; +} + +static struct dentry *rpc_mkdir_populate(struct dentry *parent, + const char *name, umode_t mode, void *private, + int (*populate)(struct dentry *, void *), void *args_populate) +{ + struct dentry *dentry; + struct inode *dir = d_inode(parent); + int error; + + inode_lock_nested(dir, I_MUTEX_PARENT); + dentry = __rpc_lookup_create_exclusive(parent, name); + if (IS_ERR(dentry)) + goto out; + error = __rpc_mkdir(dir, dentry, mode, NULL, private); + if (error != 0) + goto out_err; + if (populate != NULL) { + error = populate(dentry, args_populate); + if (error) + goto err_rmdir; + } +out: + inode_unlock(dir); + return dentry; +err_rmdir: + __rpc_rmdir(dir, dentry); +out_err: + dentry = ERR_PTR(error); + goto out; +} + +static int rpc_rmdir_depopulate(struct dentry *dentry, + void (*depopulate)(struct dentry *)) +{ + struct dentry *parent; + struct inode *dir; + int error; + + parent = dget_parent(dentry); + dir = d_inode(parent); + inode_lock_nested(dir, I_MUTEX_PARENT); + if (depopulate != NULL) + depopulate(dentry); + error = __rpc_rmdir(dir, dentry); + inode_unlock(dir); + dput(parent); + return error; +} + +/** + * rpc_mkpipe - make an rpc_pipefs file for kernel<->userspace communication + * @parent: dentry of directory to create new "pipe" in + * @name: name of pipe + * @private: private data to associate with the pipe, for the caller's use + * @pipe: &rpc_pipe containing input parameters + * + * Data is made available for userspace to read by calls to + * rpc_queue_upcall(). The actual reads will result in calls to + * @ops->upcall, which will be called with the file pointer, + * message, and userspace buffer to copy to. + * + * Writes can come at any time, and do not necessarily have to be + * responses to upcalls. They will result in calls to @msg->downcall. + * + * The @private argument passed here will be available to all these methods + * from the file pointer, via RPC_I(file_inode(file))->private. + */ +struct dentry *rpc_mkpipe_dentry(struct dentry *parent, const char *name, + void *private, struct rpc_pipe *pipe) +{ + struct dentry *dentry; + struct inode *dir = d_inode(parent); + umode_t umode = S_IFIFO | 0600; + int err; + + if (pipe->ops->upcall == NULL) + umode &= ~0444; + if (pipe->ops->downcall == NULL) + umode &= ~0222; + + inode_lock_nested(dir, I_MUTEX_PARENT); + dentry = __rpc_lookup_create_exclusive(parent, name); + if (IS_ERR(dentry)) + goto out; + err = __rpc_mkpipe_dentry(dir, dentry, umode, &rpc_pipe_fops, + private, pipe); + if (err) + goto out_err; +out: + inode_unlock(dir); + return dentry; +out_err: + dentry = ERR_PTR(err); + printk(KERN_WARNING "%s: %s() failed to create pipe %pd/%s (errno = %d)\n", + __FILE__, __func__, parent, name, + err); + goto out; +} +EXPORT_SYMBOL_GPL(rpc_mkpipe_dentry); + +/** + * rpc_unlink - remove a pipe + * @dentry: dentry for the pipe, as returned from rpc_mkpipe + * + * After this call, lookups will no longer find the pipe, and any + * attempts to read or write using preexisting opens of the pipe will + * return -EPIPE. + */ +int +rpc_unlink(struct dentry *dentry) +{ + struct dentry *parent; + struct inode *dir; + int error = 0; + + parent = dget_parent(dentry); + dir = d_inode(parent); + inode_lock_nested(dir, I_MUTEX_PARENT); + error = __rpc_rmpipe(dir, dentry); + inode_unlock(dir); + dput(parent); + return error; +} +EXPORT_SYMBOL_GPL(rpc_unlink); + +/** + * rpc_init_pipe_dir_head - initialise a struct rpc_pipe_dir_head + * @pdh: pointer to struct rpc_pipe_dir_head + */ +void rpc_init_pipe_dir_head(struct rpc_pipe_dir_head *pdh) +{ + INIT_LIST_HEAD(&pdh->pdh_entries); + pdh->pdh_dentry = NULL; +} +EXPORT_SYMBOL_GPL(rpc_init_pipe_dir_head); + +/** + * rpc_init_pipe_dir_object - initialise a struct rpc_pipe_dir_object + * @pdo: pointer to struct rpc_pipe_dir_object + * @pdo_ops: pointer to const struct rpc_pipe_dir_object_ops + * @pdo_data: pointer to caller-defined data + */ +void rpc_init_pipe_dir_object(struct rpc_pipe_dir_object *pdo, + const struct rpc_pipe_dir_object_ops *pdo_ops, + void *pdo_data) +{ + INIT_LIST_HEAD(&pdo->pdo_head); + pdo->pdo_ops = pdo_ops; + pdo->pdo_data = pdo_data; +} +EXPORT_SYMBOL_GPL(rpc_init_pipe_dir_object); + +static int +rpc_add_pipe_dir_object_locked(struct net *net, + struct rpc_pipe_dir_head *pdh, + struct rpc_pipe_dir_object *pdo) +{ + int ret = 0; + + if (pdh->pdh_dentry) + ret = pdo->pdo_ops->create(pdh->pdh_dentry, pdo); + if (ret == 0) + list_add_tail(&pdo->pdo_head, &pdh->pdh_entries); + return ret; +} + +static void +rpc_remove_pipe_dir_object_locked(struct net *net, + struct rpc_pipe_dir_head *pdh, + struct rpc_pipe_dir_object *pdo) +{ + if (pdh->pdh_dentry) + pdo->pdo_ops->destroy(pdh->pdh_dentry, pdo); + list_del_init(&pdo->pdo_head); +} + +/** + * rpc_add_pipe_dir_object - associate a rpc_pipe_dir_object to a directory + * @net: pointer to struct net + * @pdh: pointer to struct rpc_pipe_dir_head + * @pdo: pointer to struct rpc_pipe_dir_object + * + */ +int +rpc_add_pipe_dir_object(struct net *net, + struct rpc_pipe_dir_head *pdh, + struct rpc_pipe_dir_object *pdo) +{ + int ret = 0; + + if (list_empty(&pdo->pdo_head)) { + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + mutex_lock(&sn->pipefs_sb_lock); + ret = rpc_add_pipe_dir_object_locked(net, pdh, pdo); + mutex_unlock(&sn->pipefs_sb_lock); + } + return ret; +} +EXPORT_SYMBOL_GPL(rpc_add_pipe_dir_object); + +/** + * rpc_remove_pipe_dir_object - remove a rpc_pipe_dir_object from a directory + * @net: pointer to struct net + * @pdh: pointer to struct rpc_pipe_dir_head + * @pdo: pointer to struct rpc_pipe_dir_object + * + */ +void +rpc_remove_pipe_dir_object(struct net *net, + struct rpc_pipe_dir_head *pdh, + struct rpc_pipe_dir_object *pdo) +{ + if (!list_empty(&pdo->pdo_head)) { + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + mutex_lock(&sn->pipefs_sb_lock); + rpc_remove_pipe_dir_object_locked(net, pdh, pdo); + mutex_unlock(&sn->pipefs_sb_lock); + } +} +EXPORT_SYMBOL_GPL(rpc_remove_pipe_dir_object); + +/** + * rpc_find_or_alloc_pipe_dir_object + * @net: pointer to struct net + * @pdh: pointer to struct rpc_pipe_dir_head + * @match: match struct rpc_pipe_dir_object to data + * @alloc: allocate a new struct rpc_pipe_dir_object + * @data: user defined data for match() and alloc() + * + */ +struct rpc_pipe_dir_object * +rpc_find_or_alloc_pipe_dir_object(struct net *net, + struct rpc_pipe_dir_head *pdh, + int (*match)(struct rpc_pipe_dir_object *, void *), + struct rpc_pipe_dir_object *(*alloc)(void *), + void *data) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_pipe_dir_object *pdo; + + mutex_lock(&sn->pipefs_sb_lock); + list_for_each_entry(pdo, &pdh->pdh_entries, pdo_head) { + if (!match(pdo, data)) + continue; + goto out; + } + pdo = alloc(data); + if (!pdo) + goto out; + rpc_add_pipe_dir_object_locked(net, pdh, pdo); +out: + mutex_unlock(&sn->pipefs_sb_lock); + return pdo; +} +EXPORT_SYMBOL_GPL(rpc_find_or_alloc_pipe_dir_object); + +static void +rpc_create_pipe_dir_objects(struct rpc_pipe_dir_head *pdh) +{ + struct rpc_pipe_dir_object *pdo; + struct dentry *dir = pdh->pdh_dentry; + + list_for_each_entry(pdo, &pdh->pdh_entries, pdo_head) + pdo->pdo_ops->create(dir, pdo); +} + +static void +rpc_destroy_pipe_dir_objects(struct rpc_pipe_dir_head *pdh) +{ + struct rpc_pipe_dir_object *pdo; + struct dentry *dir = pdh->pdh_dentry; + + list_for_each_entry(pdo, &pdh->pdh_entries, pdo_head) + pdo->pdo_ops->destroy(dir, pdo); +} + +enum { + RPCAUTH_info, + RPCAUTH_EOF +}; + +static const struct rpc_filelist authfiles[] = { + [RPCAUTH_info] = { + .name = "info", + .i_fop = &rpc_info_operations, + .mode = S_IFREG | 0400, + }, +}; + +static int rpc_clntdir_populate(struct dentry *dentry, void *private) +{ + return rpc_populate(dentry, + authfiles, RPCAUTH_info, RPCAUTH_EOF, + private); +} + +static void rpc_clntdir_depopulate(struct dentry *dentry) +{ + rpc_depopulate(dentry, authfiles, RPCAUTH_info, RPCAUTH_EOF); +} + +/** + * rpc_create_client_dir - Create a new rpc_client directory in rpc_pipefs + * @dentry: the parent of new directory + * @name: the name of new directory + * @rpc_client: rpc client to associate with this directory + * + * This creates a directory at the given @path associated with + * @rpc_clnt, which will contain a file named "info" with some basic + * information about the client, together with any "pipes" that may + * later be created using rpc_mkpipe(). + */ +struct dentry *rpc_create_client_dir(struct dentry *dentry, + const char *name, + struct rpc_clnt *rpc_client) +{ + struct dentry *ret; + + ret = rpc_mkdir_populate(dentry, name, 0555, NULL, + rpc_clntdir_populate, rpc_client); + if (!IS_ERR(ret)) { + rpc_client->cl_pipedir_objects.pdh_dentry = ret; + rpc_create_pipe_dir_objects(&rpc_client->cl_pipedir_objects); + } + return ret; +} + +/** + * rpc_remove_client_dir - Remove a directory created with rpc_create_client_dir() + * @rpc_client: rpc_client for the pipe + */ +int rpc_remove_client_dir(struct rpc_clnt *rpc_client) +{ + struct dentry *dentry = rpc_client->cl_pipedir_objects.pdh_dentry; + + if (dentry == NULL) + return 0; + rpc_destroy_pipe_dir_objects(&rpc_client->cl_pipedir_objects); + rpc_client->cl_pipedir_objects.pdh_dentry = NULL; + return rpc_rmdir_depopulate(dentry, rpc_clntdir_depopulate); +} + +static const struct rpc_filelist cache_pipefs_files[3] = { + [0] = { + .name = "channel", + .i_fop = &cache_file_operations_pipefs, + .mode = S_IFREG | 0600, + }, + [1] = { + .name = "content", + .i_fop = &content_file_operations_pipefs, + .mode = S_IFREG | 0400, + }, + [2] = { + .name = "flush", + .i_fop = &cache_flush_operations_pipefs, + .mode = S_IFREG | 0600, + }, +}; + +static int rpc_cachedir_populate(struct dentry *dentry, void *private) +{ + return rpc_populate(dentry, + cache_pipefs_files, 0, 3, + private); +} + +static void rpc_cachedir_depopulate(struct dentry *dentry) +{ + rpc_depopulate(dentry, cache_pipefs_files, 0, 3); +} + +struct dentry *rpc_create_cache_dir(struct dentry *parent, const char *name, + umode_t umode, struct cache_detail *cd) +{ + return rpc_mkdir_populate(parent, name, umode, NULL, + rpc_cachedir_populate, cd); +} + +void rpc_remove_cache_dir(struct dentry *dentry) +{ + rpc_rmdir_depopulate(dentry, rpc_cachedir_depopulate); +} + +/* + * populate the filesystem + */ +static const struct super_operations s_ops = { + .alloc_inode = rpc_alloc_inode, + .destroy_inode = rpc_destroy_inode, + .statfs = simple_statfs, +}; + +#define RPCAUTH_GSSMAGIC 0x67596969 + +/* + * We have a single directory with 1 node in it. + */ +enum { + RPCAUTH_lockd, + RPCAUTH_mount, + RPCAUTH_nfs, + RPCAUTH_portmap, + RPCAUTH_statd, + RPCAUTH_nfsd4_cb, + RPCAUTH_cache, + RPCAUTH_nfsd, + RPCAUTH_gssd, + RPCAUTH_RootEOF +}; + +static const struct rpc_filelist files[] = { + [RPCAUTH_lockd] = { + .name = "lockd", + .mode = S_IFDIR | 0555, + }, + [RPCAUTH_mount] = { + .name = "mount", + .mode = S_IFDIR | 0555, + }, + [RPCAUTH_nfs] = { + .name = "nfs", + .mode = S_IFDIR | 0555, + }, + [RPCAUTH_portmap] = { + .name = "portmap", + .mode = S_IFDIR | 0555, + }, + [RPCAUTH_statd] = { + .name = "statd", + .mode = S_IFDIR | 0555, + }, + [RPCAUTH_nfsd4_cb] = { + .name = "nfsd4_cb", + .mode = S_IFDIR | 0555, + }, + [RPCAUTH_cache] = { + .name = "cache", + .mode = S_IFDIR | 0555, + }, + [RPCAUTH_nfsd] = { + .name = "nfsd", + .mode = S_IFDIR | 0555, + }, + [RPCAUTH_gssd] = { + .name = "gssd", + .mode = S_IFDIR | 0555, + }, +}; + +/* + * This call can be used only in RPC pipefs mount notification hooks. + */ +struct dentry *rpc_d_lookup_sb(const struct super_block *sb, + const unsigned char *dir_name) +{ + struct qstr dir = QSTR_INIT(dir_name, strlen(dir_name)); + return d_hash_and_lookup(sb->s_root, &dir); +} +EXPORT_SYMBOL_GPL(rpc_d_lookup_sb); + +int rpc_pipefs_init_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + sn->gssd_dummy = rpc_mkpipe_data(&gssd_dummy_pipe_ops, 0); + if (IS_ERR(sn->gssd_dummy)) + return PTR_ERR(sn->gssd_dummy); + + mutex_init(&sn->pipefs_sb_lock); + sn->pipe_version = -1; + return 0; +} + +void rpc_pipefs_exit_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + rpc_destroy_pipe_data(sn->gssd_dummy); +} + +/* + * This call will be used for per network namespace operations calls. + * Note: Function will be returned with pipefs_sb_lock taken if superblock was + * found. This lock have to be released by rpc_put_sb_net() when all operations + * will be completed. + */ +struct super_block *rpc_get_sb_net(const struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + mutex_lock(&sn->pipefs_sb_lock); + if (sn->pipefs_sb) + return sn->pipefs_sb; + mutex_unlock(&sn->pipefs_sb_lock); + return NULL; +} +EXPORT_SYMBOL_GPL(rpc_get_sb_net); + +void rpc_put_sb_net(const struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + WARN_ON(sn->pipefs_sb == NULL); + mutex_unlock(&sn->pipefs_sb_lock); +} +EXPORT_SYMBOL_GPL(rpc_put_sb_net); + +static const struct rpc_filelist gssd_dummy_clnt_dir[] = { + [0] = { + .name = "clntXX", + .mode = S_IFDIR | 0555, + }, +}; + +static ssize_t +dummy_downcall(struct file *filp, const char __user *src, size_t len) +{ + return -EINVAL; +} + +static const struct rpc_pipe_ops gssd_dummy_pipe_ops = { + .upcall = rpc_pipe_generic_upcall, + .downcall = dummy_downcall, +}; + +/* + * Here we present a bogus "info" file to keep rpc.gssd happy. We don't expect + * that it will ever use this info to handle an upcall, but rpc.gssd expects + * that this file will be there and have a certain format. + */ +static int +rpc_show_dummy_info(struct seq_file *m, void *v) +{ + seq_printf(m, "RPC server: %s\n", utsname()->nodename); + seq_printf(m, "service: foo (1) version 0\n"); + seq_printf(m, "address: 127.0.0.1\n"); + seq_printf(m, "protocol: tcp\n"); + seq_printf(m, "port: 0\n"); + return 0; +} + +static int +rpc_dummy_info_open(struct inode *inode, struct file *file) +{ + return single_open(file, rpc_show_dummy_info, NULL); +} + +static const struct file_operations rpc_dummy_info_operations = { + .owner = THIS_MODULE, + .open = rpc_dummy_info_open, + .read = seq_read, + .llseek = seq_lseek, + .release = single_release, +}; + +static const struct rpc_filelist gssd_dummy_info_file[] = { + [0] = { + .name = "info", + .i_fop = &rpc_dummy_info_operations, + .mode = S_IFREG | 0400, + }, +}; + +/** + * rpc_gssd_dummy_populate - create a dummy gssd pipe + * @root: root of the rpc_pipefs filesystem + * @pipe_data: pipe data created when netns is initialized + * + * Create a dummy set of directories and a pipe that gssd can hold open to + * indicate that it is up and running. + */ +static struct dentry * +rpc_gssd_dummy_populate(struct dentry *root, struct rpc_pipe *pipe_data) +{ + int ret = 0; + struct dentry *gssd_dentry; + struct dentry *clnt_dentry = NULL; + struct dentry *pipe_dentry = NULL; + struct qstr q = QSTR_INIT(files[RPCAUTH_gssd].name, + strlen(files[RPCAUTH_gssd].name)); + + /* We should never get this far if "gssd" doesn't exist */ + gssd_dentry = d_hash_and_lookup(root, &q); + if (!gssd_dentry) + return ERR_PTR(-ENOENT); + + ret = rpc_populate(gssd_dentry, gssd_dummy_clnt_dir, 0, 1, NULL); + if (ret) { + pipe_dentry = ERR_PTR(ret); + goto out; + } + + q.name = gssd_dummy_clnt_dir[0].name; + q.len = strlen(gssd_dummy_clnt_dir[0].name); + clnt_dentry = d_hash_and_lookup(gssd_dentry, &q); + if (!clnt_dentry) { + __rpc_depopulate(gssd_dentry, gssd_dummy_clnt_dir, 0, 1); + pipe_dentry = ERR_PTR(-ENOENT); + goto out; + } + + ret = rpc_populate(clnt_dentry, gssd_dummy_info_file, 0, 1, NULL); + if (ret) { + __rpc_depopulate(gssd_dentry, gssd_dummy_clnt_dir, 0, 1); + pipe_dentry = ERR_PTR(ret); + goto out; + } + + pipe_dentry = rpc_mkpipe_dentry(clnt_dentry, "gssd", NULL, pipe_data); + if (IS_ERR(pipe_dentry)) { + __rpc_depopulate(clnt_dentry, gssd_dummy_info_file, 0, 1); + __rpc_depopulate(gssd_dentry, gssd_dummy_clnt_dir, 0, 1); + } +out: + dput(clnt_dentry); + dput(gssd_dentry); + return pipe_dentry; +} + +static void +rpc_gssd_dummy_depopulate(struct dentry *pipe_dentry) +{ + struct dentry *clnt_dir = pipe_dentry->d_parent; + struct dentry *gssd_dir = clnt_dir->d_parent; + + dget(pipe_dentry); + __rpc_rmpipe(d_inode(clnt_dir), pipe_dentry); + __rpc_depopulate(clnt_dir, gssd_dummy_info_file, 0, 1); + __rpc_depopulate(gssd_dir, gssd_dummy_clnt_dir, 0, 1); + dput(pipe_dentry); +} + +static int +rpc_fill_super(struct super_block *sb, void *data, int silent) +{ + struct inode *inode; + struct dentry *root, *gssd_dentry; + struct net *net = get_net(sb->s_fs_info); + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + int err; + + sb->s_blocksize = PAGE_SIZE; + sb->s_blocksize_bits = PAGE_SHIFT; + sb->s_magic = RPCAUTH_GSSMAGIC; + sb->s_op = &s_ops; + sb->s_d_op = &simple_dentry_operations; + sb->s_time_gran = 1; + + inode = rpc_get_inode(sb, S_IFDIR | 0555); + sb->s_root = root = d_make_root(inode); + if (!root) + return -ENOMEM; + if (rpc_populate(root, files, RPCAUTH_lockd, RPCAUTH_RootEOF, NULL)) + return -ENOMEM; + + gssd_dentry = rpc_gssd_dummy_populate(root, sn->gssd_dummy); + if (IS_ERR(gssd_dentry)) { + __rpc_depopulate(root, files, RPCAUTH_lockd, RPCAUTH_RootEOF); + return PTR_ERR(gssd_dentry); + } + + dprintk("RPC: sending pipefs MOUNT notification for net %x%s\n", + net->ns.inum, NET_NAME(net)); + mutex_lock(&sn->pipefs_sb_lock); + sn->pipefs_sb = sb; + err = blocking_notifier_call_chain(&rpc_pipefs_notifier_list, + RPC_PIPEFS_MOUNT, + sb); + if (err) + goto err_depopulate; + mutex_unlock(&sn->pipefs_sb_lock); + return 0; + +err_depopulate: + rpc_gssd_dummy_depopulate(gssd_dentry); + blocking_notifier_call_chain(&rpc_pipefs_notifier_list, + RPC_PIPEFS_UMOUNT, + sb); + sn->pipefs_sb = NULL; + __rpc_depopulate(root, files, RPCAUTH_lockd, RPCAUTH_RootEOF); + mutex_unlock(&sn->pipefs_sb_lock); + return err; +} + +bool +gssd_running(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_pipe *pipe = sn->gssd_dummy; + + return pipe->nreaders || pipe->nwriters; +} +EXPORT_SYMBOL_GPL(gssd_running); + +static struct dentry * +rpc_mount(struct file_system_type *fs_type, + int flags, const char *dev_name, void *data) +{ + struct net *net = current->nsproxy->net_ns; + return mount_ns(fs_type, flags, data, net, net->user_ns, rpc_fill_super); +} + +static void rpc_kill_sb(struct super_block *sb) +{ + struct net *net = sb->s_fs_info; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + mutex_lock(&sn->pipefs_sb_lock); + if (sn->pipefs_sb != sb) { + mutex_unlock(&sn->pipefs_sb_lock); + goto out; + } + sn->pipefs_sb = NULL; + dprintk("RPC: sending pipefs UMOUNT notification for net %x%s\n", + net->ns.inum, NET_NAME(net)); + blocking_notifier_call_chain(&rpc_pipefs_notifier_list, + RPC_PIPEFS_UMOUNT, + sb); + mutex_unlock(&sn->pipefs_sb_lock); +out: + kill_litter_super(sb); + put_net(net); +} + +static struct file_system_type rpc_pipe_fs_type = { + .owner = THIS_MODULE, + .name = "rpc_pipefs", + .mount = rpc_mount, + .kill_sb = rpc_kill_sb, +}; +MODULE_ALIAS_FS("rpc_pipefs"); +MODULE_ALIAS("rpc_pipefs"); + +static void +init_once(void *foo) +{ + struct rpc_inode *rpci = (struct rpc_inode *) foo; + + inode_init_once(&rpci->vfs_inode); + rpci->private = NULL; + rpci->pipe = NULL; + init_waitqueue_head(&rpci->waitq); +} + +int register_rpc_pipefs(void) +{ + int err; + + rpc_inode_cachep = kmem_cache_create("rpc_inode_cache", + sizeof(struct rpc_inode), + 0, (SLAB_HWCACHE_ALIGN|SLAB_RECLAIM_ACCOUNT| + SLAB_MEM_SPREAD|SLAB_ACCOUNT), + init_once); + if (!rpc_inode_cachep) + return -ENOMEM; + err = rpc_clients_notifier_register(); + if (err) + goto err_notifier; + err = register_filesystem(&rpc_pipe_fs_type); + if (err) + goto err_register; + return 0; + +err_register: + rpc_clients_notifier_unregister(); +err_notifier: + kmem_cache_destroy(rpc_inode_cachep); + return err; +} + +void unregister_rpc_pipefs(void) +{ + rpc_clients_notifier_unregister(); + kmem_cache_destroy(rpc_inode_cachep); + unregister_filesystem(&rpc_pipe_fs_type); +} diff --git a/net/sunrpc/rpcb_clnt.c b/net/sunrpc/rpcb_clnt.c new file mode 100644 index 000000000..ba8f36731 --- /dev/null +++ b/net/sunrpc/rpcb_clnt.c @@ -0,0 +1,1169 @@ +/* + * In-kernel rpcbind client supporting versions 2, 3, and 4 of the rpcbind + * protocol + * + * Based on RFC 1833: "Binding Protocols for ONC RPC Version 2" and + * RFC 3530: "Network File System (NFS) version 4 Protocol" + * + * Original: Gilles Quillard, Bull Open Source, 2005 <gilles.quillard@bull.net> + * Updated: Chuck Lever, Oracle Corporation, 2007 <chuck.lever@oracle.com> + * + * Descended from net/sunrpc/pmap_clnt.c, + * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de> + */ + +#include <linux/module.h> + +#include <linux/types.h> +#include <linux/socket.h> +#include <linux/un.h> +#include <linux/in.h> +#include <linux/in6.h> +#include <linux/kernel.h> +#include <linux/errno.h> +#include <linux/mutex.h> +#include <linux/slab.h> +#include <net/ipv6.h> + +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/addr.h> +#include <linux/sunrpc/sched.h> +#include <linux/sunrpc/xprtsock.h> + +#include "netns.h" + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_BIND +#endif + +#define RPCBIND_SOCK_PATHNAME "/var/run/rpcbind.sock" + +#define RPCBIND_PROGRAM (100000u) +#define RPCBIND_PORT (111u) + +#define RPCBVERS_2 (2u) +#define RPCBVERS_3 (3u) +#define RPCBVERS_4 (4u) + +enum { + RPCBPROC_NULL, + RPCBPROC_SET, + RPCBPROC_UNSET, + RPCBPROC_GETPORT, + RPCBPROC_GETADDR = 3, /* alias for GETPORT */ + RPCBPROC_DUMP, + RPCBPROC_CALLIT, + RPCBPROC_BCAST = 5, /* alias for CALLIT */ + RPCBPROC_GETTIME, + RPCBPROC_UADDR2TADDR, + RPCBPROC_TADDR2UADDR, + RPCBPROC_GETVERSADDR, + RPCBPROC_INDIRECT, + RPCBPROC_GETADDRLIST, + RPCBPROC_GETSTAT, +}; + +/* + * r_owner + * + * The "owner" is allowed to unset a service in the rpcbind database. + * + * For AF_LOCAL SET/UNSET requests, rpcbind treats this string as a + * UID which it maps to a local user name via a password lookup. + * In all other cases it is ignored. + * + * For SET/UNSET requests, user space provides a value, even for + * network requests, and GETADDR uses an empty string. We follow + * those precedents here. + */ +#define RPCB_OWNER_STRING "0" +#define RPCB_MAXOWNERLEN sizeof(RPCB_OWNER_STRING) + +/* + * XDR data type sizes + */ +#define RPCB_program_sz (1) +#define RPCB_version_sz (1) +#define RPCB_protocol_sz (1) +#define RPCB_port_sz (1) +#define RPCB_boolean_sz (1) + +#define RPCB_netid_sz (1 + XDR_QUADLEN(RPCBIND_MAXNETIDLEN)) +#define RPCB_addr_sz (1 + XDR_QUADLEN(RPCBIND_MAXUADDRLEN)) +#define RPCB_ownerstring_sz (1 + XDR_QUADLEN(RPCB_MAXOWNERLEN)) + +/* + * XDR argument and result sizes + */ +#define RPCB_mappingargs_sz (RPCB_program_sz + RPCB_version_sz + \ + RPCB_protocol_sz + RPCB_port_sz) +#define RPCB_getaddrargs_sz (RPCB_program_sz + RPCB_version_sz + \ + RPCB_netid_sz + RPCB_addr_sz + \ + RPCB_ownerstring_sz) + +#define RPCB_getportres_sz RPCB_port_sz +#define RPCB_setres_sz RPCB_boolean_sz + +/* + * Note that RFC 1833 does not put any size restrictions on the + * address string returned by the remote rpcbind database. + */ +#define RPCB_getaddrres_sz RPCB_addr_sz + +static void rpcb_getport_done(struct rpc_task *, void *); +static void rpcb_map_release(void *data); +static const struct rpc_program rpcb_program; + +struct rpcbind_args { + struct rpc_xprt * r_xprt; + + u32 r_prog; + u32 r_vers; + u32 r_prot; + unsigned short r_port; + const char * r_netid; + const char * r_addr; + const char * r_owner; + + int r_status; +}; + +static const struct rpc_procinfo rpcb_procedures2[]; +static const struct rpc_procinfo rpcb_procedures3[]; +static const struct rpc_procinfo rpcb_procedures4[]; + +struct rpcb_info { + u32 rpc_vers; + const struct rpc_procinfo *rpc_proc; +}; + +static const struct rpcb_info rpcb_next_version[]; +static const struct rpcb_info rpcb_next_version6[]; + +static const struct rpc_call_ops rpcb_getport_ops = { + .rpc_call_done = rpcb_getport_done, + .rpc_release = rpcb_map_release, +}; + +static void rpcb_wake_rpcbind_waiters(struct rpc_xprt *xprt, int status) +{ + xprt_clear_binding(xprt); + rpc_wake_up_status(&xprt->binding, status); +} + +static void rpcb_map_release(void *data) +{ + struct rpcbind_args *map = data; + + rpcb_wake_rpcbind_waiters(map->r_xprt, map->r_status); + xprt_put(map->r_xprt); + kfree(map->r_addr); + kfree(map); +} + +static int rpcb_get_local(struct net *net) +{ + int cnt; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + spin_lock(&sn->rpcb_clnt_lock); + if (sn->rpcb_users) + sn->rpcb_users++; + cnt = sn->rpcb_users; + spin_unlock(&sn->rpcb_clnt_lock); + + return cnt; +} + +void rpcb_put_local(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_clnt *clnt = sn->rpcb_local_clnt; + struct rpc_clnt *clnt4 = sn->rpcb_local_clnt4; + int shutdown = 0; + + spin_lock(&sn->rpcb_clnt_lock); + if (sn->rpcb_users) { + if (--sn->rpcb_users == 0) { + sn->rpcb_local_clnt = NULL; + sn->rpcb_local_clnt4 = NULL; + } + shutdown = !sn->rpcb_users; + } + spin_unlock(&sn->rpcb_clnt_lock); + + if (shutdown) { + /* + * cleanup_rpcb_clnt - remove xprtsock's sysctls, unregister + */ + if (clnt4) + rpc_shutdown_client(clnt4); + if (clnt) + rpc_shutdown_client(clnt); + } +} + +static void rpcb_set_local(struct net *net, struct rpc_clnt *clnt, + struct rpc_clnt *clnt4, + bool is_af_local) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + /* Protected by rpcb_create_local_mutex */ + sn->rpcb_local_clnt = clnt; + sn->rpcb_local_clnt4 = clnt4; + sn->rpcb_is_af_local = is_af_local ? 1 : 0; + smp_wmb(); + sn->rpcb_users = 1; + dprintk("RPC: created new rpcb local clients (rpcb_local_clnt: " + "%p, rpcb_local_clnt4: %p) for net %x%s\n", + sn->rpcb_local_clnt, sn->rpcb_local_clnt4, + net->ns.inum, (net == &init_net) ? " (init_net)" : ""); +} + +/* + * Returns zero on success, otherwise a negative errno value + * is returned. + */ +static int rpcb_create_local_unix(struct net *net) +{ + static const struct sockaddr_un rpcb_localaddr_rpcbind = { + .sun_family = AF_LOCAL, + .sun_path = RPCBIND_SOCK_PATHNAME, + }; + struct rpc_create_args args = { + .net = net, + .protocol = XPRT_TRANSPORT_LOCAL, + .address = (struct sockaddr *)&rpcb_localaddr_rpcbind, + .addrsize = sizeof(rpcb_localaddr_rpcbind), + .servername = "localhost", + .program = &rpcb_program, + .version = RPCBVERS_2, + .authflavor = RPC_AUTH_NULL, + /* + * We turn off the idle timeout to prevent the kernel + * from automatically disconnecting the socket. + * Otherwise, we'd have to cache the mount namespace + * of the caller and somehow pass that to the socket + * reconnect code. + */ + .flags = RPC_CLNT_CREATE_NO_IDLE_TIMEOUT, + }; + struct rpc_clnt *clnt, *clnt4; + int result = 0; + + /* + * Because we requested an RPC PING at transport creation time, + * this works only if the user space portmapper is rpcbind, and + * it's listening on AF_LOCAL on the named socket. + */ + clnt = rpc_create(&args); + if (IS_ERR(clnt)) { + dprintk("RPC: failed to create AF_LOCAL rpcbind " + "client (errno %ld).\n", PTR_ERR(clnt)); + result = PTR_ERR(clnt); + goto out; + } + + clnt4 = rpc_bind_new_program(clnt, &rpcb_program, RPCBVERS_4); + if (IS_ERR(clnt4)) { + dprintk("RPC: failed to bind second program to " + "rpcbind v4 client (errno %ld).\n", + PTR_ERR(clnt4)); + clnt4 = NULL; + } + + rpcb_set_local(net, clnt, clnt4, true); + +out: + return result; +} + +/* + * Returns zero on success, otherwise a negative errno value + * is returned. + */ +static int rpcb_create_local_net(struct net *net) +{ + static const struct sockaddr_in rpcb_inaddr_loopback = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_LOOPBACK), + .sin_port = htons(RPCBIND_PORT), + }; + struct rpc_create_args args = { + .net = net, + .protocol = XPRT_TRANSPORT_TCP, + .address = (struct sockaddr *)&rpcb_inaddr_loopback, + .addrsize = sizeof(rpcb_inaddr_loopback), + .servername = "localhost", + .program = &rpcb_program, + .version = RPCBVERS_2, + .authflavor = RPC_AUTH_UNIX, + .flags = RPC_CLNT_CREATE_NOPING, + }; + struct rpc_clnt *clnt, *clnt4; + int result = 0; + + clnt = rpc_create(&args); + if (IS_ERR(clnt)) { + dprintk("RPC: failed to create local rpcbind " + "client (errno %ld).\n", PTR_ERR(clnt)); + result = PTR_ERR(clnt); + goto out; + } + + /* + * This results in an RPC ping. On systems running portmapper, + * the v4 ping will fail. Proceed anyway, but disallow rpcb + * v4 upcalls. + */ + clnt4 = rpc_bind_new_program(clnt, &rpcb_program, RPCBVERS_4); + if (IS_ERR(clnt4)) { + dprintk("RPC: failed to bind second program to " + "rpcbind v4 client (errno %ld).\n", + PTR_ERR(clnt4)); + clnt4 = NULL; + } + + rpcb_set_local(net, clnt, clnt4, false); + +out: + return result; +} + +/* + * Returns zero on success, otherwise a negative errno value + * is returned. + */ +int rpcb_create_local(struct net *net) +{ + static DEFINE_MUTEX(rpcb_create_local_mutex); + int result = 0; + + if (rpcb_get_local(net)) + return result; + + mutex_lock(&rpcb_create_local_mutex); + if (rpcb_get_local(net)) + goto out; + + if (rpcb_create_local_unix(net) != 0) + result = rpcb_create_local_net(net); + +out: + mutex_unlock(&rpcb_create_local_mutex); + return result; +} + +static struct rpc_clnt *rpcb_create(struct net *net, const char *nodename, + const char *hostname, + struct sockaddr *srvaddr, size_t salen, + int proto, u32 version) +{ + struct rpc_create_args args = { + .net = net, + .protocol = proto, + .address = srvaddr, + .addrsize = salen, + .servername = hostname, + .nodename = nodename, + .program = &rpcb_program, + .version = version, + .authflavor = RPC_AUTH_UNIX, + .flags = (RPC_CLNT_CREATE_NOPING | + RPC_CLNT_CREATE_NONPRIVPORT), + }; + + switch (srvaddr->sa_family) { + case AF_INET: + ((struct sockaddr_in *)srvaddr)->sin_port = htons(RPCBIND_PORT); + break; + case AF_INET6: + ((struct sockaddr_in6 *)srvaddr)->sin6_port = htons(RPCBIND_PORT); + break; + default: + return ERR_PTR(-EAFNOSUPPORT); + } + + return rpc_create(&args); +} + +static int rpcb_register_call(struct sunrpc_net *sn, struct rpc_clnt *clnt, struct rpc_message *msg, bool is_set) +{ + int flags = RPC_TASK_NOCONNECT; + int error, result = 0; + + if (is_set || !sn->rpcb_is_af_local) + flags = RPC_TASK_SOFTCONN; + msg->rpc_resp = &result; + + error = rpc_call_sync(clnt, msg, flags); + if (error < 0) { + dprintk("RPC: failed to contact local rpcbind " + "server (errno %d).\n", -error); + return error; + } + + if (!result) + return -EACCES; + return 0; +} + +/** + * rpcb_register - set or unset a port registration with the local rpcbind svc + * @net: target network namespace + * @prog: RPC program number to bind + * @vers: RPC version number to bind + * @prot: transport protocol to register + * @port: port value to register + * + * Returns zero if the registration request was dispatched successfully + * and the rpcbind daemon returned success. Otherwise, returns an errno + * value that reflects the nature of the error (request could not be + * dispatched, timed out, or rpcbind returned an error). + * + * RPC services invoke this function to advertise their contact + * information via the system's rpcbind daemon. RPC services + * invoke this function once for each [program, version, transport] + * tuple they wish to advertise. + * + * Callers may also unregister RPC services that are no longer + * available by setting the passed-in port to zero. This removes + * all registered transports for [program, version] from the local + * rpcbind database. + * + * This function uses rpcbind protocol version 2 to contact the + * local rpcbind daemon. + * + * Registration works over both AF_INET and AF_INET6, and services + * registered via this function are advertised as available for any + * address. If the local rpcbind daemon is listening on AF_INET6, + * services registered via this function will be advertised on + * IN6ADDR_ANY (ie available for all AF_INET and AF_INET6 + * addresses). + */ +int rpcb_register(struct net *net, u32 prog, u32 vers, int prot, unsigned short port) +{ + struct rpcbind_args map = { + .r_prog = prog, + .r_vers = vers, + .r_prot = prot, + .r_port = port, + }; + struct rpc_message msg = { + .rpc_argp = &map, + }; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + bool is_set = false; + + dprintk("RPC: %sregistering (%u, %u, %d, %u) with local " + "rpcbind\n", (port ? "" : "un"), + prog, vers, prot, port); + + msg.rpc_proc = &rpcb_procedures2[RPCBPROC_UNSET]; + if (port != 0) { + msg.rpc_proc = &rpcb_procedures2[RPCBPROC_SET]; + is_set = true; + } + + return rpcb_register_call(sn, sn->rpcb_local_clnt, &msg, is_set); +} + +/* + * Fill in AF_INET family-specific arguments to register + */ +static int rpcb_register_inet4(struct sunrpc_net *sn, + const struct sockaddr *sap, + struct rpc_message *msg) +{ + const struct sockaddr_in *sin = (const struct sockaddr_in *)sap; + struct rpcbind_args *map = msg->rpc_argp; + unsigned short port = ntohs(sin->sin_port); + bool is_set = false; + int result; + + map->r_addr = rpc_sockaddr2uaddr(sap, GFP_KERNEL); + + dprintk("RPC: %sregistering [%u, %u, %s, '%s'] with " + "local rpcbind\n", (port ? "" : "un"), + map->r_prog, map->r_vers, + map->r_addr, map->r_netid); + + msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET]; + if (port != 0) { + msg->rpc_proc = &rpcb_procedures4[RPCBPROC_SET]; + is_set = true; + } + + result = rpcb_register_call(sn, sn->rpcb_local_clnt4, msg, is_set); + kfree(map->r_addr); + return result; +} + +/* + * Fill in AF_INET6 family-specific arguments to register + */ +static int rpcb_register_inet6(struct sunrpc_net *sn, + const struct sockaddr *sap, + struct rpc_message *msg) +{ + const struct sockaddr_in6 *sin6 = (const struct sockaddr_in6 *)sap; + struct rpcbind_args *map = msg->rpc_argp; + unsigned short port = ntohs(sin6->sin6_port); + bool is_set = false; + int result; + + map->r_addr = rpc_sockaddr2uaddr(sap, GFP_KERNEL); + + dprintk("RPC: %sregistering [%u, %u, %s, '%s'] with " + "local rpcbind\n", (port ? "" : "un"), + map->r_prog, map->r_vers, + map->r_addr, map->r_netid); + + msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET]; + if (port != 0) { + msg->rpc_proc = &rpcb_procedures4[RPCBPROC_SET]; + is_set = true; + } + + result = rpcb_register_call(sn, sn->rpcb_local_clnt4, msg, is_set); + kfree(map->r_addr); + return result; +} + +static int rpcb_unregister_all_protofamilies(struct sunrpc_net *sn, + struct rpc_message *msg) +{ + struct rpcbind_args *map = msg->rpc_argp; + + dprintk("RPC: unregistering [%u, %u, '%s'] with " + "local rpcbind\n", + map->r_prog, map->r_vers, map->r_netid); + + map->r_addr = ""; + msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET]; + + return rpcb_register_call(sn, sn->rpcb_local_clnt4, msg, false); +} + +/** + * rpcb_v4_register - set or unset a port registration with the local rpcbind + * @net: target network namespace + * @program: RPC program number of service to (un)register + * @version: RPC version number of service to (un)register + * @address: address family, IP address, and port to (un)register + * @netid: netid of transport protocol to (un)register + * + * Returns zero if the registration request was dispatched successfully + * and the rpcbind daemon returned success. Otherwise, returns an errno + * value that reflects the nature of the error (request could not be + * dispatched, timed out, or rpcbind returned an error). + * + * RPC services invoke this function to advertise their contact + * information via the system's rpcbind daemon. RPC services + * invoke this function once for each [program, version, address, + * netid] tuple they wish to advertise. + * + * Callers may also unregister RPC services that are registered at a + * specific address by setting the port number in @address to zero. + * They may unregister all registered protocol families at once for + * a service by passing a NULL @address argument. If @netid is "" + * then all netids for [program, version, address] are unregistered. + * + * This function uses rpcbind protocol version 4 to contact the + * local rpcbind daemon. The local rpcbind daemon must support + * version 4 of the rpcbind protocol in order for these functions + * to register a service successfully. + * + * Supported netids include "udp" and "tcp" for UDP and TCP over + * IPv4, and "udp6" and "tcp6" for UDP and TCP over IPv6, + * respectively. + * + * The contents of @address determine the address family and the + * port to be registered. The usual practice is to pass INADDR_ANY + * as the raw address, but specifying a non-zero address is also + * supported by this API if the caller wishes to advertise an RPC + * service on a specific network interface. + * + * Note that passing in INADDR_ANY does not create the same service + * registration as IN6ADDR_ANY. The former advertises an RPC + * service on any IPv4 address, but not on IPv6. The latter + * advertises the service on all IPv4 and IPv6 addresses. + */ +int rpcb_v4_register(struct net *net, const u32 program, const u32 version, + const struct sockaddr *address, const char *netid) +{ + struct rpcbind_args map = { + .r_prog = program, + .r_vers = version, + .r_netid = netid, + .r_owner = RPCB_OWNER_STRING, + }; + struct rpc_message msg = { + .rpc_argp = &map, + }; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + if (sn->rpcb_local_clnt4 == NULL) + return -EPROTONOSUPPORT; + + if (address == NULL) + return rpcb_unregister_all_protofamilies(sn, &msg); + + switch (address->sa_family) { + case AF_INET: + return rpcb_register_inet4(sn, address, &msg); + case AF_INET6: + return rpcb_register_inet6(sn, address, &msg); + } + + return -EAFNOSUPPORT; +} + +static struct rpc_task *rpcb_call_async(struct rpc_clnt *rpcb_clnt, + struct rpcbind_args *map, const struct rpc_procinfo *proc) +{ + struct rpc_message msg = { + .rpc_proc = proc, + .rpc_argp = map, + .rpc_resp = map, + }; + struct rpc_task_setup task_setup_data = { + .rpc_client = rpcb_clnt, + .rpc_message = &msg, + .callback_ops = &rpcb_getport_ops, + .callback_data = map, + .flags = RPC_TASK_ASYNC | RPC_TASK_SOFTCONN, + }; + + return rpc_run_task(&task_setup_data); +} + +/* + * In the case where rpc clients have been cloned, we want to make + * sure that we use the program number/version etc of the actual + * owner of the xprt. To do so, we walk back up the tree of parents + * to find whoever created the transport and/or whoever has the + * autobind flag set. + */ +static struct rpc_clnt *rpcb_find_transport_owner(struct rpc_clnt *clnt) +{ + struct rpc_clnt *parent = clnt->cl_parent; + struct rpc_xprt_switch *xps = rcu_access_pointer(clnt->cl_xpi.xpi_xpswitch); + + while (parent != clnt) { + if (rcu_access_pointer(parent->cl_xpi.xpi_xpswitch) != xps) + break; + if (clnt->cl_autobind) + break; + clnt = parent; + parent = parent->cl_parent; + } + return clnt; +} + +/** + * rpcb_getport_async - obtain the port for a given RPC service on a given host + * @task: task that is waiting for portmapper request + * + * This one can be called for an ongoing RPC request, and can be used in + * an async (rpciod) context. + */ +void rpcb_getport_async(struct rpc_task *task) +{ + struct rpc_clnt *clnt; + const struct rpc_procinfo *proc; + u32 bind_version; + struct rpc_xprt *xprt; + struct rpc_clnt *rpcb_clnt; + struct rpcbind_args *map; + struct rpc_task *child; + struct sockaddr_storage addr; + struct sockaddr *sap = (struct sockaddr *)&addr; + size_t salen; + int status; + + rcu_read_lock(); + clnt = rpcb_find_transport_owner(task->tk_client); + rcu_read_unlock(); + xprt = xprt_get(task->tk_xprt); + + dprintk("RPC: %5u %s(%s, %u, %u, %d)\n", + task->tk_pid, __func__, + xprt->servername, clnt->cl_prog, clnt->cl_vers, xprt->prot); + + /* Put self on the wait queue to ensure we get notified if + * some other task is already attempting to bind the port */ + rpc_sleep_on(&xprt->binding, task, NULL); + + if (xprt_test_and_set_binding(xprt)) { + dprintk("RPC: %5u %s: waiting for another binder\n", + task->tk_pid, __func__); + xprt_put(xprt); + return; + } + + /* Someone else may have bound if we slept */ + if (xprt_bound(xprt)) { + status = 0; + dprintk("RPC: %5u %s: already bound\n", + task->tk_pid, __func__); + goto bailout_nofree; + } + + /* Parent transport's destination address */ + salen = rpc_peeraddr(clnt, sap, sizeof(addr)); + + /* Don't ever use rpcbind v2 for AF_INET6 requests */ + switch (sap->sa_family) { + case AF_INET: + proc = rpcb_next_version[xprt->bind_index].rpc_proc; + bind_version = rpcb_next_version[xprt->bind_index].rpc_vers; + break; + case AF_INET6: + proc = rpcb_next_version6[xprt->bind_index].rpc_proc; + bind_version = rpcb_next_version6[xprt->bind_index].rpc_vers; + break; + default: + status = -EAFNOSUPPORT; + dprintk("RPC: %5u %s: bad address family\n", + task->tk_pid, __func__); + goto bailout_nofree; + } + if (proc == NULL) { + xprt->bind_index = 0; + status = -EPFNOSUPPORT; + dprintk("RPC: %5u %s: no more getport versions available\n", + task->tk_pid, __func__); + goto bailout_nofree; + } + + dprintk("RPC: %5u %s: trying rpcbind version %u\n", + task->tk_pid, __func__, bind_version); + + rpcb_clnt = rpcb_create(xprt->xprt_net, + clnt->cl_nodename, + xprt->servername, sap, salen, + xprt->prot, bind_version); + if (IS_ERR(rpcb_clnt)) { + status = PTR_ERR(rpcb_clnt); + dprintk("RPC: %5u %s: rpcb_create failed, error %ld\n", + task->tk_pid, __func__, PTR_ERR(rpcb_clnt)); + goto bailout_nofree; + } + + map = kzalloc(sizeof(struct rpcbind_args), GFP_ATOMIC); + if (!map) { + status = -ENOMEM; + dprintk("RPC: %5u %s: no memory available\n", + task->tk_pid, __func__); + goto bailout_release_client; + } + map->r_prog = clnt->cl_prog; + map->r_vers = clnt->cl_vers; + map->r_prot = xprt->prot; + map->r_port = 0; + map->r_xprt = xprt; + map->r_status = -EIO; + + switch (bind_version) { + case RPCBVERS_4: + case RPCBVERS_3: + map->r_netid = xprt->address_strings[RPC_DISPLAY_NETID]; + map->r_addr = rpc_sockaddr2uaddr(sap, GFP_ATOMIC); + if (!map->r_addr) { + status = -ENOMEM; + dprintk("RPC: %5u %s: no memory available\n", + task->tk_pid, __func__); + goto bailout_free_args; + } + map->r_owner = ""; + break; + case RPCBVERS_2: + map->r_addr = NULL; + break; + default: + BUG(); + } + + child = rpcb_call_async(rpcb_clnt, map, proc); + rpc_release_client(rpcb_clnt); + if (IS_ERR(child)) { + /* rpcb_map_release() has freed the arguments */ + dprintk("RPC: %5u %s: rpc_run_task failed\n", + task->tk_pid, __func__); + return; + } + + xprt->stat.bind_count++; + rpc_put_task(child); + return; + +bailout_free_args: + kfree(map); +bailout_release_client: + rpc_release_client(rpcb_clnt); +bailout_nofree: + rpcb_wake_rpcbind_waiters(xprt, status); + task->tk_status = status; + xprt_put(xprt); +} +EXPORT_SYMBOL_GPL(rpcb_getport_async); + +/* + * Rpcbind child task calls this callback via tk_exit. + */ +static void rpcb_getport_done(struct rpc_task *child, void *data) +{ + struct rpcbind_args *map = data; + struct rpc_xprt *xprt = map->r_xprt; + int status = child->tk_status; + + /* Garbage reply: retry with a lesser rpcbind version */ + if (status == -EIO) + status = -EPROTONOSUPPORT; + + /* rpcbind server doesn't support this rpcbind protocol version */ + if (status == -EPROTONOSUPPORT) + xprt->bind_index++; + + if (status < 0) { + /* rpcbind server not available on remote host? */ + xprt->ops->set_port(xprt, 0); + } else if (map->r_port == 0) { + /* Requested RPC service wasn't registered on remote host */ + xprt->ops->set_port(xprt, 0); + status = -EACCES; + } else { + /* Succeeded */ + xprt->ops->set_port(xprt, map->r_port); + xprt_set_bound(xprt); + status = 0; + } + + dprintk("RPC: %5u rpcb_getport_done(status %d, port %u)\n", + child->tk_pid, status, map->r_port); + + map->r_status = status; +} + +/* + * XDR functions for rpcbind + */ + +static void rpcb_enc_mapping(struct rpc_rqst *req, struct xdr_stream *xdr, + const void *data) +{ + const struct rpcbind_args *rpcb = data; + __be32 *p; + + dprintk("RPC: %5u encoding PMAP_%s call (%u, %u, %d, %u)\n", + req->rq_task->tk_pid, + req->rq_task->tk_msg.rpc_proc->p_name, + rpcb->r_prog, rpcb->r_vers, rpcb->r_prot, rpcb->r_port); + + p = xdr_reserve_space(xdr, RPCB_mappingargs_sz << 2); + *p++ = cpu_to_be32(rpcb->r_prog); + *p++ = cpu_to_be32(rpcb->r_vers); + *p++ = cpu_to_be32(rpcb->r_prot); + *p = cpu_to_be32(rpcb->r_port); +} + +static int rpcb_dec_getport(struct rpc_rqst *req, struct xdr_stream *xdr, + void *data) +{ + struct rpcbind_args *rpcb = data; + unsigned long port; + __be32 *p; + + rpcb->r_port = 0; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -EIO; + + port = be32_to_cpup(p); + dprintk("RPC: %5u PMAP_%s result: %lu\n", req->rq_task->tk_pid, + req->rq_task->tk_msg.rpc_proc->p_name, port); + if (unlikely(port > USHRT_MAX)) + return -EIO; + + rpcb->r_port = port; + return 0; +} + +static int rpcb_dec_set(struct rpc_rqst *req, struct xdr_stream *xdr, + void *data) +{ + unsigned int *boolp = data; + __be32 *p; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -EIO; + + *boolp = 0; + if (*p != xdr_zero) + *boolp = 1; + + dprintk("RPC: %5u RPCB_%s call %s\n", + req->rq_task->tk_pid, + req->rq_task->tk_msg.rpc_proc->p_name, + (*boolp ? "succeeded" : "failed")); + return 0; +} + +static void encode_rpcb_string(struct xdr_stream *xdr, const char *string, + const u32 maxstrlen) +{ + __be32 *p; + u32 len; + + len = strlen(string); + WARN_ON_ONCE(len > maxstrlen); + if (len > maxstrlen) + /* truncate and hope for the best */ + len = maxstrlen; + p = xdr_reserve_space(xdr, 4 + len); + xdr_encode_opaque(p, string, len); +} + +static void rpcb_enc_getaddr(struct rpc_rqst *req, struct xdr_stream *xdr, + const void *data) +{ + const struct rpcbind_args *rpcb = data; + __be32 *p; + + dprintk("RPC: %5u encoding RPCB_%s call (%u, %u, '%s', '%s')\n", + req->rq_task->tk_pid, + req->rq_task->tk_msg.rpc_proc->p_name, + rpcb->r_prog, rpcb->r_vers, + rpcb->r_netid, rpcb->r_addr); + + p = xdr_reserve_space(xdr, (RPCB_program_sz + RPCB_version_sz) << 2); + *p++ = cpu_to_be32(rpcb->r_prog); + *p = cpu_to_be32(rpcb->r_vers); + + encode_rpcb_string(xdr, rpcb->r_netid, RPCBIND_MAXNETIDLEN); + encode_rpcb_string(xdr, rpcb->r_addr, RPCBIND_MAXUADDRLEN); + encode_rpcb_string(xdr, rpcb->r_owner, RPCB_MAXOWNERLEN); +} + +static int rpcb_dec_getaddr(struct rpc_rqst *req, struct xdr_stream *xdr, + void *data) +{ + struct rpcbind_args *rpcb = data; + struct sockaddr_storage address; + struct sockaddr *sap = (struct sockaddr *)&address; + __be32 *p; + u32 len; + + rpcb->r_port = 0; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + goto out_fail; + len = be32_to_cpup(p); + + /* + * If the returned universal address is a null string, + * the requested RPC service was not registered. + */ + if (len == 0) { + dprintk("RPC: %5u RPCB reply: program not registered\n", + req->rq_task->tk_pid); + return 0; + } + + if (unlikely(len > RPCBIND_MAXUADDRLEN)) + goto out_fail; + + p = xdr_inline_decode(xdr, len); + if (unlikely(p == NULL)) + goto out_fail; + dprintk("RPC: %5u RPCB_%s reply: %*pE\n", req->rq_task->tk_pid, + req->rq_task->tk_msg.rpc_proc->p_name, len, (char *)p); + + if (rpc_uaddr2sockaddr(req->rq_xprt->xprt_net, (char *)p, len, + sap, sizeof(address)) == 0) + goto out_fail; + rpcb->r_port = rpc_get_port(sap); + + return 0; + +out_fail: + dprintk("RPC: %5u malformed RPCB_%s reply\n", + req->rq_task->tk_pid, + req->rq_task->tk_msg.rpc_proc->p_name); + return -EIO; +} + +/* + * Not all rpcbind procedures described in RFC 1833 are implemented + * since the Linux kernel RPC code requires only these. + */ + +static const struct rpc_procinfo rpcb_procedures2[] = { + [RPCBPROC_SET] = { + .p_proc = RPCBPROC_SET, + .p_encode = rpcb_enc_mapping, + .p_decode = rpcb_dec_set, + .p_arglen = RPCB_mappingargs_sz, + .p_replen = RPCB_setres_sz, + .p_statidx = RPCBPROC_SET, + .p_timer = 0, + .p_name = "SET", + }, + [RPCBPROC_UNSET] = { + .p_proc = RPCBPROC_UNSET, + .p_encode = rpcb_enc_mapping, + .p_decode = rpcb_dec_set, + .p_arglen = RPCB_mappingargs_sz, + .p_replen = RPCB_setres_sz, + .p_statidx = RPCBPROC_UNSET, + .p_timer = 0, + .p_name = "UNSET", + }, + [RPCBPROC_GETPORT] = { + .p_proc = RPCBPROC_GETPORT, + .p_encode = rpcb_enc_mapping, + .p_decode = rpcb_dec_getport, + .p_arglen = RPCB_mappingargs_sz, + .p_replen = RPCB_getportres_sz, + .p_statidx = RPCBPROC_GETPORT, + .p_timer = 0, + .p_name = "GETPORT", + }, +}; + +static const struct rpc_procinfo rpcb_procedures3[] = { + [RPCBPROC_SET] = { + .p_proc = RPCBPROC_SET, + .p_encode = rpcb_enc_getaddr, + .p_decode = rpcb_dec_set, + .p_arglen = RPCB_getaddrargs_sz, + .p_replen = RPCB_setres_sz, + .p_statidx = RPCBPROC_SET, + .p_timer = 0, + .p_name = "SET", + }, + [RPCBPROC_UNSET] = { + .p_proc = RPCBPROC_UNSET, + .p_encode = rpcb_enc_getaddr, + .p_decode = rpcb_dec_set, + .p_arglen = RPCB_getaddrargs_sz, + .p_replen = RPCB_setres_sz, + .p_statidx = RPCBPROC_UNSET, + .p_timer = 0, + .p_name = "UNSET", + }, + [RPCBPROC_GETADDR] = { + .p_proc = RPCBPROC_GETADDR, + .p_encode = rpcb_enc_getaddr, + .p_decode = rpcb_dec_getaddr, + .p_arglen = RPCB_getaddrargs_sz, + .p_replen = RPCB_getaddrres_sz, + .p_statidx = RPCBPROC_GETADDR, + .p_timer = 0, + .p_name = "GETADDR", + }, +}; + +static const struct rpc_procinfo rpcb_procedures4[] = { + [RPCBPROC_SET] = { + .p_proc = RPCBPROC_SET, + .p_encode = rpcb_enc_getaddr, + .p_decode = rpcb_dec_set, + .p_arglen = RPCB_getaddrargs_sz, + .p_replen = RPCB_setres_sz, + .p_statidx = RPCBPROC_SET, + .p_timer = 0, + .p_name = "SET", + }, + [RPCBPROC_UNSET] = { + .p_proc = RPCBPROC_UNSET, + .p_encode = rpcb_enc_getaddr, + .p_decode = rpcb_dec_set, + .p_arglen = RPCB_getaddrargs_sz, + .p_replen = RPCB_setres_sz, + .p_statidx = RPCBPROC_UNSET, + .p_timer = 0, + .p_name = "UNSET", + }, + [RPCBPROC_GETADDR] = { + .p_proc = RPCBPROC_GETADDR, + .p_encode = rpcb_enc_getaddr, + .p_decode = rpcb_dec_getaddr, + .p_arglen = RPCB_getaddrargs_sz, + .p_replen = RPCB_getaddrres_sz, + .p_statidx = RPCBPROC_GETADDR, + .p_timer = 0, + .p_name = "GETADDR", + }, +}; + +static const struct rpcb_info rpcb_next_version[] = { + { + .rpc_vers = RPCBVERS_2, + .rpc_proc = &rpcb_procedures2[RPCBPROC_GETPORT], + }, + { + .rpc_proc = NULL, + }, +}; + +static const struct rpcb_info rpcb_next_version6[] = { + { + .rpc_vers = RPCBVERS_4, + .rpc_proc = &rpcb_procedures4[RPCBPROC_GETADDR], + }, + { + .rpc_vers = RPCBVERS_3, + .rpc_proc = &rpcb_procedures3[RPCBPROC_GETADDR], + }, + { + .rpc_proc = NULL, + }, +}; + +static unsigned int rpcb_version2_counts[ARRAY_SIZE(rpcb_procedures2)]; +static const struct rpc_version rpcb_version2 = { + .number = RPCBVERS_2, + .nrprocs = ARRAY_SIZE(rpcb_procedures2), + .procs = rpcb_procedures2, + .counts = rpcb_version2_counts, +}; + +static unsigned int rpcb_version3_counts[ARRAY_SIZE(rpcb_procedures3)]; +static const struct rpc_version rpcb_version3 = { + .number = RPCBVERS_3, + .nrprocs = ARRAY_SIZE(rpcb_procedures3), + .procs = rpcb_procedures3, + .counts = rpcb_version3_counts, +}; + +static unsigned int rpcb_version4_counts[ARRAY_SIZE(rpcb_procedures4)]; +static const struct rpc_version rpcb_version4 = { + .number = RPCBVERS_4, + .nrprocs = ARRAY_SIZE(rpcb_procedures4), + .procs = rpcb_procedures4, + .counts = rpcb_version4_counts, +}; + +static const struct rpc_version *rpcb_version[] = { + NULL, + NULL, + &rpcb_version2, + &rpcb_version3, + &rpcb_version4 +}; + +static struct rpc_stat rpcb_stats; + +static const struct rpc_program rpcb_program = { + .name = "rpcbind", + .number = RPCBIND_PROGRAM, + .nrvers = ARRAY_SIZE(rpcb_version), + .version = rpcb_version, + .stats = &rpcb_stats, +}; diff --git a/net/sunrpc/sched.c b/net/sunrpc/sched.c new file mode 100644 index 000000000..e36ae4d4b --- /dev/null +++ b/net/sunrpc/sched.c @@ -0,0 +1,1198 @@ +/* + * linux/net/sunrpc/sched.c + * + * Scheduling for synchronous and asynchronous RPC requests. + * + * Copyright (C) 1996 Olaf Kirch, <okir@monad.swb.de> + * + * TCP NFS related read + write fixes + * (C) 1999 Dave Airlie, University of Limerick, Ireland <airlied@linux.ie> + */ + +#include <linux/module.h> + +#include <linux/sched.h> +#include <linux/interrupt.h> +#include <linux/slab.h> +#include <linux/mempool.h> +#include <linux/smp.h> +#include <linux/spinlock.h> +#include <linux/mutex.h> +#include <linux/freezer.h> + +#include <linux/sunrpc/clnt.h> + +#include "sunrpc.h" + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +#define RPCDBG_FACILITY RPCDBG_SCHED +#endif + +#define CREATE_TRACE_POINTS +#include <trace/events/sunrpc.h> + +/* + * RPC slabs and memory pools + */ +#define RPC_BUFFER_MAXSIZE (2048) +#define RPC_BUFFER_POOLSIZE (8) +#define RPC_TASK_POOLSIZE (8) +static struct kmem_cache *rpc_task_slabp __read_mostly; +static struct kmem_cache *rpc_buffer_slabp __read_mostly; +static mempool_t *rpc_task_mempool __read_mostly; +static mempool_t *rpc_buffer_mempool __read_mostly; + +static void rpc_async_schedule(struct work_struct *); +static void rpc_release_task(struct rpc_task *task); +static void __rpc_queue_timer_fn(struct timer_list *t); + +/* + * RPC tasks sit here while waiting for conditions to improve. + */ +static struct rpc_wait_queue delay_queue; + +/* + * rpciod-related stuff + */ +struct workqueue_struct *rpciod_workqueue __read_mostly; +struct workqueue_struct *xprtiod_workqueue __read_mostly; + +/* + * Disable the timer for a given RPC task. Should be called with + * queue->lock and bh_disabled in order to avoid races within + * rpc_run_timer(). + */ +static void +__rpc_disable_timer(struct rpc_wait_queue *queue, struct rpc_task *task) +{ + if (task->tk_timeout == 0) + return; + dprintk("RPC: %5u disabling timer\n", task->tk_pid); + task->tk_timeout = 0; + list_del(&task->u.tk_wait.timer_list); + if (list_empty(&queue->timer_list.list)) + del_timer(&queue->timer_list.timer); +} + +static void +rpc_set_queue_timer(struct rpc_wait_queue *queue, unsigned long expires) +{ + queue->timer_list.expires = expires; + mod_timer(&queue->timer_list.timer, expires); +} + +/* + * Set up a timer for the current task. + */ +static void +__rpc_add_timer(struct rpc_wait_queue *queue, struct rpc_task *task) +{ + if (!task->tk_timeout) + return; + + dprintk("RPC: %5u setting alarm for %u ms\n", + task->tk_pid, jiffies_to_msecs(task->tk_timeout)); + + task->u.tk_wait.expires = jiffies + task->tk_timeout; + if (list_empty(&queue->timer_list.list) || time_before(task->u.tk_wait.expires, queue->timer_list.expires)) + rpc_set_queue_timer(queue, task->u.tk_wait.expires); + list_add(&task->u.tk_wait.timer_list, &queue->timer_list.list); +} + +static void rpc_set_waitqueue_priority(struct rpc_wait_queue *queue, int priority) +{ + if (queue->priority != priority) { + queue->priority = priority; + queue->nr = 1U << priority; + } +} + +static void rpc_reset_waitqueue_priority(struct rpc_wait_queue *queue) +{ + rpc_set_waitqueue_priority(queue, queue->maxpriority); +} + +/* + * Add a request to a queue list + */ +static void +__rpc_list_enqueue_task(struct list_head *q, struct rpc_task *task) +{ + struct rpc_task *t; + + list_for_each_entry(t, q, u.tk_wait.list) { + if (t->tk_owner == task->tk_owner) { + list_add_tail(&task->u.tk_wait.links, + &t->u.tk_wait.links); + /* Cache the queue head in task->u.tk_wait.list */ + task->u.tk_wait.list.next = q; + task->u.tk_wait.list.prev = NULL; + return; + } + } + INIT_LIST_HEAD(&task->u.tk_wait.links); + list_add_tail(&task->u.tk_wait.list, q); +} + +/* + * Remove request from a queue list + */ +static void +__rpc_list_dequeue_task(struct rpc_task *task) +{ + struct list_head *q; + struct rpc_task *t; + + if (task->u.tk_wait.list.prev == NULL) { + list_del(&task->u.tk_wait.links); + return; + } + if (!list_empty(&task->u.tk_wait.links)) { + t = list_first_entry(&task->u.tk_wait.links, + struct rpc_task, + u.tk_wait.links); + /* Assume __rpc_list_enqueue_task() cached the queue head */ + q = t->u.tk_wait.list.next; + list_add_tail(&t->u.tk_wait.list, q); + list_del(&task->u.tk_wait.links); + } + list_del(&task->u.tk_wait.list); +} + +/* + * Add new request to a priority queue. + */ +static void __rpc_add_wait_queue_priority(struct rpc_wait_queue *queue, + struct rpc_task *task, + unsigned char queue_priority) +{ + if (unlikely(queue_priority > queue->maxpriority)) + queue_priority = queue->maxpriority; + __rpc_list_enqueue_task(&queue->tasks[queue_priority], task); +} + +/* + * Add new request to wait queue. + * + * Swapper tasks always get inserted at the head of the queue. + * This should avoid many nasty memory deadlocks and hopefully + * improve overall performance. + * Everyone else gets appended to the queue to ensure proper FIFO behavior. + */ +static void __rpc_add_wait_queue(struct rpc_wait_queue *queue, + struct rpc_task *task, + unsigned char queue_priority) +{ + WARN_ON_ONCE(RPC_IS_QUEUED(task)); + if (RPC_IS_QUEUED(task)) + return; + + if (RPC_IS_PRIORITY(queue)) + __rpc_add_wait_queue_priority(queue, task, queue_priority); + else if (RPC_IS_SWAPPER(task)) + list_add(&task->u.tk_wait.list, &queue->tasks[0]); + else + list_add_tail(&task->u.tk_wait.list, &queue->tasks[0]); + task->tk_waitqueue = queue; + queue->qlen++; + /* barrier matches the read in rpc_wake_up_task_queue_locked() */ + smp_wmb(); + rpc_set_queued(task); + + dprintk("RPC: %5u added to queue %p \"%s\"\n", + task->tk_pid, queue, rpc_qname(queue)); +} + +/* + * Remove request from a priority queue. + */ +static void __rpc_remove_wait_queue_priority(struct rpc_task *task) +{ + __rpc_list_dequeue_task(task); +} + +/* + * Remove request from queue. + * Note: must be called with spin lock held. + */ +static void __rpc_remove_wait_queue(struct rpc_wait_queue *queue, struct rpc_task *task) +{ + __rpc_disable_timer(queue, task); + if (RPC_IS_PRIORITY(queue)) + __rpc_remove_wait_queue_priority(task); + else + list_del(&task->u.tk_wait.list); + queue->qlen--; + dprintk("RPC: %5u removed from queue %p \"%s\"\n", + task->tk_pid, queue, rpc_qname(queue)); +} + +static void __rpc_init_priority_wait_queue(struct rpc_wait_queue *queue, const char *qname, unsigned char nr_queues) +{ + int i; + + spin_lock_init(&queue->lock); + for (i = 0; i < ARRAY_SIZE(queue->tasks); i++) + INIT_LIST_HEAD(&queue->tasks[i]); + queue->maxpriority = nr_queues - 1; + rpc_reset_waitqueue_priority(queue); + queue->qlen = 0; + timer_setup(&queue->timer_list.timer, __rpc_queue_timer_fn, 0); + INIT_LIST_HEAD(&queue->timer_list.list); + rpc_assign_waitqueue_name(queue, qname); +} + +void rpc_init_priority_wait_queue(struct rpc_wait_queue *queue, const char *qname) +{ + __rpc_init_priority_wait_queue(queue, qname, RPC_NR_PRIORITY); +} +EXPORT_SYMBOL_GPL(rpc_init_priority_wait_queue); + +void rpc_init_wait_queue(struct rpc_wait_queue *queue, const char *qname) +{ + __rpc_init_priority_wait_queue(queue, qname, 1); +} +EXPORT_SYMBOL_GPL(rpc_init_wait_queue); + +void rpc_destroy_wait_queue(struct rpc_wait_queue *queue) +{ + del_timer_sync(&queue->timer_list.timer); +} +EXPORT_SYMBOL_GPL(rpc_destroy_wait_queue); + +static int rpc_wait_bit_killable(struct wait_bit_key *key, int mode) +{ + freezable_schedule_unsafe(); + if (signal_pending_state(mode, current)) + return -ERESTARTSYS; + return 0; +} + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) || IS_ENABLED(CONFIG_TRACEPOINTS) +static void rpc_task_set_debuginfo(struct rpc_task *task) +{ + static atomic_t rpc_pid; + + task->tk_pid = atomic_inc_return(&rpc_pid); +} +#else +static inline void rpc_task_set_debuginfo(struct rpc_task *task) +{ +} +#endif + +static void rpc_set_active(struct rpc_task *task) +{ + rpc_task_set_debuginfo(task); + set_bit(RPC_TASK_ACTIVE, &task->tk_runstate); + trace_rpc_task_begin(task, NULL); +} + +/* + * Mark an RPC call as having completed by clearing the 'active' bit + * and then waking up all tasks that were sleeping. + */ +static int rpc_complete_task(struct rpc_task *task) +{ + void *m = &task->tk_runstate; + wait_queue_head_t *wq = bit_waitqueue(m, RPC_TASK_ACTIVE); + struct wait_bit_key k = __WAIT_BIT_KEY_INITIALIZER(m, RPC_TASK_ACTIVE); + unsigned long flags; + int ret; + + trace_rpc_task_complete(task, NULL); + + spin_lock_irqsave(&wq->lock, flags); + clear_bit(RPC_TASK_ACTIVE, &task->tk_runstate); + ret = atomic_dec_and_test(&task->tk_count); + if (waitqueue_active(wq)) + __wake_up_locked_key(wq, TASK_NORMAL, &k); + spin_unlock_irqrestore(&wq->lock, flags); + return ret; +} + +/* + * Allow callers to wait for completion of an RPC call + * + * Note the use of out_of_line_wait_on_bit() rather than wait_on_bit() + * to enforce taking of the wq->lock and hence avoid races with + * rpc_complete_task(). + */ +int __rpc_wait_for_completion_task(struct rpc_task *task, wait_bit_action_f *action) +{ + if (action == NULL) + action = rpc_wait_bit_killable; + return out_of_line_wait_on_bit(&task->tk_runstate, RPC_TASK_ACTIVE, + action, TASK_KILLABLE); +} +EXPORT_SYMBOL_GPL(__rpc_wait_for_completion_task); + +/* + * Make an RPC task runnable. + * + * Note: If the task is ASYNC, and is being made runnable after sitting on an + * rpc_wait_queue, this must be called with the queue spinlock held to protect + * the wait queue operation. + * Note the ordering of rpc_test_and_set_running() and rpc_clear_queued(), + * which is needed to ensure that __rpc_execute() doesn't loop (due to the + * lockless RPC_IS_QUEUED() test) before we've had a chance to test + * the RPC_TASK_RUNNING flag. + */ +static void rpc_make_runnable(struct workqueue_struct *wq, + struct rpc_task *task) +{ + bool need_wakeup = !rpc_test_and_set_running(task); + + rpc_clear_queued(task); + if (!need_wakeup) + return; + if (RPC_IS_ASYNC(task)) { + INIT_WORK(&task->u.tk_work, rpc_async_schedule); + queue_work(wq, &task->u.tk_work); + } else + wake_up_bit(&task->tk_runstate, RPC_TASK_QUEUED); +} + +/* + * Prepare for sleeping on a wait queue. + * By always appending tasks to the list we ensure FIFO behavior. + * NB: An RPC task will only receive interrupt-driven events as long + * as it's on a wait queue. + */ +static void __rpc_sleep_on_priority(struct rpc_wait_queue *q, + struct rpc_task *task, + rpc_action action, + unsigned char queue_priority) +{ + dprintk("RPC: %5u sleep_on(queue \"%s\" time %lu)\n", + task->tk_pid, rpc_qname(q), jiffies); + + trace_rpc_task_sleep(task, q); + + __rpc_add_wait_queue(q, task, queue_priority); + + WARN_ON_ONCE(task->tk_callback != NULL); + task->tk_callback = action; + __rpc_add_timer(q, task); +} + +void rpc_sleep_on(struct rpc_wait_queue *q, struct rpc_task *task, + rpc_action action) +{ + /* We shouldn't ever put an inactive task to sleep */ + WARN_ON_ONCE(!RPC_IS_ACTIVATED(task)); + if (!RPC_IS_ACTIVATED(task)) { + task->tk_status = -EIO; + rpc_put_task_async(task); + return; + } + + /* + * Protect the queue operations. + */ + spin_lock_bh(&q->lock); + __rpc_sleep_on_priority(q, task, action, task->tk_priority); + spin_unlock_bh(&q->lock); +} +EXPORT_SYMBOL_GPL(rpc_sleep_on); + +void rpc_sleep_on_priority(struct rpc_wait_queue *q, struct rpc_task *task, + rpc_action action, int priority) +{ + /* We shouldn't ever put an inactive task to sleep */ + WARN_ON_ONCE(!RPC_IS_ACTIVATED(task)); + if (!RPC_IS_ACTIVATED(task)) { + task->tk_status = -EIO; + rpc_put_task_async(task); + return; + } + + /* + * Protect the queue operations. + */ + spin_lock_bh(&q->lock); + __rpc_sleep_on_priority(q, task, action, priority - RPC_PRIORITY_LOW); + spin_unlock_bh(&q->lock); +} +EXPORT_SYMBOL_GPL(rpc_sleep_on_priority); + +/** + * __rpc_do_wake_up_task_on_wq - wake up a single rpc_task + * @wq: workqueue on which to run task + * @queue: wait queue + * @task: task to be woken up + * + * Caller must hold queue->lock, and have cleared the task queued flag. + */ +static void __rpc_do_wake_up_task_on_wq(struct workqueue_struct *wq, + struct rpc_wait_queue *queue, + struct rpc_task *task) +{ + dprintk("RPC: %5u __rpc_wake_up_task (now %lu)\n", + task->tk_pid, jiffies); + + /* Has the task been executed yet? If not, we cannot wake it up! */ + if (!RPC_IS_ACTIVATED(task)) { + printk(KERN_ERR "RPC: Inactive task (%p) being woken up!\n", task); + return; + } + + trace_rpc_task_wakeup(task, queue); + + __rpc_remove_wait_queue(queue, task); + + rpc_make_runnable(wq, task); + + dprintk("RPC: __rpc_wake_up_task done\n"); +} + +/* + * Wake up a queued task while the queue lock is being held + */ +static void rpc_wake_up_task_on_wq_queue_locked(struct workqueue_struct *wq, + struct rpc_wait_queue *queue, struct rpc_task *task) +{ + if (RPC_IS_QUEUED(task)) { + smp_rmb(); + if (task->tk_waitqueue == queue) + __rpc_do_wake_up_task_on_wq(wq, queue, task); + } +} + +/* + * Wake up a queued task while the queue lock is being held + */ +static void rpc_wake_up_task_queue_locked(struct rpc_wait_queue *queue, struct rpc_task *task) +{ + rpc_wake_up_task_on_wq_queue_locked(rpciod_workqueue, queue, task); +} + +/* + * Wake up a task on a specific queue + */ +void rpc_wake_up_queued_task_on_wq(struct workqueue_struct *wq, + struct rpc_wait_queue *queue, + struct rpc_task *task) +{ + spin_lock_bh(&queue->lock); + rpc_wake_up_task_on_wq_queue_locked(wq, queue, task); + spin_unlock_bh(&queue->lock); +} + +/* + * Wake up a task on a specific queue + */ +void rpc_wake_up_queued_task(struct rpc_wait_queue *queue, struct rpc_task *task) +{ + spin_lock_bh(&queue->lock); + rpc_wake_up_task_queue_locked(queue, task); + spin_unlock_bh(&queue->lock); +} +EXPORT_SYMBOL_GPL(rpc_wake_up_queued_task); + +/* + * Wake up the next task on a priority queue. + */ +static struct rpc_task *__rpc_find_next_queued_priority(struct rpc_wait_queue *queue) +{ + struct list_head *q; + struct rpc_task *task; + + /* + * Service the privileged queue. + */ + q = &queue->tasks[RPC_NR_PRIORITY - 1]; + if (queue->maxpriority > RPC_PRIORITY_PRIVILEGED && !list_empty(q)) { + task = list_first_entry(q, struct rpc_task, u.tk_wait.list); + goto out; + } + + /* + * Service a batch of tasks from a single owner. + */ + q = &queue->tasks[queue->priority]; + if (!list_empty(q) && queue->nr) { + queue->nr--; + task = list_first_entry(q, struct rpc_task, u.tk_wait.list); + goto out; + } + + /* + * Service the next queue. + */ + do { + if (q == &queue->tasks[0]) + q = &queue->tasks[queue->maxpriority]; + else + q = q - 1; + if (!list_empty(q)) { + task = list_first_entry(q, struct rpc_task, u.tk_wait.list); + goto new_queue; + } + } while (q != &queue->tasks[queue->priority]); + + rpc_reset_waitqueue_priority(queue); + return NULL; + +new_queue: + rpc_set_waitqueue_priority(queue, (unsigned int)(q - &queue->tasks[0])); +out: + return task; +} + +static struct rpc_task *__rpc_find_next_queued(struct rpc_wait_queue *queue) +{ + if (RPC_IS_PRIORITY(queue)) + return __rpc_find_next_queued_priority(queue); + if (!list_empty(&queue->tasks[0])) + return list_first_entry(&queue->tasks[0], struct rpc_task, u.tk_wait.list); + return NULL; +} + +/* + * Wake up the first task on the wait queue. + */ +struct rpc_task *rpc_wake_up_first_on_wq(struct workqueue_struct *wq, + struct rpc_wait_queue *queue, + bool (*func)(struct rpc_task *, void *), void *data) +{ + struct rpc_task *task = NULL; + + dprintk("RPC: wake_up_first(%p \"%s\")\n", + queue, rpc_qname(queue)); + spin_lock_bh(&queue->lock); + task = __rpc_find_next_queued(queue); + if (task != NULL) { + if (func(task, data)) + rpc_wake_up_task_on_wq_queue_locked(wq, queue, task); + else + task = NULL; + } + spin_unlock_bh(&queue->lock); + + return task; +} + +/* + * Wake up the first task on the wait queue. + */ +struct rpc_task *rpc_wake_up_first(struct rpc_wait_queue *queue, + bool (*func)(struct rpc_task *, void *), void *data) +{ + return rpc_wake_up_first_on_wq(rpciod_workqueue, queue, func, data); +} +EXPORT_SYMBOL_GPL(rpc_wake_up_first); + +static bool rpc_wake_up_next_func(struct rpc_task *task, void *data) +{ + return true; +} + +/* + * Wake up the next task on the wait queue. +*/ +struct rpc_task *rpc_wake_up_next(struct rpc_wait_queue *queue) +{ + return rpc_wake_up_first(queue, rpc_wake_up_next_func, NULL); +} +EXPORT_SYMBOL_GPL(rpc_wake_up_next); + +/** + * rpc_wake_up - wake up all rpc_tasks + * @queue: rpc_wait_queue on which the tasks are sleeping + * + * Grabs queue->lock + */ +void rpc_wake_up(struct rpc_wait_queue *queue) +{ + struct list_head *head; + + spin_lock_bh(&queue->lock); + head = &queue->tasks[queue->maxpriority]; + for (;;) { + while (!list_empty(head)) { + struct rpc_task *task; + task = list_first_entry(head, + struct rpc_task, + u.tk_wait.list); + rpc_wake_up_task_queue_locked(queue, task); + } + if (head == &queue->tasks[0]) + break; + head--; + } + spin_unlock_bh(&queue->lock); +} +EXPORT_SYMBOL_GPL(rpc_wake_up); + +/** + * rpc_wake_up_status - wake up all rpc_tasks and set their status value. + * @queue: rpc_wait_queue on which the tasks are sleeping + * @status: status value to set + * + * Grabs queue->lock + */ +void rpc_wake_up_status(struct rpc_wait_queue *queue, int status) +{ + struct list_head *head; + + spin_lock_bh(&queue->lock); + head = &queue->tasks[queue->maxpriority]; + for (;;) { + while (!list_empty(head)) { + struct rpc_task *task; + task = list_first_entry(head, + struct rpc_task, + u.tk_wait.list); + task->tk_status = status; + rpc_wake_up_task_queue_locked(queue, task); + } + if (head == &queue->tasks[0]) + break; + head--; + } + spin_unlock_bh(&queue->lock); +} +EXPORT_SYMBOL_GPL(rpc_wake_up_status); + +static void __rpc_queue_timer_fn(struct timer_list *t) +{ + struct rpc_wait_queue *queue = from_timer(queue, t, timer_list.timer); + struct rpc_task *task, *n; + unsigned long expires, now, timeo; + + spin_lock(&queue->lock); + expires = now = jiffies; + list_for_each_entry_safe(task, n, &queue->timer_list.list, u.tk_wait.timer_list) { + timeo = task->u.tk_wait.expires; + if (time_after_eq(now, timeo)) { + dprintk("RPC: %5u timeout\n", task->tk_pid); + task->tk_status = -ETIMEDOUT; + rpc_wake_up_task_queue_locked(queue, task); + continue; + } + if (expires == now || time_after(expires, timeo)) + expires = timeo; + } + if (!list_empty(&queue->timer_list.list)) + rpc_set_queue_timer(queue, expires); + spin_unlock(&queue->lock); +} + +static void __rpc_atrun(struct rpc_task *task) +{ + if (task->tk_status == -ETIMEDOUT) + task->tk_status = 0; +} + +/* + * Run a task at a later time + */ +void rpc_delay(struct rpc_task *task, unsigned long delay) +{ + task->tk_timeout = delay; + rpc_sleep_on(&delay_queue, task, __rpc_atrun); +} +EXPORT_SYMBOL_GPL(rpc_delay); + +/* + * Helper to call task->tk_ops->rpc_call_prepare + */ +void rpc_prepare_task(struct rpc_task *task) +{ + task->tk_ops->rpc_call_prepare(task, task->tk_calldata); +} + +static void +rpc_init_task_statistics(struct rpc_task *task) +{ + /* Initialize retry counters */ + task->tk_garb_retry = 2; + task->tk_cred_retry = 2; + task->tk_rebind_retry = 2; + + /* starting timestamp */ + task->tk_start = ktime_get(); +} + +static void +rpc_reset_task_statistics(struct rpc_task *task) +{ + task->tk_timeouts = 0; + task->tk_flags &= ~(RPC_CALL_MAJORSEEN|RPC_TASK_KILLED|RPC_TASK_SENT); + + rpc_init_task_statistics(task); +} + +/* + * Helper that calls task->tk_ops->rpc_call_done if it exists + */ +void rpc_exit_task(struct rpc_task *task) +{ + task->tk_action = NULL; + if (task->tk_ops->rpc_call_done != NULL) { + task->tk_ops->rpc_call_done(task, task->tk_calldata); + if (task->tk_action != NULL) { + WARN_ON(RPC_ASSASSINATED(task)); + /* Always release the RPC slot and buffer memory */ + xprt_release(task); + rpc_reset_task_statistics(task); + } + } +} + +void rpc_exit(struct rpc_task *task, int status) +{ + task->tk_status = status; + task->tk_action = rpc_exit_task; + if (RPC_IS_QUEUED(task)) + rpc_wake_up_queued_task(task->tk_waitqueue, task); +} +EXPORT_SYMBOL_GPL(rpc_exit); + +void rpc_release_calldata(const struct rpc_call_ops *ops, void *calldata) +{ + if (ops->rpc_release != NULL) + ops->rpc_release(calldata); +} + +/* + * This is the RPC `scheduler' (or rather, the finite state machine). + */ +static void __rpc_execute(struct rpc_task *task) +{ + struct rpc_wait_queue *queue; + int task_is_async = RPC_IS_ASYNC(task); + int status = 0; + + dprintk("RPC: %5u __rpc_execute flags=0x%x\n", + task->tk_pid, task->tk_flags); + + WARN_ON_ONCE(RPC_IS_QUEUED(task)); + if (RPC_IS_QUEUED(task)) + return; + + for (;;) { + void (*do_action)(struct rpc_task *); + + /* + * Perform the next FSM step or a pending callback. + * + * tk_action may be NULL if the task has been killed. + * In particular, note that rpc_killall_tasks may + * do this at any time, so beware when dereferencing. + */ + do_action = task->tk_action; + if (task->tk_callback) { + do_action = task->tk_callback; + task->tk_callback = NULL; + } + if (!do_action) + break; + trace_rpc_task_run_action(task, do_action); + do_action(task); + + /* + * Lockless check for whether task is sleeping or not. + */ + if (!RPC_IS_QUEUED(task)) + continue; + /* + * The queue->lock protects against races with + * rpc_make_runnable(). + * + * Note that once we clear RPC_TASK_RUNNING on an asynchronous + * rpc_task, rpc_make_runnable() can assign it to a + * different workqueue. We therefore cannot assume that the + * rpc_task pointer may still be dereferenced. + */ + queue = task->tk_waitqueue; + spin_lock_bh(&queue->lock); + if (!RPC_IS_QUEUED(task)) { + spin_unlock_bh(&queue->lock); + continue; + } + rpc_clear_running(task); + spin_unlock_bh(&queue->lock); + if (task_is_async) + return; + + /* sync task: sleep here */ + dprintk("RPC: %5u sync task going to sleep\n", task->tk_pid); + status = out_of_line_wait_on_bit(&task->tk_runstate, + RPC_TASK_QUEUED, rpc_wait_bit_killable, + TASK_KILLABLE); + if (status == -ERESTARTSYS) { + /* + * When a sync task receives a signal, it exits with + * -ERESTARTSYS. In order to catch any callbacks that + * clean up after sleeping on some queue, we don't + * break the loop here, but go around once more. + */ + dprintk("RPC: %5u got signal\n", task->tk_pid); + task->tk_flags |= RPC_TASK_KILLED; + rpc_exit(task, -ERESTARTSYS); + } + dprintk("RPC: %5u sync task resuming\n", task->tk_pid); + } + + dprintk("RPC: %5u return %d, status %d\n", task->tk_pid, status, + task->tk_status); + /* Release all resources associated with the task */ + rpc_release_task(task); +} + +/* + * User-visible entry point to the scheduler. + * + * This may be called recursively if e.g. an async NFS task updates + * the attributes and finds that dirty pages must be flushed. + * NOTE: Upon exit of this function the task is guaranteed to be + * released. In particular note that tk_release() will have + * been called, so your task memory may have been freed. + */ +void rpc_execute(struct rpc_task *task) +{ + bool is_async = RPC_IS_ASYNC(task); + + rpc_set_active(task); + rpc_make_runnable(rpciod_workqueue, task); + if (!is_async) + __rpc_execute(task); +} + +static void rpc_async_schedule(struct work_struct *work) +{ + __rpc_execute(container_of(work, struct rpc_task, u.tk_work)); +} + +/** + * rpc_malloc - allocate RPC buffer resources + * @task: RPC task + * + * A single memory region is allocated, which is split between the + * RPC call and RPC reply that this task is being used for. When + * this RPC is retired, the memory is released by calling rpc_free. + * + * To prevent rpciod from hanging, this allocator never sleeps, + * returning -ENOMEM and suppressing warning if the request cannot + * be serviced immediately. The caller can arrange to sleep in a + * way that is safe for rpciod. + * + * Most requests are 'small' (under 2KiB) and can be serviced from a + * mempool, ensuring that NFS reads and writes can always proceed, + * and that there is good locality of reference for these buffers. + * + * In order to avoid memory starvation triggering more writebacks of + * NFS requests, we avoid using GFP_KERNEL. + */ +int rpc_malloc(struct rpc_task *task) +{ + struct rpc_rqst *rqst = task->tk_rqstp; + size_t size = rqst->rq_callsize + rqst->rq_rcvsize; + struct rpc_buffer *buf; + gfp_t gfp = GFP_NOIO | __GFP_NOWARN; + + if (RPC_IS_ASYNC(task)) + gfp = GFP_NOWAIT | __GFP_NOWARN; + if (RPC_IS_SWAPPER(task)) + gfp |= __GFP_MEMALLOC; + + size += sizeof(struct rpc_buffer); + if (size <= RPC_BUFFER_MAXSIZE) + buf = mempool_alloc(rpc_buffer_mempool, gfp); + else + buf = kmalloc(size, gfp); + + if (!buf) + return -ENOMEM; + + buf->len = size; + dprintk("RPC: %5u allocated buffer of size %zu at %p\n", + task->tk_pid, size, buf); + rqst->rq_buffer = buf->data; + rqst->rq_rbuffer = (char *)rqst->rq_buffer + rqst->rq_callsize; + return 0; +} +EXPORT_SYMBOL_GPL(rpc_malloc); + +/** + * rpc_free - free RPC buffer resources allocated via rpc_malloc + * @task: RPC task + * + */ +void rpc_free(struct rpc_task *task) +{ + void *buffer = task->tk_rqstp->rq_buffer; + size_t size; + struct rpc_buffer *buf; + + buf = container_of(buffer, struct rpc_buffer, data); + size = buf->len; + + dprintk("RPC: freeing buffer of size %zu at %p\n", + size, buf); + + if (size <= RPC_BUFFER_MAXSIZE) + mempool_free(buf, rpc_buffer_mempool); + else + kfree(buf); +} +EXPORT_SYMBOL_GPL(rpc_free); + +/* + * Creation and deletion of RPC task structures + */ +static void rpc_init_task(struct rpc_task *task, const struct rpc_task_setup *task_setup_data) +{ + memset(task, 0, sizeof(*task)); + atomic_set(&task->tk_count, 1); + task->tk_flags = task_setup_data->flags; + task->tk_ops = task_setup_data->callback_ops; + task->tk_calldata = task_setup_data->callback_data; + INIT_LIST_HEAD(&task->tk_task); + + task->tk_priority = task_setup_data->priority - RPC_PRIORITY_LOW; + task->tk_owner = current->tgid; + + /* Initialize workqueue for async tasks */ + task->tk_workqueue = task_setup_data->workqueue; + + task->tk_xprt = xprt_get(task_setup_data->rpc_xprt); + + if (task->tk_ops->rpc_call_prepare != NULL) + task->tk_action = rpc_prepare_task; + + rpc_init_task_statistics(task); + + dprintk("RPC: new task initialized, procpid %u\n", + task_pid_nr(current)); +} + +static struct rpc_task * +rpc_alloc_task(void) +{ + return (struct rpc_task *)mempool_alloc(rpc_task_mempool, GFP_NOIO); +} + +/* + * Create a new task for the specified client. + */ +struct rpc_task *rpc_new_task(const struct rpc_task_setup *setup_data) +{ + struct rpc_task *task = setup_data->task; + unsigned short flags = 0; + + if (task == NULL) { + task = rpc_alloc_task(); + flags = RPC_TASK_DYNAMIC; + } + + rpc_init_task(task, setup_data); + task->tk_flags |= flags; + dprintk("RPC: allocated task %p\n", task); + return task; +} + +/* + * rpc_free_task - release rpc task and perform cleanups + * + * Note that we free up the rpc_task _after_ rpc_release_calldata() + * in order to work around a workqueue dependency issue. + * + * Tejun Heo states: + * "Workqueue currently considers two work items to be the same if they're + * on the same address and won't execute them concurrently - ie. it + * makes a work item which is queued again while being executed wait + * for the previous execution to complete. + * + * If a work function frees the work item, and then waits for an event + * which should be performed by another work item and *that* work item + * recycles the freed work item, it can create a false dependency loop. + * There really is no reliable way to detect this short of verifying + * every memory free." + * + */ +static void rpc_free_task(struct rpc_task *task) +{ + unsigned short tk_flags = task->tk_flags; + + rpc_release_calldata(task->tk_ops, task->tk_calldata); + + if (tk_flags & RPC_TASK_DYNAMIC) { + dprintk("RPC: %5u freeing task\n", task->tk_pid); + mempool_free(task, rpc_task_mempool); + } +} + +static void rpc_async_release(struct work_struct *work) +{ + rpc_free_task(container_of(work, struct rpc_task, u.tk_work)); +} + +static void rpc_release_resources_task(struct rpc_task *task) +{ + xprt_release(task); + if (task->tk_msg.rpc_cred) { + put_rpccred(task->tk_msg.rpc_cred); + task->tk_msg.rpc_cred = NULL; + } + rpc_task_release_client(task); +} + +static void rpc_final_put_task(struct rpc_task *task, + struct workqueue_struct *q) +{ + if (q != NULL) { + INIT_WORK(&task->u.tk_work, rpc_async_release); + queue_work(q, &task->u.tk_work); + } else + rpc_free_task(task); +} + +static void rpc_do_put_task(struct rpc_task *task, struct workqueue_struct *q) +{ + if (atomic_dec_and_test(&task->tk_count)) { + rpc_release_resources_task(task); + rpc_final_put_task(task, q); + } +} + +void rpc_put_task(struct rpc_task *task) +{ + rpc_do_put_task(task, NULL); +} +EXPORT_SYMBOL_GPL(rpc_put_task); + +void rpc_put_task_async(struct rpc_task *task) +{ + rpc_do_put_task(task, task->tk_workqueue); +} +EXPORT_SYMBOL_GPL(rpc_put_task_async); + +static void rpc_release_task(struct rpc_task *task) +{ + dprintk("RPC: %5u release task\n", task->tk_pid); + + WARN_ON_ONCE(RPC_IS_QUEUED(task)); + + rpc_release_resources_task(task); + + /* + * Note: at this point we have been removed from rpc_clnt->cl_tasks, + * so it should be safe to use task->tk_count as a test for whether + * or not any other processes still hold references to our rpc_task. + */ + if (atomic_read(&task->tk_count) != 1 + !RPC_IS_ASYNC(task)) { + /* Wake up anyone who may be waiting for task completion */ + if (!rpc_complete_task(task)) + return; + } else { + if (!atomic_dec_and_test(&task->tk_count)) + return; + } + rpc_final_put_task(task, task->tk_workqueue); +} + +int rpciod_up(void) +{ + return try_module_get(THIS_MODULE) ? 0 : -EINVAL; +} + +void rpciod_down(void) +{ + module_put(THIS_MODULE); +} + +/* + * Start up the rpciod workqueue. + */ +static int rpciod_start(void) +{ + struct workqueue_struct *wq; + + /* + * Create the rpciod thread and wait for it to start. + */ + dprintk("RPC: creating workqueue rpciod\n"); + wq = alloc_workqueue("rpciod", WQ_MEM_RECLAIM | WQ_UNBOUND, 0); + if (!wq) + goto out_failed; + rpciod_workqueue = wq; + /* Note: highpri because network receive is latency sensitive */ + wq = alloc_workqueue("xprtiod", WQ_UNBOUND|WQ_MEM_RECLAIM|WQ_HIGHPRI, 0); + if (!wq) + goto free_rpciod; + xprtiod_workqueue = wq; + return 1; +free_rpciod: + wq = rpciod_workqueue; + rpciod_workqueue = NULL; + destroy_workqueue(wq); +out_failed: + return 0; +} + +static void rpciod_stop(void) +{ + struct workqueue_struct *wq = NULL; + + if (rpciod_workqueue == NULL) + return; + dprintk("RPC: destroying workqueue rpciod\n"); + + wq = rpciod_workqueue; + rpciod_workqueue = NULL; + destroy_workqueue(wq); + wq = xprtiod_workqueue; + xprtiod_workqueue = NULL; + destroy_workqueue(wq); +} + +void +rpc_destroy_mempool(void) +{ + rpciod_stop(); + mempool_destroy(rpc_buffer_mempool); + mempool_destroy(rpc_task_mempool); + kmem_cache_destroy(rpc_task_slabp); + kmem_cache_destroy(rpc_buffer_slabp); + rpc_destroy_wait_queue(&delay_queue); +} + +int +rpc_init_mempool(void) +{ + /* + * The following is not strictly a mempool initialisation, + * but there is no harm in doing it here + */ + rpc_init_wait_queue(&delay_queue, "delayq"); + if (!rpciod_start()) + goto err_nomem; + + rpc_task_slabp = kmem_cache_create("rpc_tasks", + sizeof(struct rpc_task), + 0, SLAB_HWCACHE_ALIGN, + NULL); + if (!rpc_task_slabp) + goto err_nomem; + rpc_buffer_slabp = kmem_cache_create("rpc_buffers", + RPC_BUFFER_MAXSIZE, + 0, SLAB_HWCACHE_ALIGN, + NULL); + if (!rpc_buffer_slabp) + goto err_nomem; + rpc_task_mempool = mempool_create_slab_pool(RPC_TASK_POOLSIZE, + rpc_task_slabp); + if (!rpc_task_mempool) + goto err_nomem; + rpc_buffer_mempool = mempool_create_slab_pool(RPC_BUFFER_POOLSIZE, + rpc_buffer_slabp); + if (!rpc_buffer_mempool) + goto err_nomem; + return 0; +err_nomem: + rpc_destroy_mempool(); + return -ENOMEM; +} diff --git a/net/sunrpc/socklib.c b/net/sunrpc/socklib.c new file mode 100644 index 000000000..f217c348b --- /dev/null +++ b/net/sunrpc/socklib.c @@ -0,0 +1,187 @@ +/* + * linux/net/sunrpc/socklib.c + * + * Common socket helper routines for RPC client and server + * + * Copyright (C) 1995, 1996 Olaf Kirch <okir@monad.swb.de> + */ + +#include <linux/compiler.h> +#include <linux/netdevice.h> +#include <linux/gfp.h> +#include <linux/skbuff.h> +#include <linux/types.h> +#include <linux/pagemap.h> +#include <linux/udp.h> +#include <linux/sunrpc/xdr.h> +#include <linux/export.h> + + +/** + * xdr_skb_read_bits - copy some data bits from skb to internal buffer + * @desc: sk_buff copy helper + * @to: copy destination + * @len: number of bytes to copy + * + * Possibly called several times to iterate over an sk_buff and copy + * data out of it. + */ +size_t xdr_skb_read_bits(struct xdr_skb_reader *desc, void *to, size_t len) +{ + if (len > desc->count) + len = desc->count; + if (unlikely(skb_copy_bits(desc->skb, desc->offset, to, len))) + return 0; + desc->count -= len; + desc->offset += len; + return len; +} +EXPORT_SYMBOL_GPL(xdr_skb_read_bits); + +/** + * xdr_skb_read_and_csum_bits - copy and checksum from skb to buffer + * @desc: sk_buff copy helper + * @to: copy destination + * @len: number of bytes to copy + * + * Same as skb_read_bits, but calculate a checksum at the same time. + */ +static size_t xdr_skb_read_and_csum_bits(struct xdr_skb_reader *desc, void *to, size_t len) +{ + unsigned int pos; + __wsum csum2; + + if (len > desc->count) + len = desc->count; + pos = desc->offset; + csum2 = skb_copy_and_csum_bits(desc->skb, pos, to, len, 0); + desc->csum = csum_block_add(desc->csum, csum2, pos); + desc->count -= len; + desc->offset += len; + return len; +} + +/** + * xdr_partial_copy_from_skb - copy data out of an skb + * @xdr: target XDR buffer + * @base: starting offset + * @desc: sk_buff copy helper + * @copy_actor: virtual method for copying data + * + */ +ssize_t xdr_partial_copy_from_skb(struct xdr_buf *xdr, unsigned int base, struct xdr_skb_reader *desc, xdr_skb_read_actor copy_actor) +{ + struct page **ppage = xdr->pages; + unsigned int len, pglen = xdr->page_len; + ssize_t copied = 0; + size_t ret; + + len = xdr->head[0].iov_len; + if (base < len) { + len -= base; + ret = copy_actor(desc, (char *)xdr->head[0].iov_base + base, len); + copied += ret; + if (ret != len || !desc->count) + goto out; + base = 0; + } else + base -= len; + + if (unlikely(pglen == 0)) + goto copy_tail; + if (unlikely(base >= pglen)) { + base -= pglen; + goto copy_tail; + } + if (base || xdr->page_base) { + pglen -= base; + base += xdr->page_base; + ppage += base >> PAGE_SHIFT; + base &= ~PAGE_MASK; + } + do { + char *kaddr; + + /* ACL likes to be lazy in allocating pages - ACLs + * are small by default but can get huge. */ + if (unlikely(*ppage == NULL)) { + *ppage = alloc_page(GFP_ATOMIC); + if (unlikely(*ppage == NULL)) { + if (copied == 0) + copied = -ENOMEM; + goto out; + } + } + + len = PAGE_SIZE; + kaddr = kmap_atomic(*ppage); + if (base) { + len -= base; + if (pglen < len) + len = pglen; + ret = copy_actor(desc, kaddr + base, len); + base = 0; + } else { + if (pglen < len) + len = pglen; + ret = copy_actor(desc, kaddr, len); + } + flush_dcache_page(*ppage); + kunmap_atomic(kaddr); + copied += ret; + if (ret != len || !desc->count) + goto out; + ppage++; + } while ((pglen -= len) != 0); +copy_tail: + len = xdr->tail[0].iov_len; + if (base < len) + copied += copy_actor(desc, (char *)xdr->tail[0].iov_base + base, len - base); +out: + return copied; +} +EXPORT_SYMBOL_GPL(xdr_partial_copy_from_skb); + +/** + * csum_partial_copy_to_xdr - checksum and copy data + * @xdr: target XDR buffer + * @skb: source skb + * + * We have set things up such that we perform the checksum of the UDP + * packet in parallel with the copies into the RPC client iovec. -DaveM + */ +int csum_partial_copy_to_xdr(struct xdr_buf *xdr, struct sk_buff *skb) +{ + struct xdr_skb_reader desc; + + desc.skb = skb; + desc.offset = 0; + desc.count = skb->len - desc.offset; + + if (skb_csum_unnecessary(skb)) + goto no_checksum; + + desc.csum = csum_partial(skb->data, desc.offset, skb->csum); + if (xdr_partial_copy_from_skb(xdr, 0, &desc, xdr_skb_read_and_csum_bits) < 0) + return -1; + if (desc.offset != skb->len) { + __wsum csum2; + csum2 = skb_checksum(skb, desc.offset, skb->len - desc.offset, 0); + desc.csum = csum_block_add(desc.csum, csum2, desc.offset); + } + if (desc.count) + return -1; + if (csum_fold(desc.csum)) + return -1; + if (unlikely(skb->ip_summed == CHECKSUM_COMPLETE) && + !skb->csum_complete_sw) + netdev_rx_csum_fault(skb->dev); + return 0; +no_checksum: + if (xdr_partial_copy_from_skb(xdr, 0, &desc, xdr_skb_read_bits) < 0) + return -1; + if (desc.count) + return -1; + return 0; +} +EXPORT_SYMBOL_GPL(csum_partial_copy_to_xdr); diff --git a/net/sunrpc/stats.c b/net/sunrpc/stats.c new file mode 100644 index 000000000..71166b393 --- /dev/null +++ b/net/sunrpc/stats.c @@ -0,0 +1,336 @@ +/* + * linux/net/sunrpc/stats.c + * + * procfs-based user access to generic RPC statistics. The stats files + * reside in /proc/net/rpc. + * + * The read routines assume that the buffer passed in is just big enough. + * If you implement an RPC service that has its own stats routine which + * appends the generic RPC stats, make sure you don't exceed the PAGE_SIZE + * limit. + * + * Copyright (C) 1995, 1996, 1997 Olaf Kirch <okir@monad.swb.de> + */ + +#include <linux/module.h> +#include <linux/slab.h> + +#include <linux/init.h> +#include <linux/kernel.h> +#include <linux/proc_fs.h> +#include <linux/seq_file.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/svcsock.h> +#include <linux/sunrpc/metrics.h> +#include <linux/rcupdate.h> + +#include <trace/events/sunrpc.h> + +#include "netns.h" + +#define RPCDBG_FACILITY RPCDBG_MISC + +/* + * Get RPC client stats + */ +static int rpc_proc_show(struct seq_file *seq, void *v) { + const struct rpc_stat *statp = seq->private; + const struct rpc_program *prog = statp->program; + unsigned int i, j; + + seq_printf(seq, + "net %u %u %u %u\n", + statp->netcnt, + statp->netudpcnt, + statp->nettcpcnt, + statp->nettcpconn); + seq_printf(seq, + "rpc %u %u %u\n", + statp->rpccnt, + statp->rpcretrans, + statp->rpcauthrefresh); + + for (i = 0; i < prog->nrvers; i++) { + const struct rpc_version *vers = prog->version[i]; + if (!vers) + continue; + seq_printf(seq, "proc%u %u", + vers->number, vers->nrprocs); + for (j = 0; j < vers->nrprocs; j++) + seq_printf(seq, " %u", vers->counts[j]); + seq_putc(seq, '\n'); + } + return 0; +} + +static int rpc_proc_open(struct inode *inode, struct file *file) +{ + return single_open(file, rpc_proc_show, PDE_DATA(inode)); +} + +static const struct file_operations rpc_proc_fops = { + .owner = THIS_MODULE, + .open = rpc_proc_open, + .read = seq_read, + .llseek = seq_lseek, + .release = single_release, +}; + +/* + * Get RPC server stats + */ +void svc_seq_show(struct seq_file *seq, const struct svc_stat *statp) +{ + const struct svc_program *prog = statp->program; + const struct svc_version *vers; + unsigned int i, j; + + seq_printf(seq, + "net %u %u %u %u\n", + statp->netcnt, + statp->netudpcnt, + statp->nettcpcnt, + statp->nettcpconn); + seq_printf(seq, + "rpc %u %u %u %u %u\n", + statp->rpccnt, + statp->rpcbadfmt+statp->rpcbadauth+statp->rpcbadclnt, + statp->rpcbadfmt, + statp->rpcbadauth, + statp->rpcbadclnt); + + for (i = 0; i < prog->pg_nvers; i++) { + vers = prog->pg_vers[i]; + if (!vers) + continue; + seq_printf(seq, "proc%d %u", i, vers->vs_nproc); + for (j = 0; j < vers->vs_nproc; j++) + seq_printf(seq, " %u", vers->vs_count[j]); + seq_putc(seq, '\n'); + } +} +EXPORT_SYMBOL_GPL(svc_seq_show); + +/** + * rpc_alloc_iostats - allocate an rpc_iostats structure + * @clnt: RPC program, version, and xprt + * + */ +struct rpc_iostats *rpc_alloc_iostats(struct rpc_clnt *clnt) +{ + struct rpc_iostats *stats; + int i; + + stats = kcalloc(clnt->cl_maxproc, sizeof(*stats), GFP_KERNEL); + if (stats) { + for (i = 0; i < clnt->cl_maxproc; i++) + spin_lock_init(&stats[i].om_lock); + } + return stats; +} +EXPORT_SYMBOL_GPL(rpc_alloc_iostats); + +/** + * rpc_free_iostats - release an rpc_iostats structure + * @stats: doomed rpc_iostats structure + * + */ +void rpc_free_iostats(struct rpc_iostats *stats) +{ + kfree(stats); +} +EXPORT_SYMBOL_GPL(rpc_free_iostats); + +/** + * rpc_count_iostats_metrics - tally up per-task stats + * @task: completed rpc_task + * @op_metrics: stat structure for OP that will accumulate stats from @task + */ +void rpc_count_iostats_metrics(const struct rpc_task *task, + struct rpc_iostats *op_metrics) +{ + struct rpc_rqst *req = task->tk_rqstp; + ktime_t backlog, execute, now; + + if (!op_metrics || !req) + return; + + now = ktime_get(); + spin_lock(&op_metrics->om_lock); + + op_metrics->om_ops++; + /* kernel API: om_ops must never become larger than om_ntrans */ + op_metrics->om_ntrans += max(req->rq_ntrans, 1); + op_metrics->om_timeouts += task->tk_timeouts; + + op_metrics->om_bytes_sent += req->rq_xmit_bytes_sent; + op_metrics->om_bytes_recv += req->rq_reply_bytes_recvd; + + backlog = 0; + if (ktime_to_ns(req->rq_xtime)) { + backlog = ktime_sub(req->rq_xtime, task->tk_start); + op_metrics->om_queue = ktime_add(op_metrics->om_queue, backlog); + } + + op_metrics->om_rtt = ktime_add(op_metrics->om_rtt, req->rq_rtt); + + execute = ktime_sub(now, task->tk_start); + op_metrics->om_execute = ktime_add(op_metrics->om_execute, execute); + + spin_unlock(&op_metrics->om_lock); + + trace_rpc_stats_latency(req->rq_task, backlog, req->rq_rtt, execute); +} +EXPORT_SYMBOL_GPL(rpc_count_iostats_metrics); + +/** + * rpc_count_iostats - tally up per-task stats + * @task: completed rpc_task + * @stats: array of stat structures + * + * Uses the statidx from @task + */ +void rpc_count_iostats(const struct rpc_task *task, struct rpc_iostats *stats) +{ + rpc_count_iostats_metrics(task, + &stats[task->tk_msg.rpc_proc->p_statidx]); +} +EXPORT_SYMBOL_GPL(rpc_count_iostats); + +static void _print_name(struct seq_file *seq, unsigned int op, + const struct rpc_procinfo *procs) +{ + if (procs[op].p_name) + seq_printf(seq, "\t%12s: ", procs[op].p_name); + else if (op == 0) + seq_printf(seq, "\t NULL: "); + else + seq_printf(seq, "\t%12u: ", op); +} + +static void _add_rpc_iostats(struct rpc_iostats *a, struct rpc_iostats *b) +{ + a->om_ops += b->om_ops; + a->om_ntrans += b->om_ntrans; + a->om_timeouts += b->om_timeouts; + a->om_bytes_sent += b->om_bytes_sent; + a->om_bytes_recv += b->om_bytes_recv; + a->om_queue = ktime_add(a->om_queue, b->om_queue); + a->om_rtt = ktime_add(a->om_rtt, b->om_rtt); + a->om_execute = ktime_add(a->om_execute, b->om_execute); +} + +static void _print_rpc_iostats(struct seq_file *seq, struct rpc_iostats *stats, + int op, const struct rpc_procinfo *procs) +{ + _print_name(seq, op, procs); + seq_printf(seq, "%lu %lu %lu %Lu %Lu %Lu %Lu %Lu\n", + stats->om_ops, + stats->om_ntrans, + stats->om_timeouts, + stats->om_bytes_sent, + stats->om_bytes_recv, + ktime_to_ms(stats->om_queue), + ktime_to_ms(stats->om_rtt), + ktime_to_ms(stats->om_execute)); +} + +void rpc_clnt_show_stats(struct seq_file *seq, struct rpc_clnt *clnt) +{ + struct rpc_xprt *xprt; + unsigned int op, maxproc = clnt->cl_maxproc; + + if (!clnt->cl_metrics) + return; + + seq_printf(seq, "\tRPC iostats version: %s ", RPC_IOSTATS_VERS); + seq_printf(seq, "p/v: %u/%u (%s)\n", + clnt->cl_prog, clnt->cl_vers, clnt->cl_program->name); + + rcu_read_lock(); + xprt = rcu_dereference(clnt->cl_xprt); + if (xprt) + xprt->ops->print_stats(xprt, seq); + rcu_read_unlock(); + + seq_printf(seq, "\tper-op statistics\n"); + for (op = 0; op < maxproc; op++) { + struct rpc_iostats stats = {}; + struct rpc_clnt *next = clnt; + do { + _add_rpc_iostats(&stats, &next->cl_metrics[op]); + if (next == next->cl_parent) + break; + next = next->cl_parent; + } while (next); + _print_rpc_iostats(seq, &stats, op, clnt->cl_procinfo); + } +} +EXPORT_SYMBOL_GPL(rpc_clnt_show_stats); + +/* + * Register/unregister RPC proc files + */ +static inline struct proc_dir_entry * +do_register(struct net *net, const char *name, void *data, + const struct file_operations *fops) +{ + struct sunrpc_net *sn; + + dprintk("RPC: registering /proc/net/rpc/%s\n", name); + sn = net_generic(net, sunrpc_net_id); + return proc_create_data(name, 0, sn->proc_net_rpc, fops, data); +} + +struct proc_dir_entry * +rpc_proc_register(struct net *net, struct rpc_stat *statp) +{ + return do_register(net, statp->program->name, statp, &rpc_proc_fops); +} +EXPORT_SYMBOL_GPL(rpc_proc_register); + +void +rpc_proc_unregister(struct net *net, const char *name) +{ + struct sunrpc_net *sn; + + sn = net_generic(net, sunrpc_net_id); + remove_proc_entry(name, sn->proc_net_rpc); +} +EXPORT_SYMBOL_GPL(rpc_proc_unregister); + +struct proc_dir_entry * +svc_proc_register(struct net *net, struct svc_stat *statp, const struct file_operations *fops) +{ + return do_register(net, statp->program->pg_name, statp, fops); +} +EXPORT_SYMBOL_GPL(svc_proc_register); + +void +svc_proc_unregister(struct net *net, const char *name) +{ + struct sunrpc_net *sn; + + sn = net_generic(net, sunrpc_net_id); + remove_proc_entry(name, sn->proc_net_rpc); +} +EXPORT_SYMBOL_GPL(svc_proc_unregister); + +int rpc_proc_init(struct net *net) +{ + struct sunrpc_net *sn; + + dprintk("RPC: registering /proc/net/rpc\n"); + sn = net_generic(net, sunrpc_net_id); + sn->proc_net_rpc = proc_mkdir("rpc", net->proc_net); + if (sn->proc_net_rpc == NULL) + return -ENOMEM; + + return 0; +} + +void rpc_proc_exit(struct net *net) +{ + dprintk("RPC: unregistering /proc/net/rpc\n"); + remove_proc_entry("rpc", net->proc_net); +} diff --git a/net/sunrpc/sunrpc.h b/net/sunrpc/sunrpc.h new file mode 100644 index 000000000..c9bacb3c9 --- /dev/null +++ b/net/sunrpc/sunrpc.h @@ -0,0 +1,59 @@ +/****************************************************************************** + +(c) 2008 NetApp. All Rights Reserved. + +NetApp provides this source code under the GPL v2 License. +The GPL v2 license is available at +http://opensource.org/licenses/gpl-license.php. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +******************************************************************************/ + +/* + * Functions and macros used internally by RPC + */ + +#ifndef _NET_SUNRPC_SUNRPC_H +#define _NET_SUNRPC_SUNRPC_H + +#include <linux/net.h> + +/* + * Header for dynamically allocated rpc buffers. + */ +struct rpc_buffer { + size_t len; + char data[]; +}; + +static inline int sock_is_loopback(struct sock *sk) +{ + struct dst_entry *dst; + int loopback = 0; + rcu_read_lock(); + dst = rcu_dereference(sk->sk_dst_cache); + if (dst && dst->dev && + (dst->dev->features & NETIF_F_LOOPBACK)) + loopback = 1; + rcu_read_unlock(); + return loopback; +} + +int svc_send_common(struct socket *sock, struct xdr_buf *xdr, + struct page *headpage, unsigned long headoffset, + struct page *tailpage, unsigned long tailoffset); + +int rpc_clients_notifier_register(void); +void rpc_clients_notifier_unregister(void); +#endif /* _NET_SUNRPC_SUNRPC_H */ diff --git a/net/sunrpc/sunrpc_syms.c b/net/sunrpc/sunrpc_syms.c new file mode 100644 index 000000000..56f9eff74 --- /dev/null +++ b/net/sunrpc/sunrpc_syms.c @@ -0,0 +1,140 @@ +/* + * linux/net/sunrpc/sunrpc_syms.c + * + * Symbols exported by the sunrpc module. + * + * Copyright (C) 1997 Olaf Kirch <okir@monad.swb.de> + */ + +#include <linux/module.h> + +#include <linux/types.h> +#include <linux/uio.h> +#include <linux/unistd.h> +#include <linux/init.h> + +#include <linux/sunrpc/sched.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/svc.h> +#include <linux/sunrpc/svcsock.h> +#include <linux/sunrpc/auth.h> +#include <linux/workqueue.h> +#include <linux/sunrpc/rpc_pipe_fs.h> +#include <linux/sunrpc/xprtsock.h> + +#include "netns.h" + +unsigned int sunrpc_net_id; +EXPORT_SYMBOL_GPL(sunrpc_net_id); + +static __net_init int sunrpc_init_net(struct net *net) +{ + int err; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + err = rpc_proc_init(net); + if (err) + goto err_proc; + + err = ip_map_cache_create(net); + if (err) + goto err_ipmap; + + err = unix_gid_cache_create(net); + if (err) + goto err_unixgid; + + err = rpc_pipefs_init_net(net); + if (err) + goto err_pipefs; + + INIT_LIST_HEAD(&sn->all_clients); + spin_lock_init(&sn->rpc_client_lock); + spin_lock_init(&sn->rpcb_clnt_lock); + return 0; + +err_pipefs: + unix_gid_cache_destroy(net); +err_unixgid: + ip_map_cache_destroy(net); +err_ipmap: + rpc_proc_exit(net); +err_proc: + return err; +} + +static __net_exit void sunrpc_exit_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + rpc_pipefs_exit_net(net); + unix_gid_cache_destroy(net); + ip_map_cache_destroy(net); + rpc_proc_exit(net); + WARN_ON_ONCE(!list_empty(&sn->all_clients)); +} + +static struct pernet_operations sunrpc_net_ops = { + .init = sunrpc_init_net, + .exit = sunrpc_exit_net, + .id = &sunrpc_net_id, + .size = sizeof(struct sunrpc_net), +}; + +static int __init +init_sunrpc(void) +{ + int err = rpc_init_mempool(); + if (err) + goto out; + err = rpcauth_init_module(); + if (err) + goto out2; + + cache_initialize(); + + err = register_pernet_subsys(&sunrpc_net_ops); + if (err) + goto out3; + + err = register_rpc_pipefs(); + if (err) + goto out4; + + sunrpc_debugfs_init(); +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) + rpc_register_sysctl(); +#endif + svc_init_xprt_sock(); /* svc sock transport */ + init_socket_xprt(); /* clnt sock transport */ + return 0; + +out4: + unregister_pernet_subsys(&sunrpc_net_ops); +out3: + rpcauth_remove_module(); +out2: + rpc_destroy_mempool(); +out: + return err; +} + +static void __exit +cleanup_sunrpc(void) +{ + rpc_cleanup_clids(); + rpcauth_remove_module(); + cleanup_socket_xprt(); + svc_cleanup_xprt_sock(); + sunrpc_debugfs_exit(); + unregister_rpc_pipefs(); + rpc_destroy_mempool(); + unregister_pernet_subsys(&sunrpc_net_ops); +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) + rpc_unregister_sysctl(); +#endif + rcu_barrier(); /* Wait for completion of call_rcu()'s */ +} +MODULE_LICENSE("GPL"); +fs_initcall(init_sunrpc); /* Ensure we're initialised before nfs */ +module_exit(cleanup_sunrpc); diff --git a/net/sunrpc/svc.c b/net/sunrpc/svc.c new file mode 100644 index 000000000..6a606af8d --- /dev/null +++ b/net/sunrpc/svc.c @@ -0,0 +1,1646 @@ +/* + * linux/net/sunrpc/svc.c + * + * High-level RPC service routines + * + * Copyright (C) 1995, 1996 Olaf Kirch <okir@monad.swb.de> + * + * Multiple threads pools and NUMAisation + * Copyright (c) 2006 Silicon Graphics, Inc. + * by Greg Banks <gnb@melbourne.sgi.com> + */ + +#include <linux/linkage.h> +#include <linux/sched/signal.h> +#include <linux/errno.h> +#include <linux/net.h> +#include <linux/in.h> +#include <linux/mm.h> +#include <linux/interrupt.h> +#include <linux/module.h> +#include <linux/kthread.h> +#include <linux/slab.h> + +#include <linux/sunrpc/types.h> +#include <linux/sunrpc/xdr.h> +#include <linux/sunrpc/stats.h> +#include <linux/sunrpc/svcsock.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/bc_xprt.h> + +#include <trace/events/sunrpc.h> + +#define RPCDBG_FACILITY RPCDBG_SVCDSP + +static void svc_unregister(const struct svc_serv *serv, struct net *net); + +#define svc_serv_is_pooled(serv) ((serv)->sv_ops->svo_function) + +#define SVC_POOL_DEFAULT SVC_POOL_GLOBAL + +/* + * Structure for mapping cpus to pools and vice versa. + * Setup once during sunrpc initialisation. + */ +struct svc_pool_map svc_pool_map = { + .mode = SVC_POOL_DEFAULT +}; +EXPORT_SYMBOL_GPL(svc_pool_map); + +static DEFINE_MUTEX(svc_pool_map_mutex);/* protects svc_pool_map.count only */ + +static int +param_set_pool_mode(const char *val, const struct kernel_param *kp) +{ + int *ip = (int *)kp->arg; + struct svc_pool_map *m = &svc_pool_map; + int err; + + mutex_lock(&svc_pool_map_mutex); + + err = -EBUSY; + if (m->count) + goto out; + + err = 0; + if (!strncmp(val, "auto", 4)) + *ip = SVC_POOL_AUTO; + else if (!strncmp(val, "global", 6)) + *ip = SVC_POOL_GLOBAL; + else if (!strncmp(val, "percpu", 6)) + *ip = SVC_POOL_PERCPU; + else if (!strncmp(val, "pernode", 7)) + *ip = SVC_POOL_PERNODE; + else + err = -EINVAL; + +out: + mutex_unlock(&svc_pool_map_mutex); + return err; +} + +static int +param_get_pool_mode(char *buf, const struct kernel_param *kp) +{ + int *ip = (int *)kp->arg; + + switch (*ip) + { + case SVC_POOL_AUTO: + return strlcpy(buf, "auto", 20); + case SVC_POOL_GLOBAL: + return strlcpy(buf, "global", 20); + case SVC_POOL_PERCPU: + return strlcpy(buf, "percpu", 20); + case SVC_POOL_PERNODE: + return strlcpy(buf, "pernode", 20); + default: + return sprintf(buf, "%d", *ip); + } +} + +module_param_call(pool_mode, param_set_pool_mode, param_get_pool_mode, + &svc_pool_map.mode, 0644); + +/* + * Detect best pool mapping mode heuristically, + * according to the machine's topology. + */ +static int +svc_pool_map_choose_mode(void) +{ + unsigned int node; + + if (nr_online_nodes > 1) { + /* + * Actually have multiple NUMA nodes, + * so split pools on NUMA node boundaries + */ + return SVC_POOL_PERNODE; + } + + node = first_online_node; + if (nr_cpus_node(node) > 2) { + /* + * Non-trivial SMP, or CONFIG_NUMA on + * non-NUMA hardware, e.g. with a generic + * x86_64 kernel on Xeons. In this case we + * want to divide the pools on cpu boundaries. + */ + return SVC_POOL_PERCPU; + } + + /* default: one global pool */ + return SVC_POOL_GLOBAL; +} + +/* + * Allocate the to_pool[] and pool_to[] arrays. + * Returns 0 on success or an errno. + */ +static int +svc_pool_map_alloc_arrays(struct svc_pool_map *m, unsigned int maxpools) +{ + m->to_pool = kcalloc(maxpools, sizeof(unsigned int), GFP_KERNEL); + if (!m->to_pool) + goto fail; + m->pool_to = kcalloc(maxpools, sizeof(unsigned int), GFP_KERNEL); + if (!m->pool_to) + goto fail_free; + + return 0; + +fail_free: + kfree(m->to_pool); + m->to_pool = NULL; +fail: + return -ENOMEM; +} + +/* + * Initialise the pool map for SVC_POOL_PERCPU mode. + * Returns number of pools or <0 on error. + */ +static int +svc_pool_map_init_percpu(struct svc_pool_map *m) +{ + unsigned int maxpools = nr_cpu_ids; + unsigned int pidx = 0; + unsigned int cpu; + int err; + + err = svc_pool_map_alloc_arrays(m, maxpools); + if (err) + return err; + + for_each_online_cpu(cpu) { + BUG_ON(pidx >= maxpools); + m->to_pool[cpu] = pidx; + m->pool_to[pidx] = cpu; + pidx++; + } + /* cpus brought online later all get mapped to pool0, sorry */ + + return pidx; +}; + + +/* + * Initialise the pool map for SVC_POOL_PERNODE mode. + * Returns number of pools or <0 on error. + */ +static int +svc_pool_map_init_pernode(struct svc_pool_map *m) +{ + unsigned int maxpools = nr_node_ids; + unsigned int pidx = 0; + unsigned int node; + int err; + + err = svc_pool_map_alloc_arrays(m, maxpools); + if (err) + return err; + + for_each_node_with_cpus(node) { + /* some architectures (e.g. SN2) have cpuless nodes */ + BUG_ON(pidx > maxpools); + m->to_pool[node] = pidx; + m->pool_to[pidx] = node; + pidx++; + } + /* nodes brought online later all get mapped to pool0, sorry */ + + return pidx; +} + + +/* + * Add a reference to the global map of cpus to pools (and + * vice versa). Initialise the map if we're the first user. + * Returns the number of pools. + */ +unsigned int +svc_pool_map_get(void) +{ + struct svc_pool_map *m = &svc_pool_map; + int npools = -1; + + mutex_lock(&svc_pool_map_mutex); + + if (m->count++) { + mutex_unlock(&svc_pool_map_mutex); + return m->npools; + } + + if (m->mode == SVC_POOL_AUTO) + m->mode = svc_pool_map_choose_mode(); + + switch (m->mode) { + case SVC_POOL_PERCPU: + npools = svc_pool_map_init_percpu(m); + break; + case SVC_POOL_PERNODE: + npools = svc_pool_map_init_pernode(m); + break; + } + + if (npools < 0) { + /* default, or memory allocation failure */ + npools = 1; + m->mode = SVC_POOL_GLOBAL; + } + m->npools = npools; + + mutex_unlock(&svc_pool_map_mutex); + return m->npools; +} +EXPORT_SYMBOL_GPL(svc_pool_map_get); + +/* + * Drop a reference to the global map of cpus to pools. + * When the last reference is dropped, the map data is + * freed; this allows the sysadmin to change the pool + * mode using the pool_mode module option without + * rebooting or re-loading sunrpc.ko. + */ +void +svc_pool_map_put(void) +{ + struct svc_pool_map *m = &svc_pool_map; + + mutex_lock(&svc_pool_map_mutex); + + if (!--m->count) { + kfree(m->to_pool); + m->to_pool = NULL; + kfree(m->pool_to); + m->pool_to = NULL; + m->npools = 0; + } + + mutex_unlock(&svc_pool_map_mutex); +} +EXPORT_SYMBOL_GPL(svc_pool_map_put); + +static int svc_pool_map_get_node(unsigned int pidx) +{ + const struct svc_pool_map *m = &svc_pool_map; + + if (m->count) { + if (m->mode == SVC_POOL_PERCPU) + return cpu_to_node(m->pool_to[pidx]); + if (m->mode == SVC_POOL_PERNODE) + return m->pool_to[pidx]; + } + return NUMA_NO_NODE; +} +/* + * Set the given thread's cpus_allowed mask so that it + * will only run on cpus in the given pool. + */ +static inline void +svc_pool_map_set_cpumask(struct task_struct *task, unsigned int pidx) +{ + struct svc_pool_map *m = &svc_pool_map; + unsigned int node = m->pool_to[pidx]; + + /* + * The caller checks for sv_nrpools > 1, which + * implies that we've been initialized. + */ + WARN_ON_ONCE(m->count == 0); + if (m->count == 0) + return; + + switch (m->mode) { + case SVC_POOL_PERCPU: + { + set_cpus_allowed_ptr(task, cpumask_of(node)); + break; + } + case SVC_POOL_PERNODE: + { + set_cpus_allowed_ptr(task, cpumask_of_node(node)); + break; + } + } +} + +/* + * Use the mapping mode to choose a pool for a given CPU. + * Used when enqueueing an incoming RPC. Always returns + * a non-NULL pool pointer. + */ +struct svc_pool * +svc_pool_for_cpu(struct svc_serv *serv, int cpu) +{ + struct svc_pool_map *m = &svc_pool_map; + unsigned int pidx = 0; + + /* + * An uninitialised map happens in a pure client when + * lockd is brought up, so silently treat it the + * same as SVC_POOL_GLOBAL. + */ + if (svc_serv_is_pooled(serv)) { + switch (m->mode) { + case SVC_POOL_PERCPU: + pidx = m->to_pool[cpu]; + break; + case SVC_POOL_PERNODE: + pidx = m->to_pool[cpu_to_node(cpu)]; + break; + } + } + return &serv->sv_pools[pidx % serv->sv_nrpools]; +} + +int svc_rpcb_setup(struct svc_serv *serv, struct net *net) +{ + int err; + + err = rpcb_create_local(net); + if (err) + return err; + + /* Remove any stale portmap registrations */ + svc_unregister(serv, net); + return 0; +} +EXPORT_SYMBOL_GPL(svc_rpcb_setup); + +void svc_rpcb_cleanup(struct svc_serv *serv, struct net *net) +{ + svc_unregister(serv, net); + rpcb_put_local(net); +} +EXPORT_SYMBOL_GPL(svc_rpcb_cleanup); + +static int svc_uses_rpcbind(struct svc_serv *serv) +{ + struct svc_program *progp; + unsigned int i; + + for (progp = serv->sv_program; progp; progp = progp->pg_next) { + for (i = 0; i < progp->pg_nvers; i++) { + if (progp->pg_vers[i] == NULL) + continue; + if (!progp->pg_vers[i]->vs_hidden) + return 1; + } + } + + return 0; +} + +int svc_bind(struct svc_serv *serv, struct net *net) +{ + if (!svc_uses_rpcbind(serv)) + return 0; + return svc_rpcb_setup(serv, net); +} +EXPORT_SYMBOL_GPL(svc_bind); + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +static void +__svc_init_bc(struct svc_serv *serv) +{ + INIT_LIST_HEAD(&serv->sv_cb_list); + spin_lock_init(&serv->sv_cb_lock); + init_waitqueue_head(&serv->sv_cb_waitq); +} +#else +static void +__svc_init_bc(struct svc_serv *serv) +{ +} +#endif + +/* + * Create an RPC service + */ +static struct svc_serv * +__svc_create(struct svc_program *prog, unsigned int bufsize, int npools, + const struct svc_serv_ops *ops) +{ + struct svc_serv *serv; + unsigned int vers; + unsigned int xdrsize; + unsigned int i; + + if (!(serv = kzalloc(sizeof(*serv), GFP_KERNEL))) + return NULL; + serv->sv_name = prog->pg_name; + serv->sv_program = prog; + serv->sv_nrthreads = 1; + serv->sv_stats = prog->pg_stats; + if (bufsize > RPCSVC_MAXPAYLOAD) + bufsize = RPCSVC_MAXPAYLOAD; + serv->sv_max_payload = bufsize? bufsize : 4096; + serv->sv_max_mesg = roundup(serv->sv_max_payload + PAGE_SIZE, PAGE_SIZE); + serv->sv_ops = ops; + xdrsize = 0; + while (prog) { + prog->pg_lovers = prog->pg_nvers-1; + for (vers=0; vers<prog->pg_nvers ; vers++) + if (prog->pg_vers[vers]) { + prog->pg_hivers = vers; + if (prog->pg_lovers > vers) + prog->pg_lovers = vers; + if (prog->pg_vers[vers]->vs_xdrsize > xdrsize) + xdrsize = prog->pg_vers[vers]->vs_xdrsize; + } + prog = prog->pg_next; + } + serv->sv_xdrsize = xdrsize; + INIT_LIST_HEAD(&serv->sv_tempsocks); + INIT_LIST_HEAD(&serv->sv_permsocks); + timer_setup(&serv->sv_temptimer, NULL, 0); + spin_lock_init(&serv->sv_lock); + + __svc_init_bc(serv); + + serv->sv_nrpools = npools; + serv->sv_pools = + kcalloc(serv->sv_nrpools, sizeof(struct svc_pool), + GFP_KERNEL); + if (!serv->sv_pools) { + kfree(serv); + return NULL; + } + + for (i = 0; i < serv->sv_nrpools; i++) { + struct svc_pool *pool = &serv->sv_pools[i]; + + dprintk("svc: initialising pool %u for %s\n", + i, serv->sv_name); + + pool->sp_id = i; + INIT_LIST_HEAD(&pool->sp_sockets); + INIT_LIST_HEAD(&pool->sp_all_threads); + spin_lock_init(&pool->sp_lock); + } + + return serv; +} + +struct svc_serv * +svc_create(struct svc_program *prog, unsigned int bufsize, + const struct svc_serv_ops *ops) +{ + return __svc_create(prog, bufsize, /*npools*/1, ops); +} +EXPORT_SYMBOL_GPL(svc_create); + +struct svc_serv * +svc_create_pooled(struct svc_program *prog, unsigned int bufsize, + const struct svc_serv_ops *ops) +{ + struct svc_serv *serv; + unsigned int npools = svc_pool_map_get(); + + serv = __svc_create(prog, bufsize, npools, ops); + if (!serv) + goto out_err; + return serv; +out_err: + svc_pool_map_put(); + return NULL; +} +EXPORT_SYMBOL_GPL(svc_create_pooled); + +void svc_shutdown_net(struct svc_serv *serv, struct net *net) +{ + svc_close_net(serv, net); + + if (serv->sv_ops->svo_shutdown) + serv->sv_ops->svo_shutdown(serv, net); +} +EXPORT_SYMBOL_GPL(svc_shutdown_net); + +/* + * Destroy an RPC service. Should be called with appropriate locking to + * protect the sv_nrthreads, sv_permsocks and sv_tempsocks. + */ +void +svc_destroy(struct svc_serv *serv) +{ + dprintk("svc: svc_destroy(%s, %d)\n", + serv->sv_program->pg_name, + serv->sv_nrthreads); + + if (serv->sv_nrthreads) { + if (--(serv->sv_nrthreads) != 0) { + svc_sock_update_bufs(serv); + return; + } + } else + printk("svc_destroy: no threads for serv=%p!\n", serv); + + del_timer_sync(&serv->sv_temptimer); + + /* + * The last user is gone and thus all sockets have to be destroyed to + * the point. Check this. + */ + BUG_ON(!list_empty(&serv->sv_permsocks)); + BUG_ON(!list_empty(&serv->sv_tempsocks)); + + cache_clean_deferred(serv); + + if (svc_serv_is_pooled(serv)) + svc_pool_map_put(); + + kfree(serv->sv_pools); + kfree(serv); +} +EXPORT_SYMBOL_GPL(svc_destroy); + +/* + * Allocate an RPC server's buffer space. + * We allocate pages and place them in rq_argpages. + */ +static int +svc_init_buffer(struct svc_rqst *rqstp, unsigned int size, int node) +{ + unsigned int pages, arghi; + + /* bc_xprt uses fore channel allocated buffers */ + if (svc_is_backchannel(rqstp)) + return 1; + + pages = size / PAGE_SIZE + 1; /* extra page as we hold both request and reply. + * We assume one is at most one page + */ + arghi = 0; + WARN_ON_ONCE(pages > RPCSVC_MAXPAGES); + if (pages > RPCSVC_MAXPAGES) + pages = RPCSVC_MAXPAGES; + while (pages) { + struct page *p = alloc_pages_node(node, GFP_KERNEL, 0); + if (!p) + break; + rqstp->rq_pages[arghi++] = p; + pages--; + } + return pages == 0; +} + +/* + * Release an RPC server buffer + */ +static void +svc_release_buffer(struct svc_rqst *rqstp) +{ + unsigned int i; + + for (i = 0; i < ARRAY_SIZE(rqstp->rq_pages); i++) + if (rqstp->rq_pages[i]) + put_page(rqstp->rq_pages[i]); +} + +struct svc_rqst * +svc_rqst_alloc(struct svc_serv *serv, struct svc_pool *pool, int node) +{ + struct svc_rqst *rqstp; + + rqstp = kzalloc_node(sizeof(*rqstp), GFP_KERNEL, node); + if (!rqstp) + return rqstp; + + __set_bit(RQ_BUSY, &rqstp->rq_flags); + spin_lock_init(&rqstp->rq_lock); + rqstp->rq_server = serv; + rqstp->rq_pool = pool; + + rqstp->rq_argp = kmalloc_node(serv->sv_xdrsize, GFP_KERNEL, node); + if (!rqstp->rq_argp) + goto out_enomem; + + rqstp->rq_resp = kmalloc_node(serv->sv_xdrsize, GFP_KERNEL, node); + if (!rqstp->rq_resp) + goto out_enomem; + + if (!svc_init_buffer(rqstp, serv->sv_max_mesg, node)) + goto out_enomem; + + return rqstp; +out_enomem: + svc_rqst_free(rqstp); + return NULL; +} +EXPORT_SYMBOL_GPL(svc_rqst_alloc); + +struct svc_rqst * +svc_prepare_thread(struct svc_serv *serv, struct svc_pool *pool, int node) +{ + struct svc_rqst *rqstp; + + rqstp = svc_rqst_alloc(serv, pool, node); + if (!rqstp) + return ERR_PTR(-ENOMEM); + + serv->sv_nrthreads++; + spin_lock_bh(&pool->sp_lock); + pool->sp_nrthreads++; + list_add_rcu(&rqstp->rq_all, &pool->sp_all_threads); + spin_unlock_bh(&pool->sp_lock); + return rqstp; +} +EXPORT_SYMBOL_GPL(svc_prepare_thread); + +/* + * Choose a pool in which to create a new thread, for svc_set_num_threads + */ +static inline struct svc_pool * +choose_pool(struct svc_serv *serv, struct svc_pool *pool, unsigned int *state) +{ + if (pool != NULL) + return pool; + + return &serv->sv_pools[(*state)++ % serv->sv_nrpools]; +} + +/* + * Choose a thread to kill, for svc_set_num_threads + */ +static inline struct task_struct * +choose_victim(struct svc_serv *serv, struct svc_pool *pool, unsigned int *state) +{ + unsigned int i; + struct task_struct *task = NULL; + + if (pool != NULL) { + spin_lock_bh(&pool->sp_lock); + } else { + /* choose a pool in round-robin fashion */ + for (i = 0; i < serv->sv_nrpools; i++) { + pool = &serv->sv_pools[--(*state) % serv->sv_nrpools]; + spin_lock_bh(&pool->sp_lock); + if (!list_empty(&pool->sp_all_threads)) + goto found_pool; + spin_unlock_bh(&pool->sp_lock); + } + return NULL; + } + +found_pool: + if (!list_empty(&pool->sp_all_threads)) { + struct svc_rqst *rqstp; + + /* + * Remove from the pool->sp_all_threads list + * so we don't try to kill it again. + */ + rqstp = list_entry(pool->sp_all_threads.next, struct svc_rqst, rq_all); + set_bit(RQ_VICTIM, &rqstp->rq_flags); + list_del_rcu(&rqstp->rq_all); + task = rqstp->rq_task; + } + spin_unlock_bh(&pool->sp_lock); + + return task; +} + +/* create new threads */ +static int +svc_start_kthreads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) +{ + struct svc_rqst *rqstp; + struct task_struct *task; + struct svc_pool *chosen_pool; + unsigned int state = serv->sv_nrthreads-1; + int node; + + do { + nrservs--; + chosen_pool = choose_pool(serv, pool, &state); + + node = svc_pool_map_get_node(chosen_pool->sp_id); + rqstp = svc_prepare_thread(serv, chosen_pool, node); + if (IS_ERR(rqstp)) + return PTR_ERR(rqstp); + + __module_get(serv->sv_ops->svo_module); + task = kthread_create_on_node(serv->sv_ops->svo_function, rqstp, + node, "%s", serv->sv_name); + if (IS_ERR(task)) { + module_put(serv->sv_ops->svo_module); + svc_exit_thread(rqstp); + return PTR_ERR(task); + } + + rqstp->rq_task = task; + if (serv->sv_nrpools > 1) + svc_pool_map_set_cpumask(task, chosen_pool->sp_id); + + svc_sock_update_bufs(serv); + wake_up_process(task); + } while (nrservs > 0); + + return 0; +} + + +/* destroy old threads */ +static int +svc_signal_kthreads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) +{ + struct task_struct *task; + unsigned int state = serv->sv_nrthreads-1; + + /* destroy old threads */ + do { + task = choose_victim(serv, pool, &state); + if (task == NULL) + break; + send_sig(SIGINT, task, 1); + nrservs++; + } while (nrservs < 0); + + return 0; +} + +/* + * Create or destroy enough new threads to make the number + * of threads the given number. If `pool' is non-NULL, applies + * only to threads in that pool, otherwise round-robins between + * all pools. Caller must ensure that mutual exclusion between this and + * server startup or shutdown. + * + * Destroying threads relies on the service threads filling in + * rqstp->rq_task, which only the nfs ones do. Assumes the serv + * has been created using svc_create_pooled(). + * + * Based on code that used to be in nfsd_svc() but tweaked + * to be pool-aware. + */ +int +svc_set_num_threads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) +{ + if (pool == NULL) { + /* The -1 assumes caller has done a svc_get() */ + nrservs -= (serv->sv_nrthreads-1); + } else { + spin_lock_bh(&pool->sp_lock); + nrservs -= pool->sp_nrthreads; + spin_unlock_bh(&pool->sp_lock); + } + + if (nrservs > 0) + return svc_start_kthreads(serv, pool, nrservs); + if (nrservs < 0) + return svc_signal_kthreads(serv, pool, nrservs); + return 0; +} +EXPORT_SYMBOL_GPL(svc_set_num_threads); + +/* destroy old threads */ +static int +svc_stop_kthreads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) +{ + struct task_struct *task; + unsigned int state = serv->sv_nrthreads-1; + + /* destroy old threads */ + do { + task = choose_victim(serv, pool, &state); + if (task == NULL) + break; + kthread_stop(task); + nrservs++; + } while (nrservs < 0); + return 0; +} + +int +svc_set_num_threads_sync(struct svc_serv *serv, struct svc_pool *pool, int nrservs) +{ + if (pool == NULL) { + /* The -1 assumes caller has done a svc_get() */ + nrservs -= (serv->sv_nrthreads-1); + } else { + spin_lock_bh(&pool->sp_lock); + nrservs -= pool->sp_nrthreads; + spin_unlock_bh(&pool->sp_lock); + } + + if (nrservs > 0) + return svc_start_kthreads(serv, pool, nrservs); + if (nrservs < 0) + return svc_stop_kthreads(serv, pool, nrservs); + return 0; +} +EXPORT_SYMBOL_GPL(svc_set_num_threads_sync); + +/* + * Called from a server thread as it's exiting. Caller must hold the "service + * mutex" for the service. + */ +void +svc_rqst_free(struct svc_rqst *rqstp) +{ + svc_release_buffer(rqstp); + kfree(rqstp->rq_resp); + kfree(rqstp->rq_argp); + kfree(rqstp->rq_auth_data); + kfree_rcu(rqstp, rq_rcu_head); +} +EXPORT_SYMBOL_GPL(svc_rqst_free); + +void +svc_exit_thread(struct svc_rqst *rqstp) +{ + struct svc_serv *serv = rqstp->rq_server; + struct svc_pool *pool = rqstp->rq_pool; + + spin_lock_bh(&pool->sp_lock); + pool->sp_nrthreads--; + if (!test_and_set_bit(RQ_VICTIM, &rqstp->rq_flags)) + list_del_rcu(&rqstp->rq_all); + spin_unlock_bh(&pool->sp_lock); + + svc_rqst_free(rqstp); + + /* Release the server */ + if (serv) + svc_destroy(serv); +} +EXPORT_SYMBOL_GPL(svc_exit_thread); + +/* + * Register an "inet" protocol family netid with the local + * rpcbind daemon via an rpcbind v4 SET request. + * + * No netconfig infrastructure is available in the kernel, so + * we map IP_ protocol numbers to netids by hand. + * + * Returns zero on success; a negative errno value is returned + * if any error occurs. + */ +static int __svc_rpcb_register4(struct net *net, const u32 program, + const u32 version, + const unsigned short protocol, + const unsigned short port) +{ + const struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_ANY), + .sin_port = htons(port), + }; + const char *netid; + int error; + + switch (protocol) { + case IPPROTO_UDP: + netid = RPCBIND_NETID_UDP; + break; + case IPPROTO_TCP: + netid = RPCBIND_NETID_TCP; + break; + default: + return -ENOPROTOOPT; + } + + error = rpcb_v4_register(net, program, version, + (const struct sockaddr *)&sin, netid); + + /* + * User space didn't support rpcbind v4, so retry this + * registration request with the legacy rpcbind v2 protocol. + */ + if (error == -EPROTONOSUPPORT) + error = rpcb_register(net, program, version, protocol, port); + + return error; +} + +#if IS_ENABLED(CONFIG_IPV6) +/* + * Register an "inet6" protocol family netid with the local + * rpcbind daemon via an rpcbind v4 SET request. + * + * No netconfig infrastructure is available in the kernel, so + * we map IP_ protocol numbers to netids by hand. + * + * Returns zero on success; a negative errno value is returned + * if any error occurs. + */ +static int __svc_rpcb_register6(struct net *net, const u32 program, + const u32 version, + const unsigned short protocol, + const unsigned short port) +{ + const struct sockaddr_in6 sin6 = { + .sin6_family = AF_INET6, + .sin6_addr = IN6ADDR_ANY_INIT, + .sin6_port = htons(port), + }; + const char *netid; + int error; + + switch (protocol) { + case IPPROTO_UDP: + netid = RPCBIND_NETID_UDP6; + break; + case IPPROTO_TCP: + netid = RPCBIND_NETID_TCP6; + break; + default: + return -ENOPROTOOPT; + } + + error = rpcb_v4_register(net, program, version, + (const struct sockaddr *)&sin6, netid); + + /* + * User space didn't support rpcbind version 4, so we won't + * use a PF_INET6 listener. + */ + if (error == -EPROTONOSUPPORT) + error = -EAFNOSUPPORT; + + return error; +} +#endif /* IS_ENABLED(CONFIG_IPV6) */ + +/* + * Register a kernel RPC service via rpcbind version 4. + * + * Returns zero on success; a negative errno value is returned + * if any error occurs. + */ +static int __svc_register(struct net *net, const char *progname, + const u32 program, const u32 version, + const int family, + const unsigned short protocol, + const unsigned short port) +{ + int error = -EAFNOSUPPORT; + + switch (family) { + case PF_INET: + error = __svc_rpcb_register4(net, program, version, + protocol, port); + break; +#if IS_ENABLED(CONFIG_IPV6) + case PF_INET6: + error = __svc_rpcb_register6(net, program, version, + protocol, port); +#endif + } + + return error; +} + +/** + * svc_register - register an RPC service with the local portmapper + * @serv: svc_serv struct for the service to register + * @net: net namespace for the service to register + * @family: protocol family of service's listener socket + * @proto: transport protocol number to advertise + * @port: port to advertise + * + * Service is registered for any address in the passed-in protocol family + */ +int svc_register(const struct svc_serv *serv, struct net *net, + const int family, const unsigned short proto, + const unsigned short port) +{ + struct svc_program *progp; + const struct svc_version *vers; + unsigned int i; + int error = 0; + + WARN_ON_ONCE(proto == 0 && port == 0); + if (proto == 0 && port == 0) + return -EINVAL; + + for (progp = serv->sv_program; progp; progp = progp->pg_next) { + for (i = 0; i < progp->pg_nvers; i++) { + vers = progp->pg_vers[i]; + if (vers == NULL) + continue; + + dprintk("svc: svc_register(%sv%d, %s, %u, %u)%s\n", + progp->pg_name, + i, + proto == IPPROTO_UDP? "udp" : "tcp", + port, + family, + vers->vs_hidden ? + " (but not telling portmap)" : ""); + + if (vers->vs_hidden) + continue; + + /* + * Don't register a UDP port if we need congestion + * control. + */ + if (vers->vs_need_cong_ctrl && proto == IPPROTO_UDP) + continue; + + error = __svc_register(net, progp->pg_name, progp->pg_prog, + i, family, proto, port); + + if (vers->vs_rpcb_optnl) { + error = 0; + continue; + } + + if (error < 0) { + printk(KERN_WARNING "svc: failed to register " + "%sv%u RPC service (errno %d).\n", + progp->pg_name, i, -error); + break; + } + } + } + + return error; +} + +/* + * If user space is running rpcbind, it should take the v4 UNSET + * and clear everything for this [program, version]. If user space + * is running portmap, it will reject the v4 UNSET, but won't have + * any "inet6" entries anyway. So a PMAP_UNSET should be sufficient + * in this case to clear all existing entries for [program, version]. + */ +static void __svc_unregister(struct net *net, const u32 program, const u32 version, + const char *progname) +{ + int error; + + error = rpcb_v4_register(net, program, version, NULL, ""); + + /* + * User space didn't support rpcbind v4, so retry this + * request with the legacy rpcbind v2 protocol. + */ + if (error == -EPROTONOSUPPORT) + error = rpcb_register(net, program, version, 0, 0); + + dprintk("svc: %s(%sv%u), error %d\n", + __func__, progname, version, error); +} + +/* + * All netids, bind addresses and ports registered for [program, version] + * are removed from the local rpcbind database (if the service is not + * hidden) to make way for a new instance of the service. + * + * The result of unregistration is reported via dprintk for those who want + * verification of the result, but is otherwise not important. + */ +static void svc_unregister(const struct svc_serv *serv, struct net *net) +{ + struct svc_program *progp; + unsigned long flags; + unsigned int i; + + clear_thread_flag(TIF_SIGPENDING); + + for (progp = serv->sv_program; progp; progp = progp->pg_next) { + for (i = 0; i < progp->pg_nvers; i++) { + if (progp->pg_vers[i] == NULL) + continue; + if (progp->pg_vers[i]->vs_hidden) + continue; + + dprintk("svc: attempting to unregister %sv%u\n", + progp->pg_name, i); + __svc_unregister(net, progp->pg_prog, i, progp->pg_name); + } + } + + spin_lock_irqsave(¤t->sighand->siglock, flags); + recalc_sigpending(); + spin_unlock_irqrestore(¤t->sighand->siglock, flags); +} + +/* + * dprintk the given error with the address of the client that caused it. + */ +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +static __printf(2, 3) +void svc_printk(struct svc_rqst *rqstp, const char *fmt, ...) +{ + struct va_format vaf; + va_list args; + char buf[RPC_MAX_ADDRBUFLEN]; + + va_start(args, fmt); + + vaf.fmt = fmt; + vaf.va = &args; + + dprintk("svc: %s: %pV", svc_print_addr(rqstp, buf, sizeof(buf)), &vaf); + + va_end(args); +} +#else +static __printf(2,3) void svc_printk(struct svc_rqst *rqstp, const char *fmt, ...) {} +#endif + +extern void svc_tcp_prep_reply_hdr(struct svc_rqst *); + +__be32 +svc_return_autherr(struct svc_rqst *rqstp, __be32 auth_err) +{ + set_bit(RQ_AUTHERR, &rqstp->rq_flags); + return auth_err; +} +EXPORT_SYMBOL_GPL(svc_return_autherr); + +static __be32 +svc_get_autherr(struct svc_rqst *rqstp, __be32 *statp) +{ + if (test_and_clear_bit(RQ_AUTHERR, &rqstp->rq_flags)) + return *statp; + return rpc_auth_ok; +} + +/* + * Common routine for processing the RPC request. + */ +static int +svc_process_common(struct svc_rqst *rqstp, struct kvec *argv, struct kvec *resv) +{ + struct svc_program *progp; + const struct svc_version *versp = NULL; /* compiler food */ + const struct svc_procedure *procp = NULL; + struct svc_serv *serv = rqstp->rq_server; + __be32 *statp; + u32 prog, vers, proc; + __be32 auth_stat, rpc_stat; + int auth_res; + __be32 *reply_statp; + + rpc_stat = rpc_success; + + if (argv->iov_len < 6*4) + goto err_short_len; + + /* Will be turned off by GSS integrity and privacy services */ + set_bit(RQ_SPLICE_OK, &rqstp->rq_flags); + /* Will be turned off only when NFSv4 Sessions are used */ + set_bit(RQ_USEDEFERRAL, &rqstp->rq_flags); + clear_bit(RQ_DROPME, &rqstp->rq_flags); + + /* Setup reply header */ + if (rqstp->rq_prot == IPPROTO_TCP) + svc_tcp_prep_reply_hdr(rqstp); + + svc_putu32(resv, rqstp->rq_xid); + + vers = svc_getnl(argv); + + /* First words of reply: */ + svc_putnl(resv, 1); /* REPLY */ + + if (vers != 2) /* RPC version number */ + goto err_bad_rpc; + + /* Save position in case we later decide to reject: */ + reply_statp = resv->iov_base + resv->iov_len; + + svc_putnl(resv, 0); /* ACCEPT */ + + rqstp->rq_prog = prog = svc_getnl(argv); /* program number */ + rqstp->rq_vers = vers = svc_getnl(argv); /* version number */ + rqstp->rq_proc = proc = svc_getnl(argv); /* procedure number */ + + for (progp = serv->sv_program; progp; progp = progp->pg_next) + if (prog == progp->pg_prog) + break; + + /* + * Decode auth data, and add verifier to reply buffer. + * We do this before anything else in order to get a decent + * auth verifier. + */ + auth_res = svc_authenticate(rqstp, &auth_stat); + /* Also give the program a chance to reject this call: */ + if (auth_res == SVC_OK && progp) { + auth_stat = rpc_autherr_badcred; + auth_res = progp->pg_authenticate(rqstp); + } + switch (auth_res) { + case SVC_OK: + break; + case SVC_GARBAGE: + goto err_garbage; + case SVC_SYSERR: + rpc_stat = rpc_system_err; + goto err_bad; + case SVC_DENIED: + goto err_bad_auth; + case SVC_CLOSE: + goto close; + case SVC_DROP: + goto dropit; + case SVC_COMPLETE: + goto sendit; + } + + if (progp == NULL) + goto err_bad_prog; + + if (vers >= progp->pg_nvers || + !(versp = progp->pg_vers[vers])) + goto err_bad_vers; + + /* + * Some protocol versions (namely NFSv4) require some form of + * congestion control. (See RFC 7530 section 3.1 paragraph 2) + * In other words, UDP is not allowed. We mark those when setting + * up the svc_xprt, and verify that here. + * + * The spec is not very clear about what error should be returned + * when someone tries to access a server that is listening on UDP + * for lower versions. RPC_PROG_MISMATCH seems to be the closest + * fit. + */ + if (versp->vs_need_cong_ctrl && rqstp->rq_xprt && + !test_bit(XPT_CONG_CTRL, &rqstp->rq_xprt->xpt_flags)) + goto err_bad_vers; + + procp = versp->vs_proc + proc; + if (proc >= versp->vs_nproc || !procp->pc_func) + goto err_bad_proc; + rqstp->rq_procinfo = procp; + + /* Syntactic check complete */ + serv->sv_stats->rpccnt++; + trace_svc_process(rqstp, progp->pg_name); + + /* Build the reply header. */ + statp = resv->iov_base +resv->iov_len; + svc_putnl(resv, RPC_SUCCESS); + + /* Bump per-procedure stats counter */ + versp->vs_count[proc]++; + + /* Initialize storage for argp and resp */ + memset(rqstp->rq_argp, 0, procp->pc_argsize); + memset(rqstp->rq_resp, 0, procp->pc_ressize); + + /* un-reserve some of the out-queue now that we have a + * better idea of reply size + */ + if (procp->pc_xdrressize) + svc_reserve_auth(rqstp, procp->pc_xdrressize<<2); + + /* Call the function that processes the request. */ + if (!versp->vs_dispatch) { + /* + * Decode arguments + * XXX: why do we ignore the return value? + */ + if (procp->pc_decode && + !procp->pc_decode(rqstp, argv->iov_base)) + goto err_garbage; + + *statp = procp->pc_func(rqstp); + + /* Encode reply */ + if (*statp == rpc_drop_reply || + test_bit(RQ_DROPME, &rqstp->rq_flags)) { + if (procp->pc_release) + procp->pc_release(rqstp); + goto dropit; + } + auth_stat = svc_get_autherr(rqstp, statp); + if (auth_stat != rpc_auth_ok) + goto err_release_bad_auth; + if (*statp == rpc_success && procp->pc_encode && + !procp->pc_encode(rqstp, resv->iov_base + resv->iov_len)) { + dprintk("svc: failed to encode reply\n"); + /* serv->sv_stats->rpcsystemerr++; */ + *statp = rpc_system_err; + } + } else { + dprintk("svc: calling dispatcher\n"); + if (!versp->vs_dispatch(rqstp, statp)) { + /* Release reply info */ + if (procp->pc_release) + procp->pc_release(rqstp); + goto dropit; + } + } + + /* Check RPC status result */ + if (*statp != rpc_success) + resv->iov_len = ((void*)statp) - resv->iov_base + 4; + + /* Release reply info */ + if (procp->pc_release) + procp->pc_release(rqstp); + + if (procp->pc_encode == NULL) + goto dropit; + + sendit: + if (svc_authorise(rqstp)) + goto close_xprt; + return 1; /* Caller can now send it */ + + dropit: + svc_authorise(rqstp); /* doesn't hurt to call this twice */ + dprintk("svc: svc_process dropit\n"); + return 0; + + close: + svc_authorise(rqstp); +close_xprt: + if (rqstp->rq_xprt && test_bit(XPT_TEMP, &rqstp->rq_xprt->xpt_flags)) + svc_close_xprt(rqstp->rq_xprt); + dprintk("svc: svc_process close\n"); + return 0; + +err_short_len: + svc_printk(rqstp, "short len %zd, dropping request\n", + argv->iov_len); + goto close_xprt; + +err_bad_rpc: + serv->sv_stats->rpcbadfmt++; + svc_putnl(resv, 1); /* REJECT */ + svc_putnl(resv, 0); /* RPC_MISMATCH */ + svc_putnl(resv, 2); /* Only RPCv2 supported */ + svc_putnl(resv, 2); + goto sendit; + +err_release_bad_auth: + if (procp->pc_release) + procp->pc_release(rqstp); +err_bad_auth: + dprintk("svc: authentication failed (%d)\n", ntohl(auth_stat)); + serv->sv_stats->rpcbadauth++; + /* Restore write pointer to location of accept status: */ + xdr_ressize_check(rqstp, reply_statp); + svc_putnl(resv, 1); /* REJECT */ + svc_putnl(resv, 1); /* AUTH_ERROR */ + svc_putnl(resv, ntohl(auth_stat)); /* status */ + goto sendit; + +err_bad_prog: + dprintk("svc: unknown program %d\n", prog); + serv->sv_stats->rpcbadfmt++; + svc_putnl(resv, RPC_PROG_UNAVAIL); + goto sendit; + +err_bad_vers: + svc_printk(rqstp, "unknown version (%d for prog %d, %s)\n", + vers, prog, progp->pg_name); + + serv->sv_stats->rpcbadfmt++; + svc_putnl(resv, RPC_PROG_MISMATCH); + svc_putnl(resv, progp->pg_lovers); + svc_putnl(resv, progp->pg_hivers); + goto sendit; + +err_bad_proc: + svc_printk(rqstp, "unknown procedure (%d)\n", proc); + + serv->sv_stats->rpcbadfmt++; + svc_putnl(resv, RPC_PROC_UNAVAIL); + goto sendit; + +err_garbage: + svc_printk(rqstp, "failed to decode args\n"); + + rpc_stat = rpc_garbage_args; +err_bad: + serv->sv_stats->rpcbadfmt++; + svc_putnl(resv, ntohl(rpc_stat)); + goto sendit; +} + +/* + * Process the RPC request. + */ +int +svc_process(struct svc_rqst *rqstp) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec *resv = &rqstp->rq_res.head[0]; + struct svc_serv *serv = rqstp->rq_server; + u32 dir; + + /* + * Setup response xdr_buf. + * Initially it has just one page + */ + rqstp->rq_next_page = &rqstp->rq_respages[1]; + resv->iov_base = page_address(rqstp->rq_respages[0]); + resv->iov_len = 0; + rqstp->rq_res.pages = rqstp->rq_respages + 1; + rqstp->rq_res.len = 0; + rqstp->rq_res.page_base = 0; + rqstp->rq_res.page_len = 0; + rqstp->rq_res.buflen = PAGE_SIZE; + rqstp->rq_res.tail[0].iov_base = NULL; + rqstp->rq_res.tail[0].iov_len = 0; + + dir = svc_getnl(argv); + if (dir != 0) { + /* direction != CALL */ + svc_printk(rqstp, "bad direction %d, dropping request\n", dir); + serv->sv_stats->rpcbadfmt++; + goto out_drop; + } + + /* Returns 1 for send, 0 for drop */ + if (likely(svc_process_common(rqstp, argv, resv))) + return svc_send(rqstp); + +out_drop: + svc_drop(rqstp); + return 0; +} +EXPORT_SYMBOL_GPL(svc_process); + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +/* + * Process a backchannel RPC request that arrived over an existing + * outbound connection + */ +int +bc_svc_process(struct svc_serv *serv, struct rpc_rqst *req, + struct svc_rqst *rqstp) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec *resv = &rqstp->rq_res.head[0]; + struct rpc_task *task; + int proc_error; + int error; + + dprintk("svc: %s(%p)\n", __func__, req); + + /* Build the svc_rqst used by the common processing routine */ + rqstp->rq_xid = req->rq_xid; + rqstp->rq_prot = req->rq_xprt->prot; + rqstp->rq_server = serv; + rqstp->rq_bc_net = req->rq_xprt->xprt_net; + + rqstp->rq_addrlen = sizeof(req->rq_xprt->addr); + memcpy(&rqstp->rq_addr, &req->rq_xprt->addr, rqstp->rq_addrlen); + memcpy(&rqstp->rq_arg, &req->rq_rcv_buf, sizeof(rqstp->rq_arg)); + memcpy(&rqstp->rq_res, &req->rq_snd_buf, sizeof(rqstp->rq_res)); + + /* Adjust the argument buffer length */ + rqstp->rq_arg.len = req->rq_private_buf.len; + if (rqstp->rq_arg.len <= rqstp->rq_arg.head[0].iov_len) { + rqstp->rq_arg.head[0].iov_len = rqstp->rq_arg.len; + rqstp->rq_arg.page_len = 0; + } else if (rqstp->rq_arg.len <= rqstp->rq_arg.head[0].iov_len + + rqstp->rq_arg.page_len) + rqstp->rq_arg.page_len = rqstp->rq_arg.len - + rqstp->rq_arg.head[0].iov_len; + else + rqstp->rq_arg.len = rqstp->rq_arg.head[0].iov_len + + rqstp->rq_arg.page_len; + + /* reset result send buffer "put" position */ + resv->iov_len = 0; + + /* + * Skip the next two words because they've already been + * processed in the transport + */ + svc_getu32(argv); /* XID */ + svc_getnl(argv); /* CALLDIR */ + + /* Parse and execute the bc call */ + proc_error = svc_process_common(rqstp, argv, resv); + + atomic_inc(&req->rq_xprt->bc_free_slots); + if (!proc_error) { + /* Processing error: drop the request */ + xprt_free_bc_request(req); + return 0; + } + + /* Finally, send the reply synchronously */ + memcpy(&req->rq_snd_buf, &rqstp->rq_res, sizeof(req->rq_snd_buf)); + task = rpc_run_bc_task(req); + if (IS_ERR(task)) { + error = PTR_ERR(task); + goto out; + } + + WARN_ON_ONCE(atomic_read(&task->tk_count) != 1); + error = task->tk_status; + rpc_put_task(task); + +out: + dprintk("svc: %s(), error=%d\n", __func__, error); + return error; +} +EXPORT_SYMBOL_GPL(bc_svc_process); +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +/* + * Return (transport-specific) limit on the rpc payload. + */ +u32 svc_max_payload(const struct svc_rqst *rqstp) +{ + u32 max = rqstp->rq_xprt->xpt_class->xcl_max_payload; + + if (rqstp->rq_server->sv_max_payload < max) + max = rqstp->rq_server->sv_max_payload; + return max; +} +EXPORT_SYMBOL_GPL(svc_max_payload); + +/** + * svc_fill_write_vector - Construct data argument for VFS write call + * @rqstp: svc_rqst to operate on + * @pages: list of pages containing data payload + * @first: buffer containing first section of write payload + * @total: total number of bytes of write payload + * + * Fills in rqstp::rq_vec, and returns the number of elements. + */ +unsigned int svc_fill_write_vector(struct svc_rqst *rqstp, struct page **pages, + struct kvec *first, size_t total) +{ + struct kvec *vec = rqstp->rq_vec; + unsigned int i; + + /* Some types of transport can present the write payload + * entirely in rq_arg.pages. In this case, @first is empty. + */ + i = 0; + if (first->iov_len) { + vec[i].iov_base = first->iov_base; + vec[i].iov_len = min_t(size_t, total, first->iov_len); + total -= vec[i].iov_len; + ++i; + } + + while (total) { + vec[i].iov_base = page_address(*pages); + vec[i].iov_len = min_t(size_t, total, PAGE_SIZE); + total -= vec[i].iov_len; + ++i; + ++pages; + } + + WARN_ON_ONCE(i > ARRAY_SIZE(rqstp->rq_vec)); + return i; +} +EXPORT_SYMBOL_GPL(svc_fill_write_vector); + +/** + * svc_fill_symlink_pathname - Construct pathname argument for VFS symlink call + * @rqstp: svc_rqst to operate on + * @first: buffer containing first section of pathname + * @p: buffer containing remaining section of pathname + * @total: total length of the pathname argument + * + * The VFS symlink API demands a NUL-terminated pathname in mapped memory. + * Returns pointer to a NUL-terminated string, or an ERR_PTR. Caller must free + * the returned string. + */ +char *svc_fill_symlink_pathname(struct svc_rqst *rqstp, struct kvec *first, + void *p, size_t total) +{ + size_t len, remaining; + char *result, *dst; + + result = kmalloc(total + 1, GFP_KERNEL); + if (!result) + return ERR_PTR(-ESERVERFAULT); + + dst = result; + remaining = total; + + len = min_t(size_t, total, first->iov_len); + if (len) { + memcpy(dst, first->iov_base, len); + dst += len; + remaining -= len; + } + + if (remaining) { + len = min_t(size_t, remaining, PAGE_SIZE); + memcpy(dst, p, len); + dst += len; + } + + *dst = '\0'; + + /* Sanity check: Linux doesn't allow the pathname argument to + * contain a NUL byte. + */ + if (strlen(result) != total) { + kfree(result); + return ERR_PTR(-EINVAL); + } + return result; +} +EXPORT_SYMBOL_GPL(svc_fill_symlink_pathname); diff --git a/net/sunrpc/svc_xprt.c b/net/sunrpc/svc_xprt.c new file mode 100644 index 000000000..4b56b949a --- /dev/null +++ b/net/sunrpc/svc_xprt.c @@ -0,0 +1,1425 @@ +/* + * linux/net/sunrpc/svc_xprt.c + * + * Author: Tom Tucker <tom@opengridcomputing.com> + */ + +#include <linux/sched.h> +#include <linux/errno.h> +#include <linux/freezer.h> +#include <linux/kthread.h> +#include <linux/slab.h> +#include <net/sock.h> +#include <linux/sunrpc/addr.h> +#include <linux/sunrpc/stats.h> +#include <linux/sunrpc/svc_xprt.h> +#include <linux/sunrpc/svcsock.h> +#include <linux/sunrpc/xprt.h> +#include <linux/module.h> +#include <linux/netdevice.h> +#include <trace/events/sunrpc.h> + +#define RPCDBG_FACILITY RPCDBG_SVCXPRT + +static unsigned int svc_rpc_per_connection_limit __read_mostly; +module_param(svc_rpc_per_connection_limit, uint, 0644); + + +static struct svc_deferred_req *svc_deferred_dequeue(struct svc_xprt *xprt); +static int svc_deferred_recv(struct svc_rqst *rqstp); +static struct cache_deferred_req *svc_defer(struct cache_req *req); +static void svc_age_temp_xprts(struct timer_list *t); +static void svc_delete_xprt(struct svc_xprt *xprt); + +/* apparently the "standard" is that clients close + * idle connections after 5 minutes, servers after + * 6 minutes + * http://www.connectathon.org/talks96/nfstcp.pdf + */ +static int svc_conn_age_period = 6*60; + +/* List of registered transport classes */ +static DEFINE_SPINLOCK(svc_xprt_class_lock); +static LIST_HEAD(svc_xprt_class_list); + +/* SMP locking strategy: + * + * svc_pool->sp_lock protects most of the fields of that pool. + * svc_serv->sv_lock protects sv_tempsocks, sv_permsocks, sv_tmpcnt. + * when both need to be taken (rare), svc_serv->sv_lock is first. + * The "service mutex" protects svc_serv->sv_nrthread. + * svc_sock->sk_lock protects the svc_sock->sk_deferred list + * and the ->sk_info_authunix cache. + * + * The XPT_BUSY bit in xprt->xpt_flags prevents a transport being + * enqueued multiply. During normal transport processing this bit + * is set by svc_xprt_enqueue and cleared by svc_xprt_received. + * Providers should not manipulate this bit directly. + * + * Some flags can be set to certain values at any time + * providing that certain rules are followed: + * + * XPT_CONN, XPT_DATA: + * - Can be set or cleared at any time. + * - After a set, svc_xprt_enqueue must be called to enqueue + * the transport for processing. + * - After a clear, the transport must be read/accepted. + * If this succeeds, it must be set again. + * XPT_CLOSE: + * - Can set at any time. It is never cleared. + * XPT_DEAD: + * - Can only be set while XPT_BUSY is held which ensures + * that no other thread will be using the transport or will + * try to set XPT_DEAD. + */ +int svc_reg_xprt_class(struct svc_xprt_class *xcl) +{ + struct svc_xprt_class *cl; + int res = -EEXIST; + + dprintk("svc: Adding svc transport class '%s'\n", xcl->xcl_name); + + INIT_LIST_HEAD(&xcl->xcl_list); + spin_lock(&svc_xprt_class_lock); + /* Make sure there isn't already a class with the same name */ + list_for_each_entry(cl, &svc_xprt_class_list, xcl_list) { + if (strcmp(xcl->xcl_name, cl->xcl_name) == 0) + goto out; + } + list_add_tail(&xcl->xcl_list, &svc_xprt_class_list); + res = 0; +out: + spin_unlock(&svc_xprt_class_lock); + return res; +} +EXPORT_SYMBOL_GPL(svc_reg_xprt_class); + +void svc_unreg_xprt_class(struct svc_xprt_class *xcl) +{ + dprintk("svc: Removing svc transport class '%s'\n", xcl->xcl_name); + spin_lock(&svc_xprt_class_lock); + list_del_init(&xcl->xcl_list); + spin_unlock(&svc_xprt_class_lock); +} +EXPORT_SYMBOL_GPL(svc_unreg_xprt_class); + +/** + * svc_print_xprts - Format the transport list for printing + * @buf: target buffer for formatted address + * @maxlen: length of target buffer + * + * Fills in @buf with a string containing a list of transport names, each name + * terminated with '\n'. If the buffer is too small, some entries may be + * missing, but it is guaranteed that all lines in the output buffer are + * complete. + * + * Returns positive length of the filled-in string. + */ +int svc_print_xprts(char *buf, int maxlen) +{ + struct svc_xprt_class *xcl; + char tmpstr[80]; + int len = 0; + buf[0] = '\0'; + + spin_lock(&svc_xprt_class_lock); + list_for_each_entry(xcl, &svc_xprt_class_list, xcl_list) { + int slen; + + slen = snprintf(tmpstr, sizeof(tmpstr), "%s %d\n", + xcl->xcl_name, xcl->xcl_max_payload); + if (slen >= sizeof(tmpstr) || len + slen >= maxlen) + break; + len += slen; + strcat(buf, tmpstr); + } + spin_unlock(&svc_xprt_class_lock); + + return len; +} + +static void svc_xprt_free(struct kref *kref) +{ + struct svc_xprt *xprt = + container_of(kref, struct svc_xprt, xpt_ref); + struct module *owner = xprt->xpt_class->xcl_owner; + if (test_bit(XPT_CACHE_AUTH, &xprt->xpt_flags)) + svcauth_unix_info_release(xprt); + put_net(xprt->xpt_net); + /* See comment on corresponding get in xs_setup_bc_tcp(): */ + if (xprt->xpt_bc_xprt) + xprt_put(xprt->xpt_bc_xprt); + if (xprt->xpt_bc_xps) + xprt_switch_put(xprt->xpt_bc_xps); + xprt->xpt_ops->xpo_free(xprt); + module_put(owner); +} + +void svc_xprt_put(struct svc_xprt *xprt) +{ + kref_put(&xprt->xpt_ref, svc_xprt_free); +} +EXPORT_SYMBOL_GPL(svc_xprt_put); + +/* + * Called by transport drivers to initialize the transport independent + * portion of the transport instance. + */ +void svc_xprt_init(struct net *net, struct svc_xprt_class *xcl, + struct svc_xprt *xprt, struct svc_serv *serv) +{ + memset(xprt, 0, sizeof(*xprt)); + xprt->xpt_class = xcl; + xprt->xpt_ops = xcl->xcl_ops; + kref_init(&xprt->xpt_ref); + xprt->xpt_server = serv; + INIT_LIST_HEAD(&xprt->xpt_list); + INIT_LIST_HEAD(&xprt->xpt_ready); + INIT_LIST_HEAD(&xprt->xpt_deferred); + INIT_LIST_HEAD(&xprt->xpt_users); + mutex_init(&xprt->xpt_mutex); + spin_lock_init(&xprt->xpt_lock); + set_bit(XPT_BUSY, &xprt->xpt_flags); + rpc_init_wait_queue(&xprt->xpt_bc_pending, "xpt_bc_pending"); + xprt->xpt_net = get_net(net); + strcpy(xprt->xpt_remotebuf, "uninitialized"); +} +EXPORT_SYMBOL_GPL(svc_xprt_init); + +static struct svc_xprt *__svc_xpo_create(struct svc_xprt_class *xcl, + struct svc_serv *serv, + struct net *net, + const int family, + const unsigned short port, + int flags) +{ + struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_ANY), + .sin_port = htons(port), + }; +#if IS_ENABLED(CONFIG_IPV6) + struct sockaddr_in6 sin6 = { + .sin6_family = AF_INET6, + .sin6_addr = IN6ADDR_ANY_INIT, + .sin6_port = htons(port), + }; +#endif + struct sockaddr *sap; + size_t len; + + switch (family) { + case PF_INET: + sap = (struct sockaddr *)&sin; + len = sizeof(sin); + break; +#if IS_ENABLED(CONFIG_IPV6) + case PF_INET6: + sap = (struct sockaddr *)&sin6; + len = sizeof(sin6); + break; +#endif + default: + return ERR_PTR(-EAFNOSUPPORT); + } + + return xcl->xcl_ops->xpo_create(serv, net, sap, len, flags); +} + +/* + * svc_xprt_received conditionally queues the transport for processing + * by another thread. The caller must hold the XPT_BUSY bit and must + * not thereafter touch transport data. + * + * Note: XPT_DATA only gets cleared when a read-attempt finds no (or + * insufficient) data. + */ +static void svc_xprt_received(struct svc_xprt *xprt) +{ + if (!test_bit(XPT_BUSY, &xprt->xpt_flags)) { + WARN_ONCE(1, "xprt=0x%p already busy!", xprt); + return; + } + + /* As soon as we clear busy, the xprt could be closed and + * 'put', so we need a reference to call svc_enqueue_xprt with: + */ + svc_xprt_get(xprt); + smp_mb__before_atomic(); + clear_bit(XPT_BUSY, &xprt->xpt_flags); + xprt->xpt_server->sv_ops->svo_enqueue_xprt(xprt); + svc_xprt_put(xprt); +} + +void svc_add_new_perm_xprt(struct svc_serv *serv, struct svc_xprt *new) +{ + clear_bit(XPT_TEMP, &new->xpt_flags); + spin_lock_bh(&serv->sv_lock); + list_add(&new->xpt_list, &serv->sv_permsocks); + spin_unlock_bh(&serv->sv_lock); + svc_xprt_received(new); +} + +static int _svc_create_xprt(struct svc_serv *serv, const char *xprt_name, + struct net *net, const int family, + const unsigned short port, int flags) +{ + struct svc_xprt_class *xcl; + + spin_lock(&svc_xprt_class_lock); + list_for_each_entry(xcl, &svc_xprt_class_list, xcl_list) { + struct svc_xprt *newxprt; + unsigned short newport; + + if (strcmp(xprt_name, xcl->xcl_name)) + continue; + + if (!try_module_get(xcl->xcl_owner)) + goto err; + + spin_unlock(&svc_xprt_class_lock); + newxprt = __svc_xpo_create(xcl, serv, net, family, port, flags); + if (IS_ERR(newxprt)) { + module_put(xcl->xcl_owner); + return PTR_ERR(newxprt); + } + svc_add_new_perm_xprt(serv, newxprt); + newport = svc_xprt_local_port(newxprt); + return newport; + } + err: + spin_unlock(&svc_xprt_class_lock); + /* This errno is exposed to user space. Provide a reasonable + * perror msg for a bad transport. */ + return -EPROTONOSUPPORT; +} + +int svc_create_xprt(struct svc_serv *serv, const char *xprt_name, + struct net *net, const int family, + const unsigned short port, int flags) +{ + int err; + + dprintk("svc: creating transport %s[%d]\n", xprt_name, port); + err = _svc_create_xprt(serv, xprt_name, net, family, port, flags); + if (err == -EPROTONOSUPPORT) { + request_module("svc%s", xprt_name); + err = _svc_create_xprt(serv, xprt_name, net, family, port, flags); + } + if (err) + dprintk("svc: transport %s not found, err %d\n", + xprt_name, err); + return err; +} +EXPORT_SYMBOL_GPL(svc_create_xprt); + +/* + * Copy the local and remote xprt addresses to the rqstp structure + */ +void svc_xprt_copy_addrs(struct svc_rqst *rqstp, struct svc_xprt *xprt) +{ + memcpy(&rqstp->rq_addr, &xprt->xpt_remote, xprt->xpt_remotelen); + rqstp->rq_addrlen = xprt->xpt_remotelen; + + /* + * Destination address in request is needed for binding the + * source address in RPC replies/callbacks later. + */ + memcpy(&rqstp->rq_daddr, &xprt->xpt_local, xprt->xpt_locallen); + rqstp->rq_daddrlen = xprt->xpt_locallen; +} +EXPORT_SYMBOL_GPL(svc_xprt_copy_addrs); + +/** + * svc_print_addr - Format rq_addr field for printing + * @rqstp: svc_rqst struct containing address to print + * @buf: target buffer for formatted address + * @len: length of target buffer + * + */ +char *svc_print_addr(struct svc_rqst *rqstp, char *buf, size_t len) +{ + return __svc_print_addr(svc_addr(rqstp), buf, len); +} +EXPORT_SYMBOL_GPL(svc_print_addr); + +static bool svc_xprt_slots_in_range(struct svc_xprt *xprt) +{ + unsigned int limit = svc_rpc_per_connection_limit; + int nrqsts = atomic_read(&xprt->xpt_nr_rqsts); + + return limit == 0 || (nrqsts >= 0 && nrqsts < limit); +} + +static bool svc_xprt_reserve_slot(struct svc_rqst *rqstp, struct svc_xprt *xprt) +{ + if (!test_bit(RQ_DATA, &rqstp->rq_flags)) { + if (!svc_xprt_slots_in_range(xprt)) + return false; + atomic_inc(&xprt->xpt_nr_rqsts); + set_bit(RQ_DATA, &rqstp->rq_flags); + } + return true; +} + +static void svc_xprt_release_slot(struct svc_rqst *rqstp) +{ + struct svc_xprt *xprt = rqstp->rq_xprt; + if (test_and_clear_bit(RQ_DATA, &rqstp->rq_flags)) { + atomic_dec(&xprt->xpt_nr_rqsts); + svc_xprt_enqueue(xprt); + } +} + +static bool svc_xprt_has_something_to_do(struct svc_xprt *xprt) +{ + if (xprt->xpt_flags & ((1<<XPT_CONN)|(1<<XPT_CLOSE))) + return true; + if (xprt->xpt_flags & ((1<<XPT_DATA)|(1<<XPT_DEFERRED))) { + if (xprt->xpt_ops->xpo_has_wspace(xprt) && + svc_xprt_slots_in_range(xprt)) + return true; + trace_svc_xprt_no_write_space(xprt); + return false; + } + return false; +} + +void svc_xprt_do_enqueue(struct svc_xprt *xprt) +{ + struct svc_pool *pool; + struct svc_rqst *rqstp = NULL; + int cpu; + + if (!svc_xprt_has_something_to_do(xprt)) + return; + + /* Mark transport as busy. It will remain in this state until + * the provider calls svc_xprt_received. We update XPT_BUSY + * atomically because it also guards against trying to enqueue + * the transport twice. + */ + if (test_and_set_bit(XPT_BUSY, &xprt->xpt_flags)) + return; + + cpu = get_cpu(); + pool = svc_pool_for_cpu(xprt->xpt_server, cpu); + + atomic_long_inc(&pool->sp_stats.packets); + + spin_lock_bh(&pool->sp_lock); + list_add_tail(&xprt->xpt_ready, &pool->sp_sockets); + pool->sp_stats.sockets_queued++; + spin_unlock_bh(&pool->sp_lock); + + /* find a thread for this xprt */ + rcu_read_lock(); + list_for_each_entry_rcu(rqstp, &pool->sp_all_threads, rq_all) { + if (test_and_set_bit(RQ_BUSY, &rqstp->rq_flags)) + continue; + atomic_long_inc(&pool->sp_stats.threads_woken); + rqstp->rq_qtime = ktime_get(); + wake_up_process(rqstp->rq_task); + goto out_unlock; + } + set_bit(SP_CONGESTED, &pool->sp_flags); + rqstp = NULL; +out_unlock: + rcu_read_unlock(); + put_cpu(); + trace_svc_xprt_do_enqueue(xprt, rqstp); +} +EXPORT_SYMBOL_GPL(svc_xprt_do_enqueue); + +/* + * Queue up a transport with data pending. If there are idle nfsd + * processes, wake 'em up. + * + */ +void svc_xprt_enqueue(struct svc_xprt *xprt) +{ + if (test_bit(XPT_BUSY, &xprt->xpt_flags)) + return; + xprt->xpt_server->sv_ops->svo_enqueue_xprt(xprt); +} +EXPORT_SYMBOL_GPL(svc_xprt_enqueue); + +/* + * Dequeue the first transport, if there is one. + */ +static struct svc_xprt *svc_xprt_dequeue(struct svc_pool *pool) +{ + struct svc_xprt *xprt = NULL; + + if (list_empty(&pool->sp_sockets)) + goto out; + + spin_lock_bh(&pool->sp_lock); + if (likely(!list_empty(&pool->sp_sockets))) { + xprt = list_first_entry(&pool->sp_sockets, + struct svc_xprt, xpt_ready); + list_del_init(&xprt->xpt_ready); + svc_xprt_get(xprt); + } + spin_unlock_bh(&pool->sp_lock); +out: + return xprt; +} + +/** + * svc_reserve - change the space reserved for the reply to a request. + * @rqstp: The request in question + * @space: new max space to reserve + * + * Each request reserves some space on the output queue of the transport + * to make sure the reply fits. This function reduces that reserved + * space to be the amount of space used already, plus @space. + * + */ +void svc_reserve(struct svc_rqst *rqstp, int space) +{ + struct svc_xprt *xprt = rqstp->rq_xprt; + + space += rqstp->rq_res.head[0].iov_len; + + if (xprt && space < rqstp->rq_reserved) { + atomic_sub((rqstp->rq_reserved - space), &xprt->xpt_reserved); + rqstp->rq_reserved = space; + + svc_xprt_enqueue(xprt); + } +} +EXPORT_SYMBOL_GPL(svc_reserve); + +static void svc_xprt_release(struct svc_rqst *rqstp) +{ + struct svc_xprt *xprt = rqstp->rq_xprt; + + xprt->xpt_ops->xpo_release_rqst(rqstp); + + kfree(rqstp->rq_deferred); + rqstp->rq_deferred = NULL; + + svc_free_res_pages(rqstp); + rqstp->rq_res.page_len = 0; + rqstp->rq_res.page_base = 0; + + /* Reset response buffer and release + * the reservation. + * But first, check that enough space was reserved + * for the reply, otherwise we have a bug! + */ + if ((rqstp->rq_res.len) > rqstp->rq_reserved) + printk(KERN_ERR "RPC request reserved %d but used %d\n", + rqstp->rq_reserved, + rqstp->rq_res.len); + + rqstp->rq_res.head[0].iov_len = 0; + svc_reserve(rqstp, 0); + svc_xprt_release_slot(rqstp); + rqstp->rq_xprt = NULL; + svc_xprt_put(xprt); +} + +/* + * Some svc_serv's will have occasional work to do, even when a xprt is not + * waiting to be serviced. This function is there to "kick" a task in one of + * those services so that it can wake up and do that work. Note that we only + * bother with pool 0 as we don't need to wake up more than one thread for + * this purpose. + */ +void svc_wake_up(struct svc_serv *serv) +{ + struct svc_rqst *rqstp; + struct svc_pool *pool; + + pool = &serv->sv_pools[0]; + + rcu_read_lock(); + list_for_each_entry_rcu(rqstp, &pool->sp_all_threads, rq_all) { + /* skip any that aren't queued */ + if (test_bit(RQ_BUSY, &rqstp->rq_flags)) + continue; + rcu_read_unlock(); + wake_up_process(rqstp->rq_task); + trace_svc_wake_up(rqstp->rq_task->pid); + return; + } + rcu_read_unlock(); + + /* No free entries available */ + set_bit(SP_TASK_PENDING, &pool->sp_flags); + smp_wmb(); + trace_svc_wake_up(0); +} +EXPORT_SYMBOL_GPL(svc_wake_up); + +int svc_port_is_privileged(struct sockaddr *sin) +{ + switch (sin->sa_family) { + case AF_INET: + return ntohs(((struct sockaddr_in *)sin)->sin_port) + < PROT_SOCK; + case AF_INET6: + return ntohs(((struct sockaddr_in6 *)sin)->sin6_port) + < PROT_SOCK; + default: + return 0; + } +} + +/* + * Make sure that we don't have too many active connections. If we have, + * something must be dropped. It's not clear what will happen if we allow + * "too many" connections, but when dealing with network-facing software, + * we have to code defensively. Here we do that by imposing hard limits. + * + * There's no point in trying to do random drop here for DoS + * prevention. The NFS clients does 1 reconnect in 15 seconds. An + * attacker can easily beat that. + * + * The only somewhat efficient mechanism would be if drop old + * connections from the same IP first. But right now we don't even + * record the client IP in svc_sock. + * + * single-threaded services that expect a lot of clients will probably + * need to set sv_maxconn to override the default value which is based + * on the number of threads + */ +static void svc_check_conn_limits(struct svc_serv *serv) +{ + unsigned int limit = serv->sv_maxconn ? serv->sv_maxconn : + (serv->sv_nrthreads+3) * 20; + + if (serv->sv_tmpcnt > limit) { + struct svc_xprt *xprt = NULL; + spin_lock_bh(&serv->sv_lock); + if (!list_empty(&serv->sv_tempsocks)) { + /* Try to help the admin */ + net_notice_ratelimited("%s: too many open connections, consider increasing the %s\n", + serv->sv_name, serv->sv_maxconn ? + "max number of connections" : + "number of threads"); + /* + * Always select the oldest connection. It's not fair, + * but so is life + */ + xprt = list_entry(serv->sv_tempsocks.prev, + struct svc_xprt, + xpt_list); + set_bit(XPT_CLOSE, &xprt->xpt_flags); + svc_xprt_get(xprt); + } + spin_unlock_bh(&serv->sv_lock); + + if (xprt) { + svc_xprt_enqueue(xprt); + svc_xprt_put(xprt); + } + } +} + +static int svc_alloc_arg(struct svc_rqst *rqstp) +{ + struct svc_serv *serv = rqstp->rq_server; + struct xdr_buf *arg; + int pages; + int i; + + /* now allocate needed pages. If we get a failure, sleep briefly */ + pages = (serv->sv_max_mesg + 2 * PAGE_SIZE) >> PAGE_SHIFT; + if (pages > RPCSVC_MAXPAGES) { + pr_warn_once("svc: warning: pages=%u > RPCSVC_MAXPAGES=%lu\n", + pages, RPCSVC_MAXPAGES); + /* use as many pages as possible */ + pages = RPCSVC_MAXPAGES; + } + for (i = 0; i < pages ; i++) + while (rqstp->rq_pages[i] == NULL) { + struct page *p = alloc_page(GFP_KERNEL); + if (!p) { + set_current_state(TASK_INTERRUPTIBLE); + if (signalled() || kthread_should_stop()) { + set_current_state(TASK_RUNNING); + return -EINTR; + } + schedule_timeout(msecs_to_jiffies(500)); + } + rqstp->rq_pages[i] = p; + } + rqstp->rq_page_end = &rqstp->rq_pages[i]; + rqstp->rq_pages[i++] = NULL; /* this might be seen in nfs_read_actor */ + + /* Make arg->head point to first page and arg->pages point to rest */ + arg = &rqstp->rq_arg; + arg->head[0].iov_base = page_address(rqstp->rq_pages[0]); + arg->head[0].iov_len = PAGE_SIZE; + arg->pages = rqstp->rq_pages + 1; + arg->page_base = 0; + /* save at least one page for response */ + arg->page_len = (pages-2)*PAGE_SIZE; + arg->len = (pages-1)*PAGE_SIZE; + arg->tail[0].iov_len = 0; + return 0; +} + +static bool +rqst_should_sleep(struct svc_rqst *rqstp) +{ + struct svc_pool *pool = rqstp->rq_pool; + + /* did someone call svc_wake_up? */ + if (test_and_clear_bit(SP_TASK_PENDING, &pool->sp_flags)) + return false; + + /* was a socket queued? */ + if (!list_empty(&pool->sp_sockets)) + return false; + + /* are we shutting down? */ + if (signalled() || kthread_should_stop()) + return false; + + /* are we freezing? */ + if (freezing(current)) + return false; + + return true; +} + +static struct svc_xprt *svc_get_next_xprt(struct svc_rqst *rqstp, long timeout) +{ + struct svc_pool *pool = rqstp->rq_pool; + long time_left = 0; + + /* rq_xprt should be clear on entry */ + WARN_ON_ONCE(rqstp->rq_xprt); + + rqstp->rq_xprt = svc_xprt_dequeue(pool); + if (rqstp->rq_xprt) + goto out_found; + + /* + * We have to be able to interrupt this wait + * to bring down the daemons ... + */ + set_current_state(TASK_INTERRUPTIBLE); + smp_mb__before_atomic(); + clear_bit(SP_CONGESTED, &pool->sp_flags); + clear_bit(RQ_BUSY, &rqstp->rq_flags); + smp_mb__after_atomic(); + + if (likely(rqst_should_sleep(rqstp))) + time_left = schedule_timeout(timeout); + else + __set_current_state(TASK_RUNNING); + + try_to_freeze(); + + set_bit(RQ_BUSY, &rqstp->rq_flags); + smp_mb__after_atomic(); + rqstp->rq_xprt = svc_xprt_dequeue(pool); + if (rqstp->rq_xprt) + goto out_found; + + if (!time_left) + atomic_long_inc(&pool->sp_stats.threads_timedout); + + if (signalled() || kthread_should_stop()) + return ERR_PTR(-EINTR); + return ERR_PTR(-EAGAIN); +out_found: + /* Normally we will wait up to 5 seconds for any required + * cache information to be provided. + */ + if (!test_bit(SP_CONGESTED, &pool->sp_flags)) + rqstp->rq_chandle.thread_wait = 5*HZ; + else + rqstp->rq_chandle.thread_wait = 1*HZ; + trace_svc_xprt_dequeue(rqstp); + return rqstp->rq_xprt; +} + +static void svc_add_new_temp_xprt(struct svc_serv *serv, struct svc_xprt *newxpt) +{ + spin_lock_bh(&serv->sv_lock); + set_bit(XPT_TEMP, &newxpt->xpt_flags); + list_add(&newxpt->xpt_list, &serv->sv_tempsocks); + serv->sv_tmpcnt++; + if (serv->sv_temptimer.function == NULL) { + /* setup timer to age temp transports */ + serv->sv_temptimer.function = svc_age_temp_xprts; + mod_timer(&serv->sv_temptimer, + jiffies + svc_conn_age_period * HZ); + } + spin_unlock_bh(&serv->sv_lock); + svc_xprt_received(newxpt); +} + +static int svc_handle_xprt(struct svc_rqst *rqstp, struct svc_xprt *xprt) +{ + struct svc_serv *serv = rqstp->rq_server; + int len = 0; + + if (test_bit(XPT_CLOSE, &xprt->xpt_flags)) { + dprintk("svc_recv: found XPT_CLOSE\n"); + if (test_and_clear_bit(XPT_KILL_TEMP, &xprt->xpt_flags)) + xprt->xpt_ops->xpo_kill_temp_xprt(xprt); + svc_delete_xprt(xprt); + /* Leave XPT_BUSY set on the dead xprt: */ + goto out; + } + if (test_bit(XPT_LISTENER, &xprt->xpt_flags)) { + struct svc_xprt *newxpt; + /* + * We know this module_get will succeed because the + * listener holds a reference too + */ + __module_get(xprt->xpt_class->xcl_owner); + svc_check_conn_limits(xprt->xpt_server); + newxpt = xprt->xpt_ops->xpo_accept(xprt); + if (newxpt) + svc_add_new_temp_xprt(serv, newxpt); + else + module_put(xprt->xpt_class->xcl_owner); + } else if (svc_xprt_reserve_slot(rqstp, xprt)) { + /* XPT_DATA|XPT_DEFERRED case: */ + dprintk("svc: server %p, pool %u, transport %p, inuse=%d\n", + rqstp, rqstp->rq_pool->sp_id, xprt, + kref_read(&xprt->xpt_ref)); + rqstp->rq_deferred = svc_deferred_dequeue(xprt); + if (rqstp->rq_deferred) + len = svc_deferred_recv(rqstp); + else + len = xprt->xpt_ops->xpo_recvfrom(rqstp); + rqstp->rq_stime = ktime_get(); + rqstp->rq_reserved = serv->sv_max_mesg; + atomic_add(rqstp->rq_reserved, &xprt->xpt_reserved); + } + /* clear XPT_BUSY: */ + svc_xprt_received(xprt); +out: + trace_svc_handle_xprt(xprt, len); + return len; +} + +/* + * Receive the next request on any transport. This code is carefully + * organised not to touch any cachelines in the shared svc_serv + * structure, only cachelines in the local svc_pool. + */ +int svc_recv(struct svc_rqst *rqstp, long timeout) +{ + struct svc_xprt *xprt = NULL; + struct svc_serv *serv = rqstp->rq_server; + int len, err; + + dprintk("svc: server %p waiting for data (to = %ld)\n", + rqstp, timeout); + + if (rqstp->rq_xprt) + printk(KERN_ERR + "svc_recv: service %p, transport not NULL!\n", + rqstp); + + err = svc_alloc_arg(rqstp); + if (err) + goto out; + + try_to_freeze(); + cond_resched(); + err = -EINTR; + if (signalled() || kthread_should_stop()) + goto out; + + xprt = svc_get_next_xprt(rqstp, timeout); + if (IS_ERR(xprt)) { + err = PTR_ERR(xprt); + goto out; + } + + len = svc_handle_xprt(rqstp, xprt); + + /* No data, incomplete (TCP) read, or accept() */ + err = -EAGAIN; + if (len <= 0) + goto out_release; + + clear_bit(XPT_OLD, &xprt->xpt_flags); + + xprt->xpt_ops->xpo_secure_port(rqstp); + rqstp->rq_chandle.defer = svc_defer; + rqstp->rq_xid = svc_getu32(&rqstp->rq_arg.head[0]); + + if (serv->sv_stats) + serv->sv_stats->netcnt++; + trace_svc_recv(rqstp, len); + return len; +out_release: + rqstp->rq_res.len = 0; + svc_xprt_release(rqstp); +out: + return err; +} +EXPORT_SYMBOL_GPL(svc_recv); + +/* + * Drop request + */ +void svc_drop(struct svc_rqst *rqstp) +{ + trace_svc_drop(rqstp); + dprintk("svc: xprt %p dropped request\n", rqstp->rq_xprt); + svc_xprt_release(rqstp); +} +EXPORT_SYMBOL_GPL(svc_drop); + +/* + * Return reply to client. + */ +int svc_send(struct svc_rqst *rqstp) +{ + struct svc_xprt *xprt; + int len = -EFAULT; + struct xdr_buf *xb; + + xprt = rqstp->rq_xprt; + if (!xprt) + goto out; + + /* calculate over-all length */ + xb = &rqstp->rq_res; + xb->len = xb->head[0].iov_len + + xb->page_len + + xb->tail[0].iov_len; + + /* Grab mutex to serialize outgoing data. */ + mutex_lock(&xprt->xpt_mutex); + trace_svc_stats_latency(rqstp); + if (test_bit(XPT_DEAD, &xprt->xpt_flags) + || test_bit(XPT_CLOSE, &xprt->xpt_flags)) + len = -ENOTCONN; + else + len = xprt->xpt_ops->xpo_sendto(rqstp); + mutex_unlock(&xprt->xpt_mutex); + rpc_wake_up(&xprt->xpt_bc_pending); + trace_svc_send(rqstp, len); + svc_xprt_release(rqstp); + + if (len == -ECONNREFUSED || len == -ENOTCONN || len == -EAGAIN) + len = 0; +out: + return len; +} + +/* + * Timer function to close old temporary transports, using + * a mark-and-sweep algorithm. + */ +static void svc_age_temp_xprts(struct timer_list *t) +{ + struct svc_serv *serv = from_timer(serv, t, sv_temptimer); + struct svc_xprt *xprt; + struct list_head *le, *next; + + dprintk("svc_age_temp_xprts\n"); + + if (!spin_trylock_bh(&serv->sv_lock)) { + /* busy, try again 1 sec later */ + dprintk("svc_age_temp_xprts: busy\n"); + mod_timer(&serv->sv_temptimer, jiffies + HZ); + return; + } + + list_for_each_safe(le, next, &serv->sv_tempsocks) { + xprt = list_entry(le, struct svc_xprt, xpt_list); + + /* First time through, just mark it OLD. Second time + * through, close it. */ + if (!test_and_set_bit(XPT_OLD, &xprt->xpt_flags)) + continue; + if (kref_read(&xprt->xpt_ref) > 1 || + test_bit(XPT_BUSY, &xprt->xpt_flags)) + continue; + list_del_init(le); + set_bit(XPT_CLOSE, &xprt->xpt_flags); + dprintk("queuing xprt %p for closing\n", xprt); + + /* a thread will dequeue and close it soon */ + svc_xprt_enqueue(xprt); + } + spin_unlock_bh(&serv->sv_lock); + + mod_timer(&serv->sv_temptimer, jiffies + svc_conn_age_period * HZ); +} + +/* Close temporary transports whose xpt_local matches server_addr immediately + * instead of waiting for them to be picked up by the timer. + * + * This is meant to be called from a notifier_block that runs when an ip + * address is deleted. + */ +void svc_age_temp_xprts_now(struct svc_serv *serv, struct sockaddr *server_addr) +{ + struct svc_xprt *xprt; + struct list_head *le, *next; + LIST_HEAD(to_be_closed); + + spin_lock_bh(&serv->sv_lock); + list_for_each_safe(le, next, &serv->sv_tempsocks) { + xprt = list_entry(le, struct svc_xprt, xpt_list); + if (rpc_cmp_addr(server_addr, (struct sockaddr *) + &xprt->xpt_local)) { + dprintk("svc_age_temp_xprts_now: found %p\n", xprt); + list_move(le, &to_be_closed); + } + } + spin_unlock_bh(&serv->sv_lock); + + while (!list_empty(&to_be_closed)) { + le = to_be_closed.next; + list_del_init(le); + xprt = list_entry(le, struct svc_xprt, xpt_list); + set_bit(XPT_CLOSE, &xprt->xpt_flags); + set_bit(XPT_KILL_TEMP, &xprt->xpt_flags); + dprintk("svc_age_temp_xprts_now: queuing xprt %p for closing\n", + xprt); + svc_xprt_enqueue(xprt); + } +} +EXPORT_SYMBOL_GPL(svc_age_temp_xprts_now); + +static void call_xpt_users(struct svc_xprt *xprt) +{ + struct svc_xpt_user *u; + + spin_lock(&xprt->xpt_lock); + while (!list_empty(&xprt->xpt_users)) { + u = list_first_entry(&xprt->xpt_users, struct svc_xpt_user, list); + list_del_init(&u->list); + u->callback(u); + } + spin_unlock(&xprt->xpt_lock); +} + +/* + * Remove a dead transport + */ +static void svc_delete_xprt(struct svc_xprt *xprt) +{ + struct svc_serv *serv = xprt->xpt_server; + struct svc_deferred_req *dr; + + /* Only do this once */ + if (test_and_set_bit(XPT_DEAD, &xprt->xpt_flags)) + BUG(); + + dprintk("svc: svc_delete_xprt(%p)\n", xprt); + xprt->xpt_ops->xpo_detach(xprt); + + spin_lock_bh(&serv->sv_lock); + list_del_init(&xprt->xpt_list); + WARN_ON_ONCE(!list_empty(&xprt->xpt_ready)); + if (test_bit(XPT_TEMP, &xprt->xpt_flags)) + serv->sv_tmpcnt--; + spin_unlock_bh(&serv->sv_lock); + + while ((dr = svc_deferred_dequeue(xprt)) != NULL) + kfree(dr); + + call_xpt_users(xprt); + svc_xprt_put(xprt); +} + +void svc_close_xprt(struct svc_xprt *xprt) +{ + set_bit(XPT_CLOSE, &xprt->xpt_flags); + if (test_and_set_bit(XPT_BUSY, &xprt->xpt_flags)) + /* someone else will have to effect the close */ + return; + /* + * We expect svc_close_xprt() to work even when no threads are + * running (e.g., while configuring the server before starting + * any threads), so if the transport isn't busy, we delete + * it ourself: + */ + svc_delete_xprt(xprt); +} +EXPORT_SYMBOL_GPL(svc_close_xprt); + +static int svc_close_list(struct svc_serv *serv, struct list_head *xprt_list, struct net *net) +{ + struct svc_xprt *xprt; + int ret = 0; + + spin_lock_bh(&serv->sv_lock); + list_for_each_entry(xprt, xprt_list, xpt_list) { + if (xprt->xpt_net != net) + continue; + ret++; + set_bit(XPT_CLOSE, &xprt->xpt_flags); + svc_xprt_enqueue(xprt); + } + spin_unlock_bh(&serv->sv_lock); + return ret; +} + +static struct svc_xprt *svc_dequeue_net(struct svc_serv *serv, struct net *net) +{ + struct svc_pool *pool; + struct svc_xprt *xprt; + struct svc_xprt *tmp; + int i; + + for (i = 0; i < serv->sv_nrpools; i++) { + pool = &serv->sv_pools[i]; + + spin_lock_bh(&pool->sp_lock); + list_for_each_entry_safe(xprt, tmp, &pool->sp_sockets, xpt_ready) { + if (xprt->xpt_net != net) + continue; + list_del_init(&xprt->xpt_ready); + spin_unlock_bh(&pool->sp_lock); + return xprt; + } + spin_unlock_bh(&pool->sp_lock); + } + return NULL; +} + +static void svc_clean_up_xprts(struct svc_serv *serv, struct net *net) +{ + struct svc_xprt *xprt; + + while ((xprt = svc_dequeue_net(serv, net))) { + set_bit(XPT_CLOSE, &xprt->xpt_flags); + svc_delete_xprt(xprt); + } +} + +/* + * Server threads may still be running (especially in the case where the + * service is still running in other network namespaces). + * + * So we shut down sockets the same way we would on a running server, by + * setting XPT_CLOSE, enqueuing, and letting a thread pick it up to do + * the close. In the case there are no such other threads, + * threads running, svc_clean_up_xprts() does a simple version of a + * server's main event loop, and in the case where there are other + * threads, we may need to wait a little while and then check again to + * see if they're done. + */ +void svc_close_net(struct svc_serv *serv, struct net *net) +{ + int delay = 0; + + while (svc_close_list(serv, &serv->sv_permsocks, net) + + svc_close_list(serv, &serv->sv_tempsocks, net)) { + + svc_clean_up_xprts(serv, net); + msleep(delay++); + } +} + +/* + * Handle defer and revisit of requests + */ + +static void svc_revisit(struct cache_deferred_req *dreq, int too_many) +{ + struct svc_deferred_req *dr = + container_of(dreq, struct svc_deferred_req, handle); + struct svc_xprt *xprt = dr->xprt; + + spin_lock(&xprt->xpt_lock); + set_bit(XPT_DEFERRED, &xprt->xpt_flags); + if (too_many || test_bit(XPT_DEAD, &xprt->xpt_flags)) { + spin_unlock(&xprt->xpt_lock); + dprintk("revisit canceled\n"); + svc_xprt_put(xprt); + trace_svc_drop_deferred(dr); + kfree(dr); + return; + } + dprintk("revisit queued\n"); + dr->xprt = NULL; + list_add(&dr->handle.recent, &xprt->xpt_deferred); + spin_unlock(&xprt->xpt_lock); + svc_xprt_enqueue(xprt); + svc_xprt_put(xprt); +} + +/* + * Save the request off for later processing. The request buffer looks + * like this: + * + * <xprt-header><rpc-header><rpc-pagelist><rpc-tail> + * + * This code can only handle requests that consist of an xprt-header + * and rpc-header. + */ +static struct cache_deferred_req *svc_defer(struct cache_req *req) +{ + struct svc_rqst *rqstp = container_of(req, struct svc_rqst, rq_chandle); + struct svc_deferred_req *dr; + + if (rqstp->rq_arg.page_len || !test_bit(RQ_USEDEFERRAL, &rqstp->rq_flags)) + return NULL; /* if more than a page, give up FIXME */ + if (rqstp->rq_deferred) { + dr = rqstp->rq_deferred; + rqstp->rq_deferred = NULL; + } else { + size_t skip; + size_t size; + /* FIXME maybe discard if size too large */ + size = sizeof(struct svc_deferred_req) + rqstp->rq_arg.len; + dr = kmalloc(size, GFP_KERNEL); + if (dr == NULL) + return NULL; + + dr->handle.owner = rqstp->rq_server; + dr->prot = rqstp->rq_prot; + memcpy(&dr->addr, &rqstp->rq_addr, rqstp->rq_addrlen); + dr->addrlen = rqstp->rq_addrlen; + dr->daddr = rqstp->rq_daddr; + dr->argslen = rqstp->rq_arg.len >> 2; + dr->xprt_hlen = rqstp->rq_xprt_hlen; + + /* back up head to the start of the buffer and copy */ + skip = rqstp->rq_arg.len - rqstp->rq_arg.head[0].iov_len; + memcpy(dr->args, rqstp->rq_arg.head[0].iov_base - skip, + dr->argslen << 2); + } + svc_xprt_get(rqstp->rq_xprt); + dr->xprt = rqstp->rq_xprt; + set_bit(RQ_DROPME, &rqstp->rq_flags); + + dr->handle.revisit = svc_revisit; + trace_svc_defer(rqstp); + return &dr->handle; +} + +/* + * recv data from a deferred request into an active one + */ +static int svc_deferred_recv(struct svc_rqst *rqstp) +{ + struct svc_deferred_req *dr = rqstp->rq_deferred; + + /* setup iov_base past transport header */ + rqstp->rq_arg.head[0].iov_base = dr->args + (dr->xprt_hlen>>2); + /* The iov_len does not include the transport header bytes */ + rqstp->rq_arg.head[0].iov_len = (dr->argslen<<2) - dr->xprt_hlen; + rqstp->rq_arg.page_len = 0; + /* The rq_arg.len includes the transport header bytes */ + rqstp->rq_arg.len = dr->argslen<<2; + rqstp->rq_prot = dr->prot; + memcpy(&rqstp->rq_addr, &dr->addr, dr->addrlen); + rqstp->rq_addrlen = dr->addrlen; + /* Save off transport header len in case we get deferred again */ + rqstp->rq_xprt_hlen = dr->xprt_hlen; + rqstp->rq_daddr = dr->daddr; + rqstp->rq_respages = rqstp->rq_pages; + return (dr->argslen<<2) - dr->xprt_hlen; +} + + +static struct svc_deferred_req *svc_deferred_dequeue(struct svc_xprt *xprt) +{ + struct svc_deferred_req *dr = NULL; + + if (!test_bit(XPT_DEFERRED, &xprt->xpt_flags)) + return NULL; + spin_lock(&xprt->xpt_lock); + if (!list_empty(&xprt->xpt_deferred)) { + dr = list_entry(xprt->xpt_deferred.next, + struct svc_deferred_req, + handle.recent); + list_del_init(&dr->handle.recent); + trace_svc_revisit_deferred(dr); + } else + clear_bit(XPT_DEFERRED, &xprt->xpt_flags); + spin_unlock(&xprt->xpt_lock); + return dr; +} + +/** + * svc_find_xprt - find an RPC transport instance + * @serv: pointer to svc_serv to search + * @xcl_name: C string containing transport's class name + * @net: owner net pointer + * @af: Address family of transport's local address + * @port: transport's IP port number + * + * Return the transport instance pointer for the endpoint accepting + * connections/peer traffic from the specified transport class, + * address family and port. + * + * Specifying 0 for the address family or port is effectively a + * wild-card, and will result in matching the first transport in the + * service's list that has a matching class name. + */ +struct svc_xprt *svc_find_xprt(struct svc_serv *serv, const char *xcl_name, + struct net *net, const sa_family_t af, + const unsigned short port) +{ + struct svc_xprt *xprt; + struct svc_xprt *found = NULL; + + /* Sanity check the args */ + if (serv == NULL || xcl_name == NULL) + return found; + + spin_lock_bh(&serv->sv_lock); + list_for_each_entry(xprt, &serv->sv_permsocks, xpt_list) { + if (xprt->xpt_net != net) + continue; + if (strcmp(xprt->xpt_class->xcl_name, xcl_name)) + continue; + if (af != AF_UNSPEC && af != xprt->xpt_local.ss_family) + continue; + if (port != 0 && port != svc_xprt_local_port(xprt)) + continue; + found = xprt; + svc_xprt_get(xprt); + break; + } + spin_unlock_bh(&serv->sv_lock); + return found; +} +EXPORT_SYMBOL_GPL(svc_find_xprt); + +static int svc_one_xprt_name(const struct svc_xprt *xprt, + char *pos, int remaining) +{ + int len; + + len = snprintf(pos, remaining, "%s %u\n", + xprt->xpt_class->xcl_name, + svc_xprt_local_port(xprt)); + if (len >= remaining) + return -ENAMETOOLONG; + return len; +} + +/** + * svc_xprt_names - format a buffer with a list of transport names + * @serv: pointer to an RPC service + * @buf: pointer to a buffer to be filled in + * @buflen: length of buffer to be filled in + * + * Fills in @buf with a string containing a list of transport names, + * each name terminated with '\n'. + * + * Returns positive length of the filled-in string on success; otherwise + * a negative errno value is returned if an error occurs. + */ +int svc_xprt_names(struct svc_serv *serv, char *buf, const int buflen) +{ + struct svc_xprt *xprt; + int len, totlen; + char *pos; + + /* Sanity check args */ + if (!serv) + return 0; + + spin_lock_bh(&serv->sv_lock); + + pos = buf; + totlen = 0; + list_for_each_entry(xprt, &serv->sv_permsocks, xpt_list) { + len = svc_one_xprt_name(xprt, pos, buflen - totlen); + if (len < 0) { + *buf = '\0'; + totlen = len; + } + if (len <= 0) + break; + + pos += len; + totlen += len; + } + + spin_unlock_bh(&serv->sv_lock); + return totlen; +} +EXPORT_SYMBOL_GPL(svc_xprt_names); + + +/*----------------------------------------------------------------------------*/ + +static void *svc_pool_stats_start(struct seq_file *m, loff_t *pos) +{ + unsigned int pidx = (unsigned int)*pos; + struct svc_serv *serv = m->private; + + dprintk("svc_pool_stats_start, *pidx=%u\n", pidx); + + if (!pidx) + return SEQ_START_TOKEN; + return (pidx > serv->sv_nrpools ? NULL : &serv->sv_pools[pidx-1]); +} + +static void *svc_pool_stats_next(struct seq_file *m, void *p, loff_t *pos) +{ + struct svc_pool *pool = p; + struct svc_serv *serv = m->private; + + dprintk("svc_pool_stats_next, *pos=%llu\n", *pos); + + if (p == SEQ_START_TOKEN) { + pool = &serv->sv_pools[0]; + } else { + unsigned int pidx = (pool - &serv->sv_pools[0]); + if (pidx < serv->sv_nrpools-1) + pool = &serv->sv_pools[pidx+1]; + else + pool = NULL; + } + ++*pos; + return pool; +} + +static void svc_pool_stats_stop(struct seq_file *m, void *p) +{ +} + +static int svc_pool_stats_show(struct seq_file *m, void *p) +{ + struct svc_pool *pool = p; + + if (p == SEQ_START_TOKEN) { + seq_puts(m, "# pool packets-arrived sockets-enqueued threads-woken threads-timedout\n"); + return 0; + } + + seq_printf(m, "%u %lu %lu %lu %lu\n", + pool->sp_id, + (unsigned long)atomic_long_read(&pool->sp_stats.packets), + pool->sp_stats.sockets_queued, + (unsigned long)atomic_long_read(&pool->sp_stats.threads_woken), + (unsigned long)atomic_long_read(&pool->sp_stats.threads_timedout)); + + return 0; +} + +static const struct seq_operations svc_pool_stats_seq_ops = { + .start = svc_pool_stats_start, + .next = svc_pool_stats_next, + .stop = svc_pool_stats_stop, + .show = svc_pool_stats_show, +}; + +int svc_pool_stats_open(struct svc_serv *serv, struct file *file) +{ + int err; + + err = seq_open(file, &svc_pool_stats_seq_ops); + if (!err) + ((struct seq_file *) file->private_data)->private = serv; + return err; +} +EXPORT_SYMBOL(svc_pool_stats_open); + +/*----------------------------------------------------------------------------*/ diff --git a/net/sunrpc/svcauth.c b/net/sunrpc/svcauth.c new file mode 100644 index 000000000..bb8db3cb8 --- /dev/null +++ b/net/sunrpc/svcauth.c @@ -0,0 +1,172 @@ +/* + * linux/net/sunrpc/svcauth.c + * + * The generic interface for RPC authentication on the server side. + * + * Copyright (C) 1995, 1996 Olaf Kirch <okir@monad.swb.de> + * + * CHANGES + * 19-Apr-2000 Chris Evans - Security fix + */ + +#include <linux/types.h> +#include <linux/module.h> +#include <linux/sunrpc/types.h> +#include <linux/sunrpc/xdr.h> +#include <linux/sunrpc/svcsock.h> +#include <linux/sunrpc/svcauth.h> +#include <linux/err.h> +#include <linux/hash.h> + +#define RPCDBG_FACILITY RPCDBG_AUTH + + +/* + * Table of authenticators + */ +extern struct auth_ops svcauth_null; +extern struct auth_ops svcauth_unix; + +static DEFINE_SPINLOCK(authtab_lock); +static struct auth_ops *authtab[RPC_AUTH_MAXFLAVOR] = { + [0] = &svcauth_null, + [1] = &svcauth_unix, +}; + +int +svc_authenticate(struct svc_rqst *rqstp, __be32 *authp) +{ + rpc_authflavor_t flavor; + struct auth_ops *aops; + + *authp = rpc_auth_ok; + + flavor = svc_getnl(&rqstp->rq_arg.head[0]); + + dprintk("svc: svc_authenticate (%d)\n", flavor); + + spin_lock(&authtab_lock); + if (flavor >= RPC_AUTH_MAXFLAVOR || !(aops = authtab[flavor]) || + !try_module_get(aops->owner)) { + spin_unlock(&authtab_lock); + *authp = rpc_autherr_badcred; + return SVC_DENIED; + } + spin_unlock(&authtab_lock); + + rqstp->rq_auth_slack = 0; + init_svc_cred(&rqstp->rq_cred); + + rqstp->rq_authop = aops; + return aops->accept(rqstp, authp); +} +EXPORT_SYMBOL_GPL(svc_authenticate); + +int svc_set_client(struct svc_rqst *rqstp) +{ + rqstp->rq_client = NULL; + return rqstp->rq_authop->set_client(rqstp); +} +EXPORT_SYMBOL_GPL(svc_set_client); + +/* A request, which was authenticated, has now executed. + * Time to finalise the credentials and verifier + * and release and resources + */ +int svc_authorise(struct svc_rqst *rqstp) +{ + struct auth_ops *aops = rqstp->rq_authop; + int rv = 0; + + rqstp->rq_authop = NULL; + + if (aops) { + rv = aops->release(rqstp); + module_put(aops->owner); + } + return rv; +} + +int +svc_auth_register(rpc_authflavor_t flavor, struct auth_ops *aops) +{ + int rv = -EINVAL; + spin_lock(&authtab_lock); + if (flavor < RPC_AUTH_MAXFLAVOR && authtab[flavor] == NULL) { + authtab[flavor] = aops; + rv = 0; + } + spin_unlock(&authtab_lock); + return rv; +} +EXPORT_SYMBOL_GPL(svc_auth_register); + +void +svc_auth_unregister(rpc_authflavor_t flavor) +{ + spin_lock(&authtab_lock); + if (flavor < RPC_AUTH_MAXFLAVOR) + authtab[flavor] = NULL; + spin_unlock(&authtab_lock); +} +EXPORT_SYMBOL_GPL(svc_auth_unregister); + +/************************************************** + * 'auth_domains' are stored in a hash table indexed by name. + * When the last reference to an 'auth_domain' is dropped, + * the object is unhashed and freed. + * If auth_domain_lookup fails to find an entry, it will return + * it's second argument 'new'. If this is non-null, it will + * have been atomically linked into the table. + */ + +#define DN_HASHBITS 6 +#define DN_HASHMAX (1<<DN_HASHBITS) + +static struct hlist_head auth_domain_table[DN_HASHMAX]; +static DEFINE_SPINLOCK(auth_domain_lock); + +static void auth_domain_release(struct kref *kref) +{ + struct auth_domain *dom = container_of(kref, struct auth_domain, ref); + + hlist_del(&dom->hash); + dom->flavour->domain_release(dom); + spin_unlock(&auth_domain_lock); +} + +void auth_domain_put(struct auth_domain *dom) +{ + kref_put_lock(&dom->ref, auth_domain_release, &auth_domain_lock); +} +EXPORT_SYMBOL_GPL(auth_domain_put); + +struct auth_domain * +auth_domain_lookup(char *name, struct auth_domain *new) +{ + struct auth_domain *hp; + struct hlist_head *head; + + head = &auth_domain_table[hash_str(name, DN_HASHBITS)]; + + spin_lock(&auth_domain_lock); + + hlist_for_each_entry(hp, head, hash) { + if (strcmp(hp->name, name)==0) { + kref_get(&hp->ref); + spin_unlock(&auth_domain_lock); + return hp; + } + } + if (new) + hlist_add_head(&new->hash, head); + spin_unlock(&auth_domain_lock); + return new; +} +EXPORT_SYMBOL_GPL(auth_domain_lookup); + +struct auth_domain *auth_domain_find(char *name) +{ + return auth_domain_lookup(name, NULL); +} +EXPORT_SYMBOL_GPL(auth_domain_find); diff --git a/net/sunrpc/svcauth_unix.c b/net/sunrpc/svcauth_unix.c new file mode 100644 index 000000000..af7f28fb8 --- /dev/null +++ b/net/sunrpc/svcauth_unix.c @@ -0,0 +1,908 @@ +#include <linux/types.h> +#include <linux/sched.h> +#include <linux/module.h> +#include <linux/sunrpc/types.h> +#include <linux/sunrpc/xdr.h> +#include <linux/sunrpc/svcsock.h> +#include <linux/sunrpc/svcauth.h> +#include <linux/sunrpc/gss_api.h> +#include <linux/sunrpc/addr.h> +#include <linux/err.h> +#include <linux/seq_file.h> +#include <linux/hash.h> +#include <linux/string.h> +#include <linux/slab.h> +#include <net/sock.h> +#include <net/ipv6.h> +#include <linux/kernel.h> +#include <linux/user_namespace.h> +#define RPCDBG_FACILITY RPCDBG_AUTH + + +#include "netns.h" + +/* + * AUTHUNIX and AUTHNULL credentials are both handled here. + * AUTHNULL is treated just like AUTHUNIX except that the uid/gid + * are always nobody (-2). i.e. we do the same IP address checks for + * AUTHNULL as for AUTHUNIX, and that is done here. + */ + + +struct unix_domain { + struct auth_domain h; + /* other stuff later */ +}; + +extern struct auth_ops svcauth_null; +extern struct auth_ops svcauth_unix; + +static void svcauth_unix_domain_release(struct auth_domain *dom) +{ + struct unix_domain *ud = container_of(dom, struct unix_domain, h); + + kfree(dom->name); + kfree(ud); +} + +struct auth_domain *unix_domain_find(char *name) +{ + struct auth_domain *rv; + struct unix_domain *new = NULL; + + rv = auth_domain_lookup(name, NULL); + while(1) { + if (rv) { + if (new && rv != &new->h) + svcauth_unix_domain_release(&new->h); + + if (rv->flavour != &svcauth_unix) { + auth_domain_put(rv); + return NULL; + } + return rv; + } + + new = kmalloc(sizeof(*new), GFP_KERNEL); + if (new == NULL) + return NULL; + kref_init(&new->h.ref); + new->h.name = kstrdup(name, GFP_KERNEL); + if (new->h.name == NULL) { + kfree(new); + return NULL; + } + new->h.flavour = &svcauth_unix; + rv = auth_domain_lookup(name, &new->h); + } +} +EXPORT_SYMBOL_GPL(unix_domain_find); + + +/************************************************** + * cache for IP address to unix_domain + * as needed by AUTH_UNIX + */ +#define IP_HASHBITS 8 +#define IP_HASHMAX (1<<IP_HASHBITS) + +struct ip_map { + struct cache_head h; + char m_class[8]; /* e.g. "nfsd" */ + struct in6_addr m_addr; + struct unix_domain *m_client; +}; + +static void ip_map_put(struct kref *kref) +{ + struct cache_head *item = container_of(kref, struct cache_head, ref); + struct ip_map *im = container_of(item, struct ip_map,h); + + if (test_bit(CACHE_VALID, &item->flags) && + !test_bit(CACHE_NEGATIVE, &item->flags)) + auth_domain_put(&im->m_client->h); + kfree(im); +} + +static inline int hash_ip6(const struct in6_addr *ip) +{ + return hash_32(ipv6_addr_hash(ip), IP_HASHBITS); +} +static int ip_map_match(struct cache_head *corig, struct cache_head *cnew) +{ + struct ip_map *orig = container_of(corig, struct ip_map, h); + struct ip_map *new = container_of(cnew, struct ip_map, h); + return strcmp(orig->m_class, new->m_class) == 0 && + ipv6_addr_equal(&orig->m_addr, &new->m_addr); +} +static void ip_map_init(struct cache_head *cnew, struct cache_head *citem) +{ + struct ip_map *new = container_of(cnew, struct ip_map, h); + struct ip_map *item = container_of(citem, struct ip_map, h); + + strcpy(new->m_class, item->m_class); + new->m_addr = item->m_addr; +} +static void update(struct cache_head *cnew, struct cache_head *citem) +{ + struct ip_map *new = container_of(cnew, struct ip_map, h); + struct ip_map *item = container_of(citem, struct ip_map, h); + + kref_get(&item->m_client->h.ref); + new->m_client = item->m_client; +} +static struct cache_head *ip_map_alloc(void) +{ + struct ip_map *i = kmalloc(sizeof(*i), GFP_KERNEL); + if (i) + return &i->h; + else + return NULL; +} + +static void ip_map_request(struct cache_detail *cd, + struct cache_head *h, + char **bpp, int *blen) +{ + char text_addr[40]; + struct ip_map *im = container_of(h, struct ip_map, h); + + if (ipv6_addr_v4mapped(&(im->m_addr))) { + snprintf(text_addr, 20, "%pI4", &im->m_addr.s6_addr32[3]); + } else { + snprintf(text_addr, 40, "%pI6", &im->m_addr); + } + qword_add(bpp, blen, im->m_class); + qword_add(bpp, blen, text_addr); + (*bpp)[-1] = '\n'; +} + +static struct ip_map *__ip_map_lookup(struct cache_detail *cd, char *class, struct in6_addr *addr); +static int __ip_map_update(struct cache_detail *cd, struct ip_map *ipm, struct unix_domain *udom, time_t expiry); + +static int ip_map_parse(struct cache_detail *cd, + char *mesg, int mlen) +{ + /* class ipaddress [domainname] */ + /* should be safe just to use the start of the input buffer + * for scratch: */ + char *buf = mesg; + int len; + char class[8]; + union { + struct sockaddr sa; + struct sockaddr_in s4; + struct sockaddr_in6 s6; + } address; + struct sockaddr_in6 sin6; + int err; + + struct ip_map *ipmp; + struct auth_domain *dom; + time_t expiry; + + if (mesg[mlen-1] != '\n') + return -EINVAL; + mesg[mlen-1] = 0; + + /* class */ + len = qword_get(&mesg, class, sizeof(class)); + if (len <= 0) return -EINVAL; + + /* ip address */ + len = qword_get(&mesg, buf, mlen); + if (len <= 0) return -EINVAL; + + if (rpc_pton(cd->net, buf, len, &address.sa, sizeof(address)) == 0) + return -EINVAL; + switch (address.sa.sa_family) { + case AF_INET: + /* Form a mapped IPv4 address in sin6 */ + sin6.sin6_family = AF_INET6; + ipv6_addr_set_v4mapped(address.s4.sin_addr.s_addr, + &sin6.sin6_addr); + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + memcpy(&sin6, &address.s6, sizeof(sin6)); + break; +#endif + default: + return -EINVAL; + } + + expiry = get_expiry(&mesg); + if (expiry ==0) + return -EINVAL; + + /* domainname, or empty for NEGATIVE */ + len = qword_get(&mesg, buf, mlen); + if (len < 0) return -EINVAL; + + if (len) { + dom = unix_domain_find(buf); + if (dom == NULL) + return -ENOENT; + } else + dom = NULL; + + /* IPv6 scope IDs are ignored for now */ + ipmp = __ip_map_lookup(cd, class, &sin6.sin6_addr); + if (ipmp) { + err = __ip_map_update(cd, ipmp, + container_of(dom, struct unix_domain, h), + expiry); + } else + err = -ENOMEM; + + if (dom) + auth_domain_put(dom); + + cache_flush(); + return err; +} + +static int ip_map_show(struct seq_file *m, + struct cache_detail *cd, + struct cache_head *h) +{ + struct ip_map *im; + struct in6_addr addr; + char *dom = "-no-domain-"; + + if (h == NULL) { + seq_puts(m, "#class IP domain\n"); + return 0; + } + im = container_of(h, struct ip_map, h); + /* class addr domain */ + addr = im->m_addr; + + if (test_bit(CACHE_VALID, &h->flags) && + !test_bit(CACHE_NEGATIVE, &h->flags)) + dom = im->m_client->h.name; + + if (ipv6_addr_v4mapped(&addr)) { + seq_printf(m, "%s %pI4 %s\n", + im->m_class, &addr.s6_addr32[3], dom); + } else { + seq_printf(m, "%s %pI6 %s\n", im->m_class, &addr, dom); + } + return 0; +} + + +static struct ip_map *__ip_map_lookup(struct cache_detail *cd, char *class, + struct in6_addr *addr) +{ + struct ip_map ip; + struct cache_head *ch; + + strcpy(ip.m_class, class); + ip.m_addr = *addr; + ch = sunrpc_cache_lookup(cd, &ip.h, + hash_str(class, IP_HASHBITS) ^ + hash_ip6(addr)); + + if (ch) + return container_of(ch, struct ip_map, h); + else + return NULL; +} + +static inline struct ip_map *ip_map_lookup(struct net *net, char *class, + struct in6_addr *addr) +{ + struct sunrpc_net *sn; + + sn = net_generic(net, sunrpc_net_id); + return __ip_map_lookup(sn->ip_map_cache, class, addr); +} + +static int __ip_map_update(struct cache_detail *cd, struct ip_map *ipm, + struct unix_domain *udom, time_t expiry) +{ + struct ip_map ip; + struct cache_head *ch; + + ip.m_client = udom; + ip.h.flags = 0; + if (!udom) + set_bit(CACHE_NEGATIVE, &ip.h.flags); + ip.h.expiry_time = expiry; + ch = sunrpc_cache_update(cd, &ip.h, &ipm->h, + hash_str(ipm->m_class, IP_HASHBITS) ^ + hash_ip6(&ipm->m_addr)); + if (!ch) + return -ENOMEM; + cache_put(ch, cd); + return 0; +} + +static inline int ip_map_update(struct net *net, struct ip_map *ipm, + struct unix_domain *udom, time_t expiry) +{ + struct sunrpc_net *sn; + + sn = net_generic(net, sunrpc_net_id); + return __ip_map_update(sn->ip_map_cache, ipm, udom, expiry); +} + +void svcauth_unix_purge(struct net *net) +{ + struct sunrpc_net *sn; + + sn = net_generic(net, sunrpc_net_id); + cache_purge(sn->ip_map_cache); +} +EXPORT_SYMBOL_GPL(svcauth_unix_purge); + +static inline struct ip_map * +ip_map_cached_get(struct svc_xprt *xprt) +{ + struct ip_map *ipm = NULL; + struct sunrpc_net *sn; + + if (test_bit(XPT_CACHE_AUTH, &xprt->xpt_flags)) { + spin_lock(&xprt->xpt_lock); + ipm = xprt->xpt_auth_cache; + if (ipm != NULL) { + sn = net_generic(xprt->xpt_net, sunrpc_net_id); + if (cache_is_expired(sn->ip_map_cache, &ipm->h)) { + /* + * The entry has been invalidated since it was + * remembered, e.g. by a second mount from the + * same IP address. + */ + xprt->xpt_auth_cache = NULL; + spin_unlock(&xprt->xpt_lock); + cache_put(&ipm->h, sn->ip_map_cache); + return NULL; + } + cache_get(&ipm->h); + } + spin_unlock(&xprt->xpt_lock); + } + return ipm; +} + +static inline void +ip_map_cached_put(struct svc_xprt *xprt, struct ip_map *ipm) +{ + if (test_bit(XPT_CACHE_AUTH, &xprt->xpt_flags)) { + spin_lock(&xprt->xpt_lock); + if (xprt->xpt_auth_cache == NULL) { + /* newly cached, keep the reference */ + xprt->xpt_auth_cache = ipm; + ipm = NULL; + } + spin_unlock(&xprt->xpt_lock); + } + if (ipm) { + struct sunrpc_net *sn; + + sn = net_generic(xprt->xpt_net, sunrpc_net_id); + cache_put(&ipm->h, sn->ip_map_cache); + } +} + +void +svcauth_unix_info_release(struct svc_xprt *xpt) +{ + struct ip_map *ipm; + + ipm = xpt->xpt_auth_cache; + if (ipm != NULL) { + struct sunrpc_net *sn; + + sn = net_generic(xpt->xpt_net, sunrpc_net_id); + cache_put(&ipm->h, sn->ip_map_cache); + } +} + +/**************************************************************************** + * auth.unix.gid cache + * simple cache to map a UID to a list of GIDs + * because AUTH_UNIX aka AUTH_SYS has a max of UNX_NGROUPS + */ +#define GID_HASHBITS 8 +#define GID_HASHMAX (1<<GID_HASHBITS) + +struct unix_gid { + struct cache_head h; + kuid_t uid; + struct group_info *gi; +}; + +static int unix_gid_hash(kuid_t uid) +{ + return hash_long(from_kuid(&init_user_ns, uid), GID_HASHBITS); +} + +static void unix_gid_put(struct kref *kref) +{ + struct cache_head *item = container_of(kref, struct cache_head, ref); + struct unix_gid *ug = container_of(item, struct unix_gid, h); + if (test_bit(CACHE_VALID, &item->flags) && + !test_bit(CACHE_NEGATIVE, &item->flags)) + put_group_info(ug->gi); + kfree(ug); +} + +static int unix_gid_match(struct cache_head *corig, struct cache_head *cnew) +{ + struct unix_gid *orig = container_of(corig, struct unix_gid, h); + struct unix_gid *new = container_of(cnew, struct unix_gid, h); + return uid_eq(orig->uid, new->uid); +} +static void unix_gid_init(struct cache_head *cnew, struct cache_head *citem) +{ + struct unix_gid *new = container_of(cnew, struct unix_gid, h); + struct unix_gid *item = container_of(citem, struct unix_gid, h); + new->uid = item->uid; +} +static void unix_gid_update(struct cache_head *cnew, struct cache_head *citem) +{ + struct unix_gid *new = container_of(cnew, struct unix_gid, h); + struct unix_gid *item = container_of(citem, struct unix_gid, h); + + get_group_info(item->gi); + new->gi = item->gi; +} +static struct cache_head *unix_gid_alloc(void) +{ + struct unix_gid *g = kmalloc(sizeof(*g), GFP_KERNEL); + if (g) + return &g->h; + else + return NULL; +} + +static void unix_gid_request(struct cache_detail *cd, + struct cache_head *h, + char **bpp, int *blen) +{ + char tuid[20]; + struct unix_gid *ug = container_of(h, struct unix_gid, h); + + snprintf(tuid, 20, "%u", from_kuid(&init_user_ns, ug->uid)); + qword_add(bpp, blen, tuid); + (*bpp)[-1] = '\n'; +} + +static struct unix_gid *unix_gid_lookup(struct cache_detail *cd, kuid_t uid); + +static int unix_gid_parse(struct cache_detail *cd, + char *mesg, int mlen) +{ + /* uid expiry Ngid gid0 gid1 ... gidN-1 */ + int id; + kuid_t uid; + int gids; + int rv; + int i; + int err; + time_t expiry; + struct unix_gid ug, *ugp; + + if (mesg[mlen - 1] != '\n') + return -EINVAL; + mesg[mlen-1] = 0; + + rv = get_int(&mesg, &id); + if (rv) + return -EINVAL; + uid = make_kuid(&init_user_ns, id); + ug.uid = uid; + + expiry = get_expiry(&mesg); + if (expiry == 0) + return -EINVAL; + + rv = get_int(&mesg, &gids); + if (rv || gids < 0 || gids > 8192) + return -EINVAL; + + ug.gi = groups_alloc(gids); + if (!ug.gi) + return -ENOMEM; + + for (i = 0 ; i < gids ; i++) { + int gid; + kgid_t kgid; + rv = get_int(&mesg, &gid); + err = -EINVAL; + if (rv) + goto out; + kgid = make_kgid(&init_user_ns, gid); + if (!gid_valid(kgid)) + goto out; + ug.gi->gid[i] = kgid; + } + + groups_sort(ug.gi); + ugp = unix_gid_lookup(cd, uid); + if (ugp) { + struct cache_head *ch; + ug.h.flags = 0; + ug.h.expiry_time = expiry; + ch = sunrpc_cache_update(cd, + &ug.h, &ugp->h, + unix_gid_hash(uid)); + if (!ch) + err = -ENOMEM; + else { + err = 0; + cache_put(ch, cd); + } + } else + err = -ENOMEM; + out: + if (ug.gi) + put_group_info(ug.gi); + return err; +} + +static int unix_gid_show(struct seq_file *m, + struct cache_detail *cd, + struct cache_head *h) +{ + struct user_namespace *user_ns = &init_user_ns; + struct unix_gid *ug; + int i; + int glen; + + if (h == NULL) { + seq_puts(m, "#uid cnt: gids...\n"); + return 0; + } + ug = container_of(h, struct unix_gid, h); + if (test_bit(CACHE_VALID, &h->flags) && + !test_bit(CACHE_NEGATIVE, &h->flags)) + glen = ug->gi->ngroups; + else + glen = 0; + + seq_printf(m, "%u %d:", from_kuid_munged(user_ns, ug->uid), glen); + for (i = 0; i < glen; i++) + seq_printf(m, " %d", from_kgid_munged(user_ns, ug->gi->gid[i])); + seq_printf(m, "\n"); + return 0; +} + +static const struct cache_detail unix_gid_cache_template = { + .owner = THIS_MODULE, + .hash_size = GID_HASHMAX, + .name = "auth.unix.gid", + .cache_put = unix_gid_put, + .cache_request = unix_gid_request, + .cache_parse = unix_gid_parse, + .cache_show = unix_gid_show, + .match = unix_gid_match, + .init = unix_gid_init, + .update = unix_gid_update, + .alloc = unix_gid_alloc, +}; + +int unix_gid_cache_create(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd; + int err; + + cd = cache_create_net(&unix_gid_cache_template, net); + if (IS_ERR(cd)) + return PTR_ERR(cd); + err = cache_register_net(cd, net); + if (err) { + cache_destroy_net(cd, net); + return err; + } + sn->unix_gid_cache = cd; + return 0; +} + +void unix_gid_cache_destroy(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd = sn->unix_gid_cache; + + sn->unix_gid_cache = NULL; + cache_purge(cd); + cache_unregister_net(cd, net); + cache_destroy_net(cd, net); +} + +static struct unix_gid *unix_gid_lookup(struct cache_detail *cd, kuid_t uid) +{ + struct unix_gid ug; + struct cache_head *ch; + + ug.uid = uid; + ch = sunrpc_cache_lookup(cd, &ug.h, unix_gid_hash(uid)); + if (ch) + return container_of(ch, struct unix_gid, h); + else + return NULL; +} + +static struct group_info *unix_gid_find(kuid_t uid, struct svc_rqst *rqstp) +{ + struct unix_gid *ug; + struct group_info *gi; + int ret; + struct sunrpc_net *sn = net_generic(rqstp->rq_xprt->xpt_net, + sunrpc_net_id); + + ug = unix_gid_lookup(sn->unix_gid_cache, uid); + if (!ug) + return ERR_PTR(-EAGAIN); + ret = cache_check(sn->unix_gid_cache, &ug->h, &rqstp->rq_chandle); + switch (ret) { + case -ENOENT: + return ERR_PTR(-ENOENT); + case -ETIMEDOUT: + return ERR_PTR(-ESHUTDOWN); + case 0: + gi = get_group_info(ug->gi); + cache_put(&ug->h, sn->unix_gid_cache); + return gi; + default: + return ERR_PTR(-EAGAIN); + } +} + +int +svcauth_unix_set_client(struct svc_rqst *rqstp) +{ + struct sockaddr_in *sin; + struct sockaddr_in6 *sin6, sin6_storage; + struct ip_map *ipm; + struct group_info *gi; + struct svc_cred *cred = &rqstp->rq_cred; + struct svc_xprt *xprt = rqstp->rq_xprt; + struct net *net = xprt->xpt_net; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + switch (rqstp->rq_addr.ss_family) { + case AF_INET: + sin = svc_addr_in(rqstp); + sin6 = &sin6_storage; + ipv6_addr_set_v4mapped(sin->sin_addr.s_addr, &sin6->sin6_addr); + break; + case AF_INET6: + sin6 = svc_addr_in6(rqstp); + break; + default: + BUG(); + } + + rqstp->rq_client = NULL; + if (rqstp->rq_proc == 0) + return SVC_OK; + + ipm = ip_map_cached_get(xprt); + if (ipm == NULL) + ipm = __ip_map_lookup(sn->ip_map_cache, rqstp->rq_server->sv_program->pg_class, + &sin6->sin6_addr); + + if (ipm == NULL) + return SVC_DENIED; + + switch (cache_check(sn->ip_map_cache, &ipm->h, &rqstp->rq_chandle)) { + default: + BUG(); + case -ETIMEDOUT: + return SVC_CLOSE; + case -EAGAIN: + return SVC_DROP; + case -ENOENT: + return SVC_DENIED; + case 0: + rqstp->rq_client = &ipm->m_client->h; + kref_get(&rqstp->rq_client->ref); + ip_map_cached_put(xprt, ipm); + break; + } + + gi = unix_gid_find(cred->cr_uid, rqstp); + switch (PTR_ERR(gi)) { + case -EAGAIN: + return SVC_DROP; + case -ESHUTDOWN: + return SVC_CLOSE; + case -ENOENT: + break; + default: + put_group_info(cred->cr_group_info); + cred->cr_group_info = gi; + } + return SVC_OK; +} + +EXPORT_SYMBOL_GPL(svcauth_unix_set_client); + +static int +svcauth_null_accept(struct svc_rqst *rqstp, __be32 *authp) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec *resv = &rqstp->rq_res.head[0]; + struct svc_cred *cred = &rqstp->rq_cred; + + if (argv->iov_len < 3*4) + return SVC_GARBAGE; + + if (svc_getu32(argv) != 0) { + dprintk("svc: bad null cred\n"); + *authp = rpc_autherr_badcred; + return SVC_DENIED; + } + if (svc_getu32(argv) != htonl(RPC_AUTH_NULL) || svc_getu32(argv) != 0) { + dprintk("svc: bad null verf\n"); + *authp = rpc_autherr_badverf; + return SVC_DENIED; + } + + /* Signal that mapping to nobody uid/gid is required */ + cred->cr_uid = INVALID_UID; + cred->cr_gid = INVALID_GID; + cred->cr_group_info = groups_alloc(0); + if (cred->cr_group_info == NULL) + return SVC_CLOSE; /* kmalloc failure - client must retry */ + + /* Put NULL verifier */ + svc_putnl(resv, RPC_AUTH_NULL); + svc_putnl(resv, 0); + + rqstp->rq_cred.cr_flavor = RPC_AUTH_NULL; + return SVC_OK; +} + +static int +svcauth_null_release(struct svc_rqst *rqstp) +{ + if (rqstp->rq_client) + auth_domain_put(rqstp->rq_client); + rqstp->rq_client = NULL; + if (rqstp->rq_cred.cr_group_info) + put_group_info(rqstp->rq_cred.cr_group_info); + rqstp->rq_cred.cr_group_info = NULL; + + return 0; /* don't drop */ +} + + +struct auth_ops svcauth_null = { + .name = "null", + .owner = THIS_MODULE, + .flavour = RPC_AUTH_NULL, + .accept = svcauth_null_accept, + .release = svcauth_null_release, + .set_client = svcauth_unix_set_client, +}; + + +static int +svcauth_unix_accept(struct svc_rqst *rqstp, __be32 *authp) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec *resv = &rqstp->rq_res.head[0]; + struct svc_cred *cred = &rqstp->rq_cred; + u32 slen, i; + int len = argv->iov_len; + + if ((len -= 3*4) < 0) + return SVC_GARBAGE; + + svc_getu32(argv); /* length */ + svc_getu32(argv); /* time stamp */ + slen = XDR_QUADLEN(svc_getnl(argv)); /* machname length */ + if (slen > 64 || (len -= (slen + 3)*4) < 0) + goto badcred; + argv->iov_base = (void*)((__be32*)argv->iov_base + slen); /* skip machname */ + argv->iov_len -= slen*4; + /* + * Note: we skip uid_valid()/gid_valid() checks here for + * backwards compatibility with clients that use -1 id's. + * Instead, -1 uid or gid is later mapped to the + * (export-specific) anonymous id by nfsd_setuser. + * Supplementary gid's will be left alone. + */ + cred->cr_uid = make_kuid(&init_user_ns, svc_getnl(argv)); /* uid */ + cred->cr_gid = make_kgid(&init_user_ns, svc_getnl(argv)); /* gid */ + slen = svc_getnl(argv); /* gids length */ + if (slen > UNX_NGROUPS || (len -= (slen + 2)*4) < 0) + goto badcred; + cred->cr_group_info = groups_alloc(slen); + if (cred->cr_group_info == NULL) + return SVC_CLOSE; + for (i = 0; i < slen; i++) { + kgid_t kgid = make_kgid(&init_user_ns, svc_getnl(argv)); + cred->cr_group_info->gid[i] = kgid; + } + groups_sort(cred->cr_group_info); + if (svc_getu32(argv) != htonl(RPC_AUTH_NULL) || svc_getu32(argv) != 0) { + *authp = rpc_autherr_badverf; + return SVC_DENIED; + } + + /* Put NULL verifier */ + svc_putnl(resv, RPC_AUTH_NULL); + svc_putnl(resv, 0); + + rqstp->rq_cred.cr_flavor = RPC_AUTH_UNIX; + return SVC_OK; + +badcred: + *authp = rpc_autherr_badcred; + return SVC_DENIED; +} + +static int +svcauth_unix_release(struct svc_rqst *rqstp) +{ + /* Verifier (such as it is) is already in place. + */ + if (rqstp->rq_client) + auth_domain_put(rqstp->rq_client); + rqstp->rq_client = NULL; + if (rqstp->rq_cred.cr_group_info) + put_group_info(rqstp->rq_cred.cr_group_info); + rqstp->rq_cred.cr_group_info = NULL; + + return 0; +} + + +struct auth_ops svcauth_unix = { + .name = "unix", + .owner = THIS_MODULE, + .flavour = RPC_AUTH_UNIX, + .accept = svcauth_unix_accept, + .release = svcauth_unix_release, + .domain_release = svcauth_unix_domain_release, + .set_client = svcauth_unix_set_client, +}; + +static const struct cache_detail ip_map_cache_template = { + .owner = THIS_MODULE, + .hash_size = IP_HASHMAX, + .name = "auth.unix.ip", + .cache_put = ip_map_put, + .cache_request = ip_map_request, + .cache_parse = ip_map_parse, + .cache_show = ip_map_show, + .match = ip_map_match, + .init = ip_map_init, + .update = update, + .alloc = ip_map_alloc, +}; + +int ip_map_cache_create(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd; + int err; + + cd = cache_create_net(&ip_map_cache_template, net); + if (IS_ERR(cd)) + return PTR_ERR(cd); + err = cache_register_net(cd, net); + if (err) { + cache_destroy_net(cd, net); + return err; + } + sn->ip_map_cache = cd; + return 0; +} + +void ip_map_cache_destroy(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd = sn->ip_map_cache; + + sn->ip_map_cache = NULL; + cache_purge(cd); + cache_unregister_net(cd, net); + cache_destroy_net(cd, net); +} diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c new file mode 100644 index 000000000..d0b5a1c47 --- /dev/null +++ b/net/sunrpc/svcsock.c @@ -0,0 +1,1668 @@ +/* + * linux/net/sunrpc/svcsock.c + * + * These are the RPC server socket internals. + * + * The server scheduling algorithm does not always distribute the load + * evenly when servicing a single client. May need to modify the + * svc_xprt_enqueue procedure... + * + * TCP support is largely untested and may be a little slow. The problem + * is that we currently do two separate recvfrom's, one for the 4-byte + * record length, and the second for the actual record. This could possibly + * be improved by always reading a minimum size of around 100 bytes and + * tucking any superfluous bytes away in a temporary store. Still, that + * leaves write requests out in the rain. An alternative may be to peek at + * the first skb in the queue, and if it matches the next TCP sequence + * number, to extract the record marker. Yuck. + * + * Copyright (C) 1995, 1996 Olaf Kirch <okir@monad.swb.de> + */ + +#include <linux/kernel.h> +#include <linux/sched.h> +#include <linux/module.h> +#include <linux/errno.h> +#include <linux/fcntl.h> +#include <linux/net.h> +#include <linux/in.h> +#include <linux/inet.h> +#include <linux/udp.h> +#include <linux/tcp.h> +#include <linux/unistd.h> +#include <linux/slab.h> +#include <linux/netdevice.h> +#include <linux/skbuff.h> +#include <linux/file.h> +#include <linux/freezer.h> +#include <net/sock.h> +#include <net/checksum.h> +#include <net/ip.h> +#include <net/ipv6.h> +#include <net/udp.h> +#include <net/tcp.h> +#include <net/tcp_states.h> +#include <linux/uaccess.h> +#include <asm/ioctls.h> +#include <trace/events/skb.h> + +#include <linux/sunrpc/types.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/xdr.h> +#include <linux/sunrpc/msg_prot.h> +#include <linux/sunrpc/svcsock.h> +#include <linux/sunrpc/stats.h> +#include <linux/sunrpc/xprt.h> + +#include "sunrpc.h" + +#define RPCDBG_FACILITY RPCDBG_SVCXPRT + + +static struct svc_sock *svc_setup_socket(struct svc_serv *, struct socket *, + int flags); +static int svc_udp_recvfrom(struct svc_rqst *); +static int svc_udp_sendto(struct svc_rqst *); +static void svc_sock_detach(struct svc_xprt *); +static void svc_tcp_sock_detach(struct svc_xprt *); +static void svc_sock_free(struct svc_xprt *); + +static struct svc_xprt *svc_create_socket(struct svc_serv *, int, + struct net *, struct sockaddr *, + int, int); +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +static struct svc_xprt *svc_bc_create_socket(struct svc_serv *, int, + struct net *, struct sockaddr *, + int, int); +static void svc_bc_sock_free(struct svc_xprt *xprt); +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +#ifdef CONFIG_DEBUG_LOCK_ALLOC +static struct lock_class_key svc_key[2]; +static struct lock_class_key svc_slock_key[2]; + +static void svc_reclassify_socket(struct socket *sock) +{ + struct sock *sk = sock->sk; + + if (WARN_ON_ONCE(!sock_allow_reclassification(sk))) + return; + + switch (sk->sk_family) { + case AF_INET: + sock_lock_init_class_and_name(sk, "slock-AF_INET-NFSD", + &svc_slock_key[0], + "sk_xprt.xpt_lock-AF_INET-NFSD", + &svc_key[0]); + break; + + case AF_INET6: + sock_lock_init_class_and_name(sk, "slock-AF_INET6-NFSD", + &svc_slock_key[1], + "sk_xprt.xpt_lock-AF_INET6-NFSD", + &svc_key[1]); + break; + + default: + BUG(); + } +} +#else +static void svc_reclassify_socket(struct socket *sock) +{ +} +#endif + +/* + * Release an skbuff after use + */ +static void svc_release_skb(struct svc_rqst *rqstp) +{ + struct sk_buff *skb = rqstp->rq_xprt_ctxt; + + if (skb) { + struct svc_sock *svsk = + container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); + rqstp->rq_xprt_ctxt = NULL; + + dprintk("svc: service %p, releasing skb %p\n", rqstp, skb); + skb_free_datagram_locked(svsk->sk_sk, skb); + } +} + +static void svc_release_udp_skb(struct svc_rqst *rqstp) +{ + struct sk_buff *skb = rqstp->rq_xprt_ctxt; + + if (skb) { + rqstp->rq_xprt_ctxt = NULL; + + dprintk("svc: service %p, releasing skb %p\n", rqstp, skb); + consume_skb(skb); + } +} + +union svc_pktinfo_u { + struct in_pktinfo pkti; + struct in6_pktinfo pkti6; +}; +#define SVC_PKTINFO_SPACE \ + CMSG_SPACE(sizeof(union svc_pktinfo_u)) + +static void svc_set_cmsg_data(struct svc_rqst *rqstp, struct cmsghdr *cmh) +{ + struct svc_sock *svsk = + container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); + switch (svsk->sk_sk->sk_family) { + case AF_INET: { + struct in_pktinfo *pki = CMSG_DATA(cmh); + + cmh->cmsg_level = SOL_IP; + cmh->cmsg_type = IP_PKTINFO; + pki->ipi_ifindex = 0; + pki->ipi_spec_dst.s_addr = + svc_daddr_in(rqstp)->sin_addr.s_addr; + cmh->cmsg_len = CMSG_LEN(sizeof(*pki)); + } + break; + + case AF_INET6: { + struct in6_pktinfo *pki = CMSG_DATA(cmh); + struct sockaddr_in6 *daddr = svc_daddr_in6(rqstp); + + cmh->cmsg_level = SOL_IPV6; + cmh->cmsg_type = IPV6_PKTINFO; + pki->ipi6_ifindex = daddr->sin6_scope_id; + pki->ipi6_addr = daddr->sin6_addr; + cmh->cmsg_len = CMSG_LEN(sizeof(*pki)); + } + break; + } +} + +/* + * send routine intended to be shared by the fore- and back-channel + */ +int svc_send_common(struct socket *sock, struct xdr_buf *xdr, + struct page *headpage, unsigned long headoffset, + struct page *tailpage, unsigned long tailoffset) +{ + int result; + int size; + struct page **ppage = xdr->pages; + size_t base = xdr->page_base; + unsigned int pglen = xdr->page_len; + unsigned int flags = MSG_MORE | MSG_SENDPAGE_NOTLAST; + int slen; + int len = 0; + + slen = xdr->len; + + /* send head */ + if (slen == xdr->head[0].iov_len) + flags = 0; + len = kernel_sendpage(sock, headpage, headoffset, + xdr->head[0].iov_len, flags); + if (len != xdr->head[0].iov_len) + goto out; + slen -= xdr->head[0].iov_len; + if (slen == 0) + goto out; + + /* send page data */ + size = PAGE_SIZE - base < pglen ? PAGE_SIZE - base : pglen; + while (pglen > 0) { + if (slen == size) + flags = 0; + result = kernel_sendpage(sock, *ppage, base, size, flags); + if (result > 0) + len += result; + if (result != size) + goto out; + slen -= size; + pglen -= size; + size = PAGE_SIZE < pglen ? PAGE_SIZE : pglen; + base = 0; + ppage++; + } + + /* send tail */ + if (xdr->tail[0].iov_len) { + result = kernel_sendpage(sock, tailpage, tailoffset, + xdr->tail[0].iov_len, 0); + if (result > 0) + len += result; + } + +out: + return len; +} + + +/* + * Generic sendto routine + */ +static int svc_sendto(struct svc_rqst *rqstp, struct xdr_buf *xdr) +{ + struct svc_sock *svsk = + container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); + struct socket *sock = svsk->sk_sock; + union { + struct cmsghdr hdr; + long all[SVC_PKTINFO_SPACE / sizeof(long)]; + } buffer; + struct cmsghdr *cmh = &buffer.hdr; + int len = 0; + unsigned long tailoff; + unsigned long headoff; + RPC_IFDEBUG(char buf[RPC_MAX_ADDRBUFLEN]); + + if (rqstp->rq_prot == IPPROTO_UDP) { + struct msghdr msg = { + .msg_name = &rqstp->rq_addr, + .msg_namelen = rqstp->rq_addrlen, + .msg_control = cmh, + .msg_controllen = sizeof(buffer), + .msg_flags = MSG_MORE, + }; + + svc_set_cmsg_data(rqstp, cmh); + + if (sock_sendmsg(sock, &msg) < 0) + goto out; + } + + tailoff = ((unsigned long)xdr->tail[0].iov_base) & (PAGE_SIZE-1); + headoff = 0; + len = svc_send_common(sock, xdr, rqstp->rq_respages[0], headoff, + rqstp->rq_respages[0], tailoff); + +out: + dprintk("svc: socket %p sendto([%p %zu... ], %d) = %d (addr %s)\n", + svsk, xdr->head[0].iov_base, xdr->head[0].iov_len, + xdr->len, len, svc_print_addr(rqstp, buf, sizeof(buf))); + + return len; +} + +/* + * Report socket names for nfsdfs + */ +static int svc_one_sock_name(struct svc_sock *svsk, char *buf, int remaining) +{ + const struct sock *sk = svsk->sk_sk; + const char *proto_name = sk->sk_protocol == IPPROTO_UDP ? + "udp" : "tcp"; + int len; + + switch (sk->sk_family) { + case PF_INET: + len = snprintf(buf, remaining, "ipv4 %s %pI4 %d\n", + proto_name, + &inet_sk(sk)->inet_rcv_saddr, + inet_sk(sk)->inet_num); + break; +#if IS_ENABLED(CONFIG_IPV6) + case PF_INET6: + len = snprintf(buf, remaining, "ipv6 %s %pI6 %d\n", + proto_name, + &sk->sk_v6_rcv_saddr, + inet_sk(sk)->inet_num); + break; +#endif + default: + len = snprintf(buf, remaining, "*unknown-%d*\n", + sk->sk_family); + } + + if (len >= remaining) { + *buf = '\0'; + return -ENAMETOOLONG; + } + return len; +} + +/* + * Generic recvfrom routine. + */ +static int svc_recvfrom(struct svc_rqst *rqstp, struct kvec *iov, int nr, + int buflen) +{ + struct svc_sock *svsk = + container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); + struct msghdr msg = { + .msg_flags = MSG_DONTWAIT, + }; + int len; + + rqstp->rq_xprt_hlen = 0; + + clear_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + iov_iter_kvec(&msg.msg_iter, READ | ITER_KVEC, iov, nr, buflen); + len = sock_recvmsg(svsk->sk_sock, &msg, msg.msg_flags); + /* If we read a full record, then assume there may be more + * data to read (stream based sockets only!) + */ + if (len == buflen) + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + + dprintk("svc: socket %p recvfrom(%p, %zu) = %d\n", + svsk, iov[0].iov_base, iov[0].iov_len, len); + return len; +} + +static int svc_partial_recvfrom(struct svc_rqst *rqstp, + struct kvec *iov, int nr, + int buflen, unsigned int base) +{ + size_t save_iovlen; + void *save_iovbase; + unsigned int i; + int ret; + + if (base == 0) + return svc_recvfrom(rqstp, iov, nr, buflen); + + for (i = 0; i < nr; i++) { + if (iov[i].iov_len > base) + break; + base -= iov[i].iov_len; + } + save_iovlen = iov[i].iov_len; + save_iovbase = iov[i].iov_base; + iov[i].iov_len -= base; + iov[i].iov_base += base; + ret = svc_recvfrom(rqstp, &iov[i], nr - i, buflen); + iov[i].iov_len = save_iovlen; + iov[i].iov_base = save_iovbase; + return ret; +} + +/* + * Set socket snd and rcv buffer lengths + */ +static void svc_sock_setbufsize(struct svc_sock *svsk, unsigned int nreqs) +{ + unsigned int max_mesg = svsk->sk_xprt.xpt_server->sv_max_mesg; + struct socket *sock = svsk->sk_sock; + + nreqs = min(nreqs, INT_MAX / 2 / max_mesg); + + lock_sock(sock->sk); + sock->sk->sk_sndbuf = nreqs * max_mesg * 2; + sock->sk->sk_rcvbuf = nreqs * max_mesg * 2; + sock->sk->sk_write_space(sock->sk); + release_sock(sock->sk); +} + +static void svc_sock_secure_port(struct svc_rqst *rqstp) +{ + if (svc_port_is_privileged(svc_addr(rqstp))) + set_bit(RQ_SECURE, &rqstp->rq_flags); + else + clear_bit(RQ_SECURE, &rqstp->rq_flags); +} + +/* + * INET callback when data has been received on the socket. + */ +static void svc_data_ready(struct sock *sk) +{ + struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data; + + if (svsk) { + dprintk("svc: socket %p(inet %p), busy=%d\n", + svsk, sk, + test_bit(XPT_BUSY, &svsk->sk_xprt.xpt_flags)); + + /* Refer to svc_setup_socket() for details. */ + rmb(); + svsk->sk_odata(sk); + if (!test_and_set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags)) + svc_xprt_enqueue(&svsk->sk_xprt); + } +} + +/* + * INET callback when space is newly available on the socket. + */ +static void svc_write_space(struct sock *sk) +{ + struct svc_sock *svsk = (struct svc_sock *)(sk->sk_user_data); + + if (svsk) { + dprintk("svc: socket %p(inet %p), write_space busy=%d\n", + svsk, sk, test_bit(XPT_BUSY, &svsk->sk_xprt.xpt_flags)); + + /* Refer to svc_setup_socket() for details. */ + rmb(); + svsk->sk_owspace(sk); + svc_xprt_enqueue(&svsk->sk_xprt); + } +} + +static int svc_tcp_has_wspace(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + + if (test_bit(XPT_LISTENER, &xprt->xpt_flags)) + return 1; + return !test_bit(SOCK_NOSPACE, &svsk->sk_sock->flags); +} + +static void svc_tcp_kill_temp_xprt(struct svc_xprt *xprt) +{ + struct svc_sock *svsk; + struct socket *sock; + struct linger no_linger = { + .l_onoff = 1, + .l_linger = 0, + }; + + svsk = container_of(xprt, struct svc_sock, sk_xprt); + sock = svsk->sk_sock; + kernel_setsockopt(sock, SOL_SOCKET, SO_LINGER, + (char *)&no_linger, sizeof(no_linger)); +} + +/* + * See net/ipv6/ip_sockglue.c : ip_cmsg_recv_pktinfo + */ +static int svc_udp_get_dest_address4(struct svc_rqst *rqstp, + struct cmsghdr *cmh) +{ + struct in_pktinfo *pki = CMSG_DATA(cmh); + struct sockaddr_in *daddr = svc_daddr_in(rqstp); + + if (cmh->cmsg_type != IP_PKTINFO) + return 0; + + daddr->sin_family = AF_INET; + daddr->sin_addr.s_addr = pki->ipi_spec_dst.s_addr; + return 1; +} + +/* + * See net/ipv6/datagram.c : ip6_datagram_recv_ctl + */ +static int svc_udp_get_dest_address6(struct svc_rqst *rqstp, + struct cmsghdr *cmh) +{ + struct in6_pktinfo *pki = CMSG_DATA(cmh); + struct sockaddr_in6 *daddr = svc_daddr_in6(rqstp); + + if (cmh->cmsg_type != IPV6_PKTINFO) + return 0; + + daddr->sin6_family = AF_INET6; + daddr->sin6_addr = pki->ipi6_addr; + daddr->sin6_scope_id = pki->ipi6_ifindex; + return 1; +} + +/* + * Copy the UDP datagram's destination address to the rqstp structure. + * The 'destination' address in this case is the address to which the + * peer sent the datagram, i.e. our local address. For multihomed + * hosts, this can change from msg to msg. Note that only the IP + * address changes, the port number should remain the same. + */ +static int svc_udp_get_dest_address(struct svc_rqst *rqstp, + struct cmsghdr *cmh) +{ + switch (cmh->cmsg_level) { + case SOL_IP: + return svc_udp_get_dest_address4(rqstp, cmh); + case SOL_IPV6: + return svc_udp_get_dest_address6(rqstp, cmh); + } + + return 0; +} + +/* + * Receive a datagram from a UDP socket. + */ +static int svc_udp_recvfrom(struct svc_rqst *rqstp) +{ + struct svc_sock *svsk = + container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); + struct svc_serv *serv = svsk->sk_xprt.xpt_server; + struct sk_buff *skb; + union { + struct cmsghdr hdr; + long all[SVC_PKTINFO_SPACE / sizeof(long)]; + } buffer; + struct cmsghdr *cmh = &buffer.hdr; + struct msghdr msg = { + .msg_name = svc_addr(rqstp), + .msg_control = cmh, + .msg_controllen = sizeof(buffer), + .msg_flags = MSG_DONTWAIT, + }; + size_t len; + int err; + + if (test_and_clear_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags)) + /* udp sockets need large rcvbuf as all pending + * requests are still in that buffer. sndbuf must + * also be large enough that there is enough space + * for one reply per thread. We count all threads + * rather than threads in a particular pool, which + * provides an upper bound on the number of threads + * which will access the socket. + */ + svc_sock_setbufsize(svsk, serv->sv_nrthreads + 3); + + clear_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + skb = NULL; + err = kernel_recvmsg(svsk->sk_sock, &msg, NULL, + 0, 0, MSG_PEEK | MSG_DONTWAIT); + if (err >= 0) + skb = skb_recv_udp(svsk->sk_sk, 0, 1, &err); + + if (skb == NULL) { + if (err != -EAGAIN) { + /* possibly an icmp error */ + dprintk("svc: recvfrom returned error %d\n", -err); + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + } + return 0; + } + len = svc_addr_len(svc_addr(rqstp)); + rqstp->rq_addrlen = len; + if (skb->tstamp == 0) { + skb->tstamp = ktime_get_real(); + /* Don't enable netstamp, sunrpc doesn't + need that much accuracy */ + } + sock_write_timestamp(svsk->sk_sk, skb->tstamp); + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); /* there may be more data... */ + + len = skb->len; + rqstp->rq_arg.len = len; + + rqstp->rq_prot = IPPROTO_UDP; + + if (!svc_udp_get_dest_address(rqstp, cmh)) { + net_warn_ratelimited("svc: received unknown control message %d/%d; dropping RPC reply datagram\n", + cmh->cmsg_level, cmh->cmsg_type); + goto out_free; + } + rqstp->rq_daddrlen = svc_addr_len(svc_daddr(rqstp)); + + if (skb_is_nonlinear(skb)) { + /* we have to copy */ + local_bh_disable(); + if (csum_partial_copy_to_xdr(&rqstp->rq_arg, skb)) { + local_bh_enable(); + /* checksum error */ + goto out_free; + } + local_bh_enable(); + consume_skb(skb); + } else { + /* we can use it in-place */ + rqstp->rq_arg.head[0].iov_base = skb->data; + rqstp->rq_arg.head[0].iov_len = len; + if (skb_checksum_complete(skb)) + goto out_free; + rqstp->rq_xprt_ctxt = skb; + } + + rqstp->rq_arg.page_base = 0; + if (len <= rqstp->rq_arg.head[0].iov_len) { + rqstp->rq_arg.head[0].iov_len = len; + rqstp->rq_arg.page_len = 0; + rqstp->rq_respages = rqstp->rq_pages+1; + } else { + rqstp->rq_arg.page_len = len - rqstp->rq_arg.head[0].iov_len; + rqstp->rq_respages = rqstp->rq_pages + 1 + + DIV_ROUND_UP(rqstp->rq_arg.page_len, PAGE_SIZE); + } + rqstp->rq_next_page = rqstp->rq_respages+1; + + if (serv->sv_stats) + serv->sv_stats->netudpcnt++; + + return len; +out_free: + kfree_skb(skb); + return 0; +} + +static int +svc_udp_sendto(struct svc_rqst *rqstp) +{ + int error; + + svc_release_udp_skb(rqstp); + + error = svc_sendto(rqstp, &rqstp->rq_res); + if (error == -ECONNREFUSED) + /* ICMP error on earlier request. */ + error = svc_sendto(rqstp, &rqstp->rq_res); + + return error; +} + +static void svc_udp_prep_reply_hdr(struct svc_rqst *rqstp) +{ +} + +static int svc_udp_has_wspace(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + struct svc_serv *serv = xprt->xpt_server; + unsigned long required; + + /* + * Set the SOCK_NOSPACE flag before checking the available + * sock space. + */ + set_bit(SOCK_NOSPACE, &svsk->sk_sock->flags); + required = atomic_read(&svsk->sk_xprt.xpt_reserved) + serv->sv_max_mesg; + if (required*2 > sock_wspace(svsk->sk_sk)) + return 0; + clear_bit(SOCK_NOSPACE, &svsk->sk_sock->flags); + return 1; +} + +static struct svc_xprt *svc_udp_accept(struct svc_xprt *xprt) +{ + BUG(); + return NULL; +} + +static void svc_udp_kill_temp_xprt(struct svc_xprt *xprt) +{ +} + +static struct svc_xprt *svc_udp_create(struct svc_serv *serv, + struct net *net, + struct sockaddr *sa, int salen, + int flags) +{ + return svc_create_socket(serv, IPPROTO_UDP, net, sa, salen, flags); +} + +static const struct svc_xprt_ops svc_udp_ops = { + .xpo_create = svc_udp_create, + .xpo_recvfrom = svc_udp_recvfrom, + .xpo_sendto = svc_udp_sendto, + .xpo_release_rqst = svc_release_udp_skb, + .xpo_detach = svc_sock_detach, + .xpo_free = svc_sock_free, + .xpo_prep_reply_hdr = svc_udp_prep_reply_hdr, + .xpo_has_wspace = svc_udp_has_wspace, + .xpo_accept = svc_udp_accept, + .xpo_secure_port = svc_sock_secure_port, + .xpo_kill_temp_xprt = svc_udp_kill_temp_xprt, +}; + +static struct svc_xprt_class svc_udp_class = { + .xcl_name = "udp", + .xcl_owner = THIS_MODULE, + .xcl_ops = &svc_udp_ops, + .xcl_max_payload = RPCSVC_MAXPAYLOAD_UDP, + .xcl_ident = XPRT_TRANSPORT_UDP, +}; + +static void svc_udp_init(struct svc_sock *svsk, struct svc_serv *serv) +{ + int err, level, optname, one = 1; + + svc_xprt_init(sock_net(svsk->sk_sock->sk), &svc_udp_class, + &svsk->sk_xprt, serv); + clear_bit(XPT_CACHE_AUTH, &svsk->sk_xprt.xpt_flags); + svsk->sk_sk->sk_data_ready = svc_data_ready; + svsk->sk_sk->sk_write_space = svc_write_space; + + /* initialise setting must have enough space to + * receive and respond to one request. + * svc_udp_recvfrom will re-adjust if necessary + */ + svc_sock_setbufsize(svsk, 3); + + /* data might have come in before data_ready set up */ + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + set_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags); + + /* make sure we get destination address info */ + switch (svsk->sk_sk->sk_family) { + case AF_INET: + level = SOL_IP; + optname = IP_PKTINFO; + break; + case AF_INET6: + level = SOL_IPV6; + optname = IPV6_RECVPKTINFO; + break; + default: + BUG(); + } + err = kernel_setsockopt(svsk->sk_sock, level, optname, + (char *)&one, sizeof(one)); + dprintk("svc: kernel_setsockopt returned %d\n", err); +} + +/* + * A data_ready event on a listening socket means there's a connection + * pending. Do not use state_change as a substitute for it. + */ +static void svc_tcp_listen_data_ready(struct sock *sk) +{ + struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data; + + dprintk("svc: socket %p TCP (listen) state change %d\n", + sk, sk->sk_state); + + if (svsk) { + /* Refer to svc_setup_socket() for details. */ + rmb(); + svsk->sk_odata(sk); + } + + /* + * This callback may called twice when a new connection + * is established as a child socket inherits everything + * from a parent LISTEN socket. + * 1) data_ready method of the parent socket will be called + * when one of child sockets become ESTABLISHED. + * 2) data_ready method of the child socket may be called + * when it receives data before the socket is accepted. + * In case of 2, we should ignore it silently. + */ + if (sk->sk_state == TCP_LISTEN) { + if (svsk) { + set_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags); + svc_xprt_enqueue(&svsk->sk_xprt); + } else + printk("svc: socket %p: no user data\n", sk); + } +} + +/* + * A state change on a connected socket means it's dying or dead. + */ +static void svc_tcp_state_change(struct sock *sk) +{ + struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data; + + dprintk("svc: socket %p TCP (connected) state change %d (svsk %p)\n", + sk, sk->sk_state, sk->sk_user_data); + + if (!svsk) + printk("svc: socket %p: no user data\n", sk); + else { + /* Refer to svc_setup_socket() for details. */ + rmb(); + svsk->sk_ostate(sk); + if (sk->sk_state != TCP_ESTABLISHED) { + set_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags); + svc_xprt_enqueue(&svsk->sk_xprt); + } + } +} + +/* + * Accept a TCP connection + */ +static struct svc_xprt *svc_tcp_accept(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + struct sockaddr_storage addr; + struct sockaddr *sin = (struct sockaddr *) &addr; + struct svc_serv *serv = svsk->sk_xprt.xpt_server; + struct socket *sock = svsk->sk_sock; + struct socket *newsock; + struct svc_sock *newsvsk; + int err, slen; + RPC_IFDEBUG(char buf[RPC_MAX_ADDRBUFLEN]); + + dprintk("svc: tcp_accept %p sock %p\n", svsk, sock); + if (!sock) + return NULL; + + clear_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags); + err = kernel_accept(sock, &newsock, O_NONBLOCK); + if (err < 0) { + if (err == -ENOMEM) + printk(KERN_WARNING "%s: no more sockets!\n", + serv->sv_name); + else if (err != -EAGAIN) + net_warn_ratelimited("%s: accept failed (err %d)!\n", + serv->sv_name, -err); + return NULL; + } + set_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags); + + err = kernel_getpeername(newsock, sin); + if (err < 0) { + net_warn_ratelimited("%s: peername failed (err %d)!\n", + serv->sv_name, -err); + goto failed; /* aborted connection or whatever */ + } + slen = err; + + /* Ideally, we would want to reject connections from unauthorized + * hosts here, but when we get encryption, the IP of the host won't + * tell us anything. For now just warn about unpriv connections. + */ + if (!svc_port_is_privileged(sin)) { + dprintk("%s: connect from unprivileged port: %s\n", + serv->sv_name, + __svc_print_addr(sin, buf, sizeof(buf))); + } + dprintk("%s: connect from %s\n", serv->sv_name, + __svc_print_addr(sin, buf, sizeof(buf))); + + /* Reset the inherited callbacks before calling svc_setup_socket */ + newsock->sk->sk_state_change = svsk->sk_ostate; + newsock->sk->sk_data_ready = svsk->sk_odata; + newsock->sk->sk_write_space = svsk->sk_owspace; + + /* make sure that a write doesn't block forever when + * low on memory + */ + newsock->sk->sk_sndtimeo = HZ*30; + + newsvsk = svc_setup_socket(serv, newsock, + (SVC_SOCK_ANONYMOUS | SVC_SOCK_TEMPORARY)); + if (IS_ERR(newsvsk)) + goto failed; + svc_xprt_set_remote(&newsvsk->sk_xprt, sin, slen); + err = kernel_getsockname(newsock, sin); + slen = err; + if (unlikely(err < 0)) { + dprintk("svc_tcp_accept: kernel_getsockname error %d\n", -err); + slen = offsetof(struct sockaddr, sa_data); + } + svc_xprt_set_local(&newsvsk->sk_xprt, sin, slen); + + if (sock_is_loopback(newsock->sk)) + set_bit(XPT_LOCAL, &newsvsk->sk_xprt.xpt_flags); + else + clear_bit(XPT_LOCAL, &newsvsk->sk_xprt.xpt_flags); + if (serv->sv_stats) + serv->sv_stats->nettcpconn++; + + return &newsvsk->sk_xprt; + +failed: + sock_release(newsock); + return NULL; +} + +static unsigned int svc_tcp_restore_pages(struct svc_sock *svsk, struct svc_rqst *rqstp) +{ + unsigned int i, len, npages; + + if (svsk->sk_datalen == 0) + return 0; + len = svsk->sk_datalen; + npages = (len + PAGE_SIZE - 1) >> PAGE_SHIFT; + for (i = 0; i < npages; i++) { + if (rqstp->rq_pages[i] != NULL) + put_page(rqstp->rq_pages[i]); + BUG_ON(svsk->sk_pages[i] == NULL); + rqstp->rq_pages[i] = svsk->sk_pages[i]; + svsk->sk_pages[i] = NULL; + } + rqstp->rq_arg.head[0].iov_base = page_address(rqstp->rq_pages[0]); + return len; +} + +static void svc_tcp_save_pages(struct svc_sock *svsk, struct svc_rqst *rqstp) +{ + unsigned int i, len, npages; + + if (svsk->sk_datalen == 0) + return; + len = svsk->sk_datalen; + npages = (len + PAGE_SIZE - 1) >> PAGE_SHIFT; + for (i = 0; i < npages; i++) { + svsk->sk_pages[i] = rqstp->rq_pages[i]; + rqstp->rq_pages[i] = NULL; + } +} + +static void svc_tcp_clear_pages(struct svc_sock *svsk) +{ + unsigned int i, len, npages; + + if (svsk->sk_datalen == 0) + goto out; + len = svsk->sk_datalen; + npages = (len + PAGE_SIZE - 1) >> PAGE_SHIFT; + for (i = 0; i < npages; i++) { + if (svsk->sk_pages[i] == NULL) { + WARN_ON_ONCE(1); + continue; + } + put_page(svsk->sk_pages[i]); + svsk->sk_pages[i] = NULL; + } +out: + svsk->sk_tcplen = 0; + svsk->sk_datalen = 0; +} + +/* + * Receive fragment record header. + * If we haven't gotten the record length yet, get the next four bytes. + */ +static int svc_tcp_recv_record(struct svc_sock *svsk, struct svc_rqst *rqstp) +{ + struct svc_serv *serv = svsk->sk_xprt.xpt_server; + unsigned int want; + int len; + + if (svsk->sk_tcplen < sizeof(rpc_fraghdr)) { + struct kvec iov; + + want = sizeof(rpc_fraghdr) - svsk->sk_tcplen; + iov.iov_base = ((char *) &svsk->sk_reclen) + svsk->sk_tcplen; + iov.iov_len = want; + if ((len = svc_recvfrom(rqstp, &iov, 1, want)) < 0) + goto error; + svsk->sk_tcplen += len; + + if (len < want) { + dprintk("svc: short recvfrom while reading record " + "length (%d of %d)\n", len, want); + return -EAGAIN; + } + + dprintk("svc: TCP record, %d bytes\n", svc_sock_reclen(svsk)); + if (svc_sock_reclen(svsk) + svsk->sk_datalen > + serv->sv_max_mesg) { + net_notice_ratelimited("RPC: fragment too large: %d\n", + svc_sock_reclen(svsk)); + goto err_delete; + } + } + + return svc_sock_reclen(svsk); +error: + dprintk("RPC: TCP recv_record got %d\n", len); + return len; +err_delete: + set_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags); + return -EAGAIN; +} + +static int receive_cb_reply(struct svc_sock *svsk, struct svc_rqst *rqstp) +{ + struct rpc_xprt *bc_xprt = svsk->sk_xprt.xpt_bc_xprt; + struct rpc_rqst *req = NULL; + struct kvec *src, *dst; + __be32 *p = (__be32 *)rqstp->rq_arg.head[0].iov_base; + __be32 xid; + __be32 calldir; + + xid = *p++; + calldir = *p; + + if (!bc_xprt) + return -EAGAIN; + spin_lock(&bc_xprt->recv_lock); + req = xprt_lookup_rqst(bc_xprt, xid); + if (!req) + goto unlock_notfound; + + memcpy(&req->rq_private_buf, &req->rq_rcv_buf, sizeof(struct xdr_buf)); + /* + * XXX!: cheating for now! Only copying HEAD. + * But we know this is good enough for now (in fact, for any + * callback reply in the forseeable future). + */ + dst = &req->rq_private_buf.head[0]; + src = &rqstp->rq_arg.head[0]; + if (dst->iov_len < src->iov_len) + goto unlock_eagain; /* whatever; just giving up. */ + memcpy(dst->iov_base, src->iov_base, src->iov_len); + xprt_complete_rqst(req->rq_task, rqstp->rq_arg.len); + rqstp->rq_arg.len = 0; + spin_unlock(&bc_xprt->recv_lock); + return 0; +unlock_notfound: + printk(KERN_NOTICE + "%s: Got unrecognized reply: " + "calldir 0x%x xpt_bc_xprt %p xid %08x\n", + __func__, ntohl(calldir), + bc_xprt, ntohl(xid)); +unlock_eagain: + spin_unlock(&bc_xprt->recv_lock); + return -EAGAIN; +} + +static int copy_pages_to_kvecs(struct kvec *vec, struct page **pages, int len) +{ + int i = 0; + int t = 0; + + while (t < len) { + vec[i].iov_base = page_address(pages[i]); + vec[i].iov_len = PAGE_SIZE; + i++; + t += PAGE_SIZE; + } + return i; +} + +static void svc_tcp_fragment_received(struct svc_sock *svsk) +{ + /* If we have more data, signal svc_xprt_enqueue() to try again */ + dprintk("svc: TCP %s record (%d bytes)\n", + svc_sock_final_rec(svsk) ? "final" : "nonfinal", + svc_sock_reclen(svsk)); + svsk->sk_tcplen = 0; + svsk->sk_reclen = 0; +} + +/* + * Receive data from a TCP socket. + */ +static int svc_tcp_recvfrom(struct svc_rqst *rqstp) +{ + struct svc_sock *svsk = + container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); + struct svc_serv *serv = svsk->sk_xprt.xpt_server; + int len; + struct kvec *vec; + unsigned int want, base; + __be32 *p; + __be32 calldir; + int pnum; + + dprintk("svc: tcp_recv %p data %d conn %d close %d\n", + svsk, test_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags), + test_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags), + test_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags)); + + len = svc_tcp_recv_record(svsk, rqstp); + if (len < 0) + goto error; + + base = svc_tcp_restore_pages(svsk, rqstp); + want = svc_sock_reclen(svsk) - (svsk->sk_tcplen - sizeof(rpc_fraghdr)); + + vec = rqstp->rq_vec; + + pnum = copy_pages_to_kvecs(&vec[0], &rqstp->rq_pages[0], + svsk->sk_datalen + want); + + rqstp->rq_respages = &rqstp->rq_pages[pnum]; + rqstp->rq_next_page = rqstp->rq_respages + 1; + + /* Now receive data */ + len = svc_partial_recvfrom(rqstp, vec, pnum, want, base); + if (len >= 0) { + svsk->sk_tcplen += len; + svsk->sk_datalen += len; + } + if (len != want || !svc_sock_final_rec(svsk)) { + svc_tcp_save_pages(svsk, rqstp); + if (len < 0 && len != -EAGAIN) + goto err_delete; + if (len == want) + svc_tcp_fragment_received(svsk); + else + dprintk("svc: incomplete TCP record (%d of %d)\n", + (int)(svsk->sk_tcplen - sizeof(rpc_fraghdr)), + svc_sock_reclen(svsk)); + goto err_noclose; + } + + if (svsk->sk_datalen < 8) { + svsk->sk_datalen = 0; + goto err_delete; /* client is nuts. */ + } + + rqstp->rq_arg.len = svsk->sk_datalen; + rqstp->rq_arg.page_base = 0; + if (rqstp->rq_arg.len <= rqstp->rq_arg.head[0].iov_len) { + rqstp->rq_arg.head[0].iov_len = rqstp->rq_arg.len; + rqstp->rq_arg.page_len = 0; + } else + rqstp->rq_arg.page_len = rqstp->rq_arg.len - rqstp->rq_arg.head[0].iov_len; + + rqstp->rq_xprt_ctxt = NULL; + rqstp->rq_prot = IPPROTO_TCP; + if (test_bit(XPT_LOCAL, &svsk->sk_xprt.xpt_flags)) + set_bit(RQ_LOCAL, &rqstp->rq_flags); + else + clear_bit(RQ_LOCAL, &rqstp->rq_flags); + + p = (__be32 *)rqstp->rq_arg.head[0].iov_base; + calldir = p[1]; + if (calldir) + len = receive_cb_reply(svsk, rqstp); + + /* Reset TCP read info */ + svsk->sk_datalen = 0; + svc_tcp_fragment_received(svsk); + + if (len < 0) + goto error; + + svc_xprt_copy_addrs(rqstp, &svsk->sk_xprt); + if (serv->sv_stats) + serv->sv_stats->nettcpcnt++; + + return rqstp->rq_arg.len; + +error: + if (len != -EAGAIN) + goto err_delete; + dprintk("RPC: TCP recvfrom got EAGAIN\n"); + return 0; +err_delete: + printk(KERN_NOTICE "%s: recvfrom returned errno %d\n", + svsk->sk_xprt.xpt_server->sv_name, -len); + set_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags); +err_noclose: + return 0; /* record not complete */ +} + +/* + * Send out data on TCP socket. + */ +static int svc_tcp_sendto(struct svc_rqst *rqstp) +{ + struct xdr_buf *xbufp = &rqstp->rq_res; + int sent; + __be32 reclen; + + svc_release_skb(rqstp); + + /* Set up the first element of the reply kvec. + * Any other kvecs that may be in use have been taken + * care of by the server implementation itself. + */ + reclen = htonl(0x80000000|((xbufp->len ) - 4)); + memcpy(xbufp->head[0].iov_base, &reclen, 4); + + sent = svc_sendto(rqstp, &rqstp->rq_res); + if (sent != xbufp->len) { + printk(KERN_NOTICE + "rpc-srv/tcp: %s: %s %d when sending %d bytes " + "- shutting down socket\n", + rqstp->rq_xprt->xpt_server->sv_name, + (sent<0)?"got error":"sent only", + sent, xbufp->len); + set_bit(XPT_CLOSE, &rqstp->rq_xprt->xpt_flags); + svc_xprt_enqueue(rqstp->rq_xprt); + sent = -EAGAIN; + } + return sent; +} + +/* + * Setup response header. TCP has a 4B record length field. + */ +void svc_tcp_prep_reply_hdr(struct svc_rqst *rqstp) +{ + struct kvec *resv = &rqstp->rq_res.head[0]; + + /* tcp needs a space for the record length... */ + svc_putnl(resv, 0); +} + +static struct svc_xprt *svc_tcp_create(struct svc_serv *serv, + struct net *net, + struct sockaddr *sa, int salen, + int flags) +{ + return svc_create_socket(serv, IPPROTO_TCP, net, sa, salen, flags); +} + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +static struct svc_xprt *svc_bc_create_socket(struct svc_serv *, int, + struct net *, struct sockaddr *, + int, int); +static void svc_bc_sock_free(struct svc_xprt *xprt); + +static struct svc_xprt *svc_bc_tcp_create(struct svc_serv *serv, + struct net *net, + struct sockaddr *sa, int salen, + int flags) +{ + return svc_bc_create_socket(serv, IPPROTO_TCP, net, sa, salen, flags); +} + +static void svc_bc_tcp_sock_detach(struct svc_xprt *xprt) +{ +} + +static const struct svc_xprt_ops svc_tcp_bc_ops = { + .xpo_create = svc_bc_tcp_create, + .xpo_detach = svc_bc_tcp_sock_detach, + .xpo_free = svc_bc_sock_free, + .xpo_prep_reply_hdr = svc_tcp_prep_reply_hdr, + .xpo_secure_port = svc_sock_secure_port, +}; + +static struct svc_xprt_class svc_tcp_bc_class = { + .xcl_name = "tcp-bc", + .xcl_owner = THIS_MODULE, + .xcl_ops = &svc_tcp_bc_ops, + .xcl_max_payload = RPCSVC_MAXPAYLOAD_TCP, +}; + +static void svc_init_bc_xprt_sock(void) +{ + svc_reg_xprt_class(&svc_tcp_bc_class); +} + +static void svc_cleanup_bc_xprt_sock(void) +{ + svc_unreg_xprt_class(&svc_tcp_bc_class); +} +#else /* CONFIG_SUNRPC_BACKCHANNEL */ +static void svc_init_bc_xprt_sock(void) +{ +} + +static void svc_cleanup_bc_xprt_sock(void) +{ +} +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +static const struct svc_xprt_ops svc_tcp_ops = { + .xpo_create = svc_tcp_create, + .xpo_recvfrom = svc_tcp_recvfrom, + .xpo_sendto = svc_tcp_sendto, + .xpo_release_rqst = svc_release_skb, + .xpo_detach = svc_tcp_sock_detach, + .xpo_free = svc_sock_free, + .xpo_prep_reply_hdr = svc_tcp_prep_reply_hdr, + .xpo_has_wspace = svc_tcp_has_wspace, + .xpo_accept = svc_tcp_accept, + .xpo_secure_port = svc_sock_secure_port, + .xpo_kill_temp_xprt = svc_tcp_kill_temp_xprt, +}; + +static struct svc_xprt_class svc_tcp_class = { + .xcl_name = "tcp", + .xcl_owner = THIS_MODULE, + .xcl_ops = &svc_tcp_ops, + .xcl_max_payload = RPCSVC_MAXPAYLOAD_TCP, + .xcl_ident = XPRT_TRANSPORT_TCP, +}; + +void svc_init_xprt_sock(void) +{ + svc_reg_xprt_class(&svc_tcp_class); + svc_reg_xprt_class(&svc_udp_class); + svc_init_bc_xprt_sock(); +} + +void svc_cleanup_xprt_sock(void) +{ + svc_unreg_xprt_class(&svc_tcp_class); + svc_unreg_xprt_class(&svc_udp_class); + svc_cleanup_bc_xprt_sock(); +} + +static void svc_tcp_init(struct svc_sock *svsk, struct svc_serv *serv) +{ + struct sock *sk = svsk->sk_sk; + + svc_xprt_init(sock_net(svsk->sk_sock->sk), &svc_tcp_class, + &svsk->sk_xprt, serv); + set_bit(XPT_CACHE_AUTH, &svsk->sk_xprt.xpt_flags); + set_bit(XPT_CONG_CTRL, &svsk->sk_xprt.xpt_flags); + if (sk->sk_state == TCP_LISTEN) { + dprintk("setting up TCP socket for listening\n"); + strcpy(svsk->sk_xprt.xpt_remotebuf, "listener"); + set_bit(XPT_LISTENER, &svsk->sk_xprt.xpt_flags); + sk->sk_data_ready = svc_tcp_listen_data_ready; + set_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags); + } else { + dprintk("setting up TCP socket for reading\n"); + sk->sk_state_change = svc_tcp_state_change; + sk->sk_data_ready = svc_data_ready; + sk->sk_write_space = svc_write_space; + + svsk->sk_reclen = 0; + svsk->sk_tcplen = 0; + svsk->sk_datalen = 0; + memset(&svsk->sk_pages[0], 0, sizeof(svsk->sk_pages)); + + tcp_sk(sk)->nonagle |= TCP_NAGLE_OFF; + + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + switch (sk->sk_state) { + case TCP_SYN_RECV: + case TCP_ESTABLISHED: + break; + default: + set_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags); + } + } +} + +void svc_sock_update_bufs(struct svc_serv *serv) +{ + /* + * The number of server threads has changed. Update + * rcvbuf and sndbuf accordingly on all sockets + */ + struct svc_sock *svsk; + + spin_lock_bh(&serv->sv_lock); + list_for_each_entry(svsk, &serv->sv_permsocks, sk_xprt.xpt_list) + set_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags); + spin_unlock_bh(&serv->sv_lock); +} +EXPORT_SYMBOL_GPL(svc_sock_update_bufs); + +/* + * Initialize socket for RPC use and create svc_sock struct + */ +static struct svc_sock *svc_setup_socket(struct svc_serv *serv, + struct socket *sock, + int flags) +{ + struct svc_sock *svsk; + struct sock *inet; + int pmap_register = !(flags & SVC_SOCK_ANONYMOUS); + int err = 0; + + dprintk("svc: svc_setup_socket %p\n", sock); + svsk = kzalloc(sizeof(*svsk), GFP_KERNEL); + if (!svsk) + return ERR_PTR(-ENOMEM); + + inet = sock->sk; + + /* Register socket with portmapper */ + if (pmap_register) + err = svc_register(serv, sock_net(sock->sk), inet->sk_family, + inet->sk_protocol, + ntohs(inet_sk(inet)->inet_sport)); + + if (err < 0) { + kfree(svsk); + return ERR_PTR(err); + } + + svsk->sk_sock = sock; + svsk->sk_sk = inet; + svsk->sk_ostate = inet->sk_state_change; + svsk->sk_odata = inet->sk_data_ready; + svsk->sk_owspace = inet->sk_write_space; + /* + * This barrier is necessary in order to prevent race condition + * with svc_data_ready(), svc_listen_data_ready() and others + * when calling callbacks above. + */ + wmb(); + inet->sk_user_data = svsk; + + /* Initialize the socket */ + if (sock->type == SOCK_DGRAM) + svc_udp_init(svsk, serv); + else + svc_tcp_init(svsk, serv); + + dprintk("svc: svc_setup_socket created %p (inet %p), " + "listen %d close %d\n", + svsk, svsk->sk_sk, + test_bit(XPT_LISTENER, &svsk->sk_xprt.xpt_flags), + test_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags)); + + return svsk; +} + +bool svc_alien_sock(struct net *net, int fd) +{ + int err; + struct socket *sock = sockfd_lookup(fd, &err); + bool ret = false; + + if (!sock) + goto out; + if (sock_net(sock->sk) != net) + ret = true; + sockfd_put(sock); +out: + return ret; +} +EXPORT_SYMBOL_GPL(svc_alien_sock); + +/** + * svc_addsock - add a listener socket to an RPC service + * @serv: pointer to RPC service to which to add a new listener + * @fd: file descriptor of the new listener + * @name_return: pointer to buffer to fill in with name of listener + * @len: size of the buffer + * + * Fills in socket name and returns positive length of name if successful. + * Name is terminated with '\n'. On error, returns a negative errno + * value. + */ +int svc_addsock(struct svc_serv *serv, const int fd, char *name_return, + const size_t len) +{ + int err = 0; + struct socket *so = sockfd_lookup(fd, &err); + struct svc_sock *svsk = NULL; + struct sockaddr_storage addr; + struct sockaddr *sin = (struct sockaddr *)&addr; + int salen; + + if (!so) + return err; + err = -EAFNOSUPPORT; + if ((so->sk->sk_family != PF_INET) && (so->sk->sk_family != PF_INET6)) + goto out; + err = -EPROTONOSUPPORT; + if (so->sk->sk_protocol != IPPROTO_TCP && + so->sk->sk_protocol != IPPROTO_UDP) + goto out; + err = -EISCONN; + if (so->state > SS_UNCONNECTED) + goto out; + err = -ENOENT; + if (!try_module_get(THIS_MODULE)) + goto out; + svsk = svc_setup_socket(serv, so, SVC_SOCK_DEFAULTS); + if (IS_ERR(svsk)) { + module_put(THIS_MODULE); + err = PTR_ERR(svsk); + goto out; + } + salen = kernel_getsockname(svsk->sk_sock, sin); + if (salen >= 0) + svc_xprt_set_local(&svsk->sk_xprt, sin, salen); + svc_add_new_perm_xprt(serv, &svsk->sk_xprt); + return svc_one_sock_name(svsk, name_return, len); +out: + sockfd_put(so); + return err; +} +EXPORT_SYMBOL_GPL(svc_addsock); + +/* + * Create socket for RPC service. + */ +static struct svc_xprt *svc_create_socket(struct svc_serv *serv, + int protocol, + struct net *net, + struct sockaddr *sin, int len, + int flags) +{ + struct svc_sock *svsk; + struct socket *sock; + int error; + int type; + struct sockaddr_storage addr; + struct sockaddr *newsin = (struct sockaddr *)&addr; + int newlen; + int family; + int val; + RPC_IFDEBUG(char buf[RPC_MAX_ADDRBUFLEN]); + + dprintk("svc: svc_create_socket(%s, %d, %s)\n", + serv->sv_program->pg_name, protocol, + __svc_print_addr(sin, buf, sizeof(buf))); + + if (protocol != IPPROTO_UDP && protocol != IPPROTO_TCP) { + printk(KERN_WARNING "svc: only UDP and TCP " + "sockets supported\n"); + return ERR_PTR(-EINVAL); + } + + type = (protocol == IPPROTO_UDP)? SOCK_DGRAM : SOCK_STREAM; + switch (sin->sa_family) { + case AF_INET6: + family = PF_INET6; + break; + case AF_INET: + family = PF_INET; + break; + default: + return ERR_PTR(-EINVAL); + } + + error = __sock_create(net, family, type, protocol, &sock, 1); + if (error < 0) + return ERR_PTR(error); + + svc_reclassify_socket(sock); + + /* + * If this is an PF_INET6 listener, we want to avoid + * getting requests from IPv4 remotes. Those should + * be shunted to a PF_INET listener via rpcbind. + */ + val = 1; + if (family == PF_INET6) + kernel_setsockopt(sock, SOL_IPV6, IPV6_V6ONLY, + (char *)&val, sizeof(val)); + + if (type == SOCK_STREAM) + sock->sk->sk_reuse = SK_CAN_REUSE; /* allow address reuse */ + error = kernel_bind(sock, sin, len); + if (error < 0) + goto bummer; + + error = kernel_getsockname(sock, newsin); + if (error < 0) + goto bummer; + newlen = error; + + if (protocol == IPPROTO_TCP) { + if ((error = kernel_listen(sock, 64)) < 0) + goto bummer; + } + + svsk = svc_setup_socket(serv, sock, flags); + if (IS_ERR(svsk)) { + error = PTR_ERR(svsk); + goto bummer; + } + svc_xprt_set_local(&svsk->sk_xprt, newsin, newlen); + return (struct svc_xprt *)svsk; +bummer: + dprintk("svc: svc_create_socket error = %d\n", -error); + sock_release(sock); + return ERR_PTR(error); +} + +/* + * Detach the svc_sock from the socket so that no + * more callbacks occur. + */ +static void svc_sock_detach(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + struct sock *sk = svsk->sk_sk; + + dprintk("svc: svc_sock_detach(%p)\n", svsk); + + /* put back the old socket callbacks */ + lock_sock(sk); + sk->sk_state_change = svsk->sk_ostate; + sk->sk_data_ready = svsk->sk_odata; + sk->sk_write_space = svsk->sk_owspace; + sk->sk_user_data = NULL; + release_sock(sk); +} + +/* + * Disconnect the socket, and reset the callbacks + */ +static void svc_tcp_sock_detach(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + + dprintk("svc: svc_tcp_sock_detach(%p)\n", svsk); + + svc_sock_detach(xprt); + + if (!test_bit(XPT_LISTENER, &xprt->xpt_flags)) { + svc_tcp_clear_pages(svsk); + kernel_sock_shutdown(svsk->sk_sock, SHUT_RDWR); + } +} + +/* + * Free the svc_sock's socket resources and the svc_sock itself. + */ +static void svc_sock_free(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + dprintk("svc: svc_sock_free(%p)\n", svsk); + + if (svsk->sk_sock->file) + sockfd_put(svsk->sk_sock); + else + sock_release(svsk->sk_sock); + kfree(svsk); +} + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +/* + * Create a back channel svc_xprt which shares the fore channel socket. + */ +static struct svc_xprt *svc_bc_create_socket(struct svc_serv *serv, + int protocol, + struct net *net, + struct sockaddr *sin, int len, + int flags) +{ + struct svc_sock *svsk; + struct svc_xprt *xprt; + + if (protocol != IPPROTO_TCP) { + printk(KERN_WARNING "svc: only TCP sockets" + " supported on shared back channel\n"); + return ERR_PTR(-EINVAL); + } + + svsk = kzalloc(sizeof(*svsk), GFP_KERNEL); + if (!svsk) + return ERR_PTR(-ENOMEM); + + xprt = &svsk->sk_xprt; + svc_xprt_init(net, &svc_tcp_bc_class, xprt, serv); + set_bit(XPT_CONG_CTRL, &svsk->sk_xprt.xpt_flags); + + serv->sv_bc_xprt = xprt; + + return xprt; +} + +/* + * Free a back channel svc_sock. + */ +static void svc_bc_sock_free(struct svc_xprt *xprt) +{ + if (xprt) + kfree(container_of(xprt, struct svc_sock, sk_xprt)); +} +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ diff --git a/net/sunrpc/sysctl.c b/net/sunrpc/sysctl.c new file mode 100644 index 000000000..8c3936403 --- /dev/null +++ b/net/sunrpc/sysctl.c @@ -0,0 +1,186 @@ +/* + * linux/net/sunrpc/sysctl.c + * + * Sysctl interface to sunrpc module. + * + * I would prefer to register the sunrpc table below sys/net, but that's + * impossible at the moment. + */ + +#include <linux/types.h> +#include <linux/linkage.h> +#include <linux/ctype.h> +#include <linux/fs.h> +#include <linux/sysctl.h> +#include <linux/module.h> + +#include <linux/uaccess.h> +#include <linux/sunrpc/types.h> +#include <linux/sunrpc/sched.h> +#include <linux/sunrpc/stats.h> +#include <linux/sunrpc/svc_xprt.h> + +#include "netns.h" + +/* + * Declare the debug flags here + */ +unsigned int rpc_debug; +EXPORT_SYMBOL_GPL(rpc_debug); + +unsigned int nfs_debug; +EXPORT_SYMBOL_GPL(nfs_debug); + +unsigned int nfsd_debug; +EXPORT_SYMBOL_GPL(nfsd_debug); + +unsigned int nlm_debug; +EXPORT_SYMBOL_GPL(nlm_debug); + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) + +static struct ctl_table_header *sunrpc_table_header; +static struct ctl_table sunrpc_table[]; + +void +rpc_register_sysctl(void) +{ + if (!sunrpc_table_header) + sunrpc_table_header = register_sysctl_table(sunrpc_table); +} + +void +rpc_unregister_sysctl(void) +{ + if (sunrpc_table_header) { + unregister_sysctl_table(sunrpc_table_header); + sunrpc_table_header = NULL; + } +} + +static int proc_do_xprt(struct ctl_table *table, int write, + void __user *buffer, size_t *lenp, loff_t *ppos) +{ + char tmpbuf[256]; + size_t len; + + if ((*ppos && !write) || !*lenp) { + *lenp = 0; + return 0; + } + len = svc_print_xprts(tmpbuf, sizeof(tmpbuf)); + return simple_read_from_buffer(buffer, *lenp, ppos, tmpbuf, len); +} + +static int +proc_dodebug(struct ctl_table *table, int write, + void __user *buffer, size_t *lenp, loff_t *ppos) +{ + char tmpbuf[20], c, *s = NULL; + char __user *p; + unsigned int value; + size_t left, len; + + if ((*ppos && !write) || !*lenp) { + *lenp = 0; + return 0; + } + + left = *lenp; + + if (write) { + if (!access_ok(VERIFY_READ, buffer, left)) + return -EFAULT; + p = buffer; + while (left && __get_user(c, p) >= 0 && isspace(c)) + left--, p++; + if (!left) + goto done; + + if (left > sizeof(tmpbuf) - 1) + return -EINVAL; + if (copy_from_user(tmpbuf, p, left)) + return -EFAULT; + tmpbuf[left] = '\0'; + + value = simple_strtol(tmpbuf, &s, 0); + if (s) { + left -= (s - tmpbuf); + if (left && !isspace(*s)) + return -EINVAL; + while (left && isspace(*s)) + left--, s++; + } else + left = 0; + *(unsigned int *) table->data = value; + /* Display the RPC tasks on writing to rpc_debug */ + if (strcmp(table->procname, "rpc_debug") == 0) + rpc_show_tasks(&init_net); + } else { + len = sprintf(tmpbuf, "0x%04x", *(unsigned int *) table->data); + if (len > left) + len = left; + if (copy_to_user(buffer, tmpbuf, len)) + return -EFAULT; + if ((left -= len) > 0) { + if (put_user('\n', (char __user *)buffer + len)) + return -EFAULT; + left--; + } + } + +done: + *lenp -= left; + *ppos += *lenp; + return 0; +} + + +static struct ctl_table debug_table[] = { + { + .procname = "rpc_debug", + .data = &rpc_debug, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dodebug + }, + { + .procname = "nfs_debug", + .data = &nfs_debug, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dodebug + }, + { + .procname = "nfsd_debug", + .data = &nfsd_debug, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dodebug + }, + { + .procname = "nlm_debug", + .data = &nlm_debug, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dodebug + }, + { + .procname = "transports", + .maxlen = 256, + .mode = 0444, + .proc_handler = proc_do_xprt, + }, + { } +}; + +static struct ctl_table sunrpc_table[] = { + { + .procname = "sunrpc", + .mode = 0555, + .child = debug_table + }, + { } +}; + +#endif diff --git a/net/sunrpc/timer.c b/net/sunrpc/timer.c new file mode 100644 index 000000000..08881d0c9 --- /dev/null +++ b/net/sunrpc/timer.c @@ -0,0 +1,122 @@ +/* + * linux/net/sunrpc/timer.c + * + * Estimate RPC request round trip time. + * + * Based on packet round-trip and variance estimator algorithms described + * in appendix A of "Congestion Avoidance and Control" by Van Jacobson + * and Michael J. Karels (ACM Computer Communication Review; Proceedings + * of the Sigcomm '88 Symposium in Stanford, CA, August, 1988). + * + * This RTT estimator is used only for RPC over datagram protocols. + * + * Copyright (C) 2002 Trond Myklebust <trond.myklebust@fys.uio.no> + */ + +#include <asm/param.h> + +#include <linux/types.h> +#include <linux/unistd.h> +#include <linux/module.h> + +#include <linux/sunrpc/clnt.h> + +#define RPC_RTO_MAX (60*HZ) +#define RPC_RTO_INIT (HZ/5) +#define RPC_RTO_MIN (HZ/10) + +/** + * rpc_init_rtt - Initialize an RPC RTT estimator context + * @rt: context to initialize + * @timeo: initial timeout value, in jiffies + * + */ +void rpc_init_rtt(struct rpc_rtt *rt, unsigned long timeo) +{ + unsigned long init = 0; + unsigned int i; + + rt->timeo = timeo; + + if (timeo > RPC_RTO_INIT) + init = (timeo - RPC_RTO_INIT) << 3; + for (i = 0; i < 5; i++) { + rt->srtt[i] = init; + rt->sdrtt[i] = RPC_RTO_INIT; + rt->ntimeouts[i] = 0; + } +} +EXPORT_SYMBOL_GPL(rpc_init_rtt); + +/** + * rpc_update_rtt - Update an RPC RTT estimator context + * @rt: context to update + * @timer: timer array index (request type) + * @m: recent actual RTT, in jiffies + * + * NB: When computing the smoothed RTT and standard deviation, + * be careful not to produce negative intermediate results. + */ +void rpc_update_rtt(struct rpc_rtt *rt, unsigned int timer, long m) +{ + long *srtt, *sdrtt; + + if (timer-- == 0) + return; + + /* jiffies wrapped; ignore this one */ + if (m < 0) + return; + + if (m == 0) + m = 1L; + + srtt = (long *)&rt->srtt[timer]; + m -= *srtt >> 3; + *srtt += m; + + if (m < 0) + m = -m; + + sdrtt = (long *)&rt->sdrtt[timer]; + m -= *sdrtt >> 2; + *sdrtt += m; + + /* Set lower bound on the variance */ + if (*sdrtt < RPC_RTO_MIN) + *sdrtt = RPC_RTO_MIN; +} +EXPORT_SYMBOL_GPL(rpc_update_rtt); + +/** + * rpc_calc_rto - Provide an estimated timeout value + * @rt: context to use for calculation + * @timer: timer array index (request type) + * + * Estimate RTO for an NFS RPC sent via an unreliable datagram. Use + * the mean and mean deviation of RTT for the appropriate type of RPC + * for frequently issued RPCs, and a fixed default for the others. + * + * The justification for doing "other" this way is that these RPCs + * happen so infrequently that timer estimation would probably be + * stale. Also, since many of these RPCs are non-idempotent, a + * conservative timeout is desired. + * + * getattr, lookup, + * read, write, commit - A+4D + * other - timeo + */ +unsigned long rpc_calc_rto(struct rpc_rtt *rt, unsigned int timer) +{ + unsigned long res; + + if (timer-- == 0) + return rt->timeo; + + res = ((rt->srtt[timer] + 7) >> 3) + rt->sdrtt[timer]; + if (res > RPC_RTO_MAX) + res = RPC_RTO_MAX; + + return res; +} +EXPORT_SYMBOL_GPL(rpc_calc_rto); diff --git a/net/sunrpc/xdr.c b/net/sunrpc/xdr.c new file mode 100644 index 000000000..34596d0e4 --- /dev/null +++ b/net/sunrpc/xdr.c @@ -0,0 +1,1643 @@ +/* + * linux/net/sunrpc/xdr.c + * + * Generic XDR support. + * + * Copyright (C) 1995, 1996 Olaf Kirch <okir@monad.swb.de> + */ + +#include <linux/module.h> +#include <linux/slab.h> +#include <linux/types.h> +#include <linux/string.h> +#include <linux/kernel.h> +#include <linux/pagemap.h> +#include <linux/errno.h> +#include <linux/sunrpc/xdr.h> +#include <linux/sunrpc/msg_prot.h> + +/* + * XDR functions for basic NFS types + */ +__be32 * +xdr_encode_netobj(__be32 *p, const struct xdr_netobj *obj) +{ + unsigned int quadlen = XDR_QUADLEN(obj->len); + + p[quadlen] = 0; /* zero trailing bytes */ + *p++ = cpu_to_be32(obj->len); + memcpy(p, obj->data, obj->len); + return p + XDR_QUADLEN(obj->len); +} +EXPORT_SYMBOL_GPL(xdr_encode_netobj); + +__be32 * +xdr_decode_netobj(__be32 *p, struct xdr_netobj *obj) +{ + unsigned int len; + + if ((len = be32_to_cpu(*p++)) > XDR_MAX_NETOBJ) + return NULL; + obj->len = len; + obj->data = (u8 *) p; + return p + XDR_QUADLEN(len); +} +EXPORT_SYMBOL_GPL(xdr_decode_netobj); + +/** + * xdr_encode_opaque_fixed - Encode fixed length opaque data + * @p: pointer to current position in XDR buffer. + * @ptr: pointer to data to encode (or NULL) + * @nbytes: size of data. + * + * Copy the array of data of length nbytes at ptr to the XDR buffer + * at position p, then align to the next 32-bit boundary by padding + * with zero bytes (see RFC1832). + * Note: if ptr is NULL, only the padding is performed. + * + * Returns the updated current XDR buffer position + * + */ +__be32 *xdr_encode_opaque_fixed(__be32 *p, const void *ptr, unsigned int nbytes) +{ + if (likely(nbytes != 0)) { + unsigned int quadlen = XDR_QUADLEN(nbytes); + unsigned int padding = (quadlen << 2) - nbytes; + + if (ptr != NULL) + memcpy(p, ptr, nbytes); + if (padding != 0) + memset((char *)p + nbytes, 0, padding); + p += quadlen; + } + return p; +} +EXPORT_SYMBOL_GPL(xdr_encode_opaque_fixed); + +/** + * xdr_encode_opaque - Encode variable length opaque data + * @p: pointer to current position in XDR buffer. + * @ptr: pointer to data to encode (or NULL) + * @nbytes: size of data. + * + * Returns the updated current XDR buffer position + */ +__be32 *xdr_encode_opaque(__be32 *p, const void *ptr, unsigned int nbytes) +{ + *p++ = cpu_to_be32(nbytes); + return xdr_encode_opaque_fixed(p, ptr, nbytes); +} +EXPORT_SYMBOL_GPL(xdr_encode_opaque); + +__be32 * +xdr_encode_string(__be32 *p, const char *string) +{ + return xdr_encode_array(p, string, strlen(string)); +} +EXPORT_SYMBOL_GPL(xdr_encode_string); + +__be32 * +xdr_decode_string_inplace(__be32 *p, char **sp, + unsigned int *lenp, unsigned int maxlen) +{ + u32 len; + + len = be32_to_cpu(*p++); + if (len > maxlen) + return NULL; + *lenp = len; + *sp = (char *) p; + return p + XDR_QUADLEN(len); +} +EXPORT_SYMBOL_GPL(xdr_decode_string_inplace); + +/** + * xdr_terminate_string - '\0'-terminate a string residing in an xdr_buf + * @buf: XDR buffer where string resides + * @len: length of string, in bytes + * + */ +void +xdr_terminate_string(struct xdr_buf *buf, const u32 len) +{ + char *kaddr; + + kaddr = kmap_atomic(buf->pages[0]); + kaddr[buf->page_base + len] = '\0'; + kunmap_atomic(kaddr); +} +EXPORT_SYMBOL_GPL(xdr_terminate_string); + +void +xdr_inline_pages(struct xdr_buf *xdr, unsigned int offset, + struct page **pages, unsigned int base, unsigned int len) +{ + struct kvec *head = xdr->head; + struct kvec *tail = xdr->tail; + char *buf = (char *)head->iov_base; + unsigned int buflen = head->iov_len; + + head->iov_len = offset; + + xdr->pages = pages; + xdr->page_base = base; + xdr->page_len = len; + + tail->iov_base = buf + offset; + tail->iov_len = buflen - offset; + + xdr->buflen += len; +} +EXPORT_SYMBOL_GPL(xdr_inline_pages); + +/* + * Helper routines for doing 'memmove' like operations on a struct xdr_buf + */ + +/** + * _shift_data_right_pages + * @pages: vector of pages containing both the source and dest memory area. + * @pgto_base: page vector address of destination + * @pgfrom_base: page vector address of source + * @len: number of bytes to copy + * + * Note: the addresses pgto_base and pgfrom_base are both calculated in + * the same way: + * if a memory area starts at byte 'base' in page 'pages[i]', + * then its address is given as (i << PAGE_SHIFT) + base + * Also note: pgfrom_base must be < pgto_base, but the memory areas + * they point to may overlap. + */ +static void +_shift_data_right_pages(struct page **pages, size_t pgto_base, + size_t pgfrom_base, size_t len) +{ + struct page **pgfrom, **pgto; + char *vfrom, *vto; + size_t copy; + + BUG_ON(pgto_base <= pgfrom_base); + + pgto_base += len; + pgfrom_base += len; + + pgto = pages + (pgto_base >> PAGE_SHIFT); + pgfrom = pages + (pgfrom_base >> PAGE_SHIFT); + + pgto_base &= ~PAGE_MASK; + pgfrom_base &= ~PAGE_MASK; + + do { + /* Are any pointers crossing a page boundary? */ + if (pgto_base == 0) { + pgto_base = PAGE_SIZE; + pgto--; + } + if (pgfrom_base == 0) { + pgfrom_base = PAGE_SIZE; + pgfrom--; + } + + copy = len; + if (copy > pgto_base) + copy = pgto_base; + if (copy > pgfrom_base) + copy = pgfrom_base; + pgto_base -= copy; + pgfrom_base -= copy; + + vto = kmap_atomic(*pgto); + if (*pgto != *pgfrom) { + vfrom = kmap_atomic(*pgfrom); + memcpy(vto + pgto_base, vfrom + pgfrom_base, copy); + kunmap_atomic(vfrom); + } else + memmove(vto + pgto_base, vto + pgfrom_base, copy); + flush_dcache_page(*pgto); + kunmap_atomic(vto); + + } while ((len -= copy) != 0); +} + +/** + * _copy_to_pages + * @pages: array of pages + * @pgbase: page vector address of destination + * @p: pointer to source data + * @len: length + * + * Copies data from an arbitrary memory location into an array of pages + * The copy is assumed to be non-overlapping. + */ +static void +_copy_to_pages(struct page **pages, size_t pgbase, const char *p, size_t len) +{ + struct page **pgto; + char *vto; + size_t copy; + + pgto = pages + (pgbase >> PAGE_SHIFT); + pgbase &= ~PAGE_MASK; + + for (;;) { + copy = PAGE_SIZE - pgbase; + if (copy > len) + copy = len; + + vto = kmap_atomic(*pgto); + memcpy(vto + pgbase, p, copy); + kunmap_atomic(vto); + + len -= copy; + if (len == 0) + break; + + pgbase += copy; + if (pgbase == PAGE_SIZE) { + flush_dcache_page(*pgto); + pgbase = 0; + pgto++; + } + p += copy; + } + flush_dcache_page(*pgto); +} + +/** + * _copy_from_pages + * @p: pointer to destination + * @pages: array of pages + * @pgbase: offset of source data + * @len: length + * + * Copies data into an arbitrary memory location from an array of pages + * The copy is assumed to be non-overlapping. + */ +void +_copy_from_pages(char *p, struct page **pages, size_t pgbase, size_t len) +{ + struct page **pgfrom; + char *vfrom; + size_t copy; + + pgfrom = pages + (pgbase >> PAGE_SHIFT); + pgbase &= ~PAGE_MASK; + + do { + copy = PAGE_SIZE - pgbase; + if (copy > len) + copy = len; + + vfrom = kmap_atomic(*pgfrom); + memcpy(p, vfrom + pgbase, copy); + kunmap_atomic(vfrom); + + pgbase += copy; + if (pgbase == PAGE_SIZE) { + pgbase = 0; + pgfrom++; + } + p += copy; + + } while ((len -= copy) != 0); +} +EXPORT_SYMBOL_GPL(_copy_from_pages); + +/** + * xdr_shrink_bufhead + * @buf: xdr_buf + * @len: bytes to remove from buf->head[0] + * + * Shrinks XDR buffer's header kvec buf->head[0] by + * 'len' bytes. The extra data is not lost, but is instead + * moved into the inlined pages and/or the tail. + */ +static void +xdr_shrink_bufhead(struct xdr_buf *buf, size_t len) +{ + struct kvec *head, *tail; + size_t copy, offs; + unsigned int pglen = buf->page_len; + + tail = buf->tail; + head = buf->head; + + WARN_ON_ONCE(len > head->iov_len); + if (len > head->iov_len) + len = head->iov_len; + + /* Shift the tail first */ + if (tail->iov_len != 0) { + if (tail->iov_len > len) { + copy = tail->iov_len - len; + memmove((char *)tail->iov_base + len, + tail->iov_base, copy); + } + /* Copy from the inlined pages into the tail */ + copy = len; + if (copy > pglen) + copy = pglen; + offs = len - copy; + if (offs >= tail->iov_len) + copy = 0; + else if (copy > tail->iov_len - offs) + copy = tail->iov_len - offs; + if (copy != 0) + _copy_from_pages((char *)tail->iov_base + offs, + buf->pages, + buf->page_base + pglen + offs - len, + copy); + /* Do we also need to copy data from the head into the tail ? */ + if (len > pglen) { + offs = copy = len - pglen; + if (copy > tail->iov_len) + copy = tail->iov_len; + memcpy(tail->iov_base, + (char *)head->iov_base + + head->iov_len - offs, + copy); + } + } + /* Now handle pages */ + if (pglen != 0) { + if (pglen > len) + _shift_data_right_pages(buf->pages, + buf->page_base + len, + buf->page_base, + pglen - len); + copy = len; + if (len > pglen) + copy = pglen; + _copy_to_pages(buf->pages, buf->page_base, + (char *)head->iov_base + head->iov_len - len, + copy); + } + head->iov_len -= len; + buf->buflen -= len; + /* Have we truncated the message? */ + if (buf->len > buf->buflen) + buf->len = buf->buflen; +} + +/** + * xdr_shrink_pagelen + * @buf: xdr_buf + * @len: bytes to remove from buf->pages + * + * Shrinks XDR buffer's page array buf->pages by + * 'len' bytes. The extra data is not lost, but is instead + * moved into the tail. + */ +static void +xdr_shrink_pagelen(struct xdr_buf *buf, size_t len) +{ + struct kvec *tail; + size_t copy; + unsigned int pglen = buf->page_len; + unsigned int tailbuf_len; + + tail = buf->tail; + BUG_ON (len > pglen); + + tailbuf_len = buf->buflen - buf->head->iov_len - buf->page_len; + + /* Shift the tail first */ + if (tailbuf_len != 0) { + unsigned int free_space = tailbuf_len - tail->iov_len; + + if (len < free_space) + free_space = len; + tail->iov_len += free_space; + + copy = len; + if (tail->iov_len > len) { + char *p = (char *)tail->iov_base + len; + memmove(p, tail->iov_base, tail->iov_len - len); + } else + copy = tail->iov_len; + /* Copy from the inlined pages into the tail */ + _copy_from_pages((char *)tail->iov_base, + buf->pages, buf->page_base + pglen - len, + copy); + } + buf->page_len -= len; + buf->buflen -= len; + /* Have we truncated the message? */ + if (buf->len > buf->buflen) + buf->len = buf->buflen; +} + +void +xdr_shift_buf(struct xdr_buf *buf, size_t len) +{ + xdr_shrink_bufhead(buf, len); +} +EXPORT_SYMBOL_GPL(xdr_shift_buf); + +/** + * xdr_stream_pos - Return the current offset from the start of the xdr_stream + * @xdr: pointer to struct xdr_stream + */ +unsigned int xdr_stream_pos(const struct xdr_stream *xdr) +{ + return (unsigned int)(XDR_QUADLEN(xdr->buf->len) - xdr->nwords) << 2; +} +EXPORT_SYMBOL_GPL(xdr_stream_pos); + +/** + * xdr_init_encode - Initialize a struct xdr_stream for sending data. + * @xdr: pointer to xdr_stream struct + * @buf: pointer to XDR buffer in which to encode data + * @p: current pointer inside XDR buffer + * + * Note: at the moment the RPC client only passes the length of our + * scratch buffer in the xdr_buf's header kvec. Previously this + * meant we needed to call xdr_adjust_iovec() after encoding the + * data. With the new scheme, the xdr_stream manages the details + * of the buffer length, and takes care of adjusting the kvec + * length for us. + */ +void xdr_init_encode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p) +{ + struct kvec *iov = buf->head; + int scratch_len = buf->buflen - buf->page_len - buf->tail[0].iov_len; + + xdr_set_scratch_buffer(xdr, NULL, 0); + BUG_ON(scratch_len < 0); + xdr->buf = buf; + xdr->iov = iov; + xdr->p = (__be32 *)((char *)iov->iov_base + iov->iov_len); + xdr->end = (__be32 *)((char *)iov->iov_base + scratch_len); + BUG_ON(iov->iov_len > scratch_len); + + if (p != xdr->p && p != NULL) { + size_t len; + + BUG_ON(p < xdr->p || p > xdr->end); + len = (char *)p - (char *)xdr->p; + xdr->p = p; + buf->len += len; + iov->iov_len += len; + } +} +EXPORT_SYMBOL_GPL(xdr_init_encode); + +/** + * xdr_commit_encode - Ensure all data is written to buffer + * @xdr: pointer to xdr_stream + * + * We handle encoding across page boundaries by giving the caller a + * temporary location to write to, then later copying the data into + * place; xdr_commit_encode does that copying. + * + * Normally the caller doesn't need to call this directly, as the + * following xdr_reserve_space will do it. But an explicit call may be + * required at the end of encoding, or any other time when the xdr_buf + * data might be read. + */ +void xdr_commit_encode(struct xdr_stream *xdr) +{ + int shift = xdr->scratch.iov_len; + void *page; + + if (shift == 0) + return; + page = page_address(*xdr->page_ptr); + memcpy(xdr->scratch.iov_base, page, shift); + memmove(page, page + shift, (void *)xdr->p - page); + xdr->scratch.iov_len = 0; +} +EXPORT_SYMBOL_GPL(xdr_commit_encode); + +static __be32 *xdr_get_next_encode_buffer(struct xdr_stream *xdr, + size_t nbytes) +{ + __be32 *p; + int space_left; + int frag1bytes, frag2bytes; + + if (nbytes > PAGE_SIZE) + return NULL; /* Bigger buffers require special handling */ + if (xdr->buf->len + nbytes > xdr->buf->buflen) + return NULL; /* Sorry, we're totally out of space */ + frag1bytes = (xdr->end - xdr->p) << 2; + frag2bytes = nbytes - frag1bytes; + if (xdr->iov) + xdr->iov->iov_len += frag1bytes; + else + xdr->buf->page_len += frag1bytes; + xdr->page_ptr++; + xdr->iov = NULL; + /* + * If the last encode didn't end exactly on a page boundary, the + * next one will straddle boundaries. Encode into the next + * page, then copy it back later in xdr_commit_encode. We use + * the "scratch" iov to track any temporarily unused fragment of + * space at the end of the previous buffer: + */ + xdr->scratch.iov_base = xdr->p; + xdr->scratch.iov_len = frag1bytes; + p = page_address(*xdr->page_ptr); + /* + * Note this is where the next encode will start after we've + * shifted this one back: + */ + xdr->p = (void *)p + frag2bytes; + space_left = xdr->buf->buflen - xdr->buf->len; + if (space_left - nbytes >= PAGE_SIZE) + xdr->end = (void *)p + PAGE_SIZE; + else + xdr->end = (void *)p + space_left - frag1bytes; + + xdr->buf->page_len += frag2bytes; + xdr->buf->len += nbytes; + return p; +} + +/** + * xdr_reserve_space - Reserve buffer space for sending + * @xdr: pointer to xdr_stream + * @nbytes: number of bytes to reserve + * + * Checks that we have enough buffer space to encode 'nbytes' more + * bytes of data. If so, update the total xdr_buf length, and + * adjust the length of the current kvec. + */ +__be32 * xdr_reserve_space(struct xdr_stream *xdr, size_t nbytes) +{ + __be32 *p = xdr->p; + __be32 *q; + + xdr_commit_encode(xdr); + /* align nbytes on the next 32-bit boundary */ + nbytes += 3; + nbytes &= ~3; + q = p + (nbytes >> 2); + if (unlikely(q > xdr->end || q < p)) + return xdr_get_next_encode_buffer(xdr, nbytes); + xdr->p = q; + if (xdr->iov) + xdr->iov->iov_len += nbytes; + else + xdr->buf->page_len += nbytes; + xdr->buf->len += nbytes; + return p; +} +EXPORT_SYMBOL_GPL(xdr_reserve_space); + +/** + * xdr_truncate_encode - truncate an encode buffer + * @xdr: pointer to xdr_stream + * @len: new length of buffer + * + * Truncates the xdr stream, so that xdr->buf->len == len, + * and xdr->p points at offset len from the start of the buffer, and + * head, tail, and page lengths are adjusted to correspond. + * + * If this means moving xdr->p to a different buffer, we assume that + * that the end pointer should be set to the end of the current page, + * except in the case of the head buffer when we assume the head + * buffer's current length represents the end of the available buffer. + * + * This is *not* safe to use on a buffer that already has inlined page + * cache pages (as in a zero-copy server read reply), except for the + * simple case of truncating from one position in the tail to another. + * + */ +void xdr_truncate_encode(struct xdr_stream *xdr, size_t len) +{ + struct xdr_buf *buf = xdr->buf; + struct kvec *head = buf->head; + struct kvec *tail = buf->tail; + int fraglen; + int new; + + if (len > buf->len) { + WARN_ON_ONCE(1); + return; + } + xdr_commit_encode(xdr); + + fraglen = min_t(int, buf->len - len, tail->iov_len); + tail->iov_len -= fraglen; + buf->len -= fraglen; + if (tail->iov_len) { + xdr->p = tail->iov_base + tail->iov_len; + WARN_ON_ONCE(!xdr->end); + WARN_ON_ONCE(!xdr->iov); + return; + } + WARN_ON_ONCE(fraglen); + fraglen = min_t(int, buf->len - len, buf->page_len); + buf->page_len -= fraglen; + buf->len -= fraglen; + + new = buf->page_base + buf->page_len; + + xdr->page_ptr = buf->pages + (new >> PAGE_SHIFT); + + if (buf->page_len) { + xdr->p = page_address(*xdr->page_ptr); + xdr->end = (void *)xdr->p + PAGE_SIZE; + xdr->p = (void *)xdr->p + (new % PAGE_SIZE); + WARN_ON_ONCE(xdr->iov); + return; + } + if (fraglen) + xdr->end = head->iov_base + head->iov_len; + /* (otherwise assume xdr->end is already set) */ + xdr->page_ptr--; + head->iov_len = len; + buf->len = len; + xdr->p = head->iov_base + head->iov_len; + xdr->iov = buf->head; +} +EXPORT_SYMBOL(xdr_truncate_encode); + +/** + * xdr_restrict_buflen - decrease available buffer space + * @xdr: pointer to xdr_stream + * @newbuflen: new maximum number of bytes available + * + * Adjust our idea of how much space is available in the buffer. + * If we've already used too much space in the buffer, returns -1. + * If the available space is already smaller than newbuflen, returns 0 + * and does nothing. Otherwise, adjusts xdr->buf->buflen to newbuflen + * and ensures xdr->end is set at most offset newbuflen from the start + * of the buffer. + */ +int xdr_restrict_buflen(struct xdr_stream *xdr, int newbuflen) +{ + struct xdr_buf *buf = xdr->buf; + int left_in_this_buf = (void *)xdr->end - (void *)xdr->p; + int end_offset = buf->len + left_in_this_buf; + + if (newbuflen < 0 || newbuflen < buf->len) + return -1; + if (newbuflen > buf->buflen) + return 0; + if (newbuflen < end_offset) + xdr->end = (void *)xdr->end + newbuflen - end_offset; + buf->buflen = newbuflen; + return 0; +} +EXPORT_SYMBOL(xdr_restrict_buflen); + +/** + * xdr_write_pages - Insert a list of pages into an XDR buffer for sending + * @xdr: pointer to xdr_stream + * @pages: list of pages + * @base: offset of first byte + * @len: length of data in bytes + * + */ +void xdr_write_pages(struct xdr_stream *xdr, struct page **pages, unsigned int base, + unsigned int len) +{ + struct xdr_buf *buf = xdr->buf; + struct kvec *iov = buf->tail; + buf->pages = pages; + buf->page_base = base; + buf->page_len = len; + + iov->iov_base = (char *)xdr->p; + iov->iov_len = 0; + xdr->iov = iov; + + if (len & 3) { + unsigned int pad = 4 - (len & 3); + + BUG_ON(xdr->p >= xdr->end); + iov->iov_base = (char *)xdr->p + (len & 3); + iov->iov_len += pad; + len += pad; + *xdr->p++ = 0; + } + buf->buflen += len; + buf->len += len; +} +EXPORT_SYMBOL_GPL(xdr_write_pages); + +static void xdr_set_iov(struct xdr_stream *xdr, struct kvec *iov, + unsigned int len) +{ + if (len > iov->iov_len) + len = iov->iov_len; + xdr->p = (__be32*)iov->iov_base; + xdr->end = (__be32*)(iov->iov_base + len); + xdr->iov = iov; + xdr->page_ptr = NULL; +} + +static int xdr_set_page_base(struct xdr_stream *xdr, + unsigned int base, unsigned int len) +{ + unsigned int pgnr; + unsigned int maxlen; + unsigned int pgoff; + unsigned int pgend; + void *kaddr; + + maxlen = xdr->buf->page_len; + if (base >= maxlen) + return -EINVAL; + maxlen -= base; + if (len > maxlen) + len = maxlen; + + base += xdr->buf->page_base; + + pgnr = base >> PAGE_SHIFT; + xdr->page_ptr = &xdr->buf->pages[pgnr]; + kaddr = page_address(*xdr->page_ptr); + + pgoff = base & ~PAGE_MASK; + xdr->p = (__be32*)(kaddr + pgoff); + + pgend = pgoff + len; + if (pgend > PAGE_SIZE) + pgend = PAGE_SIZE; + xdr->end = (__be32*)(kaddr + pgend); + xdr->iov = NULL; + return 0; +} + +static void xdr_set_next_page(struct xdr_stream *xdr) +{ + unsigned int newbase; + + newbase = (1 + xdr->page_ptr - xdr->buf->pages) << PAGE_SHIFT; + newbase -= xdr->buf->page_base; + + if (xdr_set_page_base(xdr, newbase, PAGE_SIZE) < 0) + xdr_set_iov(xdr, xdr->buf->tail, xdr->nwords << 2); +} + +static bool xdr_set_next_buffer(struct xdr_stream *xdr) +{ + if (xdr->page_ptr != NULL) + xdr_set_next_page(xdr); + else if (xdr->iov == xdr->buf->head) { + if (xdr_set_page_base(xdr, 0, PAGE_SIZE) < 0) + xdr_set_iov(xdr, xdr->buf->tail, xdr->nwords << 2); + } + return xdr->p != xdr->end; +} + +/** + * xdr_init_decode - Initialize an xdr_stream for decoding data. + * @xdr: pointer to xdr_stream struct + * @buf: pointer to XDR buffer from which to decode data + * @p: current pointer inside XDR buffer + */ +void xdr_init_decode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p) +{ + xdr->buf = buf; + xdr->scratch.iov_base = NULL; + xdr->scratch.iov_len = 0; + xdr->nwords = XDR_QUADLEN(buf->len); + if (buf->head[0].iov_len != 0) + xdr_set_iov(xdr, buf->head, buf->len); + else if (buf->page_len != 0) + xdr_set_page_base(xdr, 0, buf->len); + else + xdr_set_iov(xdr, buf->head, buf->len); + if (p != NULL && p > xdr->p && xdr->end >= p) { + xdr->nwords -= p - xdr->p; + xdr->p = p; + } +} +EXPORT_SYMBOL_GPL(xdr_init_decode); + +/** + * xdr_init_decode_pages - Initialize an xdr_stream for decoding into pages + * @xdr: pointer to xdr_stream struct + * @buf: pointer to XDR buffer from which to decode data + * @pages: list of pages to decode into + * @len: length in bytes of buffer in pages + */ +void xdr_init_decode_pages(struct xdr_stream *xdr, struct xdr_buf *buf, + struct page **pages, unsigned int len) +{ + memset(buf, 0, sizeof(*buf)); + buf->pages = pages; + buf->page_len = len; + buf->buflen = len; + buf->len = len; + xdr_init_decode(xdr, buf, NULL); +} +EXPORT_SYMBOL_GPL(xdr_init_decode_pages); + +static __be32 * __xdr_inline_decode(struct xdr_stream *xdr, size_t nbytes) +{ + unsigned int nwords = XDR_QUADLEN(nbytes); + __be32 *p = xdr->p; + __be32 *q = p + nwords; + + if (unlikely(nwords > xdr->nwords || q > xdr->end || q < p)) + return NULL; + xdr->p = q; + xdr->nwords -= nwords; + return p; +} + +/** + * xdr_set_scratch_buffer - Attach a scratch buffer for decoding data. + * @xdr: pointer to xdr_stream struct + * @buf: pointer to an empty buffer + * @buflen: size of 'buf' + * + * The scratch buffer is used when decoding from an array of pages. + * If an xdr_inline_decode() call spans across page boundaries, then + * we copy the data into the scratch buffer in order to allow linear + * access. + */ +void xdr_set_scratch_buffer(struct xdr_stream *xdr, void *buf, size_t buflen) +{ + xdr->scratch.iov_base = buf; + xdr->scratch.iov_len = buflen; +} +EXPORT_SYMBOL_GPL(xdr_set_scratch_buffer); + +static __be32 *xdr_copy_to_scratch(struct xdr_stream *xdr, size_t nbytes) +{ + __be32 *p; + char *cpdest = xdr->scratch.iov_base; + size_t cplen = (char *)xdr->end - (char *)xdr->p; + + if (nbytes > xdr->scratch.iov_len) + return NULL; + p = __xdr_inline_decode(xdr, cplen); + if (p == NULL) + return NULL; + memcpy(cpdest, p, cplen); + cpdest += cplen; + nbytes -= cplen; + if (!xdr_set_next_buffer(xdr)) + return NULL; + p = __xdr_inline_decode(xdr, nbytes); + if (p == NULL) + return NULL; + memcpy(cpdest, p, nbytes); + return xdr->scratch.iov_base; +} + +/** + * xdr_inline_decode - Retrieve XDR data to decode + * @xdr: pointer to xdr_stream struct + * @nbytes: number of bytes of data to decode + * + * Check if the input buffer is long enough to enable us to decode + * 'nbytes' more bytes of data starting at the current position. + * If so return the current pointer, then update the current + * pointer position. + */ +__be32 * xdr_inline_decode(struct xdr_stream *xdr, size_t nbytes) +{ + __be32 *p; + + if (nbytes == 0) + return xdr->p; + if (xdr->p == xdr->end && !xdr_set_next_buffer(xdr)) + return NULL; + p = __xdr_inline_decode(xdr, nbytes); + if (p != NULL) + return p; + return xdr_copy_to_scratch(xdr, nbytes); +} +EXPORT_SYMBOL_GPL(xdr_inline_decode); + +static unsigned int xdr_align_pages(struct xdr_stream *xdr, unsigned int len) +{ + struct xdr_buf *buf = xdr->buf; + struct kvec *iov; + unsigned int nwords = XDR_QUADLEN(len); + unsigned int cur = xdr_stream_pos(xdr); + + if (xdr->nwords == 0) + return 0; + /* Realign pages to current pointer position */ + iov = buf->head; + if (iov->iov_len > cur) { + xdr_shrink_bufhead(buf, iov->iov_len - cur); + xdr->nwords = XDR_QUADLEN(buf->len - cur); + } + + if (nwords > xdr->nwords) { + nwords = xdr->nwords; + len = nwords << 2; + } + if (buf->page_len <= len) + len = buf->page_len; + else if (nwords < xdr->nwords) { + /* Truncate page data and move it into the tail */ + xdr_shrink_pagelen(buf, buf->page_len - len); + xdr->nwords = XDR_QUADLEN(buf->len - cur); + } + return len; +} + +/** + * xdr_read_pages - Ensure page-based XDR data to decode is aligned at current pointer position + * @xdr: pointer to xdr_stream struct + * @len: number of bytes of page data + * + * Moves data beyond the current pointer position from the XDR head[] buffer + * into the page list. Any data that lies beyond current position + "len" + * bytes is moved into the XDR tail[]. + * + * Returns the number of XDR encoded bytes now contained in the pages + */ +unsigned int xdr_read_pages(struct xdr_stream *xdr, unsigned int len) +{ + struct xdr_buf *buf = xdr->buf; + struct kvec *iov; + unsigned int nwords; + unsigned int end; + unsigned int padding; + + len = xdr_align_pages(xdr, len); + if (len == 0) + return 0; + nwords = XDR_QUADLEN(len); + padding = (nwords << 2) - len; + xdr->iov = iov = buf->tail; + /* Compute remaining message length. */ + end = ((xdr->nwords - nwords) << 2) + padding; + if (end > iov->iov_len) + end = iov->iov_len; + + /* + * Position current pointer at beginning of tail, and + * set remaining message length. + */ + xdr->p = (__be32 *)((char *)iov->iov_base + padding); + xdr->end = (__be32 *)((char *)iov->iov_base + end); + xdr->page_ptr = NULL; + xdr->nwords = XDR_QUADLEN(end - padding); + return len; +} +EXPORT_SYMBOL_GPL(xdr_read_pages); + +/** + * xdr_enter_page - decode data from the XDR page + * @xdr: pointer to xdr_stream struct + * @len: number of bytes of page data + * + * Moves data beyond the current pointer position from the XDR head[] buffer + * into the page list. Any data that lies beyond current position + "len" + * bytes is moved into the XDR tail[]. The current pointer is then + * repositioned at the beginning of the first XDR page. + */ +void xdr_enter_page(struct xdr_stream *xdr, unsigned int len) +{ + len = xdr_align_pages(xdr, len); + /* + * Position current pointer at beginning of tail, and + * set remaining message length. + */ + if (len != 0) + xdr_set_page_base(xdr, 0, len); +} +EXPORT_SYMBOL_GPL(xdr_enter_page); + +static struct kvec empty_iov = {.iov_base = NULL, .iov_len = 0}; + +void +xdr_buf_from_iov(struct kvec *iov, struct xdr_buf *buf) +{ + buf->head[0] = *iov; + buf->tail[0] = empty_iov; + buf->page_len = 0; + buf->buflen = buf->len = iov->iov_len; +} +EXPORT_SYMBOL_GPL(xdr_buf_from_iov); + +/** + * xdr_buf_subsegment - set subbuf to a portion of buf + * @buf: an xdr buffer + * @subbuf: the result buffer + * @base: beginning of range in bytes + * @len: length of range in bytes + * + * sets @subbuf to an xdr buffer representing the portion of @buf of + * length @len starting at offset @base. + * + * @buf and @subbuf may be pointers to the same struct xdr_buf. + * + * Returns -1 if base of length are out of bounds. + */ +int +xdr_buf_subsegment(struct xdr_buf *buf, struct xdr_buf *subbuf, + unsigned int base, unsigned int len) +{ + subbuf->buflen = subbuf->len = len; + if (base < buf->head[0].iov_len) { + subbuf->head[0].iov_base = buf->head[0].iov_base + base; + subbuf->head[0].iov_len = min_t(unsigned int, len, + buf->head[0].iov_len - base); + len -= subbuf->head[0].iov_len; + base = 0; + } else { + base -= buf->head[0].iov_len; + subbuf->head[0].iov_base = buf->head[0].iov_base; + subbuf->head[0].iov_len = 0; + } + + if (base < buf->page_len) { + subbuf->page_len = min(buf->page_len - base, len); + base += buf->page_base; + subbuf->page_base = base & ~PAGE_MASK; + subbuf->pages = &buf->pages[base >> PAGE_SHIFT]; + len -= subbuf->page_len; + base = 0; + } else { + base -= buf->page_len; + subbuf->pages = buf->pages; + subbuf->page_base = 0; + subbuf->page_len = 0; + } + + if (base < buf->tail[0].iov_len) { + subbuf->tail[0].iov_base = buf->tail[0].iov_base + base; + subbuf->tail[0].iov_len = min_t(unsigned int, len, + buf->tail[0].iov_len - base); + len -= subbuf->tail[0].iov_len; + base = 0; + } else { + base -= buf->tail[0].iov_len; + subbuf->tail[0].iov_base = buf->tail[0].iov_base; + subbuf->tail[0].iov_len = 0; + } + + if (base || len) + return -1; + return 0; +} +EXPORT_SYMBOL_GPL(xdr_buf_subsegment); + +/** + * xdr_buf_trim - lop at most "len" bytes off the end of "buf" + * @buf: buf to be trimmed + * @len: number of bytes to reduce "buf" by + * + * Trim an xdr_buf by the given number of bytes by fixing up the lengths. Note + * that it's possible that we'll trim less than that amount if the xdr_buf is + * too small, or if (for instance) it's all in the head and the parser has + * already read too far into it. + */ +void xdr_buf_trim(struct xdr_buf *buf, unsigned int len) +{ + size_t cur; + unsigned int trim = len; + + if (buf->tail[0].iov_len) { + cur = min_t(size_t, buf->tail[0].iov_len, trim); + buf->tail[0].iov_len -= cur; + trim -= cur; + if (!trim) + goto fix_len; + } + + if (buf->page_len) { + cur = min_t(unsigned int, buf->page_len, trim); + buf->page_len -= cur; + trim -= cur; + if (!trim) + goto fix_len; + } + + if (buf->head[0].iov_len) { + cur = min_t(size_t, buf->head[0].iov_len, trim); + buf->head[0].iov_len -= cur; + trim -= cur; + } +fix_len: + buf->len -= (len - trim); +} +EXPORT_SYMBOL_GPL(xdr_buf_trim); + +static void __read_bytes_from_xdr_buf(struct xdr_buf *subbuf, void *obj, unsigned int len) +{ + unsigned int this_len; + + this_len = min_t(unsigned int, len, subbuf->head[0].iov_len); + memcpy(obj, subbuf->head[0].iov_base, this_len); + len -= this_len; + obj += this_len; + this_len = min_t(unsigned int, len, subbuf->page_len); + if (this_len) + _copy_from_pages(obj, subbuf->pages, subbuf->page_base, this_len); + len -= this_len; + obj += this_len; + this_len = min_t(unsigned int, len, subbuf->tail[0].iov_len); + memcpy(obj, subbuf->tail[0].iov_base, this_len); +} + +/* obj is assumed to point to allocated memory of size at least len: */ +int read_bytes_from_xdr_buf(struct xdr_buf *buf, unsigned int base, void *obj, unsigned int len) +{ + struct xdr_buf subbuf; + int status; + + status = xdr_buf_subsegment(buf, &subbuf, base, len); + if (status != 0) + return status; + __read_bytes_from_xdr_buf(&subbuf, obj, len); + return 0; +} +EXPORT_SYMBOL_GPL(read_bytes_from_xdr_buf); + +static void __write_bytes_to_xdr_buf(struct xdr_buf *subbuf, void *obj, unsigned int len) +{ + unsigned int this_len; + + this_len = min_t(unsigned int, len, subbuf->head[0].iov_len); + memcpy(subbuf->head[0].iov_base, obj, this_len); + len -= this_len; + obj += this_len; + this_len = min_t(unsigned int, len, subbuf->page_len); + if (this_len) + _copy_to_pages(subbuf->pages, subbuf->page_base, obj, this_len); + len -= this_len; + obj += this_len; + this_len = min_t(unsigned int, len, subbuf->tail[0].iov_len); + memcpy(subbuf->tail[0].iov_base, obj, this_len); +} + +/* obj is assumed to point to allocated memory of size at least len: */ +int write_bytes_to_xdr_buf(struct xdr_buf *buf, unsigned int base, void *obj, unsigned int len) +{ + struct xdr_buf subbuf; + int status; + + status = xdr_buf_subsegment(buf, &subbuf, base, len); + if (status != 0) + return status; + __write_bytes_to_xdr_buf(&subbuf, obj, len); + return 0; +} +EXPORT_SYMBOL_GPL(write_bytes_to_xdr_buf); + +int +xdr_decode_word(struct xdr_buf *buf, unsigned int base, u32 *obj) +{ + __be32 raw; + int status; + + status = read_bytes_from_xdr_buf(buf, base, &raw, sizeof(*obj)); + if (status) + return status; + *obj = be32_to_cpu(raw); + return 0; +} +EXPORT_SYMBOL_GPL(xdr_decode_word); + +int +xdr_encode_word(struct xdr_buf *buf, unsigned int base, u32 obj) +{ + __be32 raw = cpu_to_be32(obj); + + return write_bytes_to_xdr_buf(buf, base, &raw, sizeof(obj)); +} +EXPORT_SYMBOL_GPL(xdr_encode_word); + +/* If the netobj starting offset bytes from the start of xdr_buf is contained + * entirely in the head or the tail, set object to point to it; otherwise + * try to find space for it at the end of the tail, copy it there, and + * set obj to point to it. */ +int xdr_buf_read_netobj(struct xdr_buf *buf, struct xdr_netobj *obj, unsigned int offset) +{ + struct xdr_buf subbuf; + + if (xdr_decode_word(buf, offset, &obj->len)) + return -EFAULT; + if (xdr_buf_subsegment(buf, &subbuf, offset + 4, obj->len)) + return -EFAULT; + + /* Is the obj contained entirely in the head? */ + obj->data = subbuf.head[0].iov_base; + if (subbuf.head[0].iov_len == obj->len) + return 0; + /* ..or is the obj contained entirely in the tail? */ + obj->data = subbuf.tail[0].iov_base; + if (subbuf.tail[0].iov_len == obj->len) + return 0; + + /* use end of tail as storage for obj: + * (We don't copy to the beginning because then we'd have + * to worry about doing a potentially overlapping copy. + * This assumes the object is at most half the length of the + * tail.) */ + if (obj->len > buf->buflen - buf->len) + return -ENOMEM; + if (buf->tail[0].iov_len != 0) + obj->data = buf->tail[0].iov_base + buf->tail[0].iov_len; + else + obj->data = buf->head[0].iov_base + buf->head[0].iov_len; + __read_bytes_from_xdr_buf(&subbuf, obj->data, obj->len); + return 0; +} +EXPORT_SYMBOL_GPL(xdr_buf_read_netobj); + +/* Returns 0 on success, or else a negative error code. */ +static int +xdr_xcode_array2(struct xdr_buf *buf, unsigned int base, + struct xdr_array2_desc *desc, int encode) +{ + char *elem = NULL, *c; + unsigned int copied = 0, todo, avail_here; + struct page **ppages = NULL; + int err; + + if (encode) { + if (xdr_encode_word(buf, base, desc->array_len) != 0) + return -EINVAL; + } else { + if (xdr_decode_word(buf, base, &desc->array_len) != 0 || + desc->array_len > desc->array_maxlen || + (unsigned long) base + 4 + desc->array_len * + desc->elem_size > buf->len) + return -EINVAL; + } + base += 4; + + if (!desc->xcode) + return 0; + + todo = desc->array_len * desc->elem_size; + + /* process head */ + if (todo && base < buf->head->iov_len) { + c = buf->head->iov_base + base; + avail_here = min_t(unsigned int, todo, + buf->head->iov_len - base); + todo -= avail_here; + + while (avail_here >= desc->elem_size) { + err = desc->xcode(desc, c); + if (err) + goto out; + c += desc->elem_size; + avail_here -= desc->elem_size; + } + if (avail_here) { + if (!elem) { + elem = kmalloc(desc->elem_size, GFP_KERNEL); + err = -ENOMEM; + if (!elem) + goto out; + } + if (encode) { + err = desc->xcode(desc, elem); + if (err) + goto out; + memcpy(c, elem, avail_here); + } else + memcpy(elem, c, avail_here); + copied = avail_here; + } + base = buf->head->iov_len; /* align to start of pages */ + } + + /* process pages array */ + base -= buf->head->iov_len; + if (todo && base < buf->page_len) { + unsigned int avail_page; + + avail_here = min(todo, buf->page_len - base); + todo -= avail_here; + + base += buf->page_base; + ppages = buf->pages + (base >> PAGE_SHIFT); + base &= ~PAGE_MASK; + avail_page = min_t(unsigned int, PAGE_SIZE - base, + avail_here); + c = kmap(*ppages) + base; + + while (avail_here) { + avail_here -= avail_page; + if (copied || avail_page < desc->elem_size) { + unsigned int l = min(avail_page, + desc->elem_size - copied); + if (!elem) { + elem = kmalloc(desc->elem_size, + GFP_KERNEL); + err = -ENOMEM; + if (!elem) + goto out; + } + if (encode) { + if (!copied) { + err = desc->xcode(desc, elem); + if (err) + goto out; + } + memcpy(c, elem + copied, l); + copied += l; + if (copied == desc->elem_size) + copied = 0; + } else { + memcpy(elem + copied, c, l); + copied += l; + if (copied == desc->elem_size) { + err = desc->xcode(desc, elem); + if (err) + goto out; + copied = 0; + } + } + avail_page -= l; + c += l; + } + while (avail_page >= desc->elem_size) { + err = desc->xcode(desc, c); + if (err) + goto out; + c += desc->elem_size; + avail_page -= desc->elem_size; + } + if (avail_page) { + unsigned int l = min(avail_page, + desc->elem_size - copied); + if (!elem) { + elem = kmalloc(desc->elem_size, + GFP_KERNEL); + err = -ENOMEM; + if (!elem) + goto out; + } + if (encode) { + if (!copied) { + err = desc->xcode(desc, elem); + if (err) + goto out; + } + memcpy(c, elem + copied, l); + copied += l; + if (copied == desc->elem_size) + copied = 0; + } else { + memcpy(elem + copied, c, l); + copied += l; + if (copied == desc->elem_size) { + err = desc->xcode(desc, elem); + if (err) + goto out; + copied = 0; + } + } + } + if (avail_here) { + kunmap(*ppages); + ppages++; + c = kmap(*ppages); + } + + avail_page = min(avail_here, + (unsigned int) PAGE_SIZE); + } + base = buf->page_len; /* align to start of tail */ + } + + /* process tail */ + base -= buf->page_len; + if (todo) { + c = buf->tail->iov_base + base; + if (copied) { + unsigned int l = desc->elem_size - copied; + + if (encode) + memcpy(c, elem + copied, l); + else { + memcpy(elem + copied, c, l); + err = desc->xcode(desc, elem); + if (err) + goto out; + } + todo -= l; + c += l; + } + while (todo) { + err = desc->xcode(desc, c); + if (err) + goto out; + c += desc->elem_size; + todo -= desc->elem_size; + } + } + err = 0; + +out: + kfree(elem); + if (ppages) + kunmap(*ppages); + return err; +} + +int +xdr_decode_array2(struct xdr_buf *buf, unsigned int base, + struct xdr_array2_desc *desc) +{ + if (base >= buf->len) + return -EINVAL; + + return xdr_xcode_array2(buf, base, desc, 0); +} +EXPORT_SYMBOL_GPL(xdr_decode_array2); + +int +xdr_encode_array2(struct xdr_buf *buf, unsigned int base, + struct xdr_array2_desc *desc) +{ + if ((unsigned long) base + 4 + desc->array_len * desc->elem_size > + buf->head->iov_len + buf->page_len + buf->tail->iov_len) + return -EINVAL; + + return xdr_xcode_array2(buf, base, desc, 1); +} +EXPORT_SYMBOL_GPL(xdr_encode_array2); + +int +xdr_process_buf(struct xdr_buf *buf, unsigned int offset, unsigned int len, + int (*actor)(struct scatterlist *, void *), void *data) +{ + int i, ret = 0; + unsigned int page_len, thislen, page_offset; + struct scatterlist sg[1]; + + sg_init_table(sg, 1); + + if (offset >= buf->head[0].iov_len) { + offset -= buf->head[0].iov_len; + } else { + thislen = buf->head[0].iov_len - offset; + if (thislen > len) + thislen = len; + sg_set_buf(sg, buf->head[0].iov_base + offset, thislen); + ret = actor(sg, data); + if (ret) + goto out; + offset = 0; + len -= thislen; + } + if (len == 0) + goto out; + + if (offset >= buf->page_len) { + offset -= buf->page_len; + } else { + page_len = buf->page_len - offset; + if (page_len > len) + page_len = len; + len -= page_len; + page_offset = (offset + buf->page_base) & (PAGE_SIZE - 1); + i = (offset + buf->page_base) >> PAGE_SHIFT; + thislen = PAGE_SIZE - page_offset; + do { + if (thislen > page_len) + thislen = page_len; + sg_set_page(sg, buf->pages[i], thislen, page_offset); + ret = actor(sg, data); + if (ret) + goto out; + page_len -= thislen; + i++; + page_offset = 0; + thislen = PAGE_SIZE; + } while (page_len != 0); + offset = 0; + } + if (len == 0) + goto out; + if (offset < buf->tail[0].iov_len) { + thislen = buf->tail[0].iov_len - offset; + if (thislen > len) + thislen = len; + sg_set_buf(sg, buf->tail[0].iov_base + offset, thislen); + ret = actor(sg, data); + len -= thislen; + } + if (len != 0) + ret = -EINVAL; +out: + return ret; +} +EXPORT_SYMBOL_GPL(xdr_process_buf); + +/** + * xdr_stream_decode_opaque - Decode variable length opaque + * @xdr: pointer to xdr_stream + * @ptr: location to store opaque data + * @size: size of storage buffer @ptr + * + * Return values: + * On success, returns size of object stored in *@ptr + * %-EBADMSG on XDR buffer overflow + * %-EMSGSIZE on overflow of storage buffer @ptr + */ +ssize_t xdr_stream_decode_opaque(struct xdr_stream *xdr, void *ptr, size_t size) +{ + ssize_t ret; + void *p; + + ret = xdr_stream_decode_opaque_inline(xdr, &p, size); + if (ret <= 0) + return ret; + memcpy(ptr, p, ret); + return ret; +} +EXPORT_SYMBOL_GPL(xdr_stream_decode_opaque); + +/** + * xdr_stream_decode_opaque_dup - Decode and duplicate variable length opaque + * @xdr: pointer to xdr_stream + * @ptr: location to store pointer to opaque data + * @maxlen: maximum acceptable object size + * @gfp_flags: GFP mask to use + * + * Return values: + * On success, returns size of object stored in *@ptr + * %-EBADMSG on XDR buffer overflow + * %-EMSGSIZE if the size of the object would exceed @maxlen + * %-ENOMEM on memory allocation failure + */ +ssize_t xdr_stream_decode_opaque_dup(struct xdr_stream *xdr, void **ptr, + size_t maxlen, gfp_t gfp_flags) +{ + ssize_t ret; + void *p; + + ret = xdr_stream_decode_opaque_inline(xdr, &p, maxlen); + if (ret > 0) { + *ptr = kmemdup(p, ret, gfp_flags); + if (*ptr != NULL) + return ret; + ret = -ENOMEM; + } + *ptr = NULL; + return ret; +} +EXPORT_SYMBOL_GPL(xdr_stream_decode_opaque_dup); + +/** + * xdr_stream_decode_string - Decode variable length string + * @xdr: pointer to xdr_stream + * @str: location to store string + * @size: size of storage buffer @str + * + * Return values: + * On success, returns length of NUL-terminated string stored in *@str + * %-EBADMSG on XDR buffer overflow + * %-EMSGSIZE on overflow of storage buffer @str + */ +ssize_t xdr_stream_decode_string(struct xdr_stream *xdr, char *str, size_t size) +{ + ssize_t ret; + void *p; + + ret = xdr_stream_decode_opaque_inline(xdr, &p, size); + if (ret > 0) { + memcpy(str, p, ret); + str[ret] = '\0'; + return strlen(str); + } + *str = '\0'; + return ret; +} +EXPORT_SYMBOL_GPL(xdr_stream_decode_string); + +/** + * xdr_stream_decode_string_dup - Decode and duplicate variable length string + * @xdr: pointer to xdr_stream + * @str: location to store pointer to string + * @maxlen: maximum acceptable string length + * @gfp_flags: GFP mask to use + * + * Return values: + * On success, returns length of NUL-terminated string stored in *@ptr + * %-EBADMSG on XDR buffer overflow + * %-EMSGSIZE if the size of the string would exceed @maxlen + * %-ENOMEM on memory allocation failure + */ +ssize_t xdr_stream_decode_string_dup(struct xdr_stream *xdr, char **str, + size_t maxlen, gfp_t gfp_flags) +{ + void *p; + ssize_t ret; + + ret = xdr_stream_decode_opaque_inline(xdr, &p, maxlen); + if (ret > 0) { + char *s = kmalloc(ret + 1, gfp_flags); + if (s != NULL) { + memcpy(s, p, ret); + s[ret] = '\0'; + *str = s; + return strlen(s); + } + ret = -ENOMEM; + } + *str = NULL; + return ret; +} +EXPORT_SYMBOL_GPL(xdr_stream_decode_string_dup); diff --git a/net/sunrpc/xprt.c b/net/sunrpc/xprt.c new file mode 100644 index 000000000..d05fa7c36 --- /dev/null +++ b/net/sunrpc/xprt.c @@ -0,0 +1,1593 @@ +/* + * linux/net/sunrpc/xprt.c + * + * This is a generic RPC call interface supporting congestion avoidance, + * and asynchronous calls. + * + * The interface works like this: + * + * - When a process places a call, it allocates a request slot if + * one is available. Otherwise, it sleeps on the backlog queue + * (xprt_reserve). + * - Next, the caller puts together the RPC message, stuffs it into + * the request struct, and calls xprt_transmit(). + * - xprt_transmit sends the message and installs the caller on the + * transport's wait list. At the same time, if a reply is expected, + * it installs a timer that is run after the packet's timeout has + * expired. + * - When a packet arrives, the data_ready handler walks the list of + * pending requests for that transport. If a matching XID is found, the + * caller is woken up, and the timer removed. + * - When no reply arrives within the timeout interval, the timer is + * fired by the kernel and runs xprt_timer(). It either adjusts the + * timeout values (minor timeout) or wakes up the caller with a status + * of -ETIMEDOUT. + * - When the caller receives a notification from RPC that a reply arrived, + * it should release the RPC slot, and process the reply. + * If the call timed out, it may choose to retry the operation by + * adjusting the initial timeout value, and simply calling rpc_call + * again. + * + * Support for async RPC is done through a set of RPC-specific scheduling + * primitives that `transparently' work for processes as well as async + * tasks that rely on callbacks. + * + * Copyright (C) 1995-1997, Olaf Kirch <okir@monad.swb.de> + * + * Transport switch API copyright (C) 2005, Chuck Lever <cel@netapp.com> + */ + +#include <linux/module.h> + +#include <linux/types.h> +#include <linux/interrupt.h> +#include <linux/workqueue.h> +#include <linux/net.h> +#include <linux/ktime.h> + +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/metrics.h> +#include <linux/sunrpc/bc_xprt.h> +#include <linux/rcupdate.h> + +#include <trace/events/sunrpc.h> + +#include "sunrpc.h" + +/* + * Local variables + */ + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_XPRT +#endif + +/* + * Local functions + */ +static void xprt_init(struct rpc_xprt *xprt, struct net *net); +static __be32 xprt_alloc_xid(struct rpc_xprt *xprt); +static void xprt_connect_status(struct rpc_task *task); +static int __xprt_get_cong(struct rpc_xprt *, struct rpc_task *); +static void __xprt_put_cong(struct rpc_xprt *, struct rpc_rqst *); +static void xprt_destroy(struct rpc_xprt *xprt); + +static DEFINE_SPINLOCK(xprt_list_lock); +static LIST_HEAD(xprt_list); + +/** + * xprt_register_transport - register a transport implementation + * @transport: transport to register + * + * If a transport implementation is loaded as a kernel module, it can + * call this interface to make itself known to the RPC client. + * + * Returns: + * 0: transport successfully registered + * -EEXIST: transport already registered + * -EINVAL: transport module being unloaded + */ +int xprt_register_transport(struct xprt_class *transport) +{ + struct xprt_class *t; + int result; + + result = -EEXIST; + spin_lock(&xprt_list_lock); + list_for_each_entry(t, &xprt_list, list) { + /* don't register the same transport class twice */ + if (t->ident == transport->ident) + goto out; + } + + list_add_tail(&transport->list, &xprt_list); + printk(KERN_INFO "RPC: Registered %s transport module.\n", + transport->name); + result = 0; + +out: + spin_unlock(&xprt_list_lock); + return result; +} +EXPORT_SYMBOL_GPL(xprt_register_transport); + +/** + * xprt_unregister_transport - unregister a transport implementation + * @transport: transport to unregister + * + * Returns: + * 0: transport successfully unregistered + * -ENOENT: transport never registered + */ +int xprt_unregister_transport(struct xprt_class *transport) +{ + struct xprt_class *t; + int result; + + result = 0; + spin_lock(&xprt_list_lock); + list_for_each_entry(t, &xprt_list, list) { + if (t == transport) { + printk(KERN_INFO + "RPC: Unregistered %s transport module.\n", + transport->name); + list_del_init(&transport->list); + goto out; + } + } + result = -ENOENT; + +out: + spin_unlock(&xprt_list_lock); + return result; +} +EXPORT_SYMBOL_GPL(xprt_unregister_transport); + +static void +xprt_class_release(const struct xprt_class *t) +{ + module_put(t->owner); +} + +static const struct xprt_class * +xprt_class_find_by_netid_locked(const char *netid) +{ + const struct xprt_class *t; + unsigned int i; + + list_for_each_entry(t, &xprt_list, list) { + for (i = 0; t->netid[i][0] != '\0'; i++) { + if (strcmp(t->netid[i], netid) != 0) + continue; + if (!try_module_get(t->owner)) + continue; + return t; + } + } + return NULL; +} + +static const struct xprt_class * +xprt_class_find_by_netid(const char *netid) +{ + const struct xprt_class *t; + + spin_lock(&xprt_list_lock); + t = xprt_class_find_by_netid_locked(netid); + if (!t) { + spin_unlock(&xprt_list_lock); + request_module("rpc%s", netid); + spin_lock(&xprt_list_lock); + t = xprt_class_find_by_netid_locked(netid); + } + spin_unlock(&xprt_list_lock); + return t; +} + +/** + * xprt_load_transport - load a transport implementation + * @netid: transport to load + * + * Returns: + * 0: transport successfully loaded + * -ENOENT: transport module not available + */ +int xprt_load_transport(const char *netid) +{ + const struct xprt_class *t; + + t = xprt_class_find_by_netid(netid); + if (!t) + return -ENOENT; + xprt_class_release(t); + return 0; +} +EXPORT_SYMBOL_GPL(xprt_load_transport); + +/** + * xprt_reserve_xprt - serialize write access to transports + * @task: task that is requesting access to the transport + * @xprt: pointer to the target transport + * + * This prevents mixing the payload of separate requests, and prevents + * transport connects from colliding with writes. No congestion control + * is provided. + */ +int xprt_reserve_xprt(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + int priority; + + if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) { + if (task == xprt->snd_task) + return 1; + goto out_sleep; + } + xprt->snd_task = task; + if (req != NULL) + req->rq_ntrans++; + + return 1; + +out_sleep: + dprintk("RPC: %5u failed to lock transport %p\n", + task->tk_pid, xprt); + task->tk_timeout = 0; + task->tk_status = -EAGAIN; + if (req == NULL) + priority = RPC_PRIORITY_LOW; + else if (!req->rq_ntrans) + priority = RPC_PRIORITY_NORMAL; + else + priority = RPC_PRIORITY_HIGH; + rpc_sleep_on_priority(&xprt->sending, task, NULL, priority); + return 0; +} +EXPORT_SYMBOL_GPL(xprt_reserve_xprt); + +static void xprt_clear_locked(struct rpc_xprt *xprt) +{ + xprt->snd_task = NULL; + if (!test_bit(XPRT_CLOSE_WAIT, &xprt->state)) { + smp_mb__before_atomic(); + clear_bit(XPRT_LOCKED, &xprt->state); + smp_mb__after_atomic(); + } else + queue_work(xprtiod_workqueue, &xprt->task_cleanup); +} + +/* + * xprt_reserve_xprt_cong - serialize write access to transports + * @task: task that is requesting access to the transport + * + * Same as xprt_reserve_xprt, but Van Jacobson congestion control is + * integrated into the decision of whether a request is allowed to be + * woken up and given access to the transport. + */ +int xprt_reserve_xprt_cong(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + int priority; + + if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) { + if (task == xprt->snd_task) + return 1; + goto out_sleep; + } + if (req == NULL) { + xprt->snd_task = task; + return 1; + } + if (__xprt_get_cong(xprt, task)) { + xprt->snd_task = task; + req->rq_ntrans++; + return 1; + } + xprt_clear_locked(xprt); +out_sleep: + if (req) + __xprt_put_cong(xprt, req); + dprintk("RPC: %5u failed to lock transport %p\n", task->tk_pid, xprt); + task->tk_timeout = 0; + task->tk_status = -EAGAIN; + if (req == NULL) + priority = RPC_PRIORITY_LOW; + else if (!req->rq_ntrans) + priority = RPC_PRIORITY_NORMAL; + else + priority = RPC_PRIORITY_HIGH; + rpc_sleep_on_priority(&xprt->sending, task, NULL, priority); + return 0; +} +EXPORT_SYMBOL_GPL(xprt_reserve_xprt_cong); + +static inline int xprt_lock_write(struct rpc_xprt *xprt, struct rpc_task *task) +{ + int retval; + + spin_lock_bh(&xprt->transport_lock); + retval = xprt->ops->reserve_xprt(xprt, task); + spin_unlock_bh(&xprt->transport_lock); + return retval; +} + +static bool __xprt_lock_write_func(struct rpc_task *task, void *data) +{ + struct rpc_xprt *xprt = data; + struct rpc_rqst *req; + + req = task->tk_rqstp; + xprt->snd_task = task; + if (req) + req->rq_ntrans++; + return true; +} + +static void __xprt_lock_write_next(struct rpc_xprt *xprt) +{ + if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) + return; + + if (rpc_wake_up_first_on_wq(xprtiod_workqueue, &xprt->sending, + __xprt_lock_write_func, xprt)) + return; + xprt_clear_locked(xprt); +} + +static bool __xprt_lock_write_cong_func(struct rpc_task *task, void *data) +{ + struct rpc_xprt *xprt = data; + struct rpc_rqst *req; + + req = task->tk_rqstp; + if (req == NULL) { + xprt->snd_task = task; + return true; + } + if (__xprt_get_cong(xprt, task)) { + xprt->snd_task = task; + req->rq_ntrans++; + return true; + } + return false; +} + +static void __xprt_lock_write_next_cong(struct rpc_xprt *xprt) +{ + if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) + return; + if (RPCXPRT_CONGESTED(xprt)) + goto out_unlock; + if (rpc_wake_up_first_on_wq(xprtiod_workqueue, &xprt->sending, + __xprt_lock_write_cong_func, xprt)) + return; +out_unlock: + xprt_clear_locked(xprt); +} + +static void xprt_task_clear_bytes_sent(struct rpc_task *task) +{ + if (task != NULL) { + struct rpc_rqst *req = task->tk_rqstp; + if (req != NULL) + req->rq_bytes_sent = 0; + } +} + +/** + * xprt_release_xprt - allow other requests to use a transport + * @xprt: transport with other tasks potentially waiting + * @task: task that is releasing access to the transport + * + * Note that "task" can be NULL. No congestion control is provided. + */ +void xprt_release_xprt(struct rpc_xprt *xprt, struct rpc_task *task) +{ + if (xprt->snd_task == task) { + xprt_task_clear_bytes_sent(task); + xprt_clear_locked(xprt); + __xprt_lock_write_next(xprt); + } +} +EXPORT_SYMBOL_GPL(xprt_release_xprt); + +/** + * xprt_release_xprt_cong - allow other requests to use a transport + * @xprt: transport with other tasks potentially waiting + * @task: task that is releasing access to the transport + * + * Note that "task" can be NULL. Another task is awoken to use the + * transport if the transport's congestion window allows it. + */ +void xprt_release_xprt_cong(struct rpc_xprt *xprt, struct rpc_task *task) +{ + if (xprt->snd_task == task) { + xprt_task_clear_bytes_sent(task); + xprt_clear_locked(xprt); + __xprt_lock_write_next_cong(xprt); + } +} +EXPORT_SYMBOL_GPL(xprt_release_xprt_cong); + +static inline void xprt_release_write(struct rpc_xprt *xprt, struct rpc_task *task) +{ + spin_lock_bh(&xprt->transport_lock); + xprt->ops->release_xprt(xprt, task); + spin_unlock_bh(&xprt->transport_lock); +} + +/* + * Van Jacobson congestion avoidance. Check if the congestion window + * overflowed. Put the task to sleep if this is the case. + */ +static int +__xprt_get_cong(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + + if (req->rq_cong) + return 1; + dprintk("RPC: %5u xprt_cwnd_limited cong = %lu cwnd = %lu\n", + task->tk_pid, xprt->cong, xprt->cwnd); + if (RPCXPRT_CONGESTED(xprt)) + return 0; + req->rq_cong = 1; + xprt->cong += RPC_CWNDSCALE; + return 1; +} + +/* + * Adjust the congestion window, and wake up the next task + * that has been sleeping due to congestion + */ +static void +__xprt_put_cong(struct rpc_xprt *xprt, struct rpc_rqst *req) +{ + if (!req->rq_cong) + return; + req->rq_cong = 0; + xprt->cong -= RPC_CWNDSCALE; + __xprt_lock_write_next_cong(xprt); +} + +/** + * xprt_release_rqst_cong - housekeeping when request is complete + * @task: RPC request that recently completed + * + * Useful for transports that require congestion control. + */ +void xprt_release_rqst_cong(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + + __xprt_put_cong(req->rq_xprt, req); +} +EXPORT_SYMBOL_GPL(xprt_release_rqst_cong); + +/** + * xprt_adjust_cwnd - adjust transport congestion window + * @xprt: pointer to xprt + * @task: recently completed RPC request used to adjust window + * @result: result code of completed RPC request + * + * The transport code maintains an estimate on the maximum number of out- + * standing RPC requests, using a smoothed version of the congestion + * avoidance implemented in 44BSD. This is basically the Van Jacobson + * congestion algorithm: If a retransmit occurs, the congestion window is + * halved; otherwise, it is incremented by 1/cwnd when + * + * - a reply is received and + * - a full number of requests are outstanding and + * - the congestion window hasn't been updated recently. + */ +void xprt_adjust_cwnd(struct rpc_xprt *xprt, struct rpc_task *task, int result) +{ + struct rpc_rqst *req = task->tk_rqstp; + unsigned long cwnd = xprt->cwnd; + + if (result >= 0 && cwnd <= xprt->cong) { + /* The (cwnd >> 1) term makes sure + * the result gets rounded properly. */ + cwnd += (RPC_CWNDSCALE * RPC_CWNDSCALE + (cwnd >> 1)) / cwnd; + if (cwnd > RPC_MAXCWND(xprt)) + cwnd = RPC_MAXCWND(xprt); + __xprt_lock_write_next_cong(xprt); + } else if (result == -ETIMEDOUT) { + cwnd >>= 1; + if (cwnd < RPC_CWNDSCALE) + cwnd = RPC_CWNDSCALE; + } + dprintk("RPC: cong %ld, cwnd was %ld, now %ld\n", + xprt->cong, xprt->cwnd, cwnd); + xprt->cwnd = cwnd; + __xprt_put_cong(xprt, req); +} +EXPORT_SYMBOL_GPL(xprt_adjust_cwnd); + +/** + * xprt_wake_pending_tasks - wake all tasks on a transport's pending queue + * @xprt: transport with waiting tasks + * @status: result code to plant in each task before waking it + * + */ +void xprt_wake_pending_tasks(struct rpc_xprt *xprt, int status) +{ + if (status < 0) + rpc_wake_up_status(&xprt->pending, status); + else + rpc_wake_up(&xprt->pending); +} +EXPORT_SYMBOL_GPL(xprt_wake_pending_tasks); + +/** + * xprt_wait_for_buffer_space - wait for transport output buffer to clear + * @task: task to be put to sleep + * @action: function pointer to be executed after wait + * + * Note that we only set the timer for the case of RPC_IS_SOFT(), since + * we don't in general want to force a socket disconnection due to + * an incomplete RPC call transmission. + */ +void xprt_wait_for_buffer_space(struct rpc_task *task, rpc_action action) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + task->tk_timeout = RPC_IS_SOFT(task) ? req->rq_timeout : 0; + rpc_sleep_on(&xprt->pending, task, action); +} +EXPORT_SYMBOL_GPL(xprt_wait_for_buffer_space); + +/** + * xprt_write_space - wake the task waiting for transport output buffer space + * @xprt: transport with waiting tasks + * + * Can be called in a soft IRQ context, so xprt_write_space never sleeps. + */ +void xprt_write_space(struct rpc_xprt *xprt) +{ + spin_lock_bh(&xprt->transport_lock); + if (xprt->snd_task) { + dprintk("RPC: write space: waking waiting task on " + "xprt %p\n", xprt); + rpc_wake_up_queued_task_on_wq(xprtiod_workqueue, + &xprt->pending, xprt->snd_task); + } + spin_unlock_bh(&xprt->transport_lock); +} +EXPORT_SYMBOL_GPL(xprt_write_space); + +/** + * xprt_set_retrans_timeout_def - set a request's retransmit timeout + * @task: task whose timeout is to be set + * + * Set a request's retransmit timeout based on the transport's + * default timeout parameters. Used by transports that don't adjust + * the retransmit timeout based on round-trip time estimation. + */ +void xprt_set_retrans_timeout_def(struct rpc_task *task) +{ + task->tk_timeout = task->tk_rqstp->rq_timeout; +} +EXPORT_SYMBOL_GPL(xprt_set_retrans_timeout_def); + +/** + * xprt_set_retrans_timeout_rtt - set a request's retransmit timeout + * @task: task whose timeout is to be set + * + * Set a request's retransmit timeout using the RTT estimator. + */ +void xprt_set_retrans_timeout_rtt(struct rpc_task *task) +{ + int timer = task->tk_msg.rpc_proc->p_timer; + struct rpc_clnt *clnt = task->tk_client; + struct rpc_rtt *rtt = clnt->cl_rtt; + struct rpc_rqst *req = task->tk_rqstp; + unsigned long max_timeout = clnt->cl_timeout->to_maxval; + + task->tk_timeout = rpc_calc_rto(rtt, timer); + task->tk_timeout <<= rpc_ntimeo(rtt, timer) + req->rq_retries; + if (task->tk_timeout > max_timeout || task->tk_timeout == 0) + task->tk_timeout = max_timeout; +} +EXPORT_SYMBOL_GPL(xprt_set_retrans_timeout_rtt); + +static void xprt_reset_majortimeo(struct rpc_rqst *req) +{ + const struct rpc_timeout *to = req->rq_task->tk_client->cl_timeout; + + req->rq_majortimeo = req->rq_timeout; + if (to->to_exponential) + req->rq_majortimeo <<= to->to_retries; + else + req->rq_majortimeo += to->to_increment * to->to_retries; + if (req->rq_majortimeo > to->to_maxval || req->rq_majortimeo == 0) + req->rq_majortimeo = to->to_maxval; + req->rq_majortimeo += jiffies; +} + +/** + * xprt_adjust_timeout - adjust timeout values for next retransmit + * @req: RPC request containing parameters to use for the adjustment + * + */ +int xprt_adjust_timeout(struct rpc_rqst *req) +{ + struct rpc_xprt *xprt = req->rq_xprt; + const struct rpc_timeout *to = req->rq_task->tk_client->cl_timeout; + int status = 0; + + if (time_before(jiffies, req->rq_majortimeo)) { + if (to->to_exponential) + req->rq_timeout <<= 1; + else + req->rq_timeout += to->to_increment; + if (to->to_maxval && req->rq_timeout >= to->to_maxval) + req->rq_timeout = to->to_maxval; + req->rq_retries++; + } else { + req->rq_timeout = to->to_initval; + req->rq_retries = 0; + xprt_reset_majortimeo(req); + /* Reset the RTT counters == "slow start" */ + spin_lock_bh(&xprt->transport_lock); + rpc_init_rtt(req->rq_task->tk_client->cl_rtt, to->to_initval); + spin_unlock_bh(&xprt->transport_lock); + status = -ETIMEDOUT; + } + + if (req->rq_timeout == 0) { + printk(KERN_WARNING "xprt_adjust_timeout: rq_timeout = 0!\n"); + req->rq_timeout = 5 * HZ; + } + return status; +} + +static void xprt_autoclose(struct work_struct *work) +{ + struct rpc_xprt *xprt = + container_of(work, struct rpc_xprt, task_cleanup); + + clear_bit(XPRT_CLOSE_WAIT, &xprt->state); + xprt->ops->close(xprt); + xprt_release_write(xprt, NULL); + wake_up_bit(&xprt->state, XPRT_LOCKED); +} + +/** + * xprt_disconnect_done - mark a transport as disconnected + * @xprt: transport to flag for disconnect + * + */ +void xprt_disconnect_done(struct rpc_xprt *xprt) +{ + dprintk("RPC: disconnected transport %p\n", xprt); + spin_lock_bh(&xprt->transport_lock); + xprt_clear_connected(xprt); + xprt_wake_pending_tasks(xprt, -EAGAIN); + spin_unlock_bh(&xprt->transport_lock); +} +EXPORT_SYMBOL_GPL(xprt_disconnect_done); + +/** + * xprt_force_disconnect - force a transport to disconnect + * @xprt: transport to disconnect + * + */ +void xprt_force_disconnect(struct rpc_xprt *xprt) +{ + /* Don't race with the test_bit() in xprt_clear_locked() */ + spin_lock_bh(&xprt->transport_lock); + set_bit(XPRT_CLOSE_WAIT, &xprt->state); + /* Try to schedule an autoclose RPC call */ + if (test_and_set_bit(XPRT_LOCKED, &xprt->state) == 0) + queue_work(xprtiod_workqueue, &xprt->task_cleanup); + xprt_wake_pending_tasks(xprt, -EAGAIN); + spin_unlock_bh(&xprt->transport_lock); +} +EXPORT_SYMBOL_GPL(xprt_force_disconnect); + +/** + * xprt_conditional_disconnect - force a transport to disconnect + * @xprt: transport to disconnect + * @cookie: 'connection cookie' + * + * This attempts to break the connection if and only if 'cookie' matches + * the current transport 'connection cookie'. It ensures that we don't + * try to break the connection more than once when we need to retransmit + * a batch of RPC requests. + * + */ +void xprt_conditional_disconnect(struct rpc_xprt *xprt, unsigned int cookie) +{ + /* Don't race with the test_bit() in xprt_clear_locked() */ + spin_lock_bh(&xprt->transport_lock); + if (cookie != xprt->connect_cookie) + goto out; + if (test_bit(XPRT_CLOSING, &xprt->state)) + goto out; + set_bit(XPRT_CLOSE_WAIT, &xprt->state); + /* Try to schedule an autoclose RPC call */ + if (test_and_set_bit(XPRT_LOCKED, &xprt->state) == 0) + queue_work(xprtiod_workqueue, &xprt->task_cleanup); + xprt_wake_pending_tasks(xprt, -EAGAIN); +out: + spin_unlock_bh(&xprt->transport_lock); +} + +static bool +xprt_has_timer(const struct rpc_xprt *xprt) +{ + return xprt->idle_timeout != 0; +} + +static void +xprt_schedule_autodisconnect(struct rpc_xprt *xprt) + __must_hold(&xprt->transport_lock) +{ + if (list_empty(&xprt->recv) && xprt_has_timer(xprt)) + mod_timer(&xprt->timer, xprt->last_used + xprt->idle_timeout); +} + +static void +xprt_init_autodisconnect(struct timer_list *t) +{ + struct rpc_xprt *xprt = from_timer(xprt, t, timer); + + spin_lock(&xprt->transport_lock); + if (!list_empty(&xprt->recv)) + goto out_abort; + /* Reset xprt->last_used to avoid connect/autodisconnect cycling */ + xprt->last_used = jiffies; + if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) + goto out_abort; + spin_unlock(&xprt->transport_lock); + queue_work(xprtiod_workqueue, &xprt->task_cleanup); + return; +out_abort: + spin_unlock(&xprt->transport_lock); +} + +bool xprt_lock_connect(struct rpc_xprt *xprt, + struct rpc_task *task, + void *cookie) +{ + bool ret = false; + + spin_lock_bh(&xprt->transport_lock); + if (!test_bit(XPRT_LOCKED, &xprt->state)) + goto out; + if (xprt->snd_task != task) + goto out; + xprt_task_clear_bytes_sent(task); + xprt->snd_task = cookie; + ret = true; +out: + spin_unlock_bh(&xprt->transport_lock); + return ret; +} + +void xprt_unlock_connect(struct rpc_xprt *xprt, void *cookie) +{ + spin_lock_bh(&xprt->transport_lock); + if (xprt->snd_task != cookie) + goto out; + if (!test_bit(XPRT_LOCKED, &xprt->state)) + goto out; + xprt->snd_task =NULL; + xprt->ops->release_xprt(xprt, NULL); + xprt_schedule_autodisconnect(xprt); +out: + spin_unlock_bh(&xprt->transport_lock); + wake_up_bit(&xprt->state, XPRT_LOCKED); +} + +/** + * xprt_connect - schedule a transport connect operation + * @task: RPC task that is requesting the connect + * + */ +void xprt_connect(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; + + dprintk("RPC: %5u xprt_connect xprt %p %s connected\n", task->tk_pid, + xprt, (xprt_connected(xprt) ? "is" : "is not")); + + if (!xprt_bound(xprt)) { + task->tk_status = -EAGAIN; + return; + } + if (!xprt_lock_write(xprt, task)) + return; + + if (test_and_clear_bit(XPRT_CLOSE_WAIT, &xprt->state)) + xprt->ops->close(xprt); + + if (!xprt_connected(xprt)) { + task->tk_rqstp->rq_bytes_sent = 0; + task->tk_timeout = task->tk_rqstp->rq_timeout; + task->tk_rqstp->rq_connect_cookie = xprt->connect_cookie; + rpc_sleep_on(&xprt->pending, task, xprt_connect_status); + + if (test_bit(XPRT_CLOSING, &xprt->state)) + return; + if (xprt_test_and_set_connecting(xprt)) + return; + /* Race breaker */ + if (!xprt_connected(xprt)) { + xprt->stat.connect_start = jiffies; + xprt->ops->connect(xprt, task); + } else { + xprt_clear_connecting(xprt); + task->tk_status = 0; + rpc_wake_up_queued_task(&xprt->pending, task); + } + } + xprt_release_write(xprt, task); +} + +static void xprt_connect_status(struct rpc_task *task) +{ + switch (task->tk_status) { + case 0: + dprintk("RPC: %5u xprt_connect_status: connection established\n", + task->tk_pid); + break; + case -ECONNREFUSED: + case -ECONNRESET: + case -ECONNABORTED: + case -ENETUNREACH: + case -EHOSTUNREACH: + case -EPIPE: + case -EAGAIN: + dprintk("RPC: %5u xprt_connect_status: retrying\n", task->tk_pid); + break; + case -ETIMEDOUT: + dprintk("RPC: %5u xprt_connect_status: connect attempt timed " + "out\n", task->tk_pid); + break; + default: + dprintk("RPC: %5u xprt_connect_status: error %d connecting to " + "server %s\n", task->tk_pid, -task->tk_status, + task->tk_rqstp->rq_xprt->servername); + task->tk_status = -EIO; + } +} + +/** + * xprt_lookup_rqst - find an RPC request corresponding to an XID + * @xprt: transport on which the original request was transmitted + * @xid: RPC XID of incoming reply + * + * Caller holds xprt->recv_lock. + */ +struct rpc_rqst *xprt_lookup_rqst(struct rpc_xprt *xprt, __be32 xid) +{ + struct rpc_rqst *entry; + + list_for_each_entry(entry, &xprt->recv, rq_list) + if (entry->rq_xid == xid) { + trace_xprt_lookup_rqst(xprt, xid, 0); + entry->rq_rtt = ktime_sub(ktime_get(), entry->rq_xtime); + return entry; + } + + dprintk("RPC: xprt_lookup_rqst did not find xid %08x\n", + ntohl(xid)); + trace_xprt_lookup_rqst(xprt, xid, -ENOENT); + xprt->stat.bad_xids++; + return NULL; +} +EXPORT_SYMBOL_GPL(xprt_lookup_rqst); + +/** + * xprt_pin_rqst - Pin a request on the transport receive list + * @req: Request to pin + * + * Caller must ensure this is atomic with the call to xprt_lookup_rqst() + * so should be holding the xprt transport lock. + */ +void xprt_pin_rqst(struct rpc_rqst *req) +{ + set_bit(RPC_TASK_MSG_RECV, &req->rq_task->tk_runstate); +} +EXPORT_SYMBOL_GPL(xprt_pin_rqst); + +/** + * xprt_unpin_rqst - Unpin a request on the transport receive list + * @req: Request to pin + * + * Caller should be holding the xprt transport lock. + */ +void xprt_unpin_rqst(struct rpc_rqst *req) +{ + struct rpc_task *task = req->rq_task; + + clear_bit(RPC_TASK_MSG_RECV, &task->tk_runstate); + if (test_bit(RPC_TASK_MSG_RECV_WAIT, &task->tk_runstate)) + wake_up_bit(&task->tk_runstate, RPC_TASK_MSG_RECV); +} +EXPORT_SYMBOL_GPL(xprt_unpin_rqst); + +static void xprt_wait_on_pinned_rqst(struct rpc_rqst *req) +__must_hold(&req->rq_xprt->recv_lock) +{ + struct rpc_task *task = req->rq_task; + + if (task && test_bit(RPC_TASK_MSG_RECV, &task->tk_runstate)) { + spin_unlock(&req->rq_xprt->recv_lock); + set_bit(RPC_TASK_MSG_RECV_WAIT, &task->tk_runstate); + wait_on_bit(&task->tk_runstate, RPC_TASK_MSG_RECV, + TASK_UNINTERRUPTIBLE); + clear_bit(RPC_TASK_MSG_RECV_WAIT, &task->tk_runstate); + spin_lock(&req->rq_xprt->recv_lock); + } +} + +/** + * xprt_update_rtt - Update RPC RTT statistics + * @task: RPC request that recently completed + * + * Caller holds xprt->recv_lock. + */ +void xprt_update_rtt(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_rtt *rtt = task->tk_client->cl_rtt; + unsigned int timer = task->tk_msg.rpc_proc->p_timer; + long m = usecs_to_jiffies(ktime_to_us(req->rq_rtt)); + + if (timer) { + if (req->rq_ntrans == 1) + rpc_update_rtt(rtt, timer, m); + rpc_set_timeo(rtt, timer, req->rq_ntrans - 1); + } +} +EXPORT_SYMBOL_GPL(xprt_update_rtt); + +/** + * xprt_complete_rqst - called when reply processing is complete + * @task: RPC request that recently completed + * @copied: actual number of bytes received from the transport + * + * Caller holds xprt->recv_lock. + */ +void xprt_complete_rqst(struct rpc_task *task, int copied) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + dprintk("RPC: %5u xid %08x complete (%d bytes received)\n", + task->tk_pid, ntohl(req->rq_xid), copied); + trace_xprt_complete_rqst(xprt, req->rq_xid, copied); + + xprt->stat.recvs++; + + list_del_init(&req->rq_list); + req->rq_private_buf.len = copied; + /* Ensure all writes are done before we update */ + /* req->rq_reply_bytes_recvd */ + smp_wmb(); + req->rq_reply_bytes_recvd = copied; + rpc_wake_up_queued_task(&xprt->pending, task); +} +EXPORT_SYMBOL_GPL(xprt_complete_rqst); + +static void xprt_timer(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + if (task->tk_status != -ETIMEDOUT) + return; + + trace_xprt_timer(xprt, req->rq_xid, task->tk_status); + if (!req->rq_reply_bytes_recvd) { + if (xprt->ops->timer) + xprt->ops->timer(xprt, task); + } else + task->tk_status = 0; +} + +/** + * xprt_prepare_transmit - reserve the transport before sending a request + * @task: RPC task about to send a request + * + */ +bool xprt_prepare_transmit(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + bool ret = false; + + dprintk("RPC: %5u xprt_prepare_transmit\n", task->tk_pid); + + spin_lock_bh(&xprt->transport_lock); + if (!req->rq_bytes_sent) { + if (req->rq_reply_bytes_recvd) { + task->tk_status = req->rq_reply_bytes_recvd; + goto out_unlock; + } + if ((task->tk_flags & RPC_TASK_NO_RETRANS_TIMEOUT) + && xprt_connected(xprt) + && req->rq_connect_cookie == xprt->connect_cookie) { + xprt->ops->set_retrans_timeout(task); + rpc_sleep_on(&xprt->pending, task, xprt_timer); + goto out_unlock; + } + } + if (!xprt->ops->reserve_xprt(xprt, task)) { + task->tk_status = -EAGAIN; + goto out_unlock; + } + ret = true; +out_unlock: + spin_unlock_bh(&xprt->transport_lock); + return ret; +} + +void xprt_end_transmit(struct rpc_task *task) +{ + xprt_release_write(task->tk_rqstp->rq_xprt, task); +} + +/** + * xprt_transmit - send an RPC request on a transport + * @task: controlling RPC task + * + * We have to copy the iovec because sendmsg fiddles with its contents. + */ +void xprt_transmit(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + unsigned int connect_cookie; + int status; + + dprintk("RPC: %5u xprt_transmit(%u)\n", task->tk_pid, req->rq_slen); + + if (!req->rq_reply_bytes_recvd) { + if (list_empty(&req->rq_list) && rpc_reply_expected(task)) { + /* + * Add to the list only if we're expecting a reply + */ + /* Update the softirq receive buffer */ + memcpy(&req->rq_private_buf, &req->rq_rcv_buf, + sizeof(req->rq_private_buf)); + /* Add request to the receive list */ + spin_lock(&xprt->recv_lock); + list_add_tail(&req->rq_list, &xprt->recv); + spin_unlock(&xprt->recv_lock); + xprt_reset_majortimeo(req); + /* Turn off autodisconnect */ + del_singleshot_timer_sync(&xprt->timer); + } + } else if (!req->rq_bytes_sent) + return; + + connect_cookie = xprt->connect_cookie; + status = xprt->ops->send_request(task); + trace_xprt_transmit(xprt, req->rq_xid, status); + if (status != 0) { + task->tk_status = status; + return; + } + xprt_inject_disconnect(xprt); + + dprintk("RPC: %5u xmit complete\n", task->tk_pid); + task->tk_flags |= RPC_TASK_SENT; + spin_lock_bh(&xprt->transport_lock); + + xprt->ops->set_retrans_timeout(task); + + xprt->stat.sends++; + xprt->stat.req_u += xprt->stat.sends - xprt->stat.recvs; + xprt->stat.bklog_u += xprt->backlog.qlen; + xprt->stat.sending_u += xprt->sending.qlen; + xprt->stat.pending_u += xprt->pending.qlen; + spin_unlock_bh(&xprt->transport_lock); + + req->rq_connect_cookie = connect_cookie; + if (rpc_reply_expected(task) && !READ_ONCE(req->rq_reply_bytes_recvd)) { + /* + * Sleep on the pending queue if we're expecting a reply. + * The spinlock ensures atomicity between the test of + * req->rq_reply_bytes_recvd, and the call to rpc_sleep_on(). + */ + spin_lock(&xprt->recv_lock); + if (!req->rq_reply_bytes_recvd) { + rpc_sleep_on(&xprt->pending, task, xprt_timer); + /* + * Send an extra queue wakeup call if the + * connection was dropped in case the call to + * rpc_sleep_on() raced. + */ + if (!xprt_connected(xprt)) + xprt_wake_pending_tasks(xprt, -ENOTCONN); + } + spin_unlock(&xprt->recv_lock); + } +} + +static void xprt_add_backlog(struct rpc_xprt *xprt, struct rpc_task *task) +{ + set_bit(XPRT_CONGESTED, &xprt->state); + rpc_sleep_on(&xprt->backlog, task, NULL); +} + +static void xprt_wake_up_backlog(struct rpc_xprt *xprt) +{ + if (rpc_wake_up_next(&xprt->backlog) == NULL) + clear_bit(XPRT_CONGESTED, &xprt->state); +} + +static bool xprt_throttle_congested(struct rpc_xprt *xprt, struct rpc_task *task) +{ + bool ret = false; + + if (!test_bit(XPRT_CONGESTED, &xprt->state)) + goto out; + spin_lock(&xprt->reserve_lock); + if (test_bit(XPRT_CONGESTED, &xprt->state)) { + rpc_sleep_on(&xprt->backlog, task, NULL); + ret = true; + } + spin_unlock(&xprt->reserve_lock); +out: + return ret; +} + +static struct rpc_rqst *xprt_dynamic_alloc_slot(struct rpc_xprt *xprt) +{ + struct rpc_rqst *req = ERR_PTR(-EAGAIN); + + if (xprt->num_reqs >= xprt->max_reqs) + goto out; + ++xprt->num_reqs; + spin_unlock(&xprt->reserve_lock); + req = kzalloc(sizeof(struct rpc_rqst), GFP_NOFS); + spin_lock(&xprt->reserve_lock); + if (req != NULL) + goto out; + --xprt->num_reqs; + req = ERR_PTR(-ENOMEM); +out: + return req; +} + +static bool xprt_dynamic_free_slot(struct rpc_xprt *xprt, struct rpc_rqst *req) +{ + if (xprt->num_reqs > xprt->min_reqs) { + --xprt->num_reqs; + kfree(req); + return true; + } + return false; +} + +void xprt_alloc_slot(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct rpc_rqst *req; + + spin_lock(&xprt->reserve_lock); + if (!list_empty(&xprt->free)) { + req = list_entry(xprt->free.next, struct rpc_rqst, rq_list); + list_del(&req->rq_list); + goto out_init_req; + } + req = xprt_dynamic_alloc_slot(xprt); + if (!IS_ERR(req)) + goto out_init_req; + switch (PTR_ERR(req)) { + case -ENOMEM: + dprintk("RPC: dynamic allocation of request slot " + "failed! Retrying\n"); + task->tk_status = -ENOMEM; + break; + case -EAGAIN: + xprt_add_backlog(xprt, task); + dprintk("RPC: waiting for request slot\n"); + /* fall through */ + default: + task->tk_status = -EAGAIN; + } + spin_unlock(&xprt->reserve_lock); + return; +out_init_req: + xprt->stat.max_slots = max_t(unsigned int, xprt->stat.max_slots, + xprt->num_reqs); + spin_unlock(&xprt->reserve_lock); + + task->tk_status = 0; + task->tk_rqstp = req; +} +EXPORT_SYMBOL_GPL(xprt_alloc_slot); + +void xprt_lock_and_alloc_slot(struct rpc_xprt *xprt, struct rpc_task *task) +{ + /* Note: grabbing the xprt_lock_write() ensures that we throttle + * new slot allocation if the transport is congested (i.e. when + * reconnecting a stream transport or when out of socket write + * buffer space). + */ + if (xprt_lock_write(xprt, task)) { + xprt_alloc_slot(xprt, task); + xprt_release_write(xprt, task); + } +} +EXPORT_SYMBOL_GPL(xprt_lock_and_alloc_slot); + +void xprt_free_slot(struct rpc_xprt *xprt, struct rpc_rqst *req) +{ + spin_lock(&xprt->reserve_lock); + if (!xprt_dynamic_free_slot(xprt, req)) { + memset(req, 0, sizeof(*req)); /* mark unused */ + list_add(&req->rq_list, &xprt->free); + } + xprt_wake_up_backlog(xprt); + spin_unlock(&xprt->reserve_lock); +} +EXPORT_SYMBOL_GPL(xprt_free_slot); + +static void xprt_free_all_slots(struct rpc_xprt *xprt) +{ + struct rpc_rqst *req; + while (!list_empty(&xprt->free)) { + req = list_first_entry(&xprt->free, struct rpc_rqst, rq_list); + list_del(&req->rq_list); + kfree(req); + } +} + +struct rpc_xprt *xprt_alloc(struct net *net, size_t size, + unsigned int num_prealloc, + unsigned int max_alloc) +{ + struct rpc_xprt *xprt; + struct rpc_rqst *req; + int i; + + xprt = kzalloc(size, GFP_KERNEL); + if (xprt == NULL) + goto out; + + xprt_init(xprt, net); + + for (i = 0; i < num_prealloc; i++) { + req = kzalloc(sizeof(struct rpc_rqst), GFP_KERNEL); + if (!req) + goto out_free; + list_add(&req->rq_list, &xprt->free); + } + if (max_alloc > num_prealloc) + xprt->max_reqs = max_alloc; + else + xprt->max_reqs = num_prealloc; + xprt->min_reqs = num_prealloc; + xprt->num_reqs = num_prealloc; + + return xprt; + +out_free: + xprt_free(xprt); +out: + return NULL; +} +EXPORT_SYMBOL_GPL(xprt_alloc); + +void xprt_free(struct rpc_xprt *xprt) +{ + put_net(xprt->xprt_net); + xprt_free_all_slots(xprt); + kfree_rcu(xprt, rcu); +} +EXPORT_SYMBOL_GPL(xprt_free); + +static __be32 +xprt_alloc_xid(struct rpc_xprt *xprt) +{ + __be32 xid; + + spin_lock(&xprt->reserve_lock); + xid = (__force __be32)xprt->xid++; + spin_unlock(&xprt->reserve_lock); + return xid; +} + +static void +xprt_init_xid(struct rpc_xprt *xprt) +{ + xprt->xid = prandom_u32(); +} + +static void +xprt_request_init(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + struct rpc_rqst *req = task->tk_rqstp; + + INIT_LIST_HEAD(&req->rq_list); + req->rq_timeout = task->tk_client->cl_timeout->to_initval; + req->rq_task = task; + req->rq_xprt = xprt; + req->rq_buffer = NULL; + req->rq_xid = xprt_alloc_xid(xprt); + req->rq_connect_cookie = xprt->connect_cookie - 1; + req->rq_bytes_sent = 0; + req->rq_snd_buf.len = 0; + req->rq_snd_buf.buflen = 0; + req->rq_rcv_buf.len = 0; + req->rq_rcv_buf.buflen = 0; + req->rq_release_snd_buf = NULL; + xprt_reset_majortimeo(req); + dprintk("RPC: %5u reserved req %p xid %08x\n", task->tk_pid, + req, ntohl(req->rq_xid)); +} + +static void +xprt_do_reserve(struct rpc_xprt *xprt, struct rpc_task *task) +{ + xprt->ops->alloc_slot(xprt, task); + if (task->tk_rqstp != NULL) + xprt_request_init(task); +} + +/** + * xprt_reserve - allocate an RPC request slot + * @task: RPC task requesting a slot allocation + * + * If the transport is marked as being congested, or if no more + * slots are available, place the task on the transport's + * backlog queue. + */ +void xprt_reserve(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + + task->tk_status = 0; + if (task->tk_rqstp != NULL) + return; + + task->tk_timeout = 0; + task->tk_status = -EAGAIN; + if (!xprt_throttle_congested(xprt, task)) + xprt_do_reserve(xprt, task); +} + +/** + * xprt_retry_reserve - allocate an RPC request slot + * @task: RPC task requesting a slot allocation + * + * If no more slots are available, place the task on the transport's + * backlog queue. + * Note that the only difference with xprt_reserve is that we now + * ignore the value of the XPRT_CONGESTED flag. + */ +void xprt_retry_reserve(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + + task->tk_status = 0; + if (task->tk_rqstp != NULL) + return; + + task->tk_timeout = 0; + task->tk_status = -EAGAIN; + xprt_do_reserve(xprt, task); +} + +/** + * xprt_release - release an RPC request slot + * @task: task which is finished with the slot + * + */ +void xprt_release(struct rpc_task *task) +{ + struct rpc_xprt *xprt; + struct rpc_rqst *req = task->tk_rqstp; + + if (req == NULL) { + if (task->tk_client) { + xprt = task->tk_xprt; + if (xprt->snd_task == task) + xprt_release_write(xprt, task); + } + return; + } + + xprt = req->rq_xprt; + if (task->tk_ops->rpc_count_stats != NULL) + task->tk_ops->rpc_count_stats(task, task->tk_calldata); + else if (task->tk_client) + rpc_count_iostats(task, task->tk_client->cl_metrics); + spin_lock(&xprt->recv_lock); + if (!list_empty(&req->rq_list)) { + list_del_init(&req->rq_list); + xprt_wait_on_pinned_rqst(req); + } + spin_unlock(&xprt->recv_lock); + spin_lock_bh(&xprt->transport_lock); + xprt->ops->release_xprt(xprt, task); + if (xprt->ops->release_request) + xprt->ops->release_request(task); + xprt->last_used = jiffies; + xprt_schedule_autodisconnect(xprt); + spin_unlock_bh(&xprt->transport_lock); + if (req->rq_buffer) + xprt->ops->buf_free(task); + xprt_inject_disconnect(xprt); + if (req->rq_cred != NULL) + put_rpccred(req->rq_cred); + task->tk_rqstp = NULL; + if (req->rq_release_snd_buf) + req->rq_release_snd_buf(req); + + dprintk("RPC: %5u release request %p\n", task->tk_pid, req); + if (likely(!bc_prealloc(req))) + xprt->ops->free_slot(xprt, req); + else + xprt_free_bc_request(req); +} + +static void xprt_init(struct rpc_xprt *xprt, struct net *net) +{ + kref_init(&xprt->kref); + + spin_lock_init(&xprt->transport_lock); + spin_lock_init(&xprt->reserve_lock); + spin_lock_init(&xprt->recv_lock); + + INIT_LIST_HEAD(&xprt->free); + INIT_LIST_HEAD(&xprt->recv); +#if defined(CONFIG_SUNRPC_BACKCHANNEL) + spin_lock_init(&xprt->bc_pa_lock); + INIT_LIST_HEAD(&xprt->bc_pa_list); +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + INIT_LIST_HEAD(&xprt->xprt_switch); + + xprt->last_used = jiffies; + xprt->cwnd = RPC_INITCWND; + xprt->bind_index = 0; + + rpc_init_wait_queue(&xprt->binding, "xprt_binding"); + rpc_init_wait_queue(&xprt->pending, "xprt_pending"); + rpc_init_priority_wait_queue(&xprt->sending, "xprt_sending"); + rpc_init_priority_wait_queue(&xprt->backlog, "xprt_backlog"); + + xprt_init_xid(xprt); + + xprt->xprt_net = get_net(net); +} + +/** + * xprt_create_transport - create an RPC transport + * @args: rpc transport creation arguments + * + */ +struct rpc_xprt *xprt_create_transport(struct xprt_create *args) +{ + struct rpc_xprt *xprt; + struct xprt_class *t; + + spin_lock(&xprt_list_lock); + list_for_each_entry(t, &xprt_list, list) { + if (t->ident == args->ident) { + spin_unlock(&xprt_list_lock); + goto found; + } + } + spin_unlock(&xprt_list_lock); + dprintk("RPC: transport (%d) not supported\n", args->ident); + return ERR_PTR(-EIO); + +found: + xprt = t->setup(args); + if (IS_ERR(xprt)) { + dprintk("RPC: xprt_create_transport: failed, %ld\n", + -PTR_ERR(xprt)); + goto out; + } + if (args->flags & XPRT_CREATE_NO_IDLE_TIMEOUT) + xprt->idle_timeout = 0; + INIT_WORK(&xprt->task_cleanup, xprt_autoclose); + if (xprt_has_timer(xprt)) + timer_setup(&xprt->timer, xprt_init_autodisconnect, 0); + else + timer_setup(&xprt->timer, NULL, 0); + + if (strlen(args->servername) > RPC_MAXNETNAMELEN) { + xprt_destroy(xprt); + return ERR_PTR(-EINVAL); + } + xprt->servername = kstrdup(args->servername, GFP_KERNEL); + if (xprt->servername == NULL) { + xprt_destroy(xprt); + return ERR_PTR(-ENOMEM); + } + + rpc_xprt_debugfs_register(xprt); + + dprintk("RPC: created transport %p with %u slots\n", xprt, + xprt->max_reqs); +out: + return xprt; +} + +static void xprt_destroy_cb(struct work_struct *work) +{ + struct rpc_xprt *xprt = + container_of(work, struct rpc_xprt, task_cleanup); + + rpc_xprt_debugfs_unregister(xprt); + rpc_destroy_wait_queue(&xprt->binding); + rpc_destroy_wait_queue(&xprt->pending); + rpc_destroy_wait_queue(&xprt->sending); + rpc_destroy_wait_queue(&xprt->backlog); + kfree(xprt->servername); + /* + * Tear down transport state and free the rpc_xprt + */ + xprt->ops->destroy(xprt); +} + +/** + * xprt_destroy - destroy an RPC transport, killing off all requests. + * @xprt: transport to destroy + * + */ +static void xprt_destroy(struct rpc_xprt *xprt) +{ + dprintk("RPC: destroying transport %p\n", xprt); + + /* + * Exclude transport connect/disconnect handlers and autoclose + */ + wait_on_bit_lock(&xprt->state, XPRT_LOCKED, TASK_UNINTERRUPTIBLE); + + /* + * xprt_schedule_autodisconnect() can run after XPRT_LOCKED + * is cleared. We use ->transport_lock to ensure the mod_timer() + * can only run *before* del_time_sync(), never after. + */ + spin_lock(&xprt->transport_lock); + del_timer_sync(&xprt->timer); + spin_unlock(&xprt->transport_lock); + + /* + * Destroy sockets etc from the system workqueue so they can + * safely flush receive work running on rpciod. + */ + INIT_WORK(&xprt->task_cleanup, xprt_destroy_cb); + schedule_work(&xprt->task_cleanup); +} + +static void xprt_destroy_kref(struct kref *kref) +{ + xprt_destroy(container_of(kref, struct rpc_xprt, kref)); +} + +/** + * xprt_get - return a reference to an RPC transport. + * @xprt: pointer to the transport + * + */ +struct rpc_xprt *xprt_get(struct rpc_xprt *xprt) +{ + if (xprt != NULL && kref_get_unless_zero(&xprt->kref)) + return xprt; + return NULL; +} +EXPORT_SYMBOL_GPL(xprt_get); + +/** + * xprt_put - release a reference to an RPC transport. + * @xprt: pointer to the transport + * + */ +void xprt_put(struct rpc_xprt *xprt) +{ + if (xprt != NULL) + kref_put(&xprt->kref, xprt_destroy_kref); +} +EXPORT_SYMBOL_GPL(xprt_put); diff --git a/net/sunrpc/xprtmultipath.c b/net/sunrpc/xprtmultipath.c new file mode 100644 index 000000000..e2d64c713 --- /dev/null +++ b/net/sunrpc/xprtmultipath.c @@ -0,0 +1,496 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Multipath support for RPC + * + * Copyright (c) 2015, 2016, Primary Data, Inc. All rights reserved. + * + * Trond Myklebust <trond.myklebust@primarydata.com> + * + */ +#include <linux/types.h> +#include <linux/kref.h> +#include <linux/list.h> +#include <linux/rcupdate.h> +#include <linux/rculist.h> +#include <linux/slab.h> +#include <asm/cmpxchg.h> +#include <linux/spinlock.h> +#include <linux/sunrpc/xprt.h> +#include <linux/sunrpc/addr.h> +#include <linux/sunrpc/xprtmultipath.h> + +typedef struct rpc_xprt *(*xprt_switch_find_xprt_t)(struct list_head *head, + const struct rpc_xprt *cur); + +static const struct rpc_xprt_iter_ops rpc_xprt_iter_singular; +static const struct rpc_xprt_iter_ops rpc_xprt_iter_roundrobin; +static const struct rpc_xprt_iter_ops rpc_xprt_iter_listall; + +static void xprt_switch_add_xprt_locked(struct rpc_xprt_switch *xps, + struct rpc_xprt *xprt) +{ + if (unlikely(xprt_get(xprt) == NULL)) + return; + list_add_tail_rcu(&xprt->xprt_switch, &xps->xps_xprt_list); + smp_wmb(); + if (xps->xps_nxprts == 0) + xps->xps_net = xprt->xprt_net; + xps->xps_nxprts++; +} + +/** + * rpc_xprt_switch_add_xprt - Add a new rpc_xprt to an rpc_xprt_switch + * @xps: pointer to struct rpc_xprt_switch + * @xprt: pointer to struct rpc_xprt + * + * Adds xprt to the end of the list of struct rpc_xprt in xps. + */ +void rpc_xprt_switch_add_xprt(struct rpc_xprt_switch *xps, + struct rpc_xprt *xprt) +{ + if (xprt == NULL) + return; + spin_lock(&xps->xps_lock); + if ((xps->xps_net == xprt->xprt_net || xps->xps_net == NULL) && + !rpc_xprt_switch_has_addr(xps, (struct sockaddr *)&xprt->addr)) + xprt_switch_add_xprt_locked(xps, xprt); + spin_unlock(&xps->xps_lock); +} + +static void xprt_switch_remove_xprt_locked(struct rpc_xprt_switch *xps, + struct rpc_xprt *xprt) +{ + if (unlikely(xprt == NULL)) + return; + xps->xps_nxprts--; + if (xps->xps_nxprts == 0) + xps->xps_net = NULL; + smp_wmb(); + list_del_rcu(&xprt->xprt_switch); +} + +/** + * rpc_xprt_switch_remove_xprt - Removes an rpc_xprt from a rpc_xprt_switch + * @xps: pointer to struct rpc_xprt_switch + * @xprt: pointer to struct rpc_xprt + * + * Removes xprt from the list of struct rpc_xprt in xps. + */ +void rpc_xprt_switch_remove_xprt(struct rpc_xprt_switch *xps, + struct rpc_xprt *xprt) +{ + spin_lock(&xps->xps_lock); + xprt_switch_remove_xprt_locked(xps, xprt); + spin_unlock(&xps->xps_lock); + xprt_put(xprt); +} + +/** + * xprt_switch_alloc - Allocate a new struct rpc_xprt_switch + * @xprt: pointer to struct rpc_xprt + * @gfp_flags: allocation flags + * + * On success, returns an initialised struct rpc_xprt_switch, containing + * the entry xprt. Returns NULL on failure. + */ +struct rpc_xprt_switch *xprt_switch_alloc(struct rpc_xprt *xprt, + gfp_t gfp_flags) +{ + struct rpc_xprt_switch *xps; + + xps = kmalloc(sizeof(*xps), gfp_flags); + if (xps != NULL) { + spin_lock_init(&xps->xps_lock); + kref_init(&xps->xps_kref); + xps->xps_nxprts = 0; + INIT_LIST_HEAD(&xps->xps_xprt_list); + xps->xps_iter_ops = &rpc_xprt_iter_singular; + xprt_switch_add_xprt_locked(xps, xprt); + } + + return xps; +} + +static void xprt_switch_free_entries(struct rpc_xprt_switch *xps) +{ + spin_lock(&xps->xps_lock); + while (!list_empty(&xps->xps_xprt_list)) { + struct rpc_xprt *xprt; + + xprt = list_first_entry(&xps->xps_xprt_list, + struct rpc_xprt, xprt_switch); + xprt_switch_remove_xprt_locked(xps, xprt); + spin_unlock(&xps->xps_lock); + xprt_put(xprt); + spin_lock(&xps->xps_lock); + } + spin_unlock(&xps->xps_lock); +} + +static void xprt_switch_free(struct kref *kref) +{ + struct rpc_xprt_switch *xps = container_of(kref, + struct rpc_xprt_switch, xps_kref); + + xprt_switch_free_entries(xps); + kfree_rcu(xps, xps_rcu); +} + +/** + * xprt_switch_get - Return a reference to a rpc_xprt_switch + * @xps: pointer to struct rpc_xprt_switch + * + * Returns a reference to xps unless the refcount is already zero. + */ +struct rpc_xprt_switch *xprt_switch_get(struct rpc_xprt_switch *xps) +{ + if (xps != NULL && kref_get_unless_zero(&xps->xps_kref)) + return xps; + return NULL; +} + +/** + * xprt_switch_put - Release a reference to a rpc_xprt_switch + * @xps: pointer to struct rpc_xprt_switch + * + * Release the reference to xps, and free it once the refcount is zero. + */ +void xprt_switch_put(struct rpc_xprt_switch *xps) +{ + if (xps != NULL) + kref_put(&xps->xps_kref, xprt_switch_free); +} + +/** + * rpc_xprt_switch_set_roundrobin - Set a round-robin policy on rpc_xprt_switch + * @xps: pointer to struct rpc_xprt_switch + * + * Sets a round-robin default policy for iterators acting on xps. + */ +void rpc_xprt_switch_set_roundrobin(struct rpc_xprt_switch *xps) +{ + if (READ_ONCE(xps->xps_iter_ops) != &rpc_xprt_iter_roundrobin) + WRITE_ONCE(xps->xps_iter_ops, &rpc_xprt_iter_roundrobin); +} + +static +const struct rpc_xprt_iter_ops *xprt_iter_ops(const struct rpc_xprt_iter *xpi) +{ + if (xpi->xpi_ops != NULL) + return xpi->xpi_ops; + return rcu_dereference(xpi->xpi_xpswitch)->xps_iter_ops; +} + +static +void xprt_iter_no_rewind(struct rpc_xprt_iter *xpi) +{ +} + +static +void xprt_iter_default_rewind(struct rpc_xprt_iter *xpi) +{ + WRITE_ONCE(xpi->xpi_cursor, NULL); +} + +static +struct rpc_xprt *xprt_switch_find_first_entry(struct list_head *head) +{ + return list_first_or_null_rcu(head, struct rpc_xprt, xprt_switch); +} + +static +struct rpc_xprt *xprt_iter_first_entry(struct rpc_xprt_iter *xpi) +{ + struct rpc_xprt_switch *xps = rcu_dereference(xpi->xpi_xpswitch); + + if (xps == NULL) + return NULL; + return xprt_switch_find_first_entry(&xps->xps_xprt_list); +} + +static +struct rpc_xprt *xprt_switch_find_current_entry(struct list_head *head, + const struct rpc_xprt *cur) +{ + struct rpc_xprt *pos; + + list_for_each_entry_rcu(pos, head, xprt_switch) { + if (cur == pos) + return pos; + } + return NULL; +} + +static +struct rpc_xprt *xprt_iter_current_entry(struct rpc_xprt_iter *xpi) +{ + struct rpc_xprt_switch *xps = rcu_dereference(xpi->xpi_xpswitch); + struct list_head *head; + + if (xps == NULL) + return NULL; + head = &xps->xps_xprt_list; + if (xpi->xpi_cursor == NULL || xps->xps_nxprts < 2) + return xprt_switch_find_first_entry(head); + return xprt_switch_find_current_entry(head, xpi->xpi_cursor); +} + +bool rpc_xprt_switch_has_addr(struct rpc_xprt_switch *xps, + const struct sockaddr *sap) +{ + struct list_head *head; + struct rpc_xprt *pos; + + if (xps == NULL || sap == NULL) + return false; + + head = &xps->xps_xprt_list; + list_for_each_entry_rcu(pos, head, xprt_switch) { + if (rpc_cmp_addr_port(sap, (struct sockaddr *)&pos->addr)) { + pr_info("RPC: addr %s already in xprt switch\n", + pos->address_strings[RPC_DISPLAY_ADDR]); + return true; + } + } + return false; +} + +static +struct rpc_xprt *xprt_switch_find_next_entry(struct list_head *head, + const struct rpc_xprt *cur) +{ + struct rpc_xprt *pos, *prev = NULL; + + list_for_each_entry_rcu(pos, head, xprt_switch) { + if (cur == prev) + return pos; + prev = pos; + } + return NULL; +} + +static +struct rpc_xprt *xprt_switch_set_next_cursor(struct list_head *head, + struct rpc_xprt **cursor, + xprt_switch_find_xprt_t find_next) +{ + struct rpc_xprt *cur, *pos, *old; + + cur = READ_ONCE(*cursor); + for (;;) { + old = cur; + pos = find_next(head, old); + if (pos == NULL) + break; + cur = cmpxchg_relaxed(cursor, old, pos); + if (cur == old) + break; + } + return pos; +} + +static +struct rpc_xprt *xprt_iter_next_entry_multiple(struct rpc_xprt_iter *xpi, + xprt_switch_find_xprt_t find_next) +{ + struct rpc_xprt_switch *xps = rcu_dereference(xpi->xpi_xpswitch); + + if (xps == NULL) + return NULL; + return xprt_switch_set_next_cursor(&xps->xps_xprt_list, + &xpi->xpi_cursor, + find_next); +} + +static +struct rpc_xprt *xprt_switch_find_next_entry_roundrobin(struct list_head *head, + const struct rpc_xprt *cur) +{ + struct rpc_xprt *ret; + + ret = xprt_switch_find_next_entry(head, cur); + if (ret != NULL) + return ret; + return xprt_switch_find_first_entry(head); +} + +static +struct rpc_xprt *xprt_iter_next_entry_roundrobin(struct rpc_xprt_iter *xpi) +{ + return xprt_iter_next_entry_multiple(xpi, + xprt_switch_find_next_entry_roundrobin); +} + +static +struct rpc_xprt *xprt_iter_next_entry_all(struct rpc_xprt_iter *xpi) +{ + return xprt_iter_next_entry_multiple(xpi, xprt_switch_find_next_entry); +} + +/* + * xprt_iter_rewind - Resets the xprt iterator + * @xpi: pointer to rpc_xprt_iter + * + * Resets xpi to ensure that it points to the first entry in the list + * of transports. + */ +static +void xprt_iter_rewind(struct rpc_xprt_iter *xpi) +{ + rcu_read_lock(); + xprt_iter_ops(xpi)->xpi_rewind(xpi); + rcu_read_unlock(); +} + +static void __xprt_iter_init(struct rpc_xprt_iter *xpi, + struct rpc_xprt_switch *xps, + const struct rpc_xprt_iter_ops *ops) +{ + rcu_assign_pointer(xpi->xpi_xpswitch, xprt_switch_get(xps)); + xpi->xpi_cursor = NULL; + xpi->xpi_ops = ops; +} + +/** + * xprt_iter_init - Initialise an xprt iterator + * @xpi: pointer to rpc_xprt_iter + * @xps: pointer to rpc_xprt_switch + * + * Initialises the iterator to use the default iterator ops + * as set in xps. This function is mainly intended for internal + * use in the rpc_client. + */ +void xprt_iter_init(struct rpc_xprt_iter *xpi, + struct rpc_xprt_switch *xps) +{ + __xprt_iter_init(xpi, xps, NULL); +} + +/** + * xprt_iter_init_listall - Initialise an xprt iterator + * @xpi: pointer to rpc_xprt_iter + * @xps: pointer to rpc_xprt_switch + * + * Initialises the iterator to iterate once through the entire list + * of entries in xps. + */ +void xprt_iter_init_listall(struct rpc_xprt_iter *xpi, + struct rpc_xprt_switch *xps) +{ + __xprt_iter_init(xpi, xps, &rpc_xprt_iter_listall); +} + +/** + * xprt_iter_xchg_switch - Atomically swap out the rpc_xprt_switch + * @xpi: pointer to rpc_xprt_iter + * @xps: pointer to a new rpc_xprt_switch or NULL + * + * Swaps out the existing xpi->xpi_xpswitch with a new value. + */ +struct rpc_xprt_switch *xprt_iter_xchg_switch(struct rpc_xprt_iter *xpi, + struct rpc_xprt_switch *newswitch) +{ + struct rpc_xprt_switch __rcu *oldswitch; + + /* Atomically swap out the old xpswitch */ + oldswitch = xchg(&xpi->xpi_xpswitch, RCU_INITIALIZER(newswitch)); + if (newswitch != NULL) + xprt_iter_rewind(xpi); + return rcu_dereference_protected(oldswitch, true); +} + +/** + * xprt_iter_destroy - Destroys the xprt iterator + * @xpi pointer to rpc_xprt_iter + */ +void xprt_iter_destroy(struct rpc_xprt_iter *xpi) +{ + xprt_switch_put(xprt_iter_xchg_switch(xpi, NULL)); +} + +/** + * xprt_iter_xprt - Returns the rpc_xprt pointed to by the cursor + * @xpi: pointer to rpc_xprt_iter + * + * Returns a pointer to the struct rpc_xprt that is currently + * pointed to by the cursor. + * Caller must be holding rcu_read_lock(). + */ +struct rpc_xprt *xprt_iter_xprt(struct rpc_xprt_iter *xpi) +{ + WARN_ON_ONCE(!rcu_read_lock_held()); + return xprt_iter_ops(xpi)->xpi_xprt(xpi); +} + +static +struct rpc_xprt *xprt_iter_get_helper(struct rpc_xprt_iter *xpi, + struct rpc_xprt *(*fn)(struct rpc_xprt_iter *)) +{ + struct rpc_xprt *ret; + + do { + ret = fn(xpi); + if (ret == NULL) + break; + ret = xprt_get(ret); + } while (ret == NULL); + return ret; +} + +/** + * xprt_iter_get_xprt - Returns the rpc_xprt pointed to by the cursor + * @xpi: pointer to rpc_xprt_iter + * + * Returns a reference to the struct rpc_xprt that is currently + * pointed to by the cursor. + */ +struct rpc_xprt *xprt_iter_get_xprt(struct rpc_xprt_iter *xpi) +{ + struct rpc_xprt *xprt; + + rcu_read_lock(); + xprt = xprt_iter_get_helper(xpi, xprt_iter_ops(xpi)->xpi_xprt); + rcu_read_unlock(); + return xprt; +} + +/** + * xprt_iter_get_next - Returns the next rpc_xprt following the cursor + * @xpi: pointer to rpc_xprt_iter + * + * Returns a reference to the struct rpc_xprt that immediately follows the + * entry pointed to by the cursor. + */ +struct rpc_xprt *xprt_iter_get_next(struct rpc_xprt_iter *xpi) +{ + struct rpc_xprt *xprt; + + rcu_read_lock(); + xprt = xprt_iter_get_helper(xpi, xprt_iter_ops(xpi)->xpi_next); + rcu_read_unlock(); + return xprt; +} + +/* Policy for always returning the first entry in the rpc_xprt_switch */ +static +const struct rpc_xprt_iter_ops rpc_xprt_iter_singular = { + .xpi_rewind = xprt_iter_no_rewind, + .xpi_xprt = xprt_iter_first_entry, + .xpi_next = xprt_iter_first_entry, +}; + +/* Policy for round-robin iteration of entries in the rpc_xprt_switch */ +static +const struct rpc_xprt_iter_ops rpc_xprt_iter_roundrobin = { + .xpi_rewind = xprt_iter_default_rewind, + .xpi_xprt = xprt_iter_current_entry, + .xpi_next = xprt_iter_next_entry_roundrobin, +}; + +/* Policy for once-through iteration of entries in the rpc_xprt_switch */ +static +const struct rpc_xprt_iter_ops rpc_xprt_iter_listall = { + .xpi_rewind = xprt_iter_default_rewind, + .xpi_xprt = xprt_iter_current_entry, + .xpi_next = xprt_iter_next_entry_all, +}; diff --git a/net/sunrpc/xprtrdma/Makefile b/net/sunrpc/xprtrdma/Makefile new file mode 100644 index 000000000..8bf19e142 --- /dev/null +++ b/net/sunrpc/xprtrdma/Makefile @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: GPL-2.0 +obj-$(CONFIG_SUNRPC_XPRT_RDMA) += rpcrdma.o + +rpcrdma-y := transport.o rpc_rdma.o verbs.o \ + fmr_ops.o frwr_ops.o \ + svc_rdma.o svc_rdma_backchannel.o svc_rdma_transport.o \ + svc_rdma_sendto.o svc_rdma_recvfrom.o svc_rdma_rw.o \ + module.o +rpcrdma-$(CONFIG_SUNRPC_BACKCHANNEL) += backchannel.o diff --git a/net/sunrpc/xprtrdma/backchannel.c b/net/sunrpc/xprtrdma/backchannel.c new file mode 100644 index 000000000..90adeff4c --- /dev/null +++ b/net/sunrpc/xprtrdma/backchannel.c @@ -0,0 +1,345 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2015 Oracle. All rights reserved. + * + * Support for backward direction RPCs on RPC/RDMA. + */ + +#include <linux/module.h> +#include <linux/sunrpc/xprt.h> +#include <linux/sunrpc/svc.h> +#include <linux/sunrpc/svc_xprt.h> +#include <linux/sunrpc/svc_rdma.h> + +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_TRANS +#endif + +#undef RPCRDMA_BACKCHANNEL_DEBUG + +static void rpcrdma_bc_free_rqst(struct rpcrdma_xprt *r_xprt, + struct rpc_rqst *rqst) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + + spin_lock(&buf->rb_reqslock); + list_del(&req->rl_all); + spin_unlock(&buf->rb_reqslock); + + rpcrdma_destroy_req(req); +} + +static int rpcrdma_bc_setup_reqs(struct rpcrdma_xprt *r_xprt, + unsigned int count) +{ + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + struct rpc_rqst *rqst; + unsigned int i; + + for (i = 0; i < (count << 1); i++) { + struct rpcrdma_regbuf *rb; + struct rpcrdma_req *req; + size_t size; + + req = rpcrdma_create_req(r_xprt); + if (IS_ERR(req)) + return PTR_ERR(req); + rqst = &req->rl_slot; + + rqst->rq_xprt = xprt; + INIT_LIST_HEAD(&rqst->rq_list); + INIT_LIST_HEAD(&rqst->rq_bc_list); + __set_bit(RPC_BC_PA_IN_USE, &rqst->rq_bc_pa_state); + spin_lock_bh(&xprt->bc_pa_lock); + list_add(&rqst->rq_bc_pa_list, &xprt->bc_pa_list); + spin_unlock_bh(&xprt->bc_pa_lock); + + size = r_xprt->rx_data.inline_rsize; + rb = rpcrdma_alloc_regbuf(size, DMA_TO_DEVICE, GFP_KERNEL); + if (IS_ERR(rb)) + goto out_fail; + req->rl_sendbuf = rb; + xdr_buf_init(&rqst->rq_snd_buf, rb->rg_base, + min_t(size_t, size, PAGE_SIZE)); + } + return 0; + +out_fail: + rpcrdma_bc_free_rqst(r_xprt, rqst); + return -ENOMEM; +} + +/** + * xprt_rdma_bc_setup - Pre-allocate resources for handling backchannel requests + * @xprt: transport associated with these backchannel resources + * @reqs: number of concurrent incoming requests to expect + * + * Returns 0 on success; otherwise a negative errno + */ +int xprt_rdma_bc_setup(struct rpc_xprt *xprt, unsigned int reqs) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + int rc; + + /* The backchannel reply path returns each rpc_rqst to the + * bc_pa_list _after_ the reply is sent. If the server is + * faster than the client, it can send another backward + * direction request before the rpc_rqst is returned to the + * list. The client rejects the request in this case. + * + * Twice as many rpc_rqsts are prepared to ensure there is + * always an rpc_rqst available as soon as a reply is sent. + */ + if (reqs > RPCRDMA_BACKWARD_WRS >> 1) + goto out_err; + + rc = rpcrdma_bc_setup_reqs(r_xprt, reqs); + if (rc) + goto out_free; + + r_xprt->rx_buf.rb_bc_srv_max_requests = reqs; + request_module("svcrdma"); + trace_xprtrdma_cb_setup(r_xprt, reqs); + return 0; + +out_free: + xprt_rdma_bc_destroy(xprt, reqs); + +out_err: + pr_err("RPC: %s: setup backchannel transport failed\n", __func__); + return -ENOMEM; +} + +/** + * xprt_rdma_bc_up - Create transport endpoint for backchannel service + * @serv: server endpoint + * @net: network namespace + * + * The "xprt" is an implied argument: it supplies the name of the + * backchannel transport class. + * + * Returns zero on success, negative errno on failure + */ +int xprt_rdma_bc_up(struct svc_serv *serv, struct net *net) +{ + int ret; + + ret = svc_create_xprt(serv, "rdma-bc", net, PF_INET, 0, 0); + if (ret < 0) + return ret; + return 0; +} + +/** + * xprt_rdma_bc_maxpayload - Return maximum backchannel message size + * @xprt: transport + * + * Returns maximum size, in bytes, of a backchannel message + */ +size_t xprt_rdma_bc_maxpayload(struct rpc_xprt *xprt) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct rpcrdma_create_data_internal *cdata = &r_xprt->rx_data; + size_t maxmsg; + + maxmsg = min_t(unsigned int, cdata->inline_rsize, cdata->inline_wsize); + maxmsg = min_t(unsigned int, maxmsg, PAGE_SIZE); + return maxmsg - RPCRDMA_HDRLEN_MIN; +} + +static int rpcrdma_bc_marshal_reply(struct rpc_rqst *rqst) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(rqst->rq_xprt); + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + __be32 *p; + + rpcrdma_set_xdrlen(&req->rl_hdrbuf, 0); + xdr_init_encode(&req->rl_stream, &req->rl_hdrbuf, + req->rl_rdmabuf->rg_base); + + p = xdr_reserve_space(&req->rl_stream, 28); + if (unlikely(!p)) + return -EIO; + *p++ = rqst->rq_xid; + *p++ = rpcrdma_version; + *p++ = cpu_to_be32(r_xprt->rx_buf.rb_bc_srv_max_requests); + *p++ = rdma_msg; + *p++ = xdr_zero; + *p++ = xdr_zero; + *p = xdr_zero; + + if (rpcrdma_prepare_send_sges(r_xprt, req, RPCRDMA_HDRLEN_MIN, + &rqst->rq_snd_buf, rpcrdma_noch)) + return -EIO; + + trace_xprtrdma_cb_reply(rqst); + return 0; +} + +/** + * xprt_rdma_bc_send_reply - marshal and send a backchannel reply + * @rqst: RPC rqst with a backchannel RPC reply in rq_snd_buf + * + * Caller holds the transport's write lock. + * + * Returns: + * %0 if the RPC message has been sent + * %-ENOTCONN if the caller should reconnect and call again + * %-EIO if a permanent error occurred and the request was not + * sent. Do not try to send this message again. + */ +int xprt_rdma_bc_send_reply(struct rpc_rqst *rqst) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(rqst->rq_xprt); + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + int rc; + + if (!xprt_connected(rqst->rq_xprt)) + goto drop_connection; + + rc = rpcrdma_bc_marshal_reply(rqst); + if (rc < 0) + goto failed_marshal; + + rpcrdma_post_recvs(r_xprt, true); + if (rpcrdma_ep_post(&r_xprt->rx_ia, &r_xprt->rx_ep, req)) + goto drop_connection; + return 0; + +failed_marshal: + if (rc != -ENOTCONN) + return rc; +drop_connection: + xprt_disconnect_done(rqst->rq_xprt); + return -ENOTCONN; +} + +/** + * xprt_rdma_bc_destroy - Release resources for handling backchannel requests + * @xprt: transport associated with these backchannel resources + * @reqs: number of incoming requests to destroy; ignored + */ +void xprt_rdma_bc_destroy(struct rpc_xprt *xprt, unsigned int reqs) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct rpc_rqst *rqst, *tmp; + + spin_lock_bh(&xprt->bc_pa_lock); + list_for_each_entry_safe(rqst, tmp, &xprt->bc_pa_list, rq_bc_pa_list) { + list_del(&rqst->rq_bc_pa_list); + spin_unlock_bh(&xprt->bc_pa_lock); + + rpcrdma_bc_free_rqst(r_xprt, rqst); + + spin_lock_bh(&xprt->bc_pa_lock); + } + spin_unlock_bh(&xprt->bc_pa_lock); +} + +/** + * xprt_rdma_bc_free_rqst - Release a backchannel rqst + * @rqst: request to release + */ +void xprt_rdma_bc_free_rqst(struct rpc_rqst *rqst) +{ + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + struct rpc_xprt *xprt = rqst->rq_xprt; + + dprintk("RPC: %s: freeing rqst %p (req %p)\n", + __func__, rqst, req); + + rpcrdma_recv_buffer_put(req->rl_reply); + req->rl_reply = NULL; + + spin_lock_bh(&xprt->bc_pa_lock); + list_add_tail(&rqst->rq_bc_pa_list, &xprt->bc_pa_list); + spin_unlock_bh(&xprt->bc_pa_lock); +} + +/** + * rpcrdma_bc_receive_call - Handle a backward direction call + * @r_xprt: transport receiving the call + * @rep: receive buffer containing the call + * + * Operational assumptions: + * o Backchannel credits are ignored, just as the NFS server + * forechannel currently does + * o The ULP manages a replay cache (eg, NFSv4.1 sessions). + * No replay detection is done at the transport level + */ +void rpcrdma_bc_receive_call(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_rep *rep) +{ + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + struct svc_serv *bc_serv; + struct rpcrdma_req *req; + struct rpc_rqst *rqst; + struct xdr_buf *buf; + size_t size; + __be32 *p; + + p = xdr_inline_decode(&rep->rr_stream, 0); + size = xdr_stream_remaining(&rep->rr_stream); + +#ifdef RPCRDMA_BACKCHANNEL_DEBUG + pr_info("RPC: %s: callback XID %08x, length=%u\n", + __func__, be32_to_cpup(p), size); + pr_info("RPC: %s: %*ph\n", __func__, size, p); +#endif + + /* Grab a free bc rqst */ + spin_lock(&xprt->bc_pa_lock); + if (list_empty(&xprt->bc_pa_list)) { + spin_unlock(&xprt->bc_pa_lock); + goto out_overflow; + } + rqst = list_first_entry(&xprt->bc_pa_list, + struct rpc_rqst, rq_bc_pa_list); + list_del(&rqst->rq_bc_pa_list); + spin_unlock(&xprt->bc_pa_lock); + + /* Prepare rqst */ + rqst->rq_reply_bytes_recvd = 0; + rqst->rq_bytes_sent = 0; + rqst->rq_xid = *p; + + rqst->rq_private_buf.len = size; + + buf = &rqst->rq_rcv_buf; + memset(buf, 0, sizeof(*buf)); + buf->head[0].iov_base = p; + buf->head[0].iov_len = size; + buf->len = size; + + /* The receive buffer has to be hooked to the rpcrdma_req + * so that it is not released while the req is pointing + * to its buffer, and so that it can be reposted after + * the Upper Layer is done decoding it. + */ + req = rpcr_to_rdmar(rqst); + req->rl_reply = rep; + trace_xprtrdma_cb_call(rqst); + + /* Queue rqst for ULP's callback service */ + bc_serv = xprt->bc_serv; + spin_lock(&bc_serv->sv_cb_lock); + list_add(&rqst->rq_bc_list, &bc_serv->sv_cb_list); + spin_unlock(&bc_serv->sv_cb_lock); + + wake_up(&bc_serv->sv_cb_waitq); + + r_xprt->rx_stats.bcall_count++; + return; + +out_overflow: + pr_warn("RPC/RDMA backchannel overflow\n"); + xprt_disconnect_done(xprt); + /* This receive buffer gets reposted automatically + * when the connection is re-established. + */ + return; +} diff --git a/net/sunrpc/xprtrdma/fmr_ops.c b/net/sunrpc/xprtrdma/fmr_ops.c new file mode 100644 index 000000000..0f7c465d9 --- /dev/null +++ b/net/sunrpc/xprtrdma/fmr_ops.c @@ -0,0 +1,348 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2015, 2017 Oracle. All rights reserved. + * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. + */ + +/* Lightweight memory registration using Fast Memory Regions (FMR). + * Referred to sometimes as MTHCAFMR mode. + * + * FMR uses synchronous memory registration and deregistration. + * FMR registration is known to be fast, but FMR deregistration + * can take tens of usecs to complete. + */ + +/* Normal operation + * + * A Memory Region is prepared for RDMA READ or WRITE using the + * ib_map_phys_fmr verb (fmr_op_map). When the RDMA operation is + * finished, the Memory Region is unmapped using the ib_unmap_fmr + * verb (fmr_op_unmap). + */ + +#include <linux/sunrpc/svc_rdma.h> + +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_TRANS +#endif + +/* Maximum scatter/gather per FMR */ +#define RPCRDMA_MAX_FMR_SGES (64) + +/* Access mode of externally registered pages */ +enum { + RPCRDMA_FMR_ACCESS_FLAGS = IB_ACCESS_REMOTE_WRITE | + IB_ACCESS_REMOTE_READ, +}; + +bool +fmr_is_supported(struct rpcrdma_ia *ia) +{ + if (!ia->ri_device->alloc_fmr) { + pr_info("rpcrdma: 'fmr' mode is not supported by device %s\n", + ia->ri_device->name); + return false; + } + return true; +} + +static int +fmr_op_init_mr(struct rpcrdma_ia *ia, struct rpcrdma_mr *mr) +{ + static struct ib_fmr_attr fmr_attr = { + .max_pages = RPCRDMA_MAX_FMR_SGES, + .max_maps = 1, + .page_shift = PAGE_SHIFT + }; + + mr->fmr.fm_physaddrs = kcalloc(RPCRDMA_MAX_FMR_SGES, + sizeof(u64), GFP_KERNEL); + if (!mr->fmr.fm_physaddrs) + goto out_free; + + mr->mr_sg = kcalloc(RPCRDMA_MAX_FMR_SGES, + sizeof(*mr->mr_sg), GFP_KERNEL); + if (!mr->mr_sg) + goto out_free; + + sg_init_table(mr->mr_sg, RPCRDMA_MAX_FMR_SGES); + + mr->fmr.fm_mr = ib_alloc_fmr(ia->ri_pd, RPCRDMA_FMR_ACCESS_FLAGS, + &fmr_attr); + if (IS_ERR(mr->fmr.fm_mr)) + goto out_fmr_err; + + INIT_LIST_HEAD(&mr->mr_list); + return 0; + +out_fmr_err: + dprintk("RPC: %s: ib_alloc_fmr returned %ld\n", __func__, + PTR_ERR(mr->fmr.fm_mr)); + +out_free: + kfree(mr->mr_sg); + kfree(mr->fmr.fm_physaddrs); + return -ENOMEM; +} + +static int +__fmr_unmap(struct rpcrdma_mr *mr) +{ + LIST_HEAD(l); + int rc; + + list_add(&mr->fmr.fm_mr->list, &l); + rc = ib_unmap_fmr(&l); + list_del(&mr->fmr.fm_mr->list); + return rc; +} + +static void +fmr_op_release_mr(struct rpcrdma_mr *mr) +{ + LIST_HEAD(unmap_list); + int rc; + + kfree(mr->fmr.fm_physaddrs); + kfree(mr->mr_sg); + + /* In case this one was left mapped, try to unmap it + * to prevent dealloc_fmr from failing with EBUSY + */ + rc = __fmr_unmap(mr); + if (rc) + pr_err("rpcrdma: final ib_unmap_fmr for %p failed %i\n", + mr, rc); + + rc = ib_dealloc_fmr(mr->fmr.fm_mr); + if (rc) + pr_err("rpcrdma: final ib_dealloc_fmr for %p returned %i\n", + mr, rc); + + kfree(mr); +} + +/* Reset of a single FMR. + */ +static void +fmr_op_recover_mr(struct rpcrdma_mr *mr) +{ + struct rpcrdma_xprt *r_xprt = mr->mr_xprt; + int rc; + + /* ORDER: invalidate first */ + rc = __fmr_unmap(mr); + if (rc) + goto out_release; + + /* ORDER: then DMA unmap */ + rpcrdma_mr_unmap_and_put(mr); + + r_xprt->rx_stats.mrs_recovered++; + return; + +out_release: + pr_err("rpcrdma: FMR reset failed (%d), %p released\n", rc, mr); + r_xprt->rx_stats.mrs_orphaned++; + + trace_xprtrdma_dma_unmap(mr); + ib_dma_unmap_sg(r_xprt->rx_ia.ri_device, + mr->mr_sg, mr->mr_nents, mr->mr_dir); + + spin_lock(&r_xprt->rx_buf.rb_mrlock); + list_del(&mr->mr_all); + spin_unlock(&r_xprt->rx_buf.rb_mrlock); + + fmr_op_release_mr(mr); +} + +/* On success, sets: + * ep->rep_attr.cap.max_send_wr + * ep->rep_attr.cap.max_recv_wr + * cdata->max_requests + * ia->ri_max_segs + */ +static int +fmr_op_open(struct rpcrdma_ia *ia, struct rpcrdma_ep *ep, + struct rpcrdma_create_data_internal *cdata) +{ + int max_qp_wr; + + max_qp_wr = ia->ri_device->attrs.max_qp_wr; + max_qp_wr -= RPCRDMA_BACKWARD_WRS; + max_qp_wr -= 1; + if (max_qp_wr < RPCRDMA_MIN_SLOT_TABLE) + return -ENOMEM; + if (cdata->max_requests > max_qp_wr) + cdata->max_requests = max_qp_wr; + ep->rep_attr.cap.max_send_wr = cdata->max_requests; + ep->rep_attr.cap.max_send_wr += RPCRDMA_BACKWARD_WRS; + ep->rep_attr.cap.max_send_wr += 1; /* for ib_drain_sq */ + ep->rep_attr.cap.max_recv_wr = cdata->max_requests; + ep->rep_attr.cap.max_recv_wr += RPCRDMA_BACKWARD_WRS; + ep->rep_attr.cap.max_recv_wr += 1; /* for ib_drain_rq */ + + ia->ri_max_segs = max_t(unsigned int, 1, RPCRDMA_MAX_DATA_SEGS / + RPCRDMA_MAX_FMR_SGES); + return 0; +} + +/* FMR mode conveys up to 64 pages of payload per chunk segment. + */ +static size_t +fmr_op_maxpages(struct rpcrdma_xprt *r_xprt) +{ + return min_t(unsigned int, RPCRDMA_MAX_DATA_SEGS, + RPCRDMA_MAX_HDR_SEGS * RPCRDMA_MAX_FMR_SGES); +} + +/* Use the ib_map_phys_fmr() verb to register a memory region + * for remote access via RDMA READ or RDMA WRITE. + */ +static struct rpcrdma_mr_seg * +fmr_op_map(struct rpcrdma_xprt *r_xprt, struct rpcrdma_mr_seg *seg, + int nsegs, bool writing, struct rpcrdma_mr **out) +{ + struct rpcrdma_mr_seg *seg1 = seg; + int len, pageoff, i, rc; + struct rpcrdma_mr *mr; + u64 *dma_pages; + + mr = rpcrdma_mr_get(r_xprt); + if (!mr) + return ERR_PTR(-EAGAIN); + + pageoff = offset_in_page(seg1->mr_offset); + seg1->mr_offset -= pageoff; /* start of page */ + seg1->mr_len += pageoff; + len = -pageoff; + if (nsegs > RPCRDMA_MAX_FMR_SGES) + nsegs = RPCRDMA_MAX_FMR_SGES; + for (i = 0; i < nsegs;) { + if (seg->mr_page) + sg_set_page(&mr->mr_sg[i], + seg->mr_page, + seg->mr_len, + offset_in_page(seg->mr_offset)); + else + sg_set_buf(&mr->mr_sg[i], seg->mr_offset, + seg->mr_len); + len += seg->mr_len; + ++seg; + ++i; + /* Check for holes */ + if ((i < nsegs && offset_in_page(seg->mr_offset)) || + offset_in_page((seg-1)->mr_offset + (seg-1)->mr_len)) + break; + } + mr->mr_dir = rpcrdma_data_dir(writing); + + mr->mr_nents = ib_dma_map_sg(r_xprt->rx_ia.ri_device, + mr->mr_sg, i, mr->mr_dir); + if (!mr->mr_nents) + goto out_dmamap_err; + trace_xprtrdma_dma_map(mr); + + for (i = 0, dma_pages = mr->fmr.fm_physaddrs; i < mr->mr_nents; i++) + dma_pages[i] = sg_dma_address(&mr->mr_sg[i]); + rc = ib_map_phys_fmr(mr->fmr.fm_mr, dma_pages, mr->mr_nents, + dma_pages[0]); + if (rc) + goto out_maperr; + + mr->mr_handle = mr->fmr.fm_mr->rkey; + mr->mr_length = len; + mr->mr_offset = dma_pages[0] + pageoff; + + *out = mr; + return seg; + +out_dmamap_err: + pr_err("rpcrdma: failed to DMA map sg %p sg_nents %d\n", + mr->mr_sg, i); + rpcrdma_mr_put(mr); + return ERR_PTR(-EIO); + +out_maperr: + pr_err("rpcrdma: ib_map_phys_fmr %u@0x%llx+%i (%d) status %i\n", + len, (unsigned long long)dma_pages[0], + pageoff, mr->mr_nents, rc); + rpcrdma_mr_unmap_and_put(mr); + return ERR_PTR(-EIO); +} + +/* Post Send WR containing the RPC Call message. + */ +static int +fmr_op_send(struct rpcrdma_ia *ia, struct rpcrdma_req *req) +{ + return ib_post_send(ia->ri_id->qp, &req->rl_sendctx->sc_wr, NULL); +} + +/* Invalidate all memory regions that were registered for "req". + * + * Sleeps until it is safe for the host CPU to access the + * previously mapped memory regions. + * + * Caller ensures that @mrs is not empty before the call. This + * function empties the list. + */ +static void +fmr_op_unmap_sync(struct rpcrdma_xprt *r_xprt, struct list_head *mrs) +{ + struct rpcrdma_mr *mr; + LIST_HEAD(unmap_list); + int rc; + + /* ORDER: Invalidate all of the req's MRs first + * + * ib_unmap_fmr() is slow, so use a single call instead + * of one call per mapped FMR. + */ + list_for_each_entry(mr, mrs, mr_list) { + dprintk("RPC: %s: unmapping fmr %p\n", + __func__, &mr->fmr); + trace_xprtrdma_localinv(mr); + list_add_tail(&mr->fmr.fm_mr->list, &unmap_list); + } + r_xprt->rx_stats.local_inv_needed++; + rc = ib_unmap_fmr(&unmap_list); + if (rc) + goto out_reset; + + /* ORDER: Now DMA unmap all of the req's MRs, and return + * them to the free MW list. + */ + while (!list_empty(mrs)) { + mr = rpcrdma_mr_pop(mrs); + list_del(&mr->fmr.fm_mr->list); + rpcrdma_mr_unmap_and_put(mr); + } + + return; + +out_reset: + pr_err("rpcrdma: ib_unmap_fmr failed (%i)\n", rc); + + while (!list_empty(mrs)) { + mr = rpcrdma_mr_pop(mrs); + list_del(&mr->fmr.fm_mr->list); + fmr_op_recover_mr(mr); + } +} + +const struct rpcrdma_memreg_ops rpcrdma_fmr_memreg_ops = { + .ro_map = fmr_op_map, + .ro_send = fmr_op_send, + .ro_unmap_sync = fmr_op_unmap_sync, + .ro_recover_mr = fmr_op_recover_mr, + .ro_open = fmr_op_open, + .ro_maxpages = fmr_op_maxpages, + .ro_init_mr = fmr_op_init_mr, + .ro_release_mr = fmr_op_release_mr, + .ro_displayname = "fmr", + .ro_send_w_inv_ok = 0, +}; diff --git a/net/sunrpc/xprtrdma/frwr_ops.c b/net/sunrpc/xprtrdma/frwr_ops.c new file mode 100644 index 000000000..1bb00dd6c --- /dev/null +++ b/net/sunrpc/xprtrdma/frwr_ops.c @@ -0,0 +1,615 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2015, 2017 Oracle. All rights reserved. + * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. + */ + +/* Lightweight memory registration using Fast Registration Work + * Requests (FRWR). + * + * FRWR features ordered asynchronous registration and deregistration + * of arbitrarily sized memory regions. This is the fastest and safest + * but most complex memory registration mode. + */ + +/* Normal operation + * + * A Memory Region is prepared for RDMA READ or WRITE using a FAST_REG + * Work Request (frwr_op_map). When the RDMA operation is finished, this + * Memory Region is invalidated using a LOCAL_INV Work Request + * (frwr_op_unmap_sync). + * + * Typically these Work Requests are not signaled, and neither are RDMA + * SEND Work Requests (with the exception of signaling occasionally to + * prevent provider work queue overflows). This greatly reduces HCA + * interrupt workload. + * + * As an optimization, frwr_op_unmap marks MRs INVALID before the + * LOCAL_INV WR is posted. If posting succeeds, the MR is placed on + * rb_mrs immediately so that no work (like managing a linked list + * under a spinlock) is needed in the completion upcall. + * + * But this means that frwr_op_map() can occasionally encounter an MR + * that is INVALID but the LOCAL_INV WR has not completed. Work Queue + * ordering prevents a subsequent FAST_REG WR from executing against + * that MR while it is still being invalidated. + */ + +/* Transport recovery + * + * ->op_map and the transport connect worker cannot run at the same + * time, but ->op_unmap can fire while the transport connect worker + * is running. Thus MR recovery is handled in ->op_map, to guarantee + * that recovered MRs are owned by a sending RPC, and not one where + * ->op_unmap could fire at the same time transport reconnect is + * being done. + * + * When the underlying transport disconnects, MRs are left in one of + * four states: + * + * INVALID: The MR was not in use before the QP entered ERROR state. + * + * VALID: The MR was registered before the QP entered ERROR state. + * + * FLUSHED_FR: The MR was being registered when the QP entered ERROR + * state, and the pending WR was flushed. + * + * FLUSHED_LI: The MR was being invalidated when the QP entered ERROR + * state, and the pending WR was flushed. + * + * When frwr_op_map encounters FLUSHED and VALID MRs, they are recovered + * with ib_dereg_mr and then are re-initialized. Because MR recovery + * allocates fresh resources, it is deferred to a workqueue, and the + * recovered MRs are placed back on the rb_mrs list when recovery is + * complete. frwr_op_map allocates another MR for the current RPC while + * the broken MR is reset. + * + * To ensure that frwr_op_map doesn't encounter an MR that is marked + * INVALID but that is about to be flushed due to a previous transport + * disconnect, the transport connect worker attempts to drain all + * pending send queue WRs before the transport is reconnected. + */ + +#include <linux/sunrpc/rpc_rdma.h> +#include <linux/sunrpc/svc_rdma.h> + +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_TRANS +#endif + +bool +frwr_is_supported(struct rpcrdma_ia *ia) +{ + struct ib_device_attr *attrs = &ia->ri_device->attrs; + + if (!(attrs->device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS)) + goto out_not_supported; + if (attrs->max_fast_reg_page_list_len == 0) + goto out_not_supported; + return true; + +out_not_supported: + pr_info("rpcrdma: 'frwr' mode is not supported by device %s\n", + ia->ri_device->name); + return false; +} + +static int +frwr_op_init_mr(struct rpcrdma_ia *ia, struct rpcrdma_mr *mr) +{ + unsigned int depth = ia->ri_max_frwr_depth; + struct rpcrdma_frwr *frwr = &mr->frwr; + int rc; + + frwr->fr_mr = ib_alloc_mr(ia->ri_pd, ia->ri_mrtype, depth); + if (IS_ERR(frwr->fr_mr)) + goto out_mr_err; + + mr->mr_sg = kcalloc(depth, sizeof(*mr->mr_sg), GFP_KERNEL); + if (!mr->mr_sg) + goto out_list_err; + + INIT_LIST_HEAD(&mr->mr_list); + sg_init_table(mr->mr_sg, depth); + init_completion(&frwr->fr_linv_done); + return 0; + +out_mr_err: + rc = PTR_ERR(frwr->fr_mr); + dprintk("RPC: %s: ib_alloc_mr status %i\n", + __func__, rc); + return rc; + +out_list_err: + rc = -ENOMEM; + dprintk("RPC: %s: sg allocation failure\n", + __func__); + ib_dereg_mr(frwr->fr_mr); + return rc; +} + +static void +frwr_op_release_mr(struct rpcrdma_mr *mr) +{ + int rc; + + rc = ib_dereg_mr(mr->frwr.fr_mr); + if (rc) + pr_err("rpcrdma: final ib_dereg_mr for %p returned %i\n", + mr, rc); + kfree(mr->mr_sg); + kfree(mr); +} + +static int +__frwr_mr_reset(struct rpcrdma_ia *ia, struct rpcrdma_mr *mr) +{ + struct rpcrdma_frwr *frwr = &mr->frwr; + int rc; + + rc = ib_dereg_mr(frwr->fr_mr); + if (rc) { + pr_warn("rpcrdma: ib_dereg_mr status %d, frwr %p orphaned\n", + rc, mr); + return rc; + } + + frwr->fr_mr = ib_alloc_mr(ia->ri_pd, ia->ri_mrtype, + ia->ri_max_frwr_depth); + if (IS_ERR(frwr->fr_mr)) { + pr_warn("rpcrdma: ib_alloc_mr status %ld, frwr %p orphaned\n", + PTR_ERR(frwr->fr_mr), mr); + return PTR_ERR(frwr->fr_mr); + } + + dprintk("RPC: %s: recovered FRWR %p\n", __func__, frwr); + frwr->fr_state = FRWR_IS_INVALID; + return 0; +} + +/* Reset of a single FRWR. Generate a fresh rkey by replacing the MR. + */ +static void +frwr_op_recover_mr(struct rpcrdma_mr *mr) +{ + enum rpcrdma_frwr_state state = mr->frwr.fr_state; + struct rpcrdma_xprt *r_xprt = mr->mr_xprt; + struct rpcrdma_ia *ia = &r_xprt->rx_ia; + int rc; + + rc = __frwr_mr_reset(ia, mr); + if (state != FRWR_FLUSHED_LI) { + trace_xprtrdma_dma_unmap(mr); + ib_dma_unmap_sg(ia->ri_device, + mr->mr_sg, mr->mr_nents, mr->mr_dir); + } + if (rc) + goto out_release; + + rpcrdma_mr_put(mr); + r_xprt->rx_stats.mrs_recovered++; + return; + +out_release: + pr_err("rpcrdma: FRWR reset failed %d, %p released\n", rc, mr); + r_xprt->rx_stats.mrs_orphaned++; + + spin_lock(&r_xprt->rx_buf.rb_mrlock); + list_del(&mr->mr_all); + spin_unlock(&r_xprt->rx_buf.rb_mrlock); + + frwr_op_release_mr(mr); +} + +/* On success, sets: + * ep->rep_attr.cap.max_send_wr + * ep->rep_attr.cap.max_recv_wr + * cdata->max_requests + * ia->ri_max_segs + * + * And these FRWR-related fields: + * ia->ri_max_frwr_depth + * ia->ri_mrtype + */ +static int +frwr_op_open(struct rpcrdma_ia *ia, struct rpcrdma_ep *ep, + struct rpcrdma_create_data_internal *cdata) +{ + struct ib_device_attr *attrs = &ia->ri_device->attrs; + int max_qp_wr, depth, delta; + + ia->ri_mrtype = IB_MR_TYPE_MEM_REG; + if (attrs->device_cap_flags & IB_DEVICE_SG_GAPS_REG) + ia->ri_mrtype = IB_MR_TYPE_SG_GAPS; + + ia->ri_max_frwr_depth = + min_t(unsigned int, RPCRDMA_MAX_DATA_SEGS, + attrs->max_fast_reg_page_list_len); + dprintk("RPC: %s: device's max FR page list len = %u\n", + __func__, ia->ri_max_frwr_depth); + + /* Add room for frwr register and invalidate WRs. + * 1. FRWR reg WR for head + * 2. FRWR invalidate WR for head + * 3. N FRWR reg WRs for pagelist + * 4. N FRWR invalidate WRs for pagelist + * 5. FRWR reg WR for tail + * 6. FRWR invalidate WR for tail + * 7. The RDMA_SEND WR + */ + depth = 7; + + /* Calculate N if the device max FRWR depth is smaller than + * RPCRDMA_MAX_DATA_SEGS. + */ + if (ia->ri_max_frwr_depth < RPCRDMA_MAX_DATA_SEGS) { + delta = RPCRDMA_MAX_DATA_SEGS - ia->ri_max_frwr_depth; + do { + depth += 2; /* FRWR reg + invalidate */ + delta -= ia->ri_max_frwr_depth; + } while (delta > 0); + } + + max_qp_wr = ia->ri_device->attrs.max_qp_wr; + max_qp_wr -= RPCRDMA_BACKWARD_WRS; + max_qp_wr -= 1; + if (max_qp_wr < RPCRDMA_MIN_SLOT_TABLE) + return -ENOMEM; + if (cdata->max_requests > max_qp_wr) + cdata->max_requests = max_qp_wr; + ep->rep_attr.cap.max_send_wr = cdata->max_requests * depth; + if (ep->rep_attr.cap.max_send_wr > max_qp_wr) { + cdata->max_requests = max_qp_wr / depth; + if (!cdata->max_requests) + return -EINVAL; + ep->rep_attr.cap.max_send_wr = cdata->max_requests * + depth; + } + ep->rep_attr.cap.max_send_wr += RPCRDMA_BACKWARD_WRS; + ep->rep_attr.cap.max_send_wr += 1; /* for ib_drain_sq */ + ep->rep_attr.cap.max_recv_wr = cdata->max_requests; + ep->rep_attr.cap.max_recv_wr += RPCRDMA_BACKWARD_WRS; + ep->rep_attr.cap.max_recv_wr += 1; /* for ib_drain_rq */ + + ia->ri_max_segs = max_t(unsigned int, 1, RPCRDMA_MAX_DATA_SEGS / + ia->ri_max_frwr_depth); + return 0; +} + +/* FRWR mode conveys a list of pages per chunk segment. The + * maximum length of that list is the FRWR page list depth. + */ +static size_t +frwr_op_maxpages(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_ia *ia = &r_xprt->rx_ia; + + return min_t(unsigned int, RPCRDMA_MAX_DATA_SEGS, + RPCRDMA_MAX_HDR_SEGS * ia->ri_max_frwr_depth); +} + +static void +__frwr_sendcompletion_flush(struct ib_wc *wc, const char *wr) +{ + if (wc->status != IB_WC_WR_FLUSH_ERR) + pr_err("rpcrdma: %s: %s (%u/0x%x)\n", + wr, ib_wc_status_msg(wc->status), + wc->status, wc->vendor_err); +} + +/** + * frwr_wc_fastreg - Invoked by RDMA provider for a flushed FastReg WC + * @cq: completion queue (ignored) + * @wc: completed WR + * + */ +static void +frwr_wc_fastreg(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct rpcrdma_frwr *frwr = + container_of(cqe, struct rpcrdma_frwr, fr_cqe); + + /* WARNING: Only wr_cqe and status are reliable at this point */ + if (wc->status != IB_WC_SUCCESS) { + frwr->fr_state = FRWR_FLUSHED_FR; + __frwr_sendcompletion_flush(wc, "fastreg"); + } + trace_xprtrdma_wc_fastreg(wc, frwr); +} + +/** + * frwr_wc_localinv - Invoked by RDMA provider for a flushed LocalInv WC + * @cq: completion queue (ignored) + * @wc: completed WR + * + */ +static void +frwr_wc_localinv(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct rpcrdma_frwr *frwr = container_of(cqe, struct rpcrdma_frwr, + fr_cqe); + + /* WARNING: Only wr_cqe and status are reliable at this point */ + if (wc->status != IB_WC_SUCCESS) { + frwr->fr_state = FRWR_FLUSHED_LI; + __frwr_sendcompletion_flush(wc, "localinv"); + } + trace_xprtrdma_wc_li(wc, frwr); +} + +/** + * frwr_wc_localinv_wake - Invoked by RDMA provider for a signaled LocalInv WC + * @cq: completion queue (ignored) + * @wc: completed WR + * + * Awaken anyone waiting for an MR to finish being fenced. + */ +static void +frwr_wc_localinv_wake(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct rpcrdma_frwr *frwr = container_of(cqe, struct rpcrdma_frwr, + fr_cqe); + + /* WARNING: Only wr_cqe and status are reliable at this point */ + if (wc->status != IB_WC_SUCCESS) { + frwr->fr_state = FRWR_FLUSHED_LI; + __frwr_sendcompletion_flush(wc, "localinv"); + } + complete(&frwr->fr_linv_done); + trace_xprtrdma_wc_li_wake(wc, frwr); +} + +/* Post a REG_MR Work Request to register a memory region + * for remote access via RDMA READ or RDMA WRITE. + */ +static struct rpcrdma_mr_seg * +frwr_op_map(struct rpcrdma_xprt *r_xprt, struct rpcrdma_mr_seg *seg, + int nsegs, bool writing, struct rpcrdma_mr **out) +{ + struct rpcrdma_ia *ia = &r_xprt->rx_ia; + bool holes_ok = ia->ri_mrtype == IB_MR_TYPE_SG_GAPS; + struct rpcrdma_frwr *frwr; + struct rpcrdma_mr *mr; + struct ib_mr *ibmr; + struct ib_reg_wr *reg_wr; + int i, n; + u8 key; + + mr = NULL; + do { + if (mr) + rpcrdma_mr_defer_recovery(mr); + mr = rpcrdma_mr_get(r_xprt); + if (!mr) + return ERR_PTR(-EAGAIN); + } while (mr->frwr.fr_state != FRWR_IS_INVALID); + frwr = &mr->frwr; + frwr->fr_state = FRWR_IS_VALID; + + if (nsegs > ia->ri_max_frwr_depth) + nsegs = ia->ri_max_frwr_depth; + for (i = 0; i < nsegs;) { + if (seg->mr_page) + sg_set_page(&mr->mr_sg[i], + seg->mr_page, + seg->mr_len, + offset_in_page(seg->mr_offset)); + else + sg_set_buf(&mr->mr_sg[i], seg->mr_offset, + seg->mr_len); + + ++seg; + ++i; + if (holes_ok) + continue; + if ((i < nsegs && offset_in_page(seg->mr_offset)) || + offset_in_page((seg-1)->mr_offset + (seg-1)->mr_len)) + break; + } + mr->mr_dir = rpcrdma_data_dir(writing); + + mr->mr_nents = ib_dma_map_sg(ia->ri_device, mr->mr_sg, i, mr->mr_dir); + if (!mr->mr_nents) + goto out_dmamap_err; + trace_xprtrdma_dma_map(mr); + + ibmr = frwr->fr_mr; + n = ib_map_mr_sg(ibmr, mr->mr_sg, mr->mr_nents, NULL, PAGE_SIZE); + if (unlikely(n != mr->mr_nents)) + goto out_mapmr_err; + + key = (u8)(ibmr->rkey & 0x000000FF); + ib_update_fast_reg_key(ibmr, ++key); + + reg_wr = &frwr->fr_regwr; + reg_wr->mr = ibmr; + reg_wr->key = ibmr->rkey; + reg_wr->access = writing ? + IB_ACCESS_REMOTE_WRITE | IB_ACCESS_LOCAL_WRITE : + IB_ACCESS_REMOTE_READ; + + mr->mr_handle = ibmr->rkey; + mr->mr_length = ibmr->length; + mr->mr_offset = ibmr->iova; + + *out = mr; + return seg; + +out_dmamap_err: + pr_err("rpcrdma: failed to DMA map sg %p sg_nents %d\n", + mr->mr_sg, i); + frwr->fr_state = FRWR_IS_INVALID; + rpcrdma_mr_put(mr); + return ERR_PTR(-EIO); + +out_mapmr_err: + pr_err("rpcrdma: failed to map mr %p (%d/%d)\n", + frwr->fr_mr, n, mr->mr_nents); + rpcrdma_mr_defer_recovery(mr); + return ERR_PTR(-EIO); +} + +/* Post Send WR containing the RPC Call message. + * + * For FRMR, chain any FastReg WRs to the Send WR. Only a + * single ib_post_send call is needed to register memory + * and then post the Send WR. + */ +static int +frwr_op_send(struct rpcrdma_ia *ia, struct rpcrdma_req *req) +{ + struct ib_send_wr *post_wr; + struct rpcrdma_mr *mr; + + post_wr = &req->rl_sendctx->sc_wr; + list_for_each_entry(mr, &req->rl_registered, mr_list) { + struct rpcrdma_frwr *frwr; + + frwr = &mr->frwr; + + frwr->fr_cqe.done = frwr_wc_fastreg; + frwr->fr_regwr.wr.next = post_wr; + frwr->fr_regwr.wr.wr_cqe = &frwr->fr_cqe; + frwr->fr_regwr.wr.num_sge = 0; + frwr->fr_regwr.wr.opcode = IB_WR_REG_MR; + frwr->fr_regwr.wr.send_flags = 0; + + post_wr = &frwr->fr_regwr.wr; + } + + /* If ib_post_send fails, the next ->send_request for + * @req will queue these MWs for recovery. + */ + return ib_post_send(ia->ri_id->qp, post_wr, NULL); +} + +/* Handle a remotely invalidated mr on the @mrs list + */ +static void +frwr_op_reminv(struct rpcrdma_rep *rep, struct list_head *mrs) +{ + struct rpcrdma_mr *mr; + + list_for_each_entry(mr, mrs, mr_list) + if (mr->mr_handle == rep->rr_inv_rkey) { + list_del_init(&mr->mr_list); + trace_xprtrdma_remoteinv(mr); + mr->frwr.fr_state = FRWR_IS_INVALID; + rpcrdma_mr_unmap_and_put(mr); + break; /* only one invalidated MR per RPC */ + } +} + +/* Invalidate all memory regions that were registered for "req". + * + * Sleeps until it is safe for the host CPU to access the + * previously mapped memory regions. + * + * Caller ensures that @mrs is not empty before the call. This + * function empties the list. + */ +static void +frwr_op_unmap_sync(struct rpcrdma_xprt *r_xprt, struct list_head *mrs) +{ + struct ib_send_wr *first, **prev, *last; + const struct ib_send_wr *bad_wr; + struct rpcrdma_ia *ia = &r_xprt->rx_ia; + struct rpcrdma_frwr *frwr; + struct rpcrdma_mr *mr; + int count, rc; + + /* ORDER: Invalidate all of the MRs first + * + * Chain the LOCAL_INV Work Requests and post them with + * a single ib_post_send() call. + */ + frwr = NULL; + count = 0; + prev = &first; + list_for_each_entry(mr, mrs, mr_list) { + mr->frwr.fr_state = FRWR_IS_INVALID; + + frwr = &mr->frwr; + trace_xprtrdma_localinv(mr); + + frwr->fr_cqe.done = frwr_wc_localinv; + last = &frwr->fr_invwr; + memset(last, 0, sizeof(*last)); + last->wr_cqe = &frwr->fr_cqe; + last->opcode = IB_WR_LOCAL_INV; + last->ex.invalidate_rkey = mr->mr_handle; + count++; + + *prev = last; + prev = &last->next; + } + if (!frwr) + goto unmap; + + /* Strong send queue ordering guarantees that when the + * last WR in the chain completes, all WRs in the chain + * are complete. + */ + last->send_flags = IB_SEND_SIGNALED; + frwr->fr_cqe.done = frwr_wc_localinv_wake; + reinit_completion(&frwr->fr_linv_done); + + /* Transport disconnect drains the receive CQ before it + * replaces the QP. The RPC reply handler won't call us + * unless ri_id->qp is a valid pointer. + */ + r_xprt->rx_stats.local_inv_needed++; + bad_wr = NULL; + rc = ib_post_send(ia->ri_id->qp, first, &bad_wr); + if (bad_wr != first) + wait_for_completion(&frwr->fr_linv_done); + if (rc) + goto reset_mrs; + + /* ORDER: Now DMA unmap all of the MRs, and return + * them to the free MR list. + */ +unmap: + while (!list_empty(mrs)) { + mr = rpcrdma_mr_pop(mrs); + rpcrdma_mr_unmap_and_put(mr); + } + return; + +reset_mrs: + pr_err("rpcrdma: FRWR invalidate ib_post_send returned %i\n", rc); + + /* Find and reset the MRs in the LOCAL_INV WRs that did not + * get posted. + */ + while (bad_wr) { + frwr = container_of(bad_wr, struct rpcrdma_frwr, + fr_invwr); + mr = container_of(frwr, struct rpcrdma_mr, frwr); + + __frwr_mr_reset(ia, mr); + + bad_wr = bad_wr->next; + } + goto unmap; +} + +const struct rpcrdma_memreg_ops rpcrdma_frwr_memreg_ops = { + .ro_map = frwr_op_map, + .ro_send = frwr_op_send, + .ro_reminv = frwr_op_reminv, + .ro_unmap_sync = frwr_op_unmap_sync, + .ro_recover_mr = frwr_op_recover_mr, + .ro_open = frwr_op_open, + .ro_maxpages = frwr_op_maxpages, + .ro_init_mr = frwr_op_init_mr, + .ro_release_mr = frwr_op_release_mr, + .ro_displayname = "frwr", + .ro_send_w_inv_ok = RPCRDMA_CMP_F_SND_W_INV_OK, +}; diff --git a/net/sunrpc/xprtrdma/module.c b/net/sunrpc/xprtrdma/module.c new file mode 100644 index 000000000..45c5b41ac --- /dev/null +++ b/net/sunrpc/xprtrdma/module.c @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2015, 2017 Oracle. All rights reserved. + */ + +/* rpcrdma.ko module initialization + */ + +#include <linux/types.h> +#include <linux/compiler.h> +#include <linux/module.h> +#include <linux/init.h> +#include <linux/sunrpc/svc_rdma.h> + +#include <asm/swab.h> + +#include "xprt_rdma.h" + +#define CREATE_TRACE_POINTS +#include <trace/events/rpcrdma.h> + +MODULE_AUTHOR("Open Grid Computing and Network Appliance, Inc."); +MODULE_DESCRIPTION("RPC/RDMA Transport"); +MODULE_LICENSE("Dual BSD/GPL"); +MODULE_ALIAS("svcrdma"); +MODULE_ALIAS("xprtrdma"); +MODULE_ALIAS("rpcrdma6"); + +static void __exit rpc_rdma_cleanup(void) +{ + xprt_rdma_cleanup(); + svc_rdma_cleanup(); +} + +static int __init rpc_rdma_init(void) +{ + int rc; + + rc = svc_rdma_init(); + if (rc) + goto out; + + rc = xprt_rdma_init(); + if (rc) + svc_rdma_cleanup(); + +out: + return rc; +} + +module_init(rpc_rdma_init); +module_exit(rpc_rdma_cleanup); diff --git a/net/sunrpc/xprtrdma/rpc_rdma.c b/net/sunrpc/xprtrdma/rpc_rdma.c new file mode 100644 index 000000000..7f9d8365c --- /dev/null +++ b/net/sunrpc/xprtrdma/rpc_rdma.c @@ -0,0 +1,1404 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2014-2017 Oracle. All rights reserved. + * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/* + * rpc_rdma.c + * + * This file contains the guts of the RPC RDMA protocol, and + * does marshaling/unmarshaling, etc. It is also where interfacing + * to the Linux RPC framework lives. + */ + +#include <linux/highmem.h> + +#include <linux/sunrpc/svc_rdma.h> + +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_TRANS +#endif + +/* Returns size of largest RPC-over-RDMA header in a Call message + * + * The largest Call header contains a full-size Read list and a + * minimal Reply chunk. + */ +static unsigned int rpcrdma_max_call_header_size(unsigned int maxsegs) +{ + unsigned int size; + + /* Fixed header fields and list discriminators */ + size = RPCRDMA_HDRLEN_MIN; + + /* Maximum Read list size */ + maxsegs += 2; /* segment for head and tail buffers */ + size += maxsegs * rpcrdma_readchunk_maxsz * sizeof(__be32); + + /* Minimal Read chunk size */ + size += sizeof(__be32); /* segment count */ + size += rpcrdma_segment_maxsz * sizeof(__be32); + size += sizeof(__be32); /* list discriminator */ + + dprintk("RPC: %s: max call header size = %u\n", + __func__, size); + return size; +} + +/* Returns size of largest RPC-over-RDMA header in a Reply message + * + * There is only one Write list or one Reply chunk per Reply + * message. The larger list is the Write list. + */ +static unsigned int rpcrdma_max_reply_header_size(unsigned int maxsegs) +{ + unsigned int size; + + /* Fixed header fields and list discriminators */ + size = RPCRDMA_HDRLEN_MIN; + + /* Maximum Write list size */ + maxsegs += 2; /* segment for head and tail buffers */ + size += sizeof(__be32); /* segment count */ + size += maxsegs * rpcrdma_segment_maxsz * sizeof(__be32); + size += sizeof(__be32); /* list discriminator */ + + dprintk("RPC: %s: max reply header size = %u\n", + __func__, size); + return size; +} + +void rpcrdma_set_max_header_sizes(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_create_data_internal *cdata = &r_xprt->rx_data; + struct rpcrdma_ia *ia = &r_xprt->rx_ia; + unsigned int maxsegs = ia->ri_max_segs; + + ia->ri_max_inline_write = cdata->inline_wsize - + rpcrdma_max_call_header_size(maxsegs); + ia->ri_max_inline_read = cdata->inline_rsize - + rpcrdma_max_reply_header_size(maxsegs); +} + +/* The client can send a request inline as long as the RPCRDMA header + * plus the RPC call fit under the transport's inline limit. If the + * combined call message size exceeds that limit, the client must use + * a Read chunk for this operation. + * + * A Read chunk is also required if sending the RPC call inline would + * exceed this device's max_sge limit. + */ +static bool rpcrdma_args_inline(struct rpcrdma_xprt *r_xprt, + struct rpc_rqst *rqst) +{ + struct xdr_buf *xdr = &rqst->rq_snd_buf; + unsigned int count, remaining, offset; + + if (xdr->len > r_xprt->rx_ia.ri_max_inline_write) + return false; + + if (xdr->page_len) { + remaining = xdr->page_len; + offset = offset_in_page(xdr->page_base); + count = RPCRDMA_MIN_SEND_SGES; + while (remaining) { + remaining -= min_t(unsigned int, + PAGE_SIZE - offset, remaining); + offset = 0; + if (++count > r_xprt->rx_ia.ri_max_send_sges) + return false; + } + } + + return true; +} + +/* The client can't know how large the actual reply will be. Thus it + * plans for the largest possible reply for that particular ULP + * operation. If the maximum combined reply message size exceeds that + * limit, the client must provide a write list or a reply chunk for + * this request. + */ +static bool rpcrdma_results_inline(struct rpcrdma_xprt *r_xprt, + struct rpc_rqst *rqst) +{ + struct rpcrdma_ia *ia = &r_xprt->rx_ia; + + return rqst->rq_rcv_buf.buflen <= ia->ri_max_inline_read; +} + +/* Split @vec on page boundaries into SGEs. FMR registers pages, not + * a byte range. Other modes coalesce these SGEs into a single MR + * when they can. + * + * Returns pointer to next available SGE, and bumps the total number + * of SGEs consumed. + */ +static struct rpcrdma_mr_seg * +rpcrdma_convert_kvec(struct kvec *vec, struct rpcrdma_mr_seg *seg, + unsigned int *n) +{ + u32 remaining, page_offset; + char *base; + + base = vec->iov_base; + page_offset = offset_in_page(base); + remaining = vec->iov_len; + while (remaining) { + seg->mr_page = NULL; + seg->mr_offset = base; + seg->mr_len = min_t(u32, PAGE_SIZE - page_offset, remaining); + remaining -= seg->mr_len; + base += seg->mr_len; + ++seg; + ++(*n); + page_offset = 0; + } + return seg; +} + +/* Convert @xdrbuf into SGEs no larger than a page each. As they + * are registered, these SGEs are then coalesced into RDMA segments + * when the selected memreg mode supports it. + * + * Returns positive number of SGEs consumed, or a negative errno. + */ + +static int +rpcrdma_convert_iovs(struct rpcrdma_xprt *r_xprt, struct xdr_buf *xdrbuf, + unsigned int pos, enum rpcrdma_chunktype type, + struct rpcrdma_mr_seg *seg) +{ + unsigned long page_base; + unsigned int len, n; + struct page **ppages; + + n = 0; + if (pos == 0) + seg = rpcrdma_convert_kvec(&xdrbuf->head[0], seg, &n); + + len = xdrbuf->page_len; + ppages = xdrbuf->pages + (xdrbuf->page_base >> PAGE_SHIFT); + page_base = offset_in_page(xdrbuf->page_base); + while (len) { + if (unlikely(!*ppages)) { + /* XXX: Certain upper layer operations do + * not provide receive buffer pages. + */ + *ppages = alloc_page(GFP_ATOMIC); + if (!*ppages) + return -ENOBUFS; + } + seg->mr_page = *ppages; + seg->mr_offset = (char *)page_base; + seg->mr_len = min_t(u32, PAGE_SIZE - page_base, len); + len -= seg->mr_len; + ++ppages; + ++seg; + ++n; + page_base = 0; + } + + /* When encoding a Read chunk, the tail iovec contains an + * XDR pad and may be omitted. + */ + if (type == rpcrdma_readch && r_xprt->rx_ia.ri_implicit_roundup) + goto out; + + /* When encoding a Write chunk, some servers need to see an + * extra segment for non-XDR-aligned Write chunks. The upper + * layer provides space in the tail iovec that may be used + * for this purpose. + */ + if (type == rpcrdma_writech && r_xprt->rx_ia.ri_implicit_roundup) + goto out; + + if (xdrbuf->tail[0].iov_len) + seg = rpcrdma_convert_kvec(&xdrbuf->tail[0], seg, &n); + +out: + if (unlikely(n > RPCRDMA_MAX_SEGS)) + return -EIO; + return n; +} + +static inline int +encode_item_present(struct xdr_stream *xdr) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, sizeof(*p)); + if (unlikely(!p)) + return -EMSGSIZE; + + *p = xdr_one; + return 0; +} + +static inline int +encode_item_not_present(struct xdr_stream *xdr) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, sizeof(*p)); + if (unlikely(!p)) + return -EMSGSIZE; + + *p = xdr_zero; + return 0; +} + +static void +xdr_encode_rdma_segment(__be32 *iptr, struct rpcrdma_mr *mr) +{ + *iptr++ = cpu_to_be32(mr->mr_handle); + *iptr++ = cpu_to_be32(mr->mr_length); + xdr_encode_hyper(iptr, mr->mr_offset); +} + +static int +encode_rdma_segment(struct xdr_stream *xdr, struct rpcrdma_mr *mr) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, 4 * sizeof(*p)); + if (unlikely(!p)) + return -EMSGSIZE; + + xdr_encode_rdma_segment(p, mr); + return 0; +} + +static int +encode_read_segment(struct xdr_stream *xdr, struct rpcrdma_mr *mr, + u32 position) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, 6 * sizeof(*p)); + if (unlikely(!p)) + return -EMSGSIZE; + + *p++ = xdr_one; /* Item present */ + *p++ = cpu_to_be32(position); + xdr_encode_rdma_segment(p, mr); + return 0; +} + +/* Register and XDR encode the Read list. Supports encoding a list of read + * segments that belong to a single read chunk. + * + * Encoding key for single-list chunks (HLOO = Handle32 Length32 Offset64): + * + * Read chunklist (a linked list): + * N elements, position P (same P for all chunks of same arg!): + * 1 - PHLOO - 1 - PHLOO - ... - 1 - PHLOO - 0 + * + * Returns zero on success, or a negative errno if a failure occurred. + * @xdr is advanced to the next position in the stream. + * + * Only a single @pos value is currently supported. + */ +static noinline int +rpcrdma_encode_read_list(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req, + struct rpc_rqst *rqst, enum rpcrdma_chunktype rtype) +{ + struct xdr_stream *xdr = &req->rl_stream; + struct rpcrdma_mr_seg *seg; + struct rpcrdma_mr *mr; + unsigned int pos; + int nsegs; + + pos = rqst->rq_snd_buf.head[0].iov_len; + if (rtype == rpcrdma_areadch) + pos = 0; + seg = req->rl_segments; + nsegs = rpcrdma_convert_iovs(r_xprt, &rqst->rq_snd_buf, pos, + rtype, seg); + if (nsegs < 0) + return nsegs; + + do { + seg = r_xprt->rx_ia.ri_ops->ro_map(r_xprt, seg, nsegs, + false, &mr); + if (IS_ERR(seg)) + return PTR_ERR(seg); + rpcrdma_mr_push(mr, &req->rl_registered); + + if (encode_read_segment(xdr, mr, pos) < 0) + return -EMSGSIZE; + + trace_xprtrdma_read_chunk(rqst->rq_task, pos, mr, nsegs); + r_xprt->rx_stats.read_chunk_count++; + nsegs -= mr->mr_nents; + } while (nsegs); + + return 0; +} + +/* Register and XDR encode the Write list. Supports encoding a list + * containing one array of plain segments that belong to a single + * write chunk. + * + * Encoding key for single-list chunks (HLOO = Handle32 Length32 Offset64): + * + * Write chunklist (a list of (one) counted array): + * N elements: + * 1 - N - HLOO - HLOO - ... - HLOO - 0 + * + * Returns zero on success, or a negative errno if a failure occurred. + * @xdr is advanced to the next position in the stream. + * + * Only a single Write chunk is currently supported. + */ +static noinline int +rpcrdma_encode_write_list(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req, + struct rpc_rqst *rqst, enum rpcrdma_chunktype wtype) +{ + struct xdr_stream *xdr = &req->rl_stream; + struct rpcrdma_mr_seg *seg; + struct rpcrdma_mr *mr; + int nsegs, nchunks; + __be32 *segcount; + + seg = req->rl_segments; + nsegs = rpcrdma_convert_iovs(r_xprt, &rqst->rq_rcv_buf, + rqst->rq_rcv_buf.head[0].iov_len, + wtype, seg); + if (nsegs < 0) + return nsegs; + + if (encode_item_present(xdr) < 0) + return -EMSGSIZE; + segcount = xdr_reserve_space(xdr, sizeof(*segcount)); + if (unlikely(!segcount)) + return -EMSGSIZE; + /* Actual value encoded below */ + + nchunks = 0; + do { + seg = r_xprt->rx_ia.ri_ops->ro_map(r_xprt, seg, nsegs, + true, &mr); + if (IS_ERR(seg)) + return PTR_ERR(seg); + rpcrdma_mr_push(mr, &req->rl_registered); + + if (encode_rdma_segment(xdr, mr) < 0) + return -EMSGSIZE; + + trace_xprtrdma_write_chunk(rqst->rq_task, mr, nsegs); + r_xprt->rx_stats.write_chunk_count++; + r_xprt->rx_stats.total_rdma_request += mr->mr_length; + nchunks++; + nsegs -= mr->mr_nents; + } while (nsegs); + + /* Update count of segments in this Write chunk */ + *segcount = cpu_to_be32(nchunks); + + return 0; +} + +/* Register and XDR encode the Reply chunk. Supports encoding an array + * of plain segments that belong to a single write (reply) chunk. + * + * Encoding key for single-list chunks (HLOO = Handle32 Length32 Offset64): + * + * Reply chunk (a counted array): + * N elements: + * 1 - N - HLOO - HLOO - ... - HLOO + * + * Returns zero on success, or a negative errno if a failure occurred. + * @xdr is advanced to the next position in the stream. + */ +static noinline int +rpcrdma_encode_reply_chunk(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req, + struct rpc_rqst *rqst, enum rpcrdma_chunktype wtype) +{ + struct xdr_stream *xdr = &req->rl_stream; + struct rpcrdma_mr_seg *seg; + struct rpcrdma_mr *mr; + int nsegs, nchunks; + __be32 *segcount; + + seg = req->rl_segments; + nsegs = rpcrdma_convert_iovs(r_xprt, &rqst->rq_rcv_buf, 0, wtype, seg); + if (nsegs < 0) + return nsegs; + + if (encode_item_present(xdr) < 0) + return -EMSGSIZE; + segcount = xdr_reserve_space(xdr, sizeof(*segcount)); + if (unlikely(!segcount)) + return -EMSGSIZE; + /* Actual value encoded below */ + + nchunks = 0; + do { + seg = r_xprt->rx_ia.ri_ops->ro_map(r_xprt, seg, nsegs, + true, &mr); + if (IS_ERR(seg)) + return PTR_ERR(seg); + rpcrdma_mr_push(mr, &req->rl_registered); + + if (encode_rdma_segment(xdr, mr) < 0) + return -EMSGSIZE; + + trace_xprtrdma_reply_chunk(rqst->rq_task, mr, nsegs); + r_xprt->rx_stats.reply_chunk_count++; + r_xprt->rx_stats.total_rdma_request += mr->mr_length; + nchunks++; + nsegs -= mr->mr_nents; + } while (nsegs); + + /* Update count of segments in the Reply chunk */ + *segcount = cpu_to_be32(nchunks); + + return 0; +} + +/** + * rpcrdma_unmap_sendctx - DMA-unmap Send buffers + * @sc: sendctx containing SGEs to unmap + * + */ +void +rpcrdma_unmap_sendctx(struct rpcrdma_sendctx *sc) +{ + struct rpcrdma_ia *ia = &sc->sc_xprt->rx_ia; + struct ib_sge *sge; + unsigned int count; + + /* The first two SGEs contain the transport header and + * the inline buffer. These are always left mapped so + * they can be cheaply re-used. + */ + sge = &sc->sc_sges[2]; + for (count = sc->sc_unmap_count; count; ++sge, --count) + ib_dma_unmap_page(ia->ri_device, + sge->addr, sge->length, DMA_TO_DEVICE); + + if (test_and_clear_bit(RPCRDMA_REQ_F_TX_RESOURCES, &sc->sc_req->rl_flags)) { + smp_mb__after_atomic(); + wake_up_bit(&sc->sc_req->rl_flags, RPCRDMA_REQ_F_TX_RESOURCES); + } +} + +/* Prepare an SGE for the RPC-over-RDMA transport header. + */ +static bool +rpcrdma_prepare_hdr_sge(struct rpcrdma_ia *ia, struct rpcrdma_req *req, + u32 len) +{ + struct rpcrdma_sendctx *sc = req->rl_sendctx; + struct rpcrdma_regbuf *rb = req->rl_rdmabuf; + struct ib_sge *sge = sc->sc_sges; + + if (!rpcrdma_dma_map_regbuf(ia, rb)) + goto out_regbuf; + sge->addr = rdmab_addr(rb); + sge->length = len; + sge->lkey = rdmab_lkey(rb); + + ib_dma_sync_single_for_device(rdmab_device(rb), sge->addr, + sge->length, DMA_TO_DEVICE); + sc->sc_wr.num_sge++; + return true; + +out_regbuf: + pr_err("rpcrdma: failed to DMA map a Send buffer\n"); + return false; +} + +/* Prepare the Send SGEs. The head and tail iovec, and each entry + * in the page list, gets its own SGE. + */ +static bool +rpcrdma_prepare_msg_sges(struct rpcrdma_ia *ia, struct rpcrdma_req *req, + struct xdr_buf *xdr, enum rpcrdma_chunktype rtype) +{ + struct rpcrdma_sendctx *sc = req->rl_sendctx; + unsigned int sge_no, page_base, len, remaining; + struct rpcrdma_regbuf *rb = req->rl_sendbuf; + struct ib_device *device = ia->ri_device; + struct ib_sge *sge = sc->sc_sges; + u32 lkey = ia->ri_pd->local_dma_lkey; + struct page *page, **ppages; + + /* The head iovec is straightforward, as it is already + * DMA-mapped. Sync the content that has changed. + */ + if (!rpcrdma_dma_map_regbuf(ia, rb)) + goto out_regbuf; + sge_no = 1; + sge[sge_no].addr = rdmab_addr(rb); + sge[sge_no].length = xdr->head[0].iov_len; + sge[sge_no].lkey = rdmab_lkey(rb); + ib_dma_sync_single_for_device(rdmab_device(rb), sge[sge_no].addr, + sge[sge_no].length, DMA_TO_DEVICE); + + /* If there is a Read chunk, the page list is being handled + * via explicit RDMA, and thus is skipped here. However, the + * tail iovec may include an XDR pad for the page list, as + * well as additional content, and may not reside in the + * same page as the head iovec. + */ + if (rtype == rpcrdma_readch) { + len = xdr->tail[0].iov_len; + + /* Do not include the tail if it is only an XDR pad */ + if (len < 4) + goto out; + + page = virt_to_page(xdr->tail[0].iov_base); + page_base = offset_in_page(xdr->tail[0].iov_base); + + /* If the content in the page list is an odd length, + * xdr_write_pages() has added a pad at the beginning + * of the tail iovec. Force the tail's non-pad content + * to land at the next XDR position in the Send message. + */ + page_base += len & 3; + len -= len & 3; + goto map_tail; + } + + /* If there is a page list present, temporarily DMA map + * and prepare an SGE for each page to be sent. + */ + if (xdr->page_len) { + ppages = xdr->pages + (xdr->page_base >> PAGE_SHIFT); + page_base = offset_in_page(xdr->page_base); + remaining = xdr->page_len; + while (remaining) { + sge_no++; + if (sge_no > RPCRDMA_MAX_SEND_SGES - 2) + goto out_mapping_overflow; + + len = min_t(u32, PAGE_SIZE - page_base, remaining); + sge[sge_no].addr = ib_dma_map_page(device, *ppages, + page_base, len, + DMA_TO_DEVICE); + if (ib_dma_mapping_error(device, sge[sge_no].addr)) + goto out_mapping_err; + sge[sge_no].length = len; + sge[sge_no].lkey = lkey; + + sc->sc_unmap_count++; + ppages++; + remaining -= len; + page_base = 0; + } + } + + /* The tail iovec is not always constructed in the same + * page where the head iovec resides (see, for example, + * gss_wrap_req_priv). To neatly accommodate that case, + * DMA map it separately. + */ + if (xdr->tail[0].iov_len) { + page = virt_to_page(xdr->tail[0].iov_base); + page_base = offset_in_page(xdr->tail[0].iov_base); + len = xdr->tail[0].iov_len; + +map_tail: + sge_no++; + sge[sge_no].addr = ib_dma_map_page(device, page, + page_base, len, + DMA_TO_DEVICE); + if (ib_dma_mapping_error(device, sge[sge_no].addr)) + goto out_mapping_err; + sge[sge_no].length = len; + sge[sge_no].lkey = lkey; + sc->sc_unmap_count++; + } + +out: + sc->sc_wr.num_sge += sge_no; + if (sc->sc_unmap_count) + __set_bit(RPCRDMA_REQ_F_TX_RESOURCES, &req->rl_flags); + return true; + +out_regbuf: + pr_err("rpcrdma: failed to DMA map a Send buffer\n"); + return false; + +out_mapping_overflow: + rpcrdma_unmap_sendctx(sc); + pr_err("rpcrdma: too many Send SGEs (%u)\n", sge_no); + return false; + +out_mapping_err: + rpcrdma_unmap_sendctx(sc); + pr_err("rpcrdma: Send mapping error\n"); + return false; +} + +/** + * rpcrdma_prepare_send_sges - Construct SGEs for a Send WR + * @r_xprt: controlling transport + * @req: context of RPC Call being marshalled + * @hdrlen: size of transport header, in bytes + * @xdr: xdr_buf containing RPC Call + * @rtype: chunk type being encoded + * + * Returns 0 on success; otherwise a negative errno is returned. + */ +int +rpcrdma_prepare_send_sges(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, u32 hdrlen, + struct xdr_buf *xdr, enum rpcrdma_chunktype rtype) +{ + req->rl_sendctx = rpcrdma_sendctx_get_locked(&r_xprt->rx_buf); + if (!req->rl_sendctx) + return -EAGAIN; + req->rl_sendctx->sc_wr.num_sge = 0; + req->rl_sendctx->sc_unmap_count = 0; + req->rl_sendctx->sc_req = req; + __clear_bit(RPCRDMA_REQ_F_TX_RESOURCES, &req->rl_flags); + + if (!rpcrdma_prepare_hdr_sge(&r_xprt->rx_ia, req, hdrlen)) + return -EIO; + + if (rtype != rpcrdma_areadch) + if (!rpcrdma_prepare_msg_sges(&r_xprt->rx_ia, req, xdr, rtype)) + return -EIO; + + return 0; +} + +/** + * rpcrdma_marshal_req - Marshal and send one RPC request + * @r_xprt: controlling transport + * @rqst: RPC request to be marshaled + * + * For the RPC in "rqst", this function: + * - Chooses the transfer mode (eg., RDMA_MSG or RDMA_NOMSG) + * - Registers Read, Write, and Reply chunks + * - Constructs the transport header + * - Posts a Send WR to send the transport header and request + * + * Returns: + * %0 if the RPC was sent successfully, + * %-ENOTCONN if the connection was lost, + * %-EAGAIN if the caller should call again with the same arguments, + * %-ENOBUFS if the caller should call again after a delay, + * %-EMSGSIZE if the transport header is too small, + * %-EIO if a permanent problem occurred while marshaling. + */ +int +rpcrdma_marshal_req(struct rpcrdma_xprt *r_xprt, struct rpc_rqst *rqst) +{ + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + struct xdr_stream *xdr = &req->rl_stream; + enum rpcrdma_chunktype rtype, wtype; + bool ddp_allowed; + __be32 *p; + int ret; + + rpcrdma_set_xdrlen(&req->rl_hdrbuf, 0); + xdr_init_encode(xdr, &req->rl_hdrbuf, + req->rl_rdmabuf->rg_base); + + /* Fixed header fields */ + ret = -EMSGSIZE; + p = xdr_reserve_space(xdr, 4 * sizeof(*p)); + if (!p) + goto out_err; + *p++ = rqst->rq_xid; + *p++ = rpcrdma_version; + *p++ = cpu_to_be32(r_xprt->rx_buf.rb_max_requests); + + /* When the ULP employs a GSS flavor that guarantees integrity + * or privacy, direct data placement of individual data items + * is not allowed. + */ + ddp_allowed = !(rqst->rq_cred->cr_auth->au_flags & + RPCAUTH_AUTH_DATATOUCH); + + /* + * Chunks needed for results? + * + * o If the expected result is under the inline threshold, all ops + * return as inline. + * o Large read ops return data as write chunk(s), header as + * inline. + * o Large non-read ops return as a single reply chunk. + */ + if (rpcrdma_results_inline(r_xprt, rqst)) + wtype = rpcrdma_noch; + else if (ddp_allowed && rqst->rq_rcv_buf.flags & XDRBUF_READ) + wtype = rpcrdma_writech; + else + wtype = rpcrdma_replych; + + /* + * Chunks needed for arguments? + * + * o If the total request is under the inline threshold, all ops + * are sent as inline. + * o Large write ops transmit data as read chunk(s), header as + * inline. + * o Large non-write ops are sent with the entire message as a + * single read chunk (protocol 0-position special case). + * + * This assumes that the upper layer does not present a request + * that both has a data payload, and whose non-data arguments + * by themselves are larger than the inline threshold. + */ + if (rpcrdma_args_inline(r_xprt, rqst)) { + *p++ = rdma_msg; + rtype = rpcrdma_noch; + } else if (ddp_allowed && rqst->rq_snd_buf.flags & XDRBUF_WRITE) { + *p++ = rdma_msg; + rtype = rpcrdma_readch; + } else { + r_xprt->rx_stats.nomsg_call_count++; + *p++ = rdma_nomsg; + rtype = rpcrdma_areadch; + } + + /* If this is a retransmit, discard previously registered + * chunks. Very likely the connection has been replaced, + * so these registrations are invalid and unusable. + */ + while (unlikely(!list_empty(&req->rl_registered))) { + struct rpcrdma_mr *mr; + + mr = rpcrdma_mr_pop(&req->rl_registered); + rpcrdma_mr_defer_recovery(mr); + } + + /* This implementation supports the following combinations + * of chunk lists in one RPC-over-RDMA Call message: + * + * - Read list + * - Write list + * - Reply chunk + * - Read list + Reply chunk + * + * It might not yet support the following combinations: + * + * - Read list + Write list + * + * It does not support the following combinations: + * + * - Write list + Reply chunk + * - Read list + Write list + Reply chunk + * + * This implementation supports only a single chunk in each + * Read or Write list. Thus for example the client cannot + * send a Call message with a Position Zero Read chunk and a + * regular Read chunk at the same time. + */ + if (rtype != rpcrdma_noch) { + ret = rpcrdma_encode_read_list(r_xprt, req, rqst, rtype); + if (ret) + goto out_err; + } + ret = encode_item_not_present(xdr); + if (ret) + goto out_err; + + if (wtype == rpcrdma_writech) { + ret = rpcrdma_encode_write_list(r_xprt, req, rqst, wtype); + if (ret) + goto out_err; + } + ret = encode_item_not_present(xdr); + if (ret) + goto out_err; + + if (wtype != rpcrdma_replych) + ret = encode_item_not_present(xdr); + else + ret = rpcrdma_encode_reply_chunk(r_xprt, req, rqst, wtype); + if (ret) + goto out_err; + + trace_xprtrdma_marshal(rqst, xdr_stream_pos(xdr), rtype, wtype); + + ret = rpcrdma_prepare_send_sges(r_xprt, req, xdr_stream_pos(xdr), + &rqst->rq_snd_buf, rtype); + if (ret) + goto out_err; + return 0; + +out_err: + switch (ret) { + case -EAGAIN: + xprt_wait_for_buffer_space(rqst->rq_task, NULL); + break; + case -ENOBUFS: + break; + default: + r_xprt->rx_stats.failed_marshal_count++; + } + return ret; +} + +/** + * rpcrdma_inline_fixup - Scatter inline received data into rqst's iovecs + * @rqst: controlling RPC request + * @srcp: points to RPC message payload in receive buffer + * @copy_len: remaining length of receive buffer content + * @pad: Write chunk pad bytes needed (zero for pure inline) + * + * The upper layer has set the maximum number of bytes it can + * receive in each component of rq_rcv_buf. These values are set in + * the head.iov_len, page_len, tail.iov_len, and buflen fields. + * + * Unlike the TCP equivalent (xdr_partial_copy_from_skb), in + * many cases this function simply updates iov_base pointers in + * rq_rcv_buf to point directly to the received reply data, to + * avoid copying reply data. + * + * Returns the count of bytes which had to be memcopied. + */ +static unsigned long +rpcrdma_inline_fixup(struct rpc_rqst *rqst, char *srcp, int copy_len, int pad) +{ + unsigned long fixup_copy_count; + int i, npages, curlen; + char *destp; + struct page **ppages; + int page_base; + + /* The head iovec is redirected to the RPC reply message + * in the receive buffer, to avoid a memcopy. + */ + rqst->rq_rcv_buf.head[0].iov_base = srcp; + rqst->rq_private_buf.head[0].iov_base = srcp; + + /* The contents of the receive buffer that follow + * head.iov_len bytes are copied into the page list. + */ + curlen = rqst->rq_rcv_buf.head[0].iov_len; + if (curlen > copy_len) + curlen = copy_len; + trace_xprtrdma_fixup(rqst, copy_len, curlen); + srcp += curlen; + copy_len -= curlen; + + ppages = rqst->rq_rcv_buf.pages + + (rqst->rq_rcv_buf.page_base >> PAGE_SHIFT); + page_base = offset_in_page(rqst->rq_rcv_buf.page_base); + fixup_copy_count = 0; + if (copy_len && rqst->rq_rcv_buf.page_len) { + int pagelist_len; + + pagelist_len = rqst->rq_rcv_buf.page_len; + if (pagelist_len > copy_len) + pagelist_len = copy_len; + npages = PAGE_ALIGN(page_base + pagelist_len) >> PAGE_SHIFT; + for (i = 0; i < npages; i++) { + curlen = PAGE_SIZE - page_base; + if (curlen > pagelist_len) + curlen = pagelist_len; + + trace_xprtrdma_fixup_pg(rqst, i, srcp, + copy_len, curlen); + destp = kmap_atomic(ppages[i]); + memcpy(destp + page_base, srcp, curlen); + flush_dcache_page(ppages[i]); + kunmap_atomic(destp); + srcp += curlen; + copy_len -= curlen; + fixup_copy_count += curlen; + pagelist_len -= curlen; + if (!pagelist_len) + break; + page_base = 0; + } + + /* Implicit padding for the last segment in a Write + * chunk is inserted inline at the front of the tail + * iovec. The upper layer ignores the content of + * the pad. Simply ensure inline content in the tail + * that follows the Write chunk is properly aligned. + */ + if (pad) + srcp -= pad; + } + + /* The tail iovec is redirected to the remaining data + * in the receive buffer, to avoid a memcopy. + */ + if (copy_len || pad) { + rqst->rq_rcv_buf.tail[0].iov_base = srcp; + rqst->rq_private_buf.tail[0].iov_base = srcp; + } + + return fixup_copy_count; +} + +/* By convention, backchannel calls arrive via rdma_msg type + * messages, and never populate the chunk lists. This makes + * the RPC/RDMA header small and fixed in size, so it is + * straightforward to check the RPC header's direction field. + */ +static bool +rpcrdma_is_bcall(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep) +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +{ + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + struct xdr_stream *xdr = &rep->rr_stream; + __be32 *p; + + if (rep->rr_proc != rdma_msg) + return false; + + /* Peek at stream contents without advancing. */ + p = xdr_inline_decode(xdr, 0); + + /* Chunk lists */ + if (*p++ != xdr_zero) + return false; + if (*p++ != xdr_zero) + return false; + if (*p++ != xdr_zero) + return false; + + /* RPC header */ + if (*p++ != rep->rr_xid) + return false; + if (*p != cpu_to_be32(RPC_CALL)) + return false; + + /* No bc service. */ + if (xprt->bc_serv == NULL) + return false; + + /* Now that we are sure this is a backchannel call, + * advance to the RPC header. + */ + p = xdr_inline_decode(xdr, 3 * sizeof(*p)); + if (unlikely(!p)) + goto out_short; + + rpcrdma_bc_receive_call(r_xprt, rep); + return true; + +out_short: + pr_warn("RPC/RDMA short backward direction call\n"); + return true; +} +#else /* CONFIG_SUNRPC_BACKCHANNEL */ +{ + return false; +} +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +static int decode_rdma_segment(struct xdr_stream *xdr, u32 *length) +{ + u32 handle; + u64 offset; + __be32 *p; + + p = xdr_inline_decode(xdr, 4 * sizeof(*p)); + if (unlikely(!p)) + return -EIO; + + handle = be32_to_cpup(p++); + *length = be32_to_cpup(p++); + xdr_decode_hyper(p, &offset); + + trace_xprtrdma_decode_seg(handle, *length, offset); + return 0; +} + +static int decode_write_chunk(struct xdr_stream *xdr, u32 *length) +{ + u32 segcount, seglength; + __be32 *p; + + p = xdr_inline_decode(xdr, sizeof(*p)); + if (unlikely(!p)) + return -EIO; + + *length = 0; + segcount = be32_to_cpup(p); + while (segcount--) { + if (decode_rdma_segment(xdr, &seglength)) + return -EIO; + *length += seglength; + } + + return 0; +} + +/* In RPC-over-RDMA Version One replies, a Read list is never + * expected. This decoder is a stub that returns an error if + * a Read list is present. + */ +static int decode_read_list(struct xdr_stream *xdr) +{ + __be32 *p; + + p = xdr_inline_decode(xdr, sizeof(*p)); + if (unlikely(!p)) + return -EIO; + if (unlikely(*p != xdr_zero)) + return -EIO; + return 0; +} + +/* Supports only one Write chunk in the Write list + */ +static int decode_write_list(struct xdr_stream *xdr, u32 *length) +{ + u32 chunklen; + bool first; + __be32 *p; + + *length = 0; + first = true; + do { + p = xdr_inline_decode(xdr, sizeof(*p)); + if (unlikely(!p)) + return -EIO; + if (*p == xdr_zero) + break; + if (!first) + return -EIO; + + if (decode_write_chunk(xdr, &chunklen)) + return -EIO; + *length += chunklen; + first = false; + } while (true); + return 0; +} + +static int decode_reply_chunk(struct xdr_stream *xdr, u32 *length) +{ + __be32 *p; + + p = xdr_inline_decode(xdr, sizeof(*p)); + if (unlikely(!p)) + return -EIO; + + *length = 0; + if (*p != xdr_zero) + if (decode_write_chunk(xdr, length)) + return -EIO; + return 0; +} + +static int +rpcrdma_decode_msg(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep, + struct rpc_rqst *rqst) +{ + struct xdr_stream *xdr = &rep->rr_stream; + u32 writelist, replychunk, rpclen; + char *base; + + /* Decode the chunk lists */ + if (decode_read_list(xdr)) + return -EIO; + if (decode_write_list(xdr, &writelist)) + return -EIO; + if (decode_reply_chunk(xdr, &replychunk)) + return -EIO; + + /* RDMA_MSG sanity checks */ + if (unlikely(replychunk)) + return -EIO; + + /* Build the RPC reply's Payload stream in rqst->rq_rcv_buf */ + base = (char *)xdr_inline_decode(xdr, 0); + rpclen = xdr_stream_remaining(xdr); + r_xprt->rx_stats.fixup_copy_count += + rpcrdma_inline_fixup(rqst, base, rpclen, writelist & 3); + + r_xprt->rx_stats.total_rdma_reply += writelist; + return rpclen + xdr_align_size(writelist); +} + +static noinline int +rpcrdma_decode_nomsg(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep) +{ + struct xdr_stream *xdr = &rep->rr_stream; + u32 writelist, replychunk; + + /* Decode the chunk lists */ + if (decode_read_list(xdr)) + return -EIO; + if (decode_write_list(xdr, &writelist)) + return -EIO; + if (decode_reply_chunk(xdr, &replychunk)) + return -EIO; + + /* RDMA_NOMSG sanity checks */ + if (unlikely(writelist)) + return -EIO; + if (unlikely(!replychunk)) + return -EIO; + + /* Reply chunk buffer already is the reply vector */ + r_xprt->rx_stats.total_rdma_reply += replychunk; + return replychunk; +} + +static noinline int +rpcrdma_decode_error(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep, + struct rpc_rqst *rqst) +{ + struct xdr_stream *xdr = &rep->rr_stream; + __be32 *p; + + p = xdr_inline_decode(xdr, sizeof(*p)); + if (unlikely(!p)) + return -EIO; + + switch (*p) { + case err_vers: + p = xdr_inline_decode(xdr, 2 * sizeof(*p)); + if (!p) + break; + dprintk("RPC: %5u: %s: server reports version error (%u-%u)\n", + rqst->rq_task->tk_pid, __func__, + be32_to_cpup(p), be32_to_cpu(*(p + 1))); + break; + case err_chunk: + dprintk("RPC: %5u: %s: server reports header decoding error\n", + rqst->rq_task->tk_pid, __func__); + break; + default: + dprintk("RPC: %5u: %s: server reports unrecognized error %d\n", + rqst->rq_task->tk_pid, __func__, be32_to_cpup(p)); + } + + r_xprt->rx_stats.bad_reply_count++; + return -EREMOTEIO; +} + +/* Perform XID lookup, reconstruction of the RPC reply, and + * RPC completion while holding the transport lock to ensure + * the rep, rqst, and rq_task pointers remain stable. + */ +void rpcrdma_complete_rqst(struct rpcrdma_rep *rep) +{ + struct rpcrdma_xprt *r_xprt = rep->rr_rxprt; + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + struct rpc_rqst *rqst = rep->rr_rqst; + unsigned long cwnd; + int status; + + xprt->reestablish_timeout = 0; + + switch (rep->rr_proc) { + case rdma_msg: + status = rpcrdma_decode_msg(r_xprt, rep, rqst); + break; + case rdma_nomsg: + status = rpcrdma_decode_nomsg(r_xprt, rep); + break; + case rdma_error: + status = rpcrdma_decode_error(r_xprt, rep, rqst); + break; + default: + status = -EIO; + } + if (status < 0) + goto out_badheader; + +out: + spin_lock(&xprt->recv_lock); + cwnd = xprt->cwnd; + xprt->cwnd = r_xprt->rx_buf.rb_credits << RPC_CWNDSHIFT; + if (xprt->cwnd > cwnd) + xprt_release_rqst_cong(rqst->rq_task); + + xprt_complete_rqst(rqst->rq_task, status); + xprt_unpin_rqst(rqst); + spin_unlock(&xprt->recv_lock); + return; + +/* If the incoming reply terminated a pending RPC, the next + * RPC call will post a replacement receive buffer as it is + * being marshaled. + */ +out_badheader: + trace_xprtrdma_reply_hdr(rep); + r_xprt->rx_stats.bad_reply_count++; + status = -EIO; + goto out; +} + +void rpcrdma_release_rqst(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req) +{ + /* Invalidate and unmap the data payloads before waking + * the waiting application. This guarantees the memory + * regions are properly fenced from the server before the + * application accesses the data. It also ensures proper + * send flow control: waking the next RPC waits until this + * RPC has relinquished all its Send Queue entries. + */ + if (!list_empty(&req->rl_registered)) + r_xprt->rx_ia.ri_ops->ro_unmap_sync(r_xprt, + &req->rl_registered); + + /* Ensure that any DMA mapped pages associated with + * the Send of the RPC Call have been unmapped before + * allowing the RPC to complete. This protects argument + * memory not controlled by the RPC client from being + * re-used before we're done with it. + */ + if (test_bit(RPCRDMA_REQ_F_TX_RESOURCES, &req->rl_flags)) { + r_xprt->rx_stats.reply_waits_for_send++; + out_of_line_wait_on_bit(&req->rl_flags, + RPCRDMA_REQ_F_TX_RESOURCES, + bit_wait, + TASK_UNINTERRUPTIBLE); + } +} + +/* Reply handling runs in the poll worker thread. Anything that + * might wait is deferred to a separate workqueue. + */ +void rpcrdma_deferred_completion(struct work_struct *work) +{ + struct rpcrdma_rep *rep = + container_of(work, struct rpcrdma_rep, rr_work); + struct rpcrdma_req *req = rpcr_to_rdmar(rep->rr_rqst); + struct rpcrdma_xprt *r_xprt = rep->rr_rxprt; + + trace_xprtrdma_defer_cmp(rep); + if (rep->rr_wc_flags & IB_WC_WITH_INVALIDATE) + r_xprt->rx_ia.ri_ops->ro_reminv(rep, &req->rl_registered); + rpcrdma_release_rqst(r_xprt, req); + rpcrdma_complete_rqst(rep); +} + +/* Process received RPC/RDMA messages. + * + * Errors must result in the RPC task either being awakened, or + * allowed to timeout, to discover the errors at that time. + */ +void rpcrdma_reply_handler(struct rpcrdma_rep *rep) +{ + struct rpcrdma_xprt *r_xprt = rep->rr_rxprt; + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_req *req; + struct rpc_rqst *rqst; + u32 credits; + __be32 *p; + + --buf->rb_posted_receives; + + if (rep->rr_hdrbuf.head[0].iov_len == 0) + goto out_badstatus; + + /* Fixed transport header fields */ + xdr_init_decode(&rep->rr_stream, &rep->rr_hdrbuf, + rep->rr_hdrbuf.head[0].iov_base); + p = xdr_inline_decode(&rep->rr_stream, 4 * sizeof(*p)); + if (unlikely(!p)) + goto out_shortreply; + rep->rr_xid = *p++; + rep->rr_vers = *p++; + credits = be32_to_cpu(*p++); + rep->rr_proc = *p++; + + if (rep->rr_vers != rpcrdma_version) + goto out_badversion; + + if (rpcrdma_is_bcall(r_xprt, rep)) + return; + + /* Match incoming rpcrdma_rep to an rpcrdma_req to + * get context for handling any incoming chunks. + */ + spin_lock(&xprt->recv_lock); + rqst = xprt_lookup_rqst(xprt, rep->rr_xid); + if (!rqst) + goto out_norqst; + xprt_pin_rqst(rqst); + + if (credits == 0) + credits = 1; /* don't deadlock */ + else if (credits > buf->rb_max_requests) + credits = buf->rb_max_requests; + buf->rb_credits = credits; + + spin_unlock(&xprt->recv_lock); + + req = rpcr_to_rdmar(rqst); + if (req->rl_reply) { + trace_xprtrdma_leaked_rep(rqst, req->rl_reply); + rpcrdma_recv_buffer_put(req->rl_reply); + } + req->rl_reply = rep; + rep->rr_rqst = rqst; + clear_bit(RPCRDMA_REQ_F_PENDING, &req->rl_flags); + + trace_xprtrdma_reply(rqst->rq_task, rep, req, credits); + + rpcrdma_post_recvs(r_xprt, false); + queue_work(rpcrdma_receive_wq, &rep->rr_work); + return; + +out_badversion: + trace_xprtrdma_reply_vers(rep); + goto repost; + +/* The RPC transaction has already been terminated, or the header + * is corrupt. + */ +out_norqst: + spin_unlock(&xprt->recv_lock); + trace_xprtrdma_reply_rqst(rep); + goto repost; + +out_shortreply: + trace_xprtrdma_reply_short(rep); + +/* If no pending RPC transaction was matched, post a replacement + * receive buffer before returning. + */ +repost: + rpcrdma_post_recvs(r_xprt, false); +out_badstatus: + rpcrdma_recv_buffer_put(rep); +} diff --git a/net/sunrpc/xprtrdma/svc_rdma.c b/net/sunrpc/xprtrdma/svc_rdma.c new file mode 100644 index 000000000..134bef6a4 --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma.c @@ -0,0 +1,266 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2015-2018 Oracle. All rights reserved. + * Copyright (c) 2005-2006 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * Author: Tom Tucker <tom@opengridcomputing.com> + */ + +#include <linux/slab.h> +#include <linux/fs.h> +#include <linux/sysctl.h> +#include <linux/workqueue.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/sched.h> +#include <linux/sunrpc/svc_rdma.h> + +#define RPCDBG_FACILITY RPCDBG_SVCXPRT + +/* RPC/RDMA parameters */ +unsigned int svcrdma_ord = 16; /* historical default */ +static unsigned int min_ord = 1; +static unsigned int max_ord = 255; +unsigned int svcrdma_max_requests = RPCRDMA_MAX_REQUESTS; +unsigned int svcrdma_max_bc_requests = RPCRDMA_MAX_BC_REQUESTS; +static unsigned int min_max_requests = 4; +static unsigned int max_max_requests = 16384; +unsigned int svcrdma_max_req_size = RPCRDMA_DEF_INLINE_THRESH; +static unsigned int min_max_inline = RPCRDMA_DEF_INLINE_THRESH; +static unsigned int max_max_inline = RPCRDMA_MAX_INLINE_THRESH; + +atomic_t rdma_stat_recv; +atomic_t rdma_stat_read; +atomic_t rdma_stat_write; +atomic_t rdma_stat_sq_starve; +atomic_t rdma_stat_rq_starve; +atomic_t rdma_stat_rq_poll; +atomic_t rdma_stat_rq_prod; +atomic_t rdma_stat_sq_poll; +atomic_t rdma_stat_sq_prod; + +struct workqueue_struct *svc_rdma_wq; + +/* + * This function implements reading and resetting an atomic_t stat + * variable through read/write to a proc file. Any write to the file + * resets the associated statistic to zero. Any read returns it's + * current value. + */ +static int read_reset_stat(struct ctl_table *table, int write, + void __user *buffer, size_t *lenp, + loff_t *ppos) +{ + atomic_t *stat = (atomic_t *)table->data; + + if (!stat) + return -EINVAL; + + if (write) + atomic_set(stat, 0); + else { + char str_buf[32]; + int len = snprintf(str_buf, 32, "%d\n", atomic_read(stat)); + if (len >= 32) + return -EFAULT; + len = strlen(str_buf); + if (*ppos > len) { + *lenp = 0; + return 0; + } + len -= *ppos; + if (len > *lenp) + len = *lenp; + if (len && copy_to_user(buffer, str_buf, len)) + return -EFAULT; + *lenp = len; + *ppos += len; + } + return 0; +} + +static struct ctl_table_header *svcrdma_table_header; +static struct ctl_table svcrdma_parm_table[] = { + { + .procname = "max_requests", + .data = &svcrdma_max_requests, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_max_requests, + .extra2 = &max_max_requests + }, + { + .procname = "max_req_size", + .data = &svcrdma_max_req_size, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_max_inline, + .extra2 = &max_max_inline + }, + { + .procname = "max_outbound_read_requests", + .data = &svcrdma_ord, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_ord, + .extra2 = &max_ord, + }, + + { + .procname = "rdma_stat_read", + .data = &rdma_stat_read, + .maxlen = sizeof(atomic_t), + .mode = 0644, + .proc_handler = read_reset_stat, + }, + { + .procname = "rdma_stat_recv", + .data = &rdma_stat_recv, + .maxlen = sizeof(atomic_t), + .mode = 0644, + .proc_handler = read_reset_stat, + }, + { + .procname = "rdma_stat_write", + .data = &rdma_stat_write, + .maxlen = sizeof(atomic_t), + .mode = 0644, + .proc_handler = read_reset_stat, + }, + { + .procname = "rdma_stat_sq_starve", + .data = &rdma_stat_sq_starve, + .maxlen = sizeof(atomic_t), + .mode = 0644, + .proc_handler = read_reset_stat, + }, + { + .procname = "rdma_stat_rq_starve", + .data = &rdma_stat_rq_starve, + .maxlen = sizeof(atomic_t), + .mode = 0644, + .proc_handler = read_reset_stat, + }, + { + .procname = "rdma_stat_rq_poll", + .data = &rdma_stat_rq_poll, + .maxlen = sizeof(atomic_t), + .mode = 0644, + .proc_handler = read_reset_stat, + }, + { + .procname = "rdma_stat_rq_prod", + .data = &rdma_stat_rq_prod, + .maxlen = sizeof(atomic_t), + .mode = 0644, + .proc_handler = read_reset_stat, + }, + { + .procname = "rdma_stat_sq_poll", + .data = &rdma_stat_sq_poll, + .maxlen = sizeof(atomic_t), + .mode = 0644, + .proc_handler = read_reset_stat, + }, + { + .procname = "rdma_stat_sq_prod", + .data = &rdma_stat_sq_prod, + .maxlen = sizeof(atomic_t), + .mode = 0644, + .proc_handler = read_reset_stat, + }, + { }, +}; + +static struct ctl_table svcrdma_table[] = { + { + .procname = "svc_rdma", + .mode = 0555, + .child = svcrdma_parm_table + }, + { }, +}; + +static struct ctl_table svcrdma_root_table[] = { + { + .procname = "sunrpc", + .mode = 0555, + .child = svcrdma_table + }, + { }, +}; + +void svc_rdma_cleanup(void) +{ + dprintk("SVCRDMA Module Removed, deregister RPC RDMA transport\n"); + destroy_workqueue(svc_rdma_wq); + if (svcrdma_table_header) { + unregister_sysctl_table(svcrdma_table_header); + svcrdma_table_header = NULL; + } +#if defined(CONFIG_SUNRPC_BACKCHANNEL) + svc_unreg_xprt_class(&svc_rdma_bc_class); +#endif + svc_unreg_xprt_class(&svc_rdma_class); +} + +int svc_rdma_init(void) +{ + dprintk("SVCRDMA Module Init, register RPC RDMA transport\n"); + dprintk("\tsvcrdma_ord : %d\n", svcrdma_ord); + dprintk("\tmax_requests : %u\n", svcrdma_max_requests); + dprintk("\tmax_bc_requests : %u\n", svcrdma_max_bc_requests); + dprintk("\tmax_inline : %d\n", svcrdma_max_req_size); + + svc_rdma_wq = alloc_workqueue("svc_rdma", 0, 0); + if (!svc_rdma_wq) + return -ENOMEM; + + if (!svcrdma_table_header) + svcrdma_table_header = + register_sysctl_table(svcrdma_root_table); + + /* Register RDMA with the SVC transport switch */ + svc_reg_xprt_class(&svc_rdma_class); +#if defined(CONFIG_SUNRPC_BACKCHANNEL) + svc_reg_xprt_class(&svc_rdma_bc_class); +#endif + return 0; +} diff --git a/net/sunrpc/xprtrdma/svc_rdma_backchannel.c b/net/sunrpc/xprtrdma/svc_rdma_backchannel.c new file mode 100644 index 000000000..cf2272a90 --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma_backchannel.c @@ -0,0 +1,355 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2015-2018 Oracle. All rights reserved. + * + * Support for backward direction RPCs on RPC/RDMA (server-side). + */ + +#include <linux/module.h> + +#include <linux/sunrpc/svc_rdma.h> + +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +#define RPCDBG_FACILITY RPCDBG_SVCXPRT + +#undef SVCRDMA_BACKCHANNEL_DEBUG + +/** + * svc_rdma_handle_bc_reply - Process incoming backchannel reply + * @xprt: controlling backchannel transport + * @rdma_resp: pointer to incoming transport header + * @rcvbuf: XDR buffer into which to decode the reply + * + * Returns: + * %0 if @rcvbuf is filled in, xprt_complete_rqst called, + * %-EAGAIN if server should call ->recvfrom again. + */ +int svc_rdma_handle_bc_reply(struct rpc_xprt *xprt, __be32 *rdma_resp, + struct xdr_buf *rcvbuf) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct kvec *dst, *src = &rcvbuf->head[0]; + struct rpc_rqst *req; + unsigned long cwnd; + u32 credits; + size_t len; + __be32 xid; + __be32 *p; + int ret; + + p = (__be32 *)src->iov_base; + len = src->iov_len; + xid = *rdma_resp; + +#ifdef SVCRDMA_BACKCHANNEL_DEBUG + pr_info("%s: xid=%08x, length=%zu\n", + __func__, be32_to_cpu(xid), len); + pr_info("%s: RPC/RDMA: %*ph\n", + __func__, (int)RPCRDMA_HDRLEN_MIN, rdma_resp); + pr_info("%s: RPC: %*ph\n", + __func__, (int)len, p); +#endif + + ret = -EAGAIN; + if (src->iov_len < 24) + goto out_shortreply; + + spin_lock(&xprt->recv_lock); + req = xprt_lookup_rqst(xprt, xid); + if (!req) + goto out_notfound; + + dst = &req->rq_private_buf.head[0]; + memcpy(&req->rq_private_buf, &req->rq_rcv_buf, sizeof(struct xdr_buf)); + if (dst->iov_len < len) + goto out_unlock; + memcpy(dst->iov_base, p, len); + + credits = be32_to_cpup(rdma_resp + 2); + if (credits == 0) + credits = 1; /* don't deadlock */ + else if (credits > r_xprt->rx_buf.rb_bc_max_requests) + credits = r_xprt->rx_buf.rb_bc_max_requests; + + spin_lock_bh(&xprt->transport_lock); + cwnd = xprt->cwnd; + xprt->cwnd = credits << RPC_CWNDSHIFT; + if (xprt->cwnd > cwnd) + xprt_release_rqst_cong(req->rq_task); + spin_unlock_bh(&xprt->transport_lock); + + + ret = 0; + xprt_complete_rqst(req->rq_task, rcvbuf->len); + rcvbuf->len = 0; + +out_unlock: + spin_unlock(&xprt->recv_lock); +out: + return ret; + +out_shortreply: + dprintk("svcrdma: short bc reply: xprt=%p, len=%zu\n", + xprt, src->iov_len); + goto out; + +out_notfound: + dprintk("svcrdma: unrecognized bc reply: xprt=%p, xid=%08x\n", + xprt, be32_to_cpu(xid)); + goto out_unlock; +} + +/* Send a backwards direction RPC call. + * + * Caller holds the connection's mutex and has already marshaled + * the RPC/RDMA request. + * + * This is similar to svc_rdma_send_reply_msg, but takes a struct + * rpc_rqst instead, does not support chunks, and avoids blocking + * memory allocation. + * + * XXX: There is still an opportunity to block in svc_rdma_send() + * if there are no SQ entries to post the Send. This may occur if + * the adapter has a small maximum SQ depth. + */ +static int svc_rdma_bc_sendto(struct svcxprt_rdma *rdma, + struct rpc_rqst *rqst, + struct svc_rdma_send_ctxt *ctxt) +{ + int ret; + + ret = svc_rdma_map_reply_msg(rdma, ctxt, &rqst->rq_snd_buf, NULL); + if (ret < 0) + return -EIO; + + /* Bump page refcnt so Send completion doesn't release + * the rq_buffer before all retransmits are complete. + */ + get_page(virt_to_page(rqst->rq_buffer)); + ctxt->sc_send_wr.opcode = IB_WR_SEND; + return svc_rdma_send(rdma, &ctxt->sc_send_wr); +} + +/* Server-side transport endpoint wants a whole page for its send + * buffer. The client RPC code constructs the RPC header in this + * buffer before it invokes ->send_request. + */ +static int +xprt_rdma_bc_allocate(struct rpc_task *task) +{ + struct rpc_rqst *rqst = task->tk_rqstp; + size_t size = rqst->rq_callsize; + struct page *page; + + if (size > PAGE_SIZE) { + WARN_ONCE(1, "svcrdma: large bc buffer request (size %zu)\n", + size); + return -EINVAL; + } + + page = alloc_page(RPCRDMA_DEF_GFP); + if (!page) + return -ENOMEM; + rqst->rq_buffer = page_address(page); + + rqst->rq_rbuffer = kmalloc(rqst->rq_rcvsize, RPCRDMA_DEF_GFP); + if (!rqst->rq_rbuffer) { + put_page(page); + return -ENOMEM; + } + return 0; +} + +static void +xprt_rdma_bc_free(struct rpc_task *task) +{ + struct rpc_rqst *rqst = task->tk_rqstp; + + put_page(virt_to_page(rqst->rq_buffer)); + kfree(rqst->rq_rbuffer); +} + +static int +rpcrdma_bc_send_request(struct svcxprt_rdma *rdma, struct rpc_rqst *rqst) +{ + struct rpc_xprt *xprt = rqst->rq_xprt; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct svc_rdma_send_ctxt *ctxt; + __be32 *p; + int rc; + + ctxt = svc_rdma_send_ctxt_get(rdma); + if (!ctxt) + goto drop_connection; + + p = ctxt->sc_xprt_buf; + *p++ = rqst->rq_xid; + *p++ = rpcrdma_version; + *p++ = cpu_to_be32(r_xprt->rx_buf.rb_bc_max_requests); + *p++ = rdma_msg; + *p++ = xdr_zero; + *p++ = xdr_zero; + *p = xdr_zero; + svc_rdma_sync_reply_hdr(rdma, ctxt, RPCRDMA_HDRLEN_MIN); + +#ifdef SVCRDMA_BACKCHANNEL_DEBUG + pr_info("%s: %*ph\n", __func__, 64, rqst->rq_buffer); +#endif + + rc = svc_rdma_bc_sendto(rdma, rqst, ctxt); + if (rc) { + svc_rdma_send_ctxt_put(rdma, ctxt); + goto drop_connection; + } + return rc; + +drop_connection: + dprintk("svcrdma: failed to send bc call\n"); + xprt_disconnect_done(xprt); + return -ENOTCONN; +} + +/* Send an RPC call on the passive end of a transport + * connection. + */ +static int +xprt_rdma_bc_send_request(struct rpc_task *task) +{ + struct rpc_rqst *rqst = task->tk_rqstp; + struct svc_xprt *sxprt = rqst->rq_xprt->bc_xprt; + struct svcxprt_rdma *rdma; + int ret; + + dprintk("svcrdma: sending bc call with xid: %08x\n", + be32_to_cpu(rqst->rq_xid)); + + if (!mutex_trylock(&sxprt->xpt_mutex)) { + rpc_sleep_on(&sxprt->xpt_bc_pending, task, NULL); + if (!mutex_trylock(&sxprt->xpt_mutex)) + return -EAGAIN; + rpc_wake_up_queued_task(&sxprt->xpt_bc_pending, task); + } + + ret = -ENOTCONN; + rdma = container_of(sxprt, struct svcxprt_rdma, sc_xprt); + if (!test_bit(XPT_DEAD, &sxprt->xpt_flags)) + ret = rpcrdma_bc_send_request(rdma, rqst); + + mutex_unlock(&sxprt->xpt_mutex); + + if (ret < 0) + return ret; + return 0; +} + +static void +xprt_rdma_bc_close(struct rpc_xprt *xprt) +{ + dprintk("svcrdma: %s: xprt %p\n", __func__, xprt); + xprt->cwnd = RPC_CWNDSHIFT; +} + +static void +xprt_rdma_bc_put(struct rpc_xprt *xprt) +{ + dprintk("svcrdma: %s: xprt %p\n", __func__, xprt); + + xprt_rdma_free_addresses(xprt); + xprt_free(xprt); + module_put(THIS_MODULE); +} + +static const struct rpc_xprt_ops xprt_rdma_bc_procs = { + .reserve_xprt = xprt_reserve_xprt_cong, + .release_xprt = xprt_release_xprt_cong, + .alloc_slot = xprt_alloc_slot, + .free_slot = xprt_free_slot, + .release_request = xprt_release_rqst_cong, + .buf_alloc = xprt_rdma_bc_allocate, + .buf_free = xprt_rdma_bc_free, + .send_request = xprt_rdma_bc_send_request, + .set_retrans_timeout = xprt_set_retrans_timeout_def, + .close = xprt_rdma_bc_close, + .destroy = xprt_rdma_bc_put, + .print_stats = xprt_rdma_print_stats +}; + +static const struct rpc_timeout xprt_rdma_bc_timeout = { + .to_initval = 60 * HZ, + .to_maxval = 60 * HZ, +}; + +/* It shouldn't matter if the number of backchannel session slots + * doesn't match the number of RPC/RDMA credits. That just means + * one or the other will have extra slots that aren't used. + */ +static struct rpc_xprt * +xprt_setup_rdma_bc(struct xprt_create *args) +{ + struct rpc_xprt *xprt; + struct rpcrdma_xprt *new_xprt; + + if (args->addrlen > sizeof(xprt->addr)) { + dprintk("RPC: %s: address too large\n", __func__); + return ERR_PTR(-EBADF); + } + + xprt = xprt_alloc(args->net, sizeof(*new_xprt), + RPCRDMA_MAX_BC_REQUESTS, + RPCRDMA_MAX_BC_REQUESTS); + if (!xprt) { + dprintk("RPC: %s: couldn't allocate rpc_xprt\n", + __func__); + return ERR_PTR(-ENOMEM); + } + + xprt->timeout = &xprt_rdma_bc_timeout; + xprt_set_bound(xprt); + xprt_set_connected(xprt); + xprt->bind_timeout = 0; + xprt->reestablish_timeout = 0; + xprt->idle_timeout = 0; + + xprt->prot = XPRT_TRANSPORT_BC_RDMA; + xprt->tsh_size = 0; + xprt->ops = &xprt_rdma_bc_procs; + + memcpy(&xprt->addr, args->dstaddr, args->addrlen); + xprt->addrlen = args->addrlen; + xprt_rdma_format_addresses(xprt, (struct sockaddr *)&xprt->addr); + xprt->resvport = 0; + + xprt->max_payload = xprt_rdma_max_inline_read; + + new_xprt = rpcx_to_rdmax(xprt); + new_xprt->rx_buf.rb_bc_max_requests = xprt->max_reqs; + + xprt_get(xprt); + args->bc_xprt->xpt_bc_xprt = xprt; + xprt->bc_xprt = args->bc_xprt; + + if (!try_module_get(THIS_MODULE)) + goto out_fail; + + /* Final put for backchannel xprt is in __svc_rdma_free */ + xprt_get(xprt); + return xprt; + +out_fail: + xprt_rdma_free_addresses(xprt); + args->bc_xprt->xpt_bc_xprt = NULL; + args->bc_xprt->xpt_bc_xps = NULL; + xprt_put(xprt); + xprt_free(xprt); + return ERR_PTR(-EINVAL); +} + +struct xprt_class xprt_rdma_bc = { + .list = LIST_HEAD_INIT(xprt_rdma_bc.list), + .name = "rdma backchannel", + .owner = THIS_MODULE, + .ident = XPRT_TRANSPORT_BC_RDMA, + .setup = xprt_setup_rdma_bc, +}; diff --git a/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c b/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c new file mode 100644 index 000000000..252495ff9 --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c @@ -0,0 +1,804 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2016-2018 Oracle. All rights reserved. + * Copyright (c) 2014 Open Grid Computing, Inc. All rights reserved. + * Copyright (c) 2005-2006 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * Author: Tom Tucker <tom@opengridcomputing.com> + */ + +/* Operation + * + * The main entry point is svc_rdma_recvfrom. This is called from + * svc_recv when the transport indicates there is incoming data to + * be read. "Data Ready" is signaled when an RDMA Receive completes, + * or when a set of RDMA Reads complete. + * + * An svc_rqst is passed in. This structure contains an array of + * free pages (rq_pages) that will contain the incoming RPC message. + * + * Short messages are moved directly into svc_rqst::rq_arg, and + * the RPC Call is ready to be processed by the Upper Layer. + * svc_rdma_recvfrom returns the length of the RPC Call message, + * completing the reception of the RPC Call. + * + * However, when an incoming message has Read chunks, + * svc_rdma_recvfrom must post RDMA Reads to pull the RPC Call's + * data payload from the client. svc_rdma_recvfrom sets up the + * RDMA Reads using pages in svc_rqst::rq_pages, which are + * transferred to an svc_rdma_recv_ctxt for the duration of the + * I/O. svc_rdma_recvfrom then returns zero, since the RPC message + * is still not yet ready. + * + * When the Read chunk payloads have become available on the + * server, "Data Ready" is raised again, and svc_recv calls + * svc_rdma_recvfrom again. This second call may use a different + * svc_rqst than the first one, thus any information that needs + * to be preserved across these two calls is kept in an + * svc_rdma_recv_ctxt. + * + * The second call to svc_rdma_recvfrom performs final assembly + * of the RPC Call message, using the RDMA Read sink pages kept in + * the svc_rdma_recv_ctxt. The xdr_buf is copied from the + * svc_rdma_recv_ctxt to the second svc_rqst. The second call returns + * the length of the completed RPC Call message. + * + * Page Management + * + * Pages under I/O must be transferred from the first svc_rqst to an + * svc_rdma_recv_ctxt before the first svc_rdma_recvfrom call returns. + * + * The first svc_rqst supplies pages for RDMA Reads. These are moved + * from rqstp::rq_pages into ctxt::pages. The consumed elements of + * the rq_pages array are set to NULL and refilled with the first + * svc_rdma_recvfrom call returns. + * + * During the second svc_rdma_recvfrom call, RDMA Read sink pages + * are transferred from the svc_rdma_recv_ctxt to the second svc_rqst + * (see rdma_read_complete() below). + */ + +#include <linux/spinlock.h> +#include <asm/unaligned.h> +#include <rdma/ib_verbs.h> +#include <rdma/rdma_cm.h> + +#include <linux/sunrpc/xdr.h> +#include <linux/sunrpc/debug.h> +#include <linux/sunrpc/rpc_rdma.h> +#include <linux/sunrpc/svc_rdma.h> + +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +#define RPCDBG_FACILITY RPCDBG_SVCXPRT + +static void svc_rdma_wc_receive(struct ib_cq *cq, struct ib_wc *wc); + +static inline struct svc_rdma_recv_ctxt * +svc_rdma_next_recv_ctxt(struct list_head *list) +{ + return list_first_entry_or_null(list, struct svc_rdma_recv_ctxt, + rc_list); +} + +static struct svc_rdma_recv_ctxt * +svc_rdma_recv_ctxt_alloc(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_recv_ctxt *ctxt; + dma_addr_t addr; + void *buffer; + + ctxt = kmalloc(sizeof(*ctxt), GFP_KERNEL); + if (!ctxt) + goto fail0; + buffer = kmalloc(rdma->sc_max_req_size, GFP_KERNEL); + if (!buffer) + goto fail1; + addr = ib_dma_map_single(rdma->sc_pd->device, buffer, + rdma->sc_max_req_size, DMA_FROM_DEVICE); + if (ib_dma_mapping_error(rdma->sc_pd->device, addr)) + goto fail2; + + ctxt->rc_recv_wr.next = NULL; + ctxt->rc_recv_wr.wr_cqe = &ctxt->rc_cqe; + ctxt->rc_recv_wr.sg_list = &ctxt->rc_recv_sge; + ctxt->rc_recv_wr.num_sge = 1; + ctxt->rc_cqe.done = svc_rdma_wc_receive; + ctxt->rc_recv_sge.addr = addr; + ctxt->rc_recv_sge.length = rdma->sc_max_req_size; + ctxt->rc_recv_sge.lkey = rdma->sc_pd->local_dma_lkey; + ctxt->rc_recv_buf = buffer; + ctxt->rc_temp = false; + return ctxt; + +fail2: + kfree(buffer); +fail1: + kfree(ctxt); +fail0: + return NULL; +} + +static void svc_rdma_recv_ctxt_destroy(struct svcxprt_rdma *rdma, + struct svc_rdma_recv_ctxt *ctxt) +{ + ib_dma_unmap_single(rdma->sc_pd->device, ctxt->rc_recv_sge.addr, + ctxt->rc_recv_sge.length, DMA_FROM_DEVICE); + kfree(ctxt->rc_recv_buf); + kfree(ctxt); +} + +/** + * svc_rdma_recv_ctxts_destroy - Release all recv_ctxt's for an xprt + * @rdma: svcxprt_rdma being torn down + * + */ +void svc_rdma_recv_ctxts_destroy(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_recv_ctxt *ctxt; + + while ((ctxt = svc_rdma_next_recv_ctxt(&rdma->sc_recv_ctxts))) { + list_del(&ctxt->rc_list); + svc_rdma_recv_ctxt_destroy(rdma, ctxt); + } +} + +static struct svc_rdma_recv_ctxt * +svc_rdma_recv_ctxt_get(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_recv_ctxt *ctxt; + + spin_lock(&rdma->sc_recv_lock); + ctxt = svc_rdma_next_recv_ctxt(&rdma->sc_recv_ctxts); + if (!ctxt) + goto out_empty; + list_del(&ctxt->rc_list); + spin_unlock(&rdma->sc_recv_lock); + +out: + ctxt->rc_page_count = 0; + return ctxt; + +out_empty: + spin_unlock(&rdma->sc_recv_lock); + + ctxt = svc_rdma_recv_ctxt_alloc(rdma); + if (!ctxt) + return NULL; + goto out; +} + +/** + * svc_rdma_recv_ctxt_put - Return recv_ctxt to free list + * @rdma: controlling svcxprt_rdma + * @ctxt: object to return to the free list + * + */ +void svc_rdma_recv_ctxt_put(struct svcxprt_rdma *rdma, + struct svc_rdma_recv_ctxt *ctxt) +{ + unsigned int i; + + for (i = 0; i < ctxt->rc_page_count; i++) + put_page(ctxt->rc_pages[i]); + + if (!ctxt->rc_temp) { + spin_lock(&rdma->sc_recv_lock); + list_add(&ctxt->rc_list, &rdma->sc_recv_ctxts); + spin_unlock(&rdma->sc_recv_lock); + } else + svc_rdma_recv_ctxt_destroy(rdma, ctxt); +} + +/** + * svc_rdma_release_rqst - Release transport-specific per-rqst resources + * @rqstp: svc_rqst being released + * + * Ensure that the recv_ctxt is released whether or not a Reply + * was sent. For example, the client could close the connection, + * or svc_process could drop an RPC, before the Reply is sent. + */ +void svc_rdma_release_rqst(struct svc_rqst *rqstp) +{ + struct svc_rdma_recv_ctxt *ctxt = rqstp->rq_xprt_ctxt; + struct svc_xprt *xprt = rqstp->rq_xprt; + struct svcxprt_rdma *rdma = + container_of(xprt, struct svcxprt_rdma, sc_xprt); + + rqstp->rq_xprt_ctxt = NULL; + if (ctxt) + svc_rdma_recv_ctxt_put(rdma, ctxt); +} + +static int __svc_rdma_post_recv(struct svcxprt_rdma *rdma, + struct svc_rdma_recv_ctxt *ctxt) +{ + int ret; + + svc_xprt_get(&rdma->sc_xprt); + ret = ib_post_recv(rdma->sc_qp, &ctxt->rc_recv_wr, NULL); + trace_svcrdma_post_recv(&ctxt->rc_recv_wr, ret); + if (ret) + goto err_post; + return 0; + +err_post: + svc_rdma_recv_ctxt_put(rdma, ctxt); + svc_xprt_put(&rdma->sc_xprt); + return ret; +} + +static int svc_rdma_post_recv(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_recv_ctxt *ctxt; + + if (test_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags)) + return 0; + ctxt = svc_rdma_recv_ctxt_get(rdma); + if (!ctxt) + return -ENOMEM; + return __svc_rdma_post_recv(rdma, ctxt); +} + +/** + * svc_rdma_post_recvs - Post initial set of Recv WRs + * @rdma: fresh svcxprt_rdma + * + * Returns true if successful, otherwise false. + */ +bool svc_rdma_post_recvs(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_recv_ctxt *ctxt; + unsigned int i; + int ret; + + for (i = 0; i < rdma->sc_max_requests; i++) { + ctxt = svc_rdma_recv_ctxt_get(rdma); + if (!ctxt) + return false; + ctxt->rc_temp = true; + ret = __svc_rdma_post_recv(rdma, ctxt); + if (ret) { + pr_err("svcrdma: failure posting recv buffers: %d\n", + ret); + return false; + } + } + return true; +} + +/** + * svc_rdma_wc_receive - Invoked by RDMA provider for each polled Receive WC + * @cq: Completion Queue context + * @wc: Work Completion object + * + * NB: The svc_xprt/svcxprt_rdma is pinned whenever it's possible that + * the Receive completion handler could be running. + */ +static void svc_rdma_wc_receive(struct ib_cq *cq, struct ib_wc *wc) +{ + struct svcxprt_rdma *rdma = cq->cq_context; + struct ib_cqe *cqe = wc->wr_cqe; + struct svc_rdma_recv_ctxt *ctxt; + + trace_svcrdma_wc_receive(wc); + + /* WARNING: Only wc->wr_cqe and wc->status are reliable */ + ctxt = container_of(cqe, struct svc_rdma_recv_ctxt, rc_cqe); + + if (wc->status != IB_WC_SUCCESS) + goto flushed; + + if (svc_rdma_post_recv(rdma)) + goto post_err; + + /* All wc fields are now known to be valid */ + ctxt->rc_byte_len = wc->byte_len; + ib_dma_sync_single_for_cpu(rdma->sc_pd->device, + ctxt->rc_recv_sge.addr, + wc->byte_len, DMA_FROM_DEVICE); + + spin_lock(&rdma->sc_rq_dto_lock); + list_add_tail(&ctxt->rc_list, &rdma->sc_rq_dto_q); + spin_unlock(&rdma->sc_rq_dto_lock); + set_bit(XPT_DATA, &rdma->sc_xprt.xpt_flags); + if (!test_bit(RDMAXPRT_CONN_PENDING, &rdma->sc_flags)) + svc_xprt_enqueue(&rdma->sc_xprt); + goto out; + +flushed: + if (wc->status != IB_WC_WR_FLUSH_ERR) + pr_err("svcrdma: Recv: %s (%u/0x%x)\n", + ib_wc_status_msg(wc->status), + wc->status, wc->vendor_err); +post_err: + svc_rdma_recv_ctxt_put(rdma, ctxt); + set_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags); + svc_xprt_enqueue(&rdma->sc_xprt); +out: + svc_xprt_put(&rdma->sc_xprt); +} + +/** + * svc_rdma_flush_recv_queues - Drain pending Receive work + * @rdma: svcxprt_rdma being shut down + * + */ +void svc_rdma_flush_recv_queues(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_recv_ctxt *ctxt; + + while ((ctxt = svc_rdma_next_recv_ctxt(&rdma->sc_read_complete_q))) { + list_del(&ctxt->rc_list); + svc_rdma_recv_ctxt_put(rdma, ctxt); + } + while ((ctxt = svc_rdma_next_recv_ctxt(&rdma->sc_rq_dto_q))) { + list_del(&ctxt->rc_list); + svc_rdma_recv_ctxt_put(rdma, ctxt); + } +} + +static void svc_rdma_build_arg_xdr(struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *ctxt) +{ + struct xdr_buf *arg = &rqstp->rq_arg; + + arg->head[0].iov_base = ctxt->rc_recv_buf; + arg->head[0].iov_len = ctxt->rc_byte_len; + arg->tail[0].iov_base = NULL; + arg->tail[0].iov_len = 0; + arg->page_len = 0; + arg->page_base = 0; + arg->buflen = ctxt->rc_byte_len; + arg->len = ctxt->rc_byte_len; +} + +/* This accommodates the largest possible Write chunk, + * in one segment. + */ +#define MAX_BYTES_WRITE_SEG ((u32)(RPCSVC_MAXPAGES << PAGE_SHIFT)) + +/* This accommodates the largest possible Position-Zero + * Read chunk or Reply chunk, in one segment. + */ +#define MAX_BYTES_SPECIAL_SEG ((u32)((RPCSVC_MAXPAGES + 2) << PAGE_SHIFT)) + +/* Sanity check the Read list. + * + * Implementation limits: + * - This implementation supports only one Read chunk. + * + * Sanity checks: + * - Read list does not overflow buffer. + * - Segment size limited by largest NFS data payload. + * + * The segment count is limited to how many segments can + * fit in the transport header without overflowing the + * buffer. That's about 40 Read segments for a 1KB inline + * threshold. + * + * Returns pointer to the following Write list. + */ +static __be32 *xdr_check_read_list(__be32 *p, const __be32 *end) +{ + u32 position; + bool first; + + first = true; + while (*p++ != xdr_zero) { + if (first) { + position = be32_to_cpup(p++); + first = false; + } else if (be32_to_cpup(p++) != position) { + return NULL; + } + p++; /* handle */ + if (be32_to_cpup(p++) > MAX_BYTES_SPECIAL_SEG) + return NULL; + p += 2; /* offset */ + + if (p > end) + return NULL; + } + return p; +} + +/* The segment count is limited to how many segments can + * fit in the transport header without overflowing the + * buffer. That's about 60 Write segments for a 1KB inline + * threshold. + */ +static __be32 *xdr_check_write_chunk(__be32 *p, const __be32 *end, + u32 maxlen) +{ + u32 i, segcount; + + segcount = be32_to_cpup(p++); + for (i = 0; i < segcount; i++) { + p++; /* handle */ + if (be32_to_cpup(p++) > maxlen) + return NULL; + p += 2; /* offset */ + + if (p > end) + return NULL; + } + + return p; +} + +/* Sanity check the Write list. + * + * Implementation limits: + * - This implementation supports only one Write chunk. + * + * Sanity checks: + * - Write list does not overflow buffer. + * - Segment size limited by largest NFS data payload. + * + * Returns pointer to the following Reply chunk. + */ +static __be32 *xdr_check_write_list(__be32 *p, const __be32 *end) +{ + u32 chcount; + + chcount = 0; + while (*p++ != xdr_zero) { + p = xdr_check_write_chunk(p, end, MAX_BYTES_WRITE_SEG); + if (!p) + return NULL; + if (chcount++ > 1) + return NULL; + } + return p; +} + +/* Sanity check the Reply chunk. + * + * Sanity checks: + * - Reply chunk does not overflow buffer. + * - Segment size limited by largest NFS data payload. + * + * Returns pointer to the following RPC header. + */ +static __be32 *xdr_check_reply_chunk(__be32 *p, const __be32 *end) +{ + if (*p++ != xdr_zero) { + p = xdr_check_write_chunk(p, end, MAX_BYTES_SPECIAL_SEG); + if (!p) + return NULL; + } + return p; +} + +/* On entry, xdr->head[0].iov_base points to first byte in the + * RPC-over-RDMA header. + * + * On successful exit, head[0] points to first byte past the + * RPC-over-RDMA header. For RDMA_MSG, this is the RPC message. + * The length of the RPC-over-RDMA header is returned. + * + * Assumptions: + * - The transport header is entirely contained in the head iovec. + */ +static int svc_rdma_xdr_decode_req(struct xdr_buf *rq_arg) +{ + __be32 *p, *end, *rdma_argp; + unsigned int hdr_len; + + /* Verify that there's enough bytes for header + something */ + if (rq_arg->len <= RPCRDMA_HDRLEN_ERR) + goto out_short; + + rdma_argp = rq_arg->head[0].iov_base; + if (*(rdma_argp + 1) != rpcrdma_version) + goto out_version; + + switch (*(rdma_argp + 3)) { + case rdma_msg: + break; + case rdma_nomsg: + break; + + case rdma_done: + goto out_drop; + + case rdma_error: + goto out_drop; + + default: + goto out_proc; + } + + end = (__be32 *)((unsigned long)rdma_argp + rq_arg->len); + p = xdr_check_read_list(rdma_argp + 4, end); + if (!p) + goto out_inval; + p = xdr_check_write_list(p, end); + if (!p) + goto out_inval; + p = xdr_check_reply_chunk(p, end); + if (!p) + goto out_inval; + if (p > end) + goto out_inval; + + rq_arg->head[0].iov_base = p; + hdr_len = (unsigned long)p - (unsigned long)rdma_argp; + rq_arg->head[0].iov_len -= hdr_len; + rq_arg->len -= hdr_len; + trace_svcrdma_decode_rqst(rdma_argp, hdr_len); + return hdr_len; + +out_short: + trace_svcrdma_decode_short(rq_arg->len); + return -EINVAL; + +out_version: + trace_svcrdma_decode_badvers(rdma_argp); + return -EPROTONOSUPPORT; + +out_drop: + trace_svcrdma_decode_drop(rdma_argp); + return 0; + +out_proc: + trace_svcrdma_decode_badproc(rdma_argp); + return -EINVAL; + +out_inval: + trace_svcrdma_decode_parse(rdma_argp); + return -EINVAL; +} + +static void rdma_read_complete(struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *head) +{ + int page_no; + + /* Move Read chunk pages to rqstp so that they will be released + * when svc_process is done with them. + */ + for (page_no = 0; page_no < head->rc_page_count; page_no++) { + put_page(rqstp->rq_pages[page_no]); + rqstp->rq_pages[page_no] = head->rc_pages[page_no]; + } + head->rc_page_count = 0; + + /* Point rq_arg.pages past header */ + rqstp->rq_arg.pages = &rqstp->rq_pages[head->rc_hdr_count]; + rqstp->rq_arg.page_len = head->rc_arg.page_len; + + /* rq_respages starts after the last arg page */ + rqstp->rq_respages = &rqstp->rq_pages[page_no]; + rqstp->rq_next_page = rqstp->rq_respages + 1; + + /* Rebuild rq_arg head and tail. */ + rqstp->rq_arg.head[0] = head->rc_arg.head[0]; + rqstp->rq_arg.tail[0] = head->rc_arg.tail[0]; + rqstp->rq_arg.len = head->rc_arg.len; + rqstp->rq_arg.buflen = head->rc_arg.buflen; +} + +static void svc_rdma_send_error(struct svcxprt_rdma *xprt, + __be32 *rdma_argp, int status) +{ + struct svc_rdma_send_ctxt *ctxt; + unsigned int length; + __be32 *p; + int ret; + + ctxt = svc_rdma_send_ctxt_get(xprt); + if (!ctxt) + return; + + p = ctxt->sc_xprt_buf; + *p++ = *rdma_argp; + *p++ = *(rdma_argp + 1); + *p++ = xprt->sc_fc_credits; + *p++ = rdma_error; + switch (status) { + case -EPROTONOSUPPORT: + *p++ = err_vers; + *p++ = rpcrdma_version; + *p++ = rpcrdma_version; + trace_svcrdma_err_vers(*rdma_argp); + break; + default: + *p++ = err_chunk; + trace_svcrdma_err_chunk(*rdma_argp); + } + length = (unsigned long)p - (unsigned long)ctxt->sc_xprt_buf; + svc_rdma_sync_reply_hdr(xprt, ctxt, length); + + ctxt->sc_send_wr.opcode = IB_WR_SEND; + ret = svc_rdma_send(xprt, &ctxt->sc_send_wr); + if (ret) + svc_rdma_send_ctxt_put(xprt, ctxt); +} + +/* By convention, backchannel calls arrive via rdma_msg type + * messages, and never populate the chunk lists. This makes + * the RPC/RDMA header small and fixed in size, so it is + * straightforward to check the RPC header's direction field. + */ +static bool svc_rdma_is_backchannel_reply(struct svc_xprt *xprt, + __be32 *rdma_resp) +{ + __be32 *p; + + if (!xprt->xpt_bc_xprt) + return false; + + p = rdma_resp + 3; + if (*p++ != rdma_msg) + return false; + + if (*p++ != xdr_zero) + return false; + if (*p++ != xdr_zero) + return false; + if (*p++ != xdr_zero) + return false; + + /* XID sanity */ + if (*p++ != *rdma_resp) + return false; + /* call direction */ + if (*p == cpu_to_be32(RPC_CALL)) + return false; + + return true; +} + +/** + * svc_rdma_recvfrom - Receive an RPC call + * @rqstp: request structure into which to receive an RPC Call + * + * Returns: + * The positive number of bytes in the RPC Call message, + * %0 if there were no Calls ready to return, + * %-EINVAL if the Read chunk data is too large, + * %-ENOMEM if rdma_rw context pool was exhausted, + * %-ENOTCONN if posting failed (connection is lost), + * %-EIO if rdma_rw initialization failed (DMA mapping, etc). + * + * Called in a loop when XPT_DATA is set. XPT_DATA is cleared only + * when there are no remaining ctxt's to process. + * + * The next ctxt is removed from the "receive" lists. + * + * - If the ctxt completes a Read, then finish assembling the Call + * message and return the number of bytes in the message. + * + * - If the ctxt completes a Receive, then construct the Call + * message from the contents of the Receive buffer. + * + * - If there are no Read chunks in this message, then finish + * assembling the Call message and return the number of bytes + * in the message. + * + * - If there are Read chunks in this message, post Read WRs to + * pull that payload and return 0. + */ +int svc_rdma_recvfrom(struct svc_rqst *rqstp) +{ + struct svc_xprt *xprt = rqstp->rq_xprt; + struct svcxprt_rdma *rdma_xprt = + container_of(xprt, struct svcxprt_rdma, sc_xprt); + struct svc_rdma_recv_ctxt *ctxt; + __be32 *p; + int ret; + + rqstp->rq_xprt_ctxt = NULL; + + spin_lock(&rdma_xprt->sc_rq_dto_lock); + ctxt = svc_rdma_next_recv_ctxt(&rdma_xprt->sc_read_complete_q); + if (ctxt) { + list_del(&ctxt->rc_list); + spin_unlock(&rdma_xprt->sc_rq_dto_lock); + rdma_read_complete(rqstp, ctxt); + goto complete; + } + ctxt = svc_rdma_next_recv_ctxt(&rdma_xprt->sc_rq_dto_q); + if (!ctxt) { + /* No new incoming requests, terminate the loop */ + clear_bit(XPT_DATA, &xprt->xpt_flags); + spin_unlock(&rdma_xprt->sc_rq_dto_lock); + return 0; + } + list_del(&ctxt->rc_list); + spin_unlock(&rdma_xprt->sc_rq_dto_lock); + + atomic_inc(&rdma_stat_recv); + + svc_rdma_build_arg_xdr(rqstp, ctxt); + + /* Prevent svc_xprt_release from releasing pages in rq_pages + * if we return 0 or an error. + */ + rqstp->rq_respages = rqstp->rq_pages; + rqstp->rq_next_page = rqstp->rq_respages; + + p = (__be32 *)rqstp->rq_arg.head[0].iov_base; + ret = svc_rdma_xdr_decode_req(&rqstp->rq_arg); + if (ret < 0) + goto out_err; + if (ret == 0) + goto out_drop; + rqstp->rq_xprt_hlen = ret; + + if (svc_rdma_is_backchannel_reply(xprt, p)) { + ret = svc_rdma_handle_bc_reply(xprt->xpt_bc_xprt, p, + &rqstp->rq_arg); + svc_rdma_recv_ctxt_put(rdma_xprt, ctxt); + return ret; + } + + p += rpcrdma_fixed_maxsz; + if (*p != xdr_zero) + goto out_readchunk; + +complete: + rqstp->rq_xprt_ctxt = ctxt; + rqstp->rq_prot = IPPROTO_MAX; + svc_xprt_copy_addrs(rqstp, xprt); + return rqstp->rq_arg.len; + +out_readchunk: + ret = svc_rdma_recv_read_chunk(rdma_xprt, rqstp, ctxt, p); + if (ret < 0) + goto out_postfail; + return 0; + +out_err: + svc_rdma_send_error(rdma_xprt, p, ret); + svc_rdma_recv_ctxt_put(rdma_xprt, ctxt); + return 0; + +out_postfail: + if (ret == -EINVAL) + svc_rdma_send_error(rdma_xprt, p, ret); + svc_rdma_recv_ctxt_put(rdma_xprt, ctxt); + return ret; + +out_drop: + svc_rdma_recv_ctxt_put(rdma_xprt, ctxt); + return 0; +} diff --git a/net/sunrpc/xprtrdma/svc_rdma_rw.c b/net/sunrpc/xprtrdma/svc_rdma_rw.c new file mode 100644 index 000000000..22f135263 --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma_rw.c @@ -0,0 +1,877 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2016-2018 Oracle. All rights reserved. + * + * Use the core R/W API to move RPC-over-RDMA Read and Write chunks. + */ + +#include <rdma/rw.h> + +#include <linux/sunrpc/rpc_rdma.h> +#include <linux/sunrpc/svc_rdma.h> +#include <linux/sunrpc/debug.h> + +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +#define RPCDBG_FACILITY RPCDBG_SVCXPRT + +static void svc_rdma_write_done(struct ib_cq *cq, struct ib_wc *wc); +static void svc_rdma_wc_read_done(struct ib_cq *cq, struct ib_wc *wc); + +/* Each R/W context contains state for one chain of RDMA Read or + * Write Work Requests. + * + * Each WR chain handles a single contiguous server-side buffer, + * because scatterlist entries after the first have to start on + * page alignment. xdr_buf iovecs cannot guarantee alignment. + * + * Each WR chain handles only one R_key. Each RPC-over-RDMA segment + * from a client may contain a unique R_key, so each WR chain moves + * up to one segment at a time. + * + * The scatterlist makes this data structure over 4KB in size. To + * make it less likely to fail, and to handle the allocation for + * smaller I/O requests without disabling bottom-halves, these + * contexts are created on demand, but cached and reused until the + * controlling svcxprt_rdma is destroyed. + */ +struct svc_rdma_rw_ctxt { + struct list_head rw_list; + struct rdma_rw_ctx rw_ctx; + int rw_nents; + struct sg_table rw_sg_table; + struct scatterlist rw_first_sgl[0]; +}; + +static inline struct svc_rdma_rw_ctxt * +svc_rdma_next_ctxt(struct list_head *list) +{ + return list_first_entry_or_null(list, struct svc_rdma_rw_ctxt, + rw_list); +} + +static struct svc_rdma_rw_ctxt * +svc_rdma_get_rw_ctxt(struct svcxprt_rdma *rdma, unsigned int sges) +{ + struct svc_rdma_rw_ctxt *ctxt; + + spin_lock(&rdma->sc_rw_ctxt_lock); + + ctxt = svc_rdma_next_ctxt(&rdma->sc_rw_ctxts); + if (ctxt) { + list_del(&ctxt->rw_list); + spin_unlock(&rdma->sc_rw_ctxt_lock); + } else { + spin_unlock(&rdma->sc_rw_ctxt_lock); + ctxt = kmalloc(sizeof(*ctxt) + + SG_CHUNK_SIZE * sizeof(struct scatterlist), + GFP_KERNEL); + if (!ctxt) + goto out; + INIT_LIST_HEAD(&ctxt->rw_list); + } + + ctxt->rw_sg_table.sgl = ctxt->rw_first_sgl; + if (sg_alloc_table_chained(&ctxt->rw_sg_table, sges, + ctxt->rw_sg_table.sgl)) { + kfree(ctxt); + ctxt = NULL; + } +out: + return ctxt; +} + +static void svc_rdma_put_rw_ctxt(struct svcxprt_rdma *rdma, + struct svc_rdma_rw_ctxt *ctxt) +{ + sg_free_table_chained(&ctxt->rw_sg_table, true); + + spin_lock(&rdma->sc_rw_ctxt_lock); + list_add(&ctxt->rw_list, &rdma->sc_rw_ctxts); + spin_unlock(&rdma->sc_rw_ctxt_lock); +} + +/** + * svc_rdma_destroy_rw_ctxts - Free accumulated R/W contexts + * @rdma: transport about to be destroyed + * + */ +void svc_rdma_destroy_rw_ctxts(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_rw_ctxt *ctxt; + + while ((ctxt = svc_rdma_next_ctxt(&rdma->sc_rw_ctxts)) != NULL) { + list_del(&ctxt->rw_list); + kfree(ctxt); + } +} + +/* A chunk context tracks all I/O for moving one Read or Write + * chunk. This is a a set of rdma_rw's that handle data movement + * for all segments of one chunk. + * + * These are small, acquired with a single allocator call, and + * no more than one is needed per chunk. They are allocated on + * demand, and not cached. + */ +struct svc_rdma_chunk_ctxt { + struct ib_cqe cc_cqe; + struct svcxprt_rdma *cc_rdma; + struct list_head cc_rwctxts; + int cc_sqecount; +}; + +static void svc_rdma_cc_init(struct svcxprt_rdma *rdma, + struct svc_rdma_chunk_ctxt *cc) +{ + cc->cc_rdma = rdma; + svc_xprt_get(&rdma->sc_xprt); + + INIT_LIST_HEAD(&cc->cc_rwctxts); + cc->cc_sqecount = 0; +} + +static void svc_rdma_cc_release(struct svc_rdma_chunk_ctxt *cc, + enum dma_data_direction dir) +{ + struct svcxprt_rdma *rdma = cc->cc_rdma; + struct svc_rdma_rw_ctxt *ctxt; + + while ((ctxt = svc_rdma_next_ctxt(&cc->cc_rwctxts)) != NULL) { + list_del(&ctxt->rw_list); + + rdma_rw_ctx_destroy(&ctxt->rw_ctx, rdma->sc_qp, + rdma->sc_port_num, ctxt->rw_sg_table.sgl, + ctxt->rw_nents, dir); + svc_rdma_put_rw_ctxt(rdma, ctxt); + } + svc_xprt_put(&rdma->sc_xprt); +} + +/* State for sending a Write or Reply chunk. + * - Tracks progress of writing one chunk over all its segments + * - Stores arguments for the SGL constructor functions + */ +struct svc_rdma_write_info { + /* write state of this chunk */ + unsigned int wi_seg_off; + unsigned int wi_seg_no; + unsigned int wi_nsegs; + __be32 *wi_segs; + + /* SGL constructor arguments */ + struct xdr_buf *wi_xdr; + unsigned char *wi_base; + unsigned int wi_next_off; + + struct svc_rdma_chunk_ctxt wi_cc; +}; + +static struct svc_rdma_write_info * +svc_rdma_write_info_alloc(struct svcxprt_rdma *rdma, __be32 *chunk) +{ + struct svc_rdma_write_info *info; + + info = kmalloc(sizeof(*info), GFP_KERNEL); + if (!info) + return info; + + info->wi_seg_off = 0; + info->wi_seg_no = 0; + info->wi_nsegs = be32_to_cpup(++chunk); + info->wi_segs = ++chunk; + svc_rdma_cc_init(rdma, &info->wi_cc); + info->wi_cc.cc_cqe.done = svc_rdma_write_done; + return info; +} + +static void svc_rdma_write_info_free(struct svc_rdma_write_info *info) +{ + svc_rdma_cc_release(&info->wi_cc, DMA_TO_DEVICE); + kfree(info); +} + +/** + * svc_rdma_write_done - Write chunk completion + * @cq: controlling Completion Queue + * @wc: Work Completion + * + * Pages under I/O are freed by a subsequent Send completion. + */ +static void svc_rdma_write_done(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct svc_rdma_chunk_ctxt *cc = + container_of(cqe, struct svc_rdma_chunk_ctxt, cc_cqe); + struct svcxprt_rdma *rdma = cc->cc_rdma; + struct svc_rdma_write_info *info = + container_of(cc, struct svc_rdma_write_info, wi_cc); + + trace_svcrdma_wc_write(wc); + + atomic_add(cc->cc_sqecount, &rdma->sc_sq_avail); + wake_up(&rdma->sc_send_wait); + + if (unlikely(wc->status != IB_WC_SUCCESS)) { + set_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags); + if (wc->status != IB_WC_WR_FLUSH_ERR) + pr_err("svcrdma: write ctx: %s (%u/0x%x)\n", + ib_wc_status_msg(wc->status), + wc->status, wc->vendor_err); + } + + svc_rdma_write_info_free(info); +} + +/* State for pulling a Read chunk. + */ +struct svc_rdma_read_info { + struct svc_rdma_recv_ctxt *ri_readctxt; + unsigned int ri_position; + unsigned int ri_pageno; + unsigned int ri_pageoff; + unsigned int ri_chunklen; + + struct svc_rdma_chunk_ctxt ri_cc; +}; + +static struct svc_rdma_read_info * +svc_rdma_read_info_alloc(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_read_info *info; + + info = kmalloc(sizeof(*info), GFP_KERNEL); + if (!info) + return info; + + svc_rdma_cc_init(rdma, &info->ri_cc); + info->ri_cc.cc_cqe.done = svc_rdma_wc_read_done; + return info; +} + +static void svc_rdma_read_info_free(struct svc_rdma_read_info *info) +{ + svc_rdma_cc_release(&info->ri_cc, DMA_FROM_DEVICE); + kfree(info); +} + +/** + * svc_rdma_wc_read_done - Handle completion of an RDMA Read ctx + * @cq: controlling Completion Queue + * @wc: Work Completion + * + */ +static void svc_rdma_wc_read_done(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct svc_rdma_chunk_ctxt *cc = + container_of(cqe, struct svc_rdma_chunk_ctxt, cc_cqe); + struct svcxprt_rdma *rdma = cc->cc_rdma; + struct svc_rdma_read_info *info = + container_of(cc, struct svc_rdma_read_info, ri_cc); + + trace_svcrdma_wc_read(wc); + + atomic_add(cc->cc_sqecount, &rdma->sc_sq_avail); + wake_up(&rdma->sc_send_wait); + + if (unlikely(wc->status != IB_WC_SUCCESS)) { + set_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags); + if (wc->status != IB_WC_WR_FLUSH_ERR) + pr_err("svcrdma: read ctx: %s (%u/0x%x)\n", + ib_wc_status_msg(wc->status), + wc->status, wc->vendor_err); + svc_rdma_recv_ctxt_put(rdma, info->ri_readctxt); + } else { + spin_lock(&rdma->sc_rq_dto_lock); + list_add_tail(&info->ri_readctxt->rc_list, + &rdma->sc_read_complete_q); + spin_unlock(&rdma->sc_rq_dto_lock); + + set_bit(XPT_DATA, &rdma->sc_xprt.xpt_flags); + svc_xprt_enqueue(&rdma->sc_xprt); + } + + svc_rdma_read_info_free(info); +} + +/* This function sleeps when the transport's Send Queue is congested. + * + * Assumptions: + * - If ib_post_send() succeeds, only one completion is expected, + * even if one or more WRs are flushed. This is true when posting + * an rdma_rw_ctx or when posting a single signaled WR. + */ +static int svc_rdma_post_chunk_ctxt(struct svc_rdma_chunk_ctxt *cc) +{ + struct svcxprt_rdma *rdma = cc->cc_rdma; + struct svc_xprt *xprt = &rdma->sc_xprt; + struct ib_send_wr *first_wr; + const struct ib_send_wr *bad_wr; + struct list_head *tmp; + struct ib_cqe *cqe; + int ret; + + if (cc->cc_sqecount > rdma->sc_sq_depth) + return -EINVAL; + + first_wr = NULL; + cqe = &cc->cc_cqe; + list_for_each(tmp, &cc->cc_rwctxts) { + struct svc_rdma_rw_ctxt *ctxt; + + ctxt = list_entry(tmp, struct svc_rdma_rw_ctxt, rw_list); + first_wr = rdma_rw_ctx_wrs(&ctxt->rw_ctx, rdma->sc_qp, + rdma->sc_port_num, cqe, first_wr); + cqe = NULL; + } + + do { + if (atomic_sub_return(cc->cc_sqecount, + &rdma->sc_sq_avail) > 0) { + ret = ib_post_send(rdma->sc_qp, first_wr, &bad_wr); + if (ret) + break; + return 0; + } + + trace_svcrdma_sq_full(rdma); + atomic_add(cc->cc_sqecount, &rdma->sc_sq_avail); + wait_event(rdma->sc_send_wait, + atomic_read(&rdma->sc_sq_avail) > cc->cc_sqecount); + trace_svcrdma_sq_retry(rdma); + } while (1); + + trace_svcrdma_sq_post_err(rdma, ret); + set_bit(XPT_CLOSE, &xprt->xpt_flags); + + /* If even one was posted, there will be a completion. */ + if (bad_wr != first_wr) + return 0; + + atomic_add(cc->cc_sqecount, &rdma->sc_sq_avail); + wake_up(&rdma->sc_send_wait); + return -ENOTCONN; +} + +/* Build and DMA-map an SGL that covers one kvec in an xdr_buf + */ +static void svc_rdma_vec_to_sg(struct svc_rdma_write_info *info, + unsigned int len, + struct svc_rdma_rw_ctxt *ctxt) +{ + struct scatterlist *sg = ctxt->rw_sg_table.sgl; + + sg_set_buf(&sg[0], info->wi_base, len); + info->wi_base += len; + + ctxt->rw_nents = 1; +} + +/* Build and DMA-map an SGL that covers part of an xdr_buf's pagelist. + */ +static void svc_rdma_pagelist_to_sg(struct svc_rdma_write_info *info, + unsigned int remaining, + struct svc_rdma_rw_ctxt *ctxt) +{ + unsigned int sge_no, sge_bytes, page_off, page_no; + struct xdr_buf *xdr = info->wi_xdr; + struct scatterlist *sg; + struct page **page; + + page_off = info->wi_next_off + xdr->page_base; + page_no = page_off >> PAGE_SHIFT; + page_off = offset_in_page(page_off); + page = xdr->pages + page_no; + info->wi_next_off += remaining; + sg = ctxt->rw_sg_table.sgl; + sge_no = 0; + do { + sge_bytes = min_t(unsigned int, remaining, + PAGE_SIZE - page_off); + sg_set_page(sg, *page, sge_bytes, page_off); + + remaining -= sge_bytes; + sg = sg_next(sg); + page_off = 0; + sge_no++; + page++; + } while (remaining); + + ctxt->rw_nents = sge_no; +} + +/* Construct RDMA Write WRs to send a portion of an xdr_buf containing + * an RPC Reply. + */ +static int +svc_rdma_build_writes(struct svc_rdma_write_info *info, + void (*constructor)(struct svc_rdma_write_info *info, + unsigned int len, + struct svc_rdma_rw_ctxt *ctxt), + unsigned int remaining) +{ + struct svc_rdma_chunk_ctxt *cc = &info->wi_cc; + struct svcxprt_rdma *rdma = cc->cc_rdma; + struct svc_rdma_rw_ctxt *ctxt; + __be32 *seg; + int ret; + + seg = info->wi_segs + info->wi_seg_no * rpcrdma_segment_maxsz; + do { + unsigned int write_len; + u32 seg_length, seg_handle; + u64 seg_offset; + + if (info->wi_seg_no >= info->wi_nsegs) + goto out_overflow; + + seg_handle = be32_to_cpup(seg); + seg_length = be32_to_cpup(seg + 1); + xdr_decode_hyper(seg + 2, &seg_offset); + seg_offset += info->wi_seg_off; + + write_len = min(remaining, seg_length - info->wi_seg_off); + ctxt = svc_rdma_get_rw_ctxt(rdma, + (write_len >> PAGE_SHIFT) + 2); + if (!ctxt) + goto out_noctx; + + constructor(info, write_len, ctxt); + ret = rdma_rw_ctx_init(&ctxt->rw_ctx, rdma->sc_qp, + rdma->sc_port_num, ctxt->rw_sg_table.sgl, + ctxt->rw_nents, 0, seg_offset, + seg_handle, DMA_TO_DEVICE); + if (ret < 0) + goto out_initerr; + + trace_svcrdma_encode_wseg(seg_handle, write_len, seg_offset); + list_add(&ctxt->rw_list, &cc->cc_rwctxts); + cc->cc_sqecount += ret; + if (write_len == seg_length - info->wi_seg_off) { + seg += 4; + info->wi_seg_no++; + info->wi_seg_off = 0; + } else { + info->wi_seg_off += write_len; + } + remaining -= write_len; + } while (remaining); + + return 0; + +out_overflow: + dprintk("svcrdma: inadequate space in Write chunk (%u)\n", + info->wi_nsegs); + return -E2BIG; + +out_noctx: + dprintk("svcrdma: no R/W ctxs available\n"); + return -ENOMEM; + +out_initerr: + svc_rdma_put_rw_ctxt(rdma, ctxt); + trace_svcrdma_dma_map_rwctx(rdma, ret); + return -EIO; +} + +/* Send one of an xdr_buf's kvecs by itself. To send a Reply + * chunk, the whole RPC Reply is written back to the client. + * This function writes either the head or tail of the xdr_buf + * containing the Reply. + */ +static int svc_rdma_send_xdr_kvec(struct svc_rdma_write_info *info, + struct kvec *vec) +{ + info->wi_base = vec->iov_base; + return svc_rdma_build_writes(info, svc_rdma_vec_to_sg, + vec->iov_len); +} + +/* Send an xdr_buf's page list by itself. A Write chunk is + * just the page list. a Reply chunk is the head, page list, + * and tail. This function is shared between the two types + * of chunk. + */ +static int svc_rdma_send_xdr_pagelist(struct svc_rdma_write_info *info, + struct xdr_buf *xdr) +{ + info->wi_xdr = xdr; + info->wi_next_off = 0; + return svc_rdma_build_writes(info, svc_rdma_pagelist_to_sg, + xdr->page_len); +} + +/** + * svc_rdma_send_write_chunk - Write all segments in a Write chunk + * @rdma: controlling RDMA transport + * @wr_ch: Write chunk provided by client + * @xdr: xdr_buf containing the data payload + * + * Returns a non-negative number of bytes the chunk consumed, or + * %-E2BIG if the payload was larger than the Write chunk, + * %-EINVAL if client provided too many segments, + * %-ENOMEM if rdma_rw context pool was exhausted, + * %-ENOTCONN if posting failed (connection is lost), + * %-EIO if rdma_rw initialization failed (DMA mapping, etc). + */ +int svc_rdma_send_write_chunk(struct svcxprt_rdma *rdma, __be32 *wr_ch, + struct xdr_buf *xdr) +{ + struct svc_rdma_write_info *info; + int ret; + + if (!xdr->page_len) + return 0; + + info = svc_rdma_write_info_alloc(rdma, wr_ch); + if (!info) + return -ENOMEM; + + ret = svc_rdma_send_xdr_pagelist(info, xdr); + if (ret < 0) + goto out_err; + + ret = svc_rdma_post_chunk_ctxt(&info->wi_cc); + if (ret < 0) + goto out_err; + + trace_svcrdma_encode_write(xdr->page_len); + return xdr->page_len; + +out_err: + svc_rdma_write_info_free(info); + return ret; +} + +/** + * svc_rdma_send_reply_chunk - Write all segments in the Reply chunk + * @rdma: controlling RDMA transport + * @rp_ch: Reply chunk provided by client + * @writelist: true if client provided a Write list + * @xdr: xdr_buf containing an RPC Reply + * + * Returns a non-negative number of bytes the chunk consumed, or + * %-E2BIG if the payload was larger than the Reply chunk, + * %-EINVAL if client provided too many segments, + * %-ENOMEM if rdma_rw context pool was exhausted, + * %-ENOTCONN if posting failed (connection is lost), + * %-EIO if rdma_rw initialization failed (DMA mapping, etc). + */ +int svc_rdma_send_reply_chunk(struct svcxprt_rdma *rdma, __be32 *rp_ch, + bool writelist, struct xdr_buf *xdr) +{ + struct svc_rdma_write_info *info; + int consumed, ret; + + info = svc_rdma_write_info_alloc(rdma, rp_ch); + if (!info) + return -ENOMEM; + + ret = svc_rdma_send_xdr_kvec(info, &xdr->head[0]); + if (ret < 0) + goto out_err; + consumed = xdr->head[0].iov_len; + + /* Send the page list in the Reply chunk only if the + * client did not provide Write chunks. + */ + if (!writelist && xdr->page_len) { + ret = svc_rdma_send_xdr_pagelist(info, xdr); + if (ret < 0) + goto out_err; + consumed += xdr->page_len; + } + + if (xdr->tail[0].iov_len) { + ret = svc_rdma_send_xdr_kvec(info, &xdr->tail[0]); + if (ret < 0) + goto out_err; + consumed += xdr->tail[0].iov_len; + } + + ret = svc_rdma_post_chunk_ctxt(&info->wi_cc); + if (ret < 0) + goto out_err; + + trace_svcrdma_encode_reply(consumed); + return consumed; + +out_err: + svc_rdma_write_info_free(info); + return ret; +} + +static int svc_rdma_build_read_segment(struct svc_rdma_read_info *info, + struct svc_rqst *rqstp, + u32 rkey, u32 len, u64 offset) +{ + struct svc_rdma_recv_ctxt *head = info->ri_readctxt; + struct svc_rdma_chunk_ctxt *cc = &info->ri_cc; + struct svc_rdma_rw_ctxt *ctxt; + unsigned int sge_no, seg_len; + struct scatterlist *sg; + int ret; + + sge_no = PAGE_ALIGN(info->ri_pageoff + len) >> PAGE_SHIFT; + ctxt = svc_rdma_get_rw_ctxt(cc->cc_rdma, sge_no); + if (!ctxt) + goto out_noctx; + ctxt->rw_nents = sge_no; + + sg = ctxt->rw_sg_table.sgl; + for (sge_no = 0; sge_no < ctxt->rw_nents; sge_no++) { + seg_len = min_t(unsigned int, len, + PAGE_SIZE - info->ri_pageoff); + + head->rc_arg.pages[info->ri_pageno] = + rqstp->rq_pages[info->ri_pageno]; + if (!info->ri_pageoff) + head->rc_page_count++; + + sg_set_page(sg, rqstp->rq_pages[info->ri_pageno], + seg_len, info->ri_pageoff); + sg = sg_next(sg); + + info->ri_pageoff += seg_len; + if (info->ri_pageoff == PAGE_SIZE) { + info->ri_pageno++; + info->ri_pageoff = 0; + } + len -= seg_len; + + /* Safety check */ + if (len && + &rqstp->rq_pages[info->ri_pageno + 1] > rqstp->rq_page_end) + goto out_overrun; + } + + ret = rdma_rw_ctx_init(&ctxt->rw_ctx, cc->cc_rdma->sc_qp, + cc->cc_rdma->sc_port_num, + ctxt->rw_sg_table.sgl, ctxt->rw_nents, + 0, offset, rkey, DMA_FROM_DEVICE); + if (ret < 0) + goto out_initerr; + + list_add(&ctxt->rw_list, &cc->cc_rwctxts); + cc->cc_sqecount += ret; + return 0; + +out_noctx: + dprintk("svcrdma: no R/W ctxs available\n"); + return -ENOMEM; + +out_overrun: + dprintk("svcrdma: request overruns rq_pages\n"); + return -EINVAL; + +out_initerr: + trace_svcrdma_dma_map_rwctx(cc->cc_rdma, ret); + svc_rdma_put_rw_ctxt(cc->cc_rdma, ctxt); + return -EIO; +} + +/* Walk the segments in the Read chunk starting at @p and construct + * RDMA Read operations to pull the chunk to the server. + */ +static int svc_rdma_build_read_chunk(struct svc_rqst *rqstp, + struct svc_rdma_read_info *info, + __be32 *p) +{ + int ret; + + ret = -EINVAL; + info->ri_chunklen = 0; + while (*p++ != xdr_zero && be32_to_cpup(p++) == info->ri_position) { + u32 rs_handle, rs_length; + u64 rs_offset; + + rs_handle = be32_to_cpup(p++); + rs_length = be32_to_cpup(p++); + p = xdr_decode_hyper(p, &rs_offset); + + ret = svc_rdma_build_read_segment(info, rqstp, + rs_handle, rs_length, + rs_offset); + if (ret < 0) + break; + + trace_svcrdma_encode_rseg(rs_handle, rs_length, rs_offset); + info->ri_chunklen += rs_length; + } + + return ret; +} + +/* Construct RDMA Reads to pull over a normal Read chunk. The chunk + * data lands in the page list of head->rc_arg.pages. + * + * Currently NFSD does not look at the head->rc_arg.tail[0] iovec. + * Therefore, XDR round-up of the Read chunk and trailing + * inline content must both be added at the end of the pagelist. + */ +static int svc_rdma_build_normal_read_chunk(struct svc_rqst *rqstp, + struct svc_rdma_read_info *info, + __be32 *p) +{ + struct svc_rdma_recv_ctxt *head = info->ri_readctxt; + int ret; + + ret = svc_rdma_build_read_chunk(rqstp, info, p); + if (ret < 0) + goto out; + + trace_svcrdma_encode_read(info->ri_chunklen, info->ri_position); + + head->rc_hdr_count = 0; + + /* Split the Receive buffer between the head and tail + * buffers at Read chunk's position. XDR roundup of the + * chunk is not included in either the pagelist or in + * the tail. + */ + head->rc_arg.tail[0].iov_base = + head->rc_arg.head[0].iov_base + info->ri_position; + head->rc_arg.tail[0].iov_len = + head->rc_arg.head[0].iov_len - info->ri_position; + head->rc_arg.head[0].iov_len = info->ri_position; + + /* Read chunk may need XDR roundup (see RFC 8166, s. 3.4.5.2). + * + * If the client already rounded up the chunk length, the + * length does not change. Otherwise, the length of the page + * list is increased to include XDR round-up. + * + * Currently these chunks always start at page offset 0, + * thus the rounded-up length never crosses a page boundary. + */ + info->ri_chunklen = XDR_QUADLEN(info->ri_chunklen) << 2; + + head->rc_arg.page_len = info->ri_chunklen; + head->rc_arg.len += info->ri_chunklen; + head->rc_arg.buflen += info->ri_chunklen; + +out: + return ret; +} + +/* Construct RDMA Reads to pull over a Position Zero Read chunk. + * The start of the data lands in the first page just after + * the Transport header, and the rest lands in the page list of + * head->rc_arg.pages. + * + * Assumptions: + * - A PZRC has an XDR-aligned length (no implicit round-up). + * - There can be no trailing inline content (IOW, we assume + * a PZRC is never sent in an RDMA_MSG message, though it's + * allowed by spec). + */ +static int svc_rdma_build_pz_read_chunk(struct svc_rqst *rqstp, + struct svc_rdma_read_info *info, + __be32 *p) +{ + struct svc_rdma_recv_ctxt *head = info->ri_readctxt; + int ret; + + ret = svc_rdma_build_read_chunk(rqstp, info, p); + if (ret < 0) + goto out; + + trace_svcrdma_encode_pzr(info->ri_chunklen); + + head->rc_arg.len += info->ri_chunklen; + head->rc_arg.buflen += info->ri_chunklen; + + head->rc_hdr_count = 1; + head->rc_arg.head[0].iov_base = page_address(head->rc_pages[0]); + head->rc_arg.head[0].iov_len = min_t(size_t, PAGE_SIZE, + info->ri_chunklen); + + head->rc_arg.page_len = info->ri_chunklen - + head->rc_arg.head[0].iov_len; + +out: + return ret; +} + +/* Pages under I/O have been copied to head->rc_pages. Ensure they + * are not released by svc_xprt_release() until the I/O is complete. + * + * This has to be done after all Read WRs are constructed to properly + * handle a page that is part of I/O on behalf of two different RDMA + * segments. + * + * Do this only if I/O has been posted. Otherwise, we do indeed want + * svc_xprt_release() to clean things up properly. + */ +static void svc_rdma_save_io_pages(struct svc_rqst *rqstp, + const unsigned int start, + const unsigned int num_pages) +{ + unsigned int i; + + for (i = start; i < num_pages + start; i++) + rqstp->rq_pages[i] = NULL; +} + +/** + * svc_rdma_recv_read_chunk - Pull a Read chunk from the client + * @rdma: controlling RDMA transport + * @rqstp: set of pages to use as Read sink buffers + * @head: pages under I/O collect here + * @p: pointer to start of Read chunk + * + * Returns: + * %0 if all needed RDMA Reads were posted successfully, + * %-EINVAL if client provided too many segments, + * %-ENOMEM if rdma_rw context pool was exhausted, + * %-ENOTCONN if posting failed (connection is lost), + * %-EIO if rdma_rw initialization failed (DMA mapping, etc). + * + * Assumptions: + * - All Read segments in @p have the same Position value. + */ +int svc_rdma_recv_read_chunk(struct svcxprt_rdma *rdma, struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *head, __be32 *p) +{ + struct svc_rdma_read_info *info; + int ret; + + /* The request (with page list) is constructed in + * head->rc_arg. Pages involved with RDMA Read I/O are + * transferred there. + */ + head->rc_arg.head[0] = rqstp->rq_arg.head[0]; + head->rc_arg.tail[0] = rqstp->rq_arg.tail[0]; + head->rc_arg.pages = head->rc_pages; + head->rc_arg.page_base = 0; + head->rc_arg.page_len = 0; + head->rc_arg.len = rqstp->rq_arg.len; + head->rc_arg.buflen = rqstp->rq_arg.buflen; + + info = svc_rdma_read_info_alloc(rdma); + if (!info) + return -ENOMEM; + info->ri_readctxt = head; + info->ri_pageno = 0; + info->ri_pageoff = 0; + + info->ri_position = be32_to_cpup(p + 1); + if (info->ri_position) + ret = svc_rdma_build_normal_read_chunk(rqstp, info, p); + else + ret = svc_rdma_build_pz_read_chunk(rqstp, info, p); + if (ret < 0) + goto out_err; + + ret = svc_rdma_post_chunk_ctxt(&info->ri_cc); + if (ret < 0) + goto out_err; + svc_rdma_save_io_pages(rqstp, 0, head->rc_page_count); + return 0; + +out_err: + svc_rdma_read_info_free(info); + return ret; +} diff --git a/net/sunrpc/xprtrdma/svc_rdma_sendto.c b/net/sunrpc/xprtrdma/svc_rdma_sendto.c new file mode 100644 index 000000000..4062cd624 --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma_sendto.c @@ -0,0 +1,929 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2016-2018 Oracle. All rights reserved. + * Copyright (c) 2014 Open Grid Computing, Inc. All rights reserved. + * Copyright (c) 2005-2006 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * Author: Tom Tucker <tom@opengridcomputing.com> + */ + +/* Operation + * + * The main entry point is svc_rdma_sendto. This is called by the + * RPC server when an RPC Reply is ready to be transmitted to a client. + * + * The passed-in svc_rqst contains a struct xdr_buf which holds an + * XDR-encoded RPC Reply message. sendto must construct the RPC-over-RDMA + * transport header, post all Write WRs needed for this Reply, then post + * a Send WR conveying the transport header and the RPC message itself to + * the client. + * + * svc_rdma_sendto must fully transmit the Reply before returning, as + * the svc_rqst will be recycled as soon as sendto returns. Remaining + * resources referred to by the svc_rqst are also recycled at that time. + * Therefore any resources that must remain longer must be detached + * from the svc_rqst and released later. + * + * Page Management + * + * The I/O that performs Reply transmission is asynchronous, and may + * complete well after sendto returns. Thus pages under I/O must be + * removed from the svc_rqst before sendto returns. + * + * The logic here depends on Send Queue and completion ordering. Since + * the Send WR is always posted last, it will always complete last. Thus + * when it completes, it is guaranteed that all previous Write WRs have + * also completed. + * + * Write WRs are constructed and posted. Each Write segment gets its own + * svc_rdma_rw_ctxt, allowing the Write completion handler to find and + * DMA-unmap the pages under I/O for that Write segment. The Write + * completion handler does not release any pages. + * + * When the Send WR is constructed, it also gets its own svc_rdma_send_ctxt. + * The ownership of all of the Reply's pages are transferred into that + * ctxt, the Send WR is posted, and sendto returns. + * + * The svc_rdma_send_ctxt is presented when the Send WR completes. The + * Send completion handler finally releases the Reply's pages. + * + * This mechanism also assumes that completions on the transport's Send + * Completion Queue do not run in parallel. Otherwise a Write completion + * and Send completion running at the same time could release pages that + * are still DMA-mapped. + * + * Error Handling + * + * - If the Send WR is posted successfully, it will either complete + * successfully, or get flushed. Either way, the Send completion + * handler releases the Reply's pages. + * - If the Send WR cannot be not posted, the forward path releases + * the Reply's pages. + * + * This handles the case, without the use of page reference counting, + * where two different Write segments send portions of the same page. + */ + +#include <linux/spinlock.h> +#include <asm/unaligned.h> + +#include <rdma/ib_verbs.h> +#include <rdma/rdma_cm.h> + +#include <linux/sunrpc/debug.h> +#include <linux/sunrpc/rpc_rdma.h> +#include <linux/sunrpc/svc_rdma.h> + +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +#define RPCDBG_FACILITY RPCDBG_SVCXPRT + +static void svc_rdma_wc_send(struct ib_cq *cq, struct ib_wc *wc); + +static inline struct svc_rdma_send_ctxt * +svc_rdma_next_send_ctxt(struct list_head *list) +{ + return list_first_entry_or_null(list, struct svc_rdma_send_ctxt, + sc_list); +} + +static struct svc_rdma_send_ctxt * +svc_rdma_send_ctxt_alloc(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_send_ctxt *ctxt; + dma_addr_t addr; + void *buffer; + size_t size; + int i; + + size = sizeof(*ctxt); + size += rdma->sc_max_send_sges * sizeof(struct ib_sge); + ctxt = kmalloc(size, GFP_KERNEL); + if (!ctxt) + goto fail0; + buffer = kmalloc(rdma->sc_max_req_size, GFP_KERNEL); + if (!buffer) + goto fail1; + addr = ib_dma_map_single(rdma->sc_pd->device, buffer, + rdma->sc_max_req_size, DMA_TO_DEVICE); + if (ib_dma_mapping_error(rdma->sc_pd->device, addr)) + goto fail2; + + ctxt->sc_send_wr.next = NULL; + ctxt->sc_send_wr.wr_cqe = &ctxt->sc_cqe; + ctxt->sc_send_wr.sg_list = ctxt->sc_sges; + ctxt->sc_send_wr.send_flags = IB_SEND_SIGNALED; + ctxt->sc_cqe.done = svc_rdma_wc_send; + ctxt->sc_xprt_buf = buffer; + ctxt->sc_sges[0].addr = addr; + + for (i = 0; i < rdma->sc_max_send_sges; i++) + ctxt->sc_sges[i].lkey = rdma->sc_pd->local_dma_lkey; + return ctxt; + +fail2: + kfree(buffer); +fail1: + kfree(ctxt); +fail0: + return NULL; +} + +/** + * svc_rdma_send_ctxts_destroy - Release all send_ctxt's for an xprt + * @rdma: svcxprt_rdma being torn down + * + */ +void svc_rdma_send_ctxts_destroy(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_send_ctxt *ctxt; + + while ((ctxt = svc_rdma_next_send_ctxt(&rdma->sc_send_ctxts))) { + list_del(&ctxt->sc_list); + ib_dma_unmap_single(rdma->sc_pd->device, + ctxt->sc_sges[0].addr, + rdma->sc_max_req_size, + DMA_TO_DEVICE); + kfree(ctxt->sc_xprt_buf); + kfree(ctxt); + } +} + +/** + * svc_rdma_send_ctxt_get - Get a free send_ctxt + * @rdma: controlling svcxprt_rdma + * + * Returns a ready-to-use send_ctxt, or NULL if none are + * available and a fresh one cannot be allocated. + */ +struct svc_rdma_send_ctxt *svc_rdma_send_ctxt_get(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_send_ctxt *ctxt; + + spin_lock(&rdma->sc_send_lock); + ctxt = svc_rdma_next_send_ctxt(&rdma->sc_send_ctxts); + if (!ctxt) + goto out_empty; + list_del(&ctxt->sc_list); + spin_unlock(&rdma->sc_send_lock); + +out: + ctxt->sc_send_wr.num_sge = 0; + ctxt->sc_cur_sge_no = 0; + ctxt->sc_page_count = 0; + return ctxt; + +out_empty: + spin_unlock(&rdma->sc_send_lock); + ctxt = svc_rdma_send_ctxt_alloc(rdma); + if (!ctxt) + return NULL; + goto out; +} + +/** + * svc_rdma_send_ctxt_put - Return send_ctxt to free list + * @rdma: controlling svcxprt_rdma + * @ctxt: object to return to the free list + * + * Pages left in sc_pages are DMA unmapped and released. + */ +void svc_rdma_send_ctxt_put(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *ctxt) +{ + struct ib_device *device = rdma->sc_cm_id->device; + unsigned int i; + + /* The first SGE contains the transport header, which + * remains mapped until @ctxt is destroyed. + */ + for (i = 1; i < ctxt->sc_send_wr.num_sge; i++) + ib_dma_unmap_page(device, + ctxt->sc_sges[i].addr, + ctxt->sc_sges[i].length, + DMA_TO_DEVICE); + + for (i = 0; i < ctxt->sc_page_count; ++i) + put_page(ctxt->sc_pages[i]); + + spin_lock(&rdma->sc_send_lock); + list_add(&ctxt->sc_list, &rdma->sc_send_ctxts); + spin_unlock(&rdma->sc_send_lock); +} + +/** + * svc_rdma_wc_send - Invoked by RDMA provider for each polled Send WC + * @cq: Completion Queue context + * @wc: Work Completion object + * + * NB: The svc_xprt/svcxprt_rdma is pinned whenever it's possible that + * the Send completion handler could be running. + */ +static void svc_rdma_wc_send(struct ib_cq *cq, struct ib_wc *wc) +{ + struct svcxprt_rdma *rdma = cq->cq_context; + struct ib_cqe *cqe = wc->wr_cqe; + struct svc_rdma_send_ctxt *ctxt; + + trace_svcrdma_wc_send(wc); + + atomic_inc(&rdma->sc_sq_avail); + wake_up(&rdma->sc_send_wait); + + ctxt = container_of(cqe, struct svc_rdma_send_ctxt, sc_cqe); + svc_rdma_send_ctxt_put(rdma, ctxt); + + if (unlikely(wc->status != IB_WC_SUCCESS)) { + set_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags); + svc_xprt_enqueue(&rdma->sc_xprt); + if (wc->status != IB_WC_WR_FLUSH_ERR) + pr_err("svcrdma: Send: %s (%u/0x%x)\n", + ib_wc_status_msg(wc->status), + wc->status, wc->vendor_err); + } + + svc_xprt_put(&rdma->sc_xprt); +} + +/** + * svc_rdma_send - Post a single Send WR + * @rdma: transport on which to post the WR + * @wr: prepared Send WR to post + * + * Returns zero the Send WR was posted successfully. Otherwise, a + * negative errno is returned. + */ +int svc_rdma_send(struct svcxprt_rdma *rdma, struct ib_send_wr *wr) +{ + int ret; + + might_sleep(); + + /* If the SQ is full, wait until an SQ entry is available */ + while (1) { + if ((atomic_dec_return(&rdma->sc_sq_avail) < 0)) { + atomic_inc(&rdma_stat_sq_starve); + trace_svcrdma_sq_full(rdma); + atomic_inc(&rdma->sc_sq_avail); + wait_event(rdma->sc_send_wait, + atomic_read(&rdma->sc_sq_avail) > 1); + if (test_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags)) + return -ENOTCONN; + trace_svcrdma_sq_retry(rdma); + continue; + } + + svc_xprt_get(&rdma->sc_xprt); + trace_svcrdma_post_send(wr); + ret = ib_post_send(rdma->sc_qp, wr, NULL); + if (ret) + break; + return 0; + } + + trace_svcrdma_sq_post_err(rdma, ret); + set_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags); + svc_xprt_put(&rdma->sc_xprt); + wake_up(&rdma->sc_send_wait); + return ret; +} + +static u32 xdr_padsize(u32 len) +{ + return (len & 3) ? (4 - (len & 3)) : 0; +} + +/* Returns length of transport header, in bytes. + */ +static unsigned int svc_rdma_reply_hdr_len(__be32 *rdma_resp) +{ + unsigned int nsegs; + __be32 *p; + + p = rdma_resp; + + /* RPC-over-RDMA V1 replies never have a Read list. */ + p += rpcrdma_fixed_maxsz + 1; + + /* Skip Write list. */ + while (*p++ != xdr_zero) { + nsegs = be32_to_cpup(p++); + p += nsegs * rpcrdma_segment_maxsz; + } + + /* Skip Reply chunk. */ + if (*p++ != xdr_zero) { + nsegs = be32_to_cpup(p++); + p += nsegs * rpcrdma_segment_maxsz; + } + + return (unsigned long)p - (unsigned long)rdma_resp; +} + +/* One Write chunk is copied from Call transport header to Reply + * transport header. Each segment's length field is updated to + * reflect number of bytes consumed in the segment. + * + * Returns number of segments in this chunk. + */ +static unsigned int xdr_encode_write_chunk(__be32 *dst, __be32 *src, + unsigned int remaining) +{ + unsigned int i, nsegs; + u32 seg_len; + + /* Write list discriminator */ + *dst++ = *src++; + + /* number of segments in this chunk */ + nsegs = be32_to_cpup(src); + *dst++ = *src++; + + for (i = nsegs; i; i--) { + /* segment's RDMA handle */ + *dst++ = *src++; + + /* bytes returned in this segment */ + seg_len = be32_to_cpu(*src); + if (remaining >= seg_len) { + /* entire segment was consumed */ + *dst = *src; + remaining -= seg_len; + } else { + /* segment only partly filled */ + *dst = cpu_to_be32(remaining); + remaining = 0; + } + dst++; src++; + + /* segment's RDMA offset */ + *dst++ = *src++; + *dst++ = *src++; + } + + return nsegs; +} + +/* The client provided a Write list in the Call message. Fill in + * the segments in the first Write chunk in the Reply's transport + * header with the number of bytes consumed in each segment. + * Remaining chunks are returned unused. + * + * Assumptions: + * - Client has provided only one Write chunk + */ +static void svc_rdma_xdr_encode_write_list(__be32 *rdma_resp, __be32 *wr_ch, + unsigned int consumed) +{ + unsigned int nsegs; + __be32 *p, *q; + + /* RPC-over-RDMA V1 replies never have a Read list. */ + p = rdma_resp + rpcrdma_fixed_maxsz + 1; + + q = wr_ch; + while (*q != xdr_zero) { + nsegs = xdr_encode_write_chunk(p, q, consumed); + q += 2 + nsegs * rpcrdma_segment_maxsz; + p += 2 + nsegs * rpcrdma_segment_maxsz; + consumed = 0; + } + + /* Terminate Write list */ + *p++ = xdr_zero; + + /* Reply chunk discriminator; may be replaced later */ + *p = xdr_zero; +} + +/* The client provided a Reply chunk in the Call message. Fill in + * the segments in the Reply chunk in the Reply message with the + * number of bytes consumed in each segment. + * + * Assumptions: + * - Reply can always fit in the provided Reply chunk + */ +static void svc_rdma_xdr_encode_reply_chunk(__be32 *rdma_resp, __be32 *rp_ch, + unsigned int consumed) +{ + __be32 *p; + + /* Find the Reply chunk in the Reply's xprt header. + * RPC-over-RDMA V1 replies never have a Read list. + */ + p = rdma_resp + rpcrdma_fixed_maxsz + 1; + + /* Skip past Write list */ + while (*p++ != xdr_zero) + p += 1 + be32_to_cpup(p) * rpcrdma_segment_maxsz; + + xdr_encode_write_chunk(p, rp_ch, consumed); +} + +/* Parse the RPC Call's transport header. + */ +static void svc_rdma_get_write_arrays(__be32 *rdma_argp, + __be32 **write, __be32 **reply) +{ + __be32 *p; + + p = rdma_argp + rpcrdma_fixed_maxsz; + + /* Read list */ + while (*p++ != xdr_zero) + p += 5; + + /* Write list */ + if (*p != xdr_zero) { + *write = p; + while (*p++ != xdr_zero) + p += 1 + be32_to_cpu(*p) * 4; + } else { + *write = NULL; + p++; + } + + /* Reply chunk */ + if (*p != xdr_zero) + *reply = p; + else + *reply = NULL; +} + +/* RPC-over-RDMA Version One private extension: Remote Invalidation. + * Responder's choice: requester signals it can handle Send With + * Invalidate, and responder chooses one rkey to invalidate. + * + * Find a candidate rkey to invalidate when sending a reply. Picks the + * first R_key it finds in the chunk lists. + * + * Returns zero if RPC's chunk lists are empty. + */ +static u32 svc_rdma_get_inv_rkey(__be32 *rdma_argp, + __be32 *wr_lst, __be32 *rp_ch) +{ + __be32 *p; + + p = rdma_argp + rpcrdma_fixed_maxsz; + if (*p != xdr_zero) + p += 2; + else if (wr_lst && be32_to_cpup(wr_lst + 1)) + p = wr_lst + 2; + else if (rp_ch && be32_to_cpup(rp_ch + 1)) + p = rp_ch + 2; + else + return 0; + return be32_to_cpup(p); +} + +static int svc_rdma_dma_map_page(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *ctxt, + struct page *page, + unsigned long offset, + unsigned int len) +{ + struct ib_device *dev = rdma->sc_cm_id->device; + dma_addr_t dma_addr; + + dma_addr = ib_dma_map_page(dev, page, offset, len, DMA_TO_DEVICE); + if (ib_dma_mapping_error(dev, dma_addr)) + goto out_maperr; + + ctxt->sc_sges[ctxt->sc_cur_sge_no].addr = dma_addr; + ctxt->sc_sges[ctxt->sc_cur_sge_no].length = len; + ctxt->sc_send_wr.num_sge++; + return 0; + +out_maperr: + trace_svcrdma_dma_map_page(rdma, page); + return -EIO; +} + +/* ib_dma_map_page() is used here because svc_rdma_dma_unmap() + * handles DMA-unmap and it uses ib_dma_unmap_page() exclusively. + */ +static int svc_rdma_dma_map_buf(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *ctxt, + unsigned char *base, + unsigned int len) +{ + return svc_rdma_dma_map_page(rdma, ctxt, virt_to_page(base), + offset_in_page(base), len); +} + +/** + * svc_rdma_sync_reply_hdr - DMA sync the transport header buffer + * @rdma: controlling transport + * @ctxt: send_ctxt for the Send WR + * @len: length of transport header + * + */ +void svc_rdma_sync_reply_hdr(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *ctxt, + unsigned int len) +{ + ctxt->sc_sges[0].length = len; + ctxt->sc_send_wr.num_sge++; + ib_dma_sync_single_for_device(rdma->sc_pd->device, + ctxt->sc_sges[0].addr, len, + DMA_TO_DEVICE); +} + +/* If the xdr_buf has more elements than the device can + * transmit in a single RDMA Send, then the reply will + * have to be copied into a bounce buffer. + */ +static bool svc_rdma_pull_up_needed(struct svcxprt_rdma *rdma, + struct xdr_buf *xdr, + __be32 *wr_lst) +{ + int elements; + + /* xdr->head */ + elements = 1; + + /* xdr->pages */ + if (!wr_lst) { + unsigned int remaining; + unsigned long pageoff; + + pageoff = xdr->page_base & ~PAGE_MASK; + remaining = xdr->page_len; + while (remaining) { + ++elements; + remaining -= min_t(u32, PAGE_SIZE - pageoff, + remaining); + pageoff = 0; + } + } + + /* xdr->tail */ + if (xdr->tail[0].iov_len) + ++elements; + + /* assume 1 SGE is needed for the transport header */ + return elements >= rdma->sc_max_send_sges; +} + +/* The device is not capable of sending the reply directly. + * Assemble the elements of @xdr into the transport header + * buffer. + */ +static int svc_rdma_pull_up_reply_msg(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *ctxt, + struct xdr_buf *xdr, __be32 *wr_lst) +{ + unsigned char *dst, *tailbase; + unsigned int taillen; + + dst = ctxt->sc_xprt_buf; + dst += ctxt->sc_sges[0].length; + + memcpy(dst, xdr->head[0].iov_base, xdr->head[0].iov_len); + dst += xdr->head[0].iov_len; + + tailbase = xdr->tail[0].iov_base; + taillen = xdr->tail[0].iov_len; + if (wr_lst) { + u32 xdrpad; + + xdrpad = xdr_padsize(xdr->page_len); + if (taillen && xdrpad) { + tailbase += xdrpad; + taillen -= xdrpad; + } + } else { + unsigned int len, remaining; + unsigned long pageoff; + struct page **ppages; + + ppages = xdr->pages + (xdr->page_base >> PAGE_SHIFT); + pageoff = xdr->page_base & ~PAGE_MASK; + remaining = xdr->page_len; + while (remaining) { + len = min_t(u32, PAGE_SIZE - pageoff, remaining); + + memcpy(dst, page_address(*ppages) + pageoff, len); + remaining -= len; + dst += len; + pageoff = 0; + ppages++; + } + } + + if (taillen) + memcpy(dst, tailbase, taillen); + + ctxt->sc_sges[0].length += xdr->len; + ib_dma_sync_single_for_device(rdma->sc_pd->device, + ctxt->sc_sges[0].addr, + ctxt->sc_sges[0].length, + DMA_TO_DEVICE); + + return 0; +} + +/* svc_rdma_map_reply_msg - Map the buffer holding RPC message + * @rdma: controlling transport + * @ctxt: send_ctxt for the Send WR + * @xdr: prepared xdr_buf containing RPC message + * @wr_lst: pointer to Call header's Write list, or NULL + * + * Load the xdr_buf into the ctxt's sge array, and DMA map each + * element as it is added. + * + * Returns zero on success, or a negative errno on failure. + */ +int svc_rdma_map_reply_msg(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *ctxt, + struct xdr_buf *xdr, __be32 *wr_lst) +{ + unsigned int len, remaining; + unsigned long page_off; + struct page **ppages; + unsigned char *base; + u32 xdr_pad; + int ret; + + if (svc_rdma_pull_up_needed(rdma, xdr, wr_lst)) + return svc_rdma_pull_up_reply_msg(rdma, ctxt, xdr, wr_lst); + + ++ctxt->sc_cur_sge_no; + ret = svc_rdma_dma_map_buf(rdma, ctxt, + xdr->head[0].iov_base, + xdr->head[0].iov_len); + if (ret < 0) + return ret; + + /* If a Write chunk is present, the xdr_buf's page list + * is not included inline. However the Upper Layer may + * have added XDR padding in the tail buffer, and that + * should not be included inline. + */ + if (wr_lst) { + base = xdr->tail[0].iov_base; + len = xdr->tail[0].iov_len; + xdr_pad = xdr_padsize(xdr->page_len); + + if (len && xdr_pad) { + base += xdr_pad; + len -= xdr_pad; + } + + goto tail; + } + + ppages = xdr->pages + (xdr->page_base >> PAGE_SHIFT); + page_off = xdr->page_base & ~PAGE_MASK; + remaining = xdr->page_len; + while (remaining) { + len = min_t(u32, PAGE_SIZE - page_off, remaining); + + ++ctxt->sc_cur_sge_no; + ret = svc_rdma_dma_map_page(rdma, ctxt, *ppages++, + page_off, len); + if (ret < 0) + return ret; + + remaining -= len; + page_off = 0; + } + + base = xdr->tail[0].iov_base; + len = xdr->tail[0].iov_len; +tail: + if (len) { + ++ctxt->sc_cur_sge_no; + ret = svc_rdma_dma_map_buf(rdma, ctxt, base, len); + if (ret < 0) + return ret; + } + + return 0; +} + +/* The svc_rqst and all resources it owns are released as soon as + * svc_rdma_sendto returns. Transfer pages under I/O to the ctxt + * so they are released by the Send completion handler. + */ +static void svc_rdma_save_io_pages(struct svc_rqst *rqstp, + struct svc_rdma_send_ctxt *ctxt) +{ + int i, pages = rqstp->rq_next_page - rqstp->rq_respages; + + ctxt->sc_page_count += pages; + for (i = 0; i < pages; i++) { + ctxt->sc_pages[i] = rqstp->rq_respages[i]; + rqstp->rq_respages[i] = NULL; + } + + /* Prevent svc_xprt_release from releasing pages in rq_pages */ + rqstp->rq_next_page = rqstp->rq_respages; +} + +/* Prepare the portion of the RPC Reply that will be transmitted + * via RDMA Send. The RPC-over-RDMA transport header is prepared + * in sc_sges[0], and the RPC xdr_buf is prepared in following sges. + * + * Depending on whether a Write list or Reply chunk is present, + * the server may send all, a portion of, or none of the xdr_buf. + * In the latter case, only the transport header (sc_sges[0]) is + * transmitted. + * + * RDMA Send is the last step of transmitting an RPC reply. Pages + * involved in the earlier RDMA Writes are here transferred out + * of the rqstp and into the ctxt's page array. These pages are + * DMA unmapped by each Write completion, but the subsequent Send + * completion finally releases these pages. + * + * Assumptions: + * - The Reply's transport header will never be larger than a page. + */ +static int svc_rdma_send_reply_msg(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *ctxt, + __be32 *rdma_argp, + struct svc_rqst *rqstp, + __be32 *wr_lst, __be32 *rp_ch) +{ + int ret; + + if (!rp_ch) { + ret = svc_rdma_map_reply_msg(rdma, ctxt, + &rqstp->rq_res, wr_lst); + if (ret < 0) + return ret; + } + + svc_rdma_save_io_pages(rqstp, ctxt); + + ctxt->sc_send_wr.opcode = IB_WR_SEND; + if (rdma->sc_snd_w_inv) { + ctxt->sc_send_wr.ex.invalidate_rkey = + svc_rdma_get_inv_rkey(rdma_argp, wr_lst, rp_ch); + if (ctxt->sc_send_wr.ex.invalidate_rkey) + ctxt->sc_send_wr.opcode = IB_WR_SEND_WITH_INV; + } + dprintk("svcrdma: posting Send WR with %u sge(s)\n", + ctxt->sc_send_wr.num_sge); + return svc_rdma_send(rdma, &ctxt->sc_send_wr); +} + +/* Given the client-provided Write and Reply chunks, the server was not + * able to form a complete reply. Return an RDMA_ERROR message so the + * client can retire this RPC transaction. As above, the Send completion + * routine releases payload pages that were part of a previous RDMA Write. + * + * Remote Invalidation is skipped for simplicity. + */ +static int svc_rdma_send_error_msg(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *ctxt, + struct svc_rqst *rqstp) +{ + __be32 *p; + int ret; + + p = ctxt->sc_xprt_buf; + trace_svcrdma_err_chunk(*p); + p += 3; + *p++ = rdma_error; + *p = err_chunk; + svc_rdma_sync_reply_hdr(rdma, ctxt, RPCRDMA_HDRLEN_ERR); + + svc_rdma_save_io_pages(rqstp, ctxt); + + ctxt->sc_send_wr.opcode = IB_WR_SEND; + ret = svc_rdma_send(rdma, &ctxt->sc_send_wr); + if (ret) { + svc_rdma_send_ctxt_put(rdma, ctxt); + return ret; + } + + return 0; +} + +void svc_rdma_prep_reply_hdr(struct svc_rqst *rqstp) +{ +} + +/** + * svc_rdma_sendto - Transmit an RPC reply + * @rqstp: processed RPC request, reply XDR already in ::rq_res + * + * Any resources still associated with @rqstp are released upon return. + * If no reply message was possible, the connection is closed. + * + * Returns: + * %0 if an RPC reply has been successfully posted, + * %-ENOMEM if a resource shortage occurred (connection is lost), + * %-ENOTCONN if posting failed (connection is lost). + */ +int svc_rdma_sendto(struct svc_rqst *rqstp) +{ + struct svc_xprt *xprt = rqstp->rq_xprt; + struct svcxprt_rdma *rdma = + container_of(xprt, struct svcxprt_rdma, sc_xprt); + struct svc_rdma_recv_ctxt *rctxt = rqstp->rq_xprt_ctxt; + __be32 *p, *rdma_argp, *rdma_resp, *wr_lst, *rp_ch; + struct xdr_buf *xdr = &rqstp->rq_res; + struct svc_rdma_send_ctxt *sctxt; + int ret; + + rdma_argp = rctxt->rc_recv_buf; + svc_rdma_get_write_arrays(rdma_argp, &wr_lst, &rp_ch); + + /* Create the RDMA response header. xprt->xpt_mutex, + * acquired in svc_send(), serializes RPC replies. The + * code path below that inserts the credit grant value + * into each transport header runs only inside this + * critical section. + */ + ret = -ENOMEM; + sctxt = svc_rdma_send_ctxt_get(rdma); + if (!sctxt) + goto err0; + rdma_resp = sctxt->sc_xprt_buf; + + p = rdma_resp; + *p++ = *rdma_argp; + *p++ = *(rdma_argp + 1); + *p++ = rdma->sc_fc_credits; + *p++ = rp_ch ? rdma_nomsg : rdma_msg; + + /* Start with empty chunks */ + *p++ = xdr_zero; + *p++ = xdr_zero; + *p = xdr_zero; + + if (wr_lst) { + /* XXX: Presume the client sent only one Write chunk */ + ret = svc_rdma_send_write_chunk(rdma, wr_lst, xdr); + if (ret < 0) + goto err2; + svc_rdma_xdr_encode_write_list(rdma_resp, wr_lst, ret); + } + if (rp_ch) { + ret = svc_rdma_send_reply_chunk(rdma, rp_ch, wr_lst, xdr); + if (ret < 0) + goto err2; + svc_rdma_xdr_encode_reply_chunk(rdma_resp, rp_ch, ret); + } + + svc_rdma_sync_reply_hdr(rdma, sctxt, svc_rdma_reply_hdr_len(rdma_resp)); + ret = svc_rdma_send_reply_msg(rdma, sctxt, rdma_argp, rqstp, + wr_lst, rp_ch); + if (ret < 0) + goto err1; + return 0; + + err2: + if (ret != -E2BIG && ret != -EINVAL) + goto err1; + + ret = svc_rdma_send_error_msg(rdma, sctxt, rqstp); + if (ret < 0) + goto err1; + return 0; + + err1: + svc_rdma_send_ctxt_put(rdma, sctxt); + err0: + trace_svcrdma_send_failed(rqstp, ret); + set_bit(XPT_CLOSE, &xprt->xpt_flags); + return -ENOTCONN; +} diff --git a/net/sunrpc/xprtrdma/svc_rdma_transport.c b/net/sunrpc/xprtrdma/svc_rdma_transport.c new file mode 100644 index 000000000..f1824303e --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma_transport.c @@ -0,0 +1,717 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2015-2018 Oracle. All rights reserved. + * Copyright (c) 2014 Open Grid Computing, Inc. All rights reserved. + * Copyright (c) 2005-2007 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * Author: Tom Tucker <tom@opengridcomputing.com> + */ + +#include <linux/interrupt.h> +#include <linux/sched.h> +#include <linux/slab.h> +#include <linux/spinlock.h> +#include <linux/workqueue.h> +#include <linux/export.h> + +#include <rdma/ib_verbs.h> +#include <rdma/rdma_cm.h> +#include <rdma/rw.h> + +#include <linux/sunrpc/addr.h> +#include <linux/sunrpc/debug.h> +#include <linux/sunrpc/rpc_rdma.h> +#include <linux/sunrpc/svc_xprt.h> +#include <linux/sunrpc/svc_rdma.h> + +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +#define RPCDBG_FACILITY RPCDBG_SVCXPRT + +static struct svcxprt_rdma *svc_rdma_create_xprt(struct svc_serv *serv, + struct net *net); +static struct svc_xprt *svc_rdma_create(struct svc_serv *serv, + struct net *net, + struct sockaddr *sa, int salen, + int flags); +static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt); +static void svc_rdma_detach(struct svc_xprt *xprt); +static void svc_rdma_free(struct svc_xprt *xprt); +static int svc_rdma_has_wspace(struct svc_xprt *xprt); +static void svc_rdma_secure_port(struct svc_rqst *); +static void svc_rdma_kill_temp_xprt(struct svc_xprt *); + +static const struct svc_xprt_ops svc_rdma_ops = { + .xpo_create = svc_rdma_create, + .xpo_recvfrom = svc_rdma_recvfrom, + .xpo_sendto = svc_rdma_sendto, + .xpo_release_rqst = svc_rdma_release_rqst, + .xpo_detach = svc_rdma_detach, + .xpo_free = svc_rdma_free, + .xpo_prep_reply_hdr = svc_rdma_prep_reply_hdr, + .xpo_has_wspace = svc_rdma_has_wspace, + .xpo_accept = svc_rdma_accept, + .xpo_secure_port = svc_rdma_secure_port, + .xpo_kill_temp_xprt = svc_rdma_kill_temp_xprt, +}; + +struct svc_xprt_class svc_rdma_class = { + .xcl_name = "rdma", + .xcl_owner = THIS_MODULE, + .xcl_ops = &svc_rdma_ops, + .xcl_max_payload = RPCSVC_MAXPAYLOAD_RDMA, + .xcl_ident = XPRT_TRANSPORT_RDMA, +}; + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +static struct svc_xprt *svc_rdma_bc_create(struct svc_serv *, struct net *, + struct sockaddr *, int, int); +static void svc_rdma_bc_detach(struct svc_xprt *); +static void svc_rdma_bc_free(struct svc_xprt *); + +static const struct svc_xprt_ops svc_rdma_bc_ops = { + .xpo_create = svc_rdma_bc_create, + .xpo_detach = svc_rdma_bc_detach, + .xpo_free = svc_rdma_bc_free, + .xpo_prep_reply_hdr = svc_rdma_prep_reply_hdr, + .xpo_secure_port = svc_rdma_secure_port, +}; + +struct svc_xprt_class svc_rdma_bc_class = { + .xcl_name = "rdma-bc", + .xcl_owner = THIS_MODULE, + .xcl_ops = &svc_rdma_bc_ops, + .xcl_max_payload = (1024 - RPCRDMA_HDRLEN_MIN) +}; + +static struct svc_xprt *svc_rdma_bc_create(struct svc_serv *serv, + struct net *net, + struct sockaddr *sa, int salen, + int flags) +{ + struct svcxprt_rdma *cma_xprt; + struct svc_xprt *xprt; + + cma_xprt = svc_rdma_create_xprt(serv, net); + if (!cma_xprt) + return ERR_PTR(-ENOMEM); + xprt = &cma_xprt->sc_xprt; + + svc_xprt_init(net, &svc_rdma_bc_class, xprt, serv); + set_bit(XPT_CONG_CTRL, &xprt->xpt_flags); + serv->sv_bc_xprt = xprt; + + dprintk("svcrdma: %s(%p)\n", __func__, xprt); + return xprt; +} + +static void svc_rdma_bc_detach(struct svc_xprt *xprt) +{ + dprintk("svcrdma: %s(%p)\n", __func__, xprt); +} + +static void svc_rdma_bc_free(struct svc_xprt *xprt) +{ + struct svcxprt_rdma *rdma = + container_of(xprt, struct svcxprt_rdma, sc_xprt); + + dprintk("svcrdma: %s(%p)\n", __func__, xprt); + if (xprt) + kfree(rdma); +} +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +/* QP event handler */ +static void qp_event_handler(struct ib_event *event, void *context) +{ + struct svc_xprt *xprt = context; + + trace_svcrdma_qp_error(event, (struct sockaddr *)&xprt->xpt_remote); + switch (event->event) { + /* These are considered benign events */ + case IB_EVENT_PATH_MIG: + case IB_EVENT_COMM_EST: + case IB_EVENT_SQ_DRAINED: + case IB_EVENT_QP_LAST_WQE_REACHED: + break; + + /* These are considered fatal events */ + case IB_EVENT_PATH_MIG_ERR: + case IB_EVENT_QP_FATAL: + case IB_EVENT_QP_REQ_ERR: + case IB_EVENT_QP_ACCESS_ERR: + case IB_EVENT_DEVICE_FATAL: + default: + set_bit(XPT_CLOSE, &xprt->xpt_flags); + svc_xprt_enqueue(xprt); + break; + } +} + +static struct svcxprt_rdma *svc_rdma_create_xprt(struct svc_serv *serv, + struct net *net) +{ + struct svcxprt_rdma *cma_xprt = kzalloc(sizeof *cma_xprt, GFP_KERNEL); + + if (!cma_xprt) { + dprintk("svcrdma: failed to create new transport\n"); + return NULL; + } + svc_xprt_init(net, &svc_rdma_class, &cma_xprt->sc_xprt, serv); + INIT_LIST_HEAD(&cma_xprt->sc_accept_q); + INIT_LIST_HEAD(&cma_xprt->sc_rq_dto_q); + INIT_LIST_HEAD(&cma_xprt->sc_read_complete_q); + INIT_LIST_HEAD(&cma_xprt->sc_send_ctxts); + INIT_LIST_HEAD(&cma_xprt->sc_recv_ctxts); + INIT_LIST_HEAD(&cma_xprt->sc_rw_ctxts); + init_waitqueue_head(&cma_xprt->sc_send_wait); + + spin_lock_init(&cma_xprt->sc_lock); + spin_lock_init(&cma_xprt->sc_rq_dto_lock); + spin_lock_init(&cma_xprt->sc_send_lock); + spin_lock_init(&cma_xprt->sc_recv_lock); + spin_lock_init(&cma_xprt->sc_rw_ctxt_lock); + + /* + * Note that this implies that the underlying transport support + * has some form of congestion control (see RFC 7530 section 3.1 + * paragraph 2). For now, we assume that all supported RDMA + * transports are suitable here. + */ + set_bit(XPT_CONG_CTRL, &cma_xprt->sc_xprt.xpt_flags); + + return cma_xprt; +} + +static void +svc_rdma_parse_connect_private(struct svcxprt_rdma *newxprt, + struct rdma_conn_param *param) +{ + const struct rpcrdma_connect_private *pmsg = param->private_data; + + if (pmsg && + pmsg->cp_magic == rpcrdma_cmp_magic && + pmsg->cp_version == RPCRDMA_CMP_VERSION) { + newxprt->sc_snd_w_inv = pmsg->cp_flags & + RPCRDMA_CMP_F_SND_W_INV_OK; + + dprintk("svcrdma: client send_size %u, recv_size %u " + "remote inv %ssupported\n", + rpcrdma_decode_buffer_size(pmsg->cp_send_size), + rpcrdma_decode_buffer_size(pmsg->cp_recv_size), + newxprt->sc_snd_w_inv ? "" : "un"); + } +} + +/* + * This function handles the CONNECT_REQUEST event on a listening + * endpoint. It is passed the cma_id for the _new_ connection. The context in + * this cma_id is inherited from the listening cma_id and is the svc_xprt + * structure for the listening endpoint. + * + * This function creates a new xprt for the new connection and enqueues it on + * the accept queue for the listent xprt. When the listen thread is kicked, it + * will call the recvfrom method on the listen xprt which will accept the new + * connection. + */ +static void handle_connect_req(struct rdma_cm_id *new_cma_id, + struct rdma_conn_param *param) +{ + struct svcxprt_rdma *listen_xprt = new_cma_id->context; + struct svcxprt_rdma *newxprt; + struct sockaddr *sa; + + /* Create a new transport */ + newxprt = svc_rdma_create_xprt(listen_xprt->sc_xprt.xpt_server, + listen_xprt->sc_xprt.xpt_net); + if (!newxprt) + return; + newxprt->sc_cm_id = new_cma_id; + new_cma_id->context = newxprt; + svc_rdma_parse_connect_private(newxprt, param); + + /* Save client advertised inbound read limit for use later in accept. */ + newxprt->sc_ord = param->initiator_depth; + + sa = (struct sockaddr *)&newxprt->sc_cm_id->route.addr.dst_addr; + svc_xprt_set_remote(&newxprt->sc_xprt, sa, svc_addr_len(sa)); + /* The remote port is arbitrary and not under the control of the + * client ULP. Set it to a fixed value so that the DRC continues + * to be effective after a reconnect. + */ + rpc_set_port((struct sockaddr *)&newxprt->sc_xprt.xpt_remote, 0); + + sa = (struct sockaddr *)&newxprt->sc_cm_id->route.addr.src_addr; + svc_xprt_set_local(&newxprt->sc_xprt, sa, svc_addr_len(sa)); + + /* + * Enqueue the new transport on the accept queue of the listening + * transport + */ + spin_lock_bh(&listen_xprt->sc_lock); + list_add_tail(&newxprt->sc_accept_q, &listen_xprt->sc_accept_q); + spin_unlock_bh(&listen_xprt->sc_lock); + + set_bit(XPT_CONN, &listen_xprt->sc_xprt.xpt_flags); + svc_xprt_enqueue(&listen_xprt->sc_xprt); +} + +/* + * Handles events generated on the listening endpoint. These events will be + * either be incoming connect requests or adapter removal events. + */ +static int rdma_listen_handler(struct rdma_cm_id *cma_id, + struct rdma_cm_event *event) +{ + struct sockaddr *sap = (struct sockaddr *)&cma_id->route.addr.src_addr; + + trace_svcrdma_cm_event(event, sap); + + switch (event->event) { + case RDMA_CM_EVENT_CONNECT_REQUEST: + dprintk("svcrdma: Connect request on cma_id=%p, xprt = %p, " + "event = %s (%d)\n", cma_id, cma_id->context, + rdma_event_msg(event->event), event->event); + handle_connect_req(cma_id, &event->param.conn); + break; + default: + /* NB: No device removal upcall for INADDR_ANY listeners */ + dprintk("svcrdma: Unexpected event on listening endpoint %p, " + "event = %s (%d)\n", cma_id, + rdma_event_msg(event->event), event->event); + break; + } + + return 0; +} + +static int rdma_cma_handler(struct rdma_cm_id *cma_id, + struct rdma_cm_event *event) +{ + struct sockaddr *sap = (struct sockaddr *)&cma_id->route.addr.dst_addr; + struct svcxprt_rdma *rdma = cma_id->context; + struct svc_xprt *xprt = &rdma->sc_xprt; + + trace_svcrdma_cm_event(event, sap); + + switch (event->event) { + case RDMA_CM_EVENT_ESTABLISHED: + /* Accept complete */ + svc_xprt_get(xprt); + dprintk("svcrdma: Connection completed on DTO xprt=%p, " + "cm_id=%p\n", xprt, cma_id); + clear_bit(RDMAXPRT_CONN_PENDING, &rdma->sc_flags); + svc_xprt_enqueue(xprt); + break; + case RDMA_CM_EVENT_DISCONNECTED: + dprintk("svcrdma: Disconnect on DTO xprt=%p, cm_id=%p\n", + xprt, cma_id); + set_bit(XPT_CLOSE, &xprt->xpt_flags); + svc_xprt_enqueue(xprt); + svc_xprt_put(xprt); + break; + case RDMA_CM_EVENT_DEVICE_REMOVAL: + dprintk("svcrdma: Device removal cma_id=%p, xprt = %p, " + "event = %s (%d)\n", cma_id, xprt, + rdma_event_msg(event->event), event->event); + set_bit(XPT_CLOSE, &xprt->xpt_flags); + svc_xprt_enqueue(xprt); + svc_xprt_put(xprt); + break; + default: + dprintk("svcrdma: Unexpected event on DTO endpoint %p, " + "event = %s (%d)\n", cma_id, + rdma_event_msg(event->event), event->event); + break; + } + return 0; +} + +/* + * Create a listening RDMA service endpoint. + */ +static struct svc_xprt *svc_rdma_create(struct svc_serv *serv, + struct net *net, + struct sockaddr *sa, int salen, + int flags) +{ + struct rdma_cm_id *listen_id; + struct svcxprt_rdma *cma_xprt; + int ret; + + dprintk("svcrdma: Creating RDMA listener\n"); + if ((sa->sa_family != AF_INET) && (sa->sa_family != AF_INET6)) { + dprintk("svcrdma: Address family %d is not supported.\n", sa->sa_family); + return ERR_PTR(-EAFNOSUPPORT); + } + cma_xprt = svc_rdma_create_xprt(serv, net); + if (!cma_xprt) + return ERR_PTR(-ENOMEM); + set_bit(XPT_LISTENER, &cma_xprt->sc_xprt.xpt_flags); + strcpy(cma_xprt->sc_xprt.xpt_remotebuf, "listener"); + + listen_id = rdma_create_id(net, rdma_listen_handler, cma_xprt, + RDMA_PS_TCP, IB_QPT_RC); + if (IS_ERR(listen_id)) { + ret = PTR_ERR(listen_id); + dprintk("svcrdma: rdma_create_id failed = %d\n", ret); + goto err0; + } + + /* Allow both IPv4 and IPv6 sockets to bind a single port + * at the same time. + */ +#if IS_ENABLED(CONFIG_IPV6) + ret = rdma_set_afonly(listen_id, 1); + if (ret) { + dprintk("svcrdma: rdma_set_afonly failed = %d\n", ret); + goto err1; + } +#endif + ret = rdma_bind_addr(listen_id, sa); + if (ret) { + dprintk("svcrdma: rdma_bind_addr failed = %d\n", ret); + goto err1; + } + cma_xprt->sc_cm_id = listen_id; + + ret = rdma_listen(listen_id, RPCRDMA_LISTEN_BACKLOG); + if (ret) { + dprintk("svcrdma: rdma_listen failed = %d\n", ret); + goto err1; + } + + /* + * We need to use the address from the cm_id in case the + * caller specified 0 for the port number. + */ + sa = (struct sockaddr *)&cma_xprt->sc_cm_id->route.addr.src_addr; + svc_xprt_set_local(&cma_xprt->sc_xprt, sa, salen); + + return &cma_xprt->sc_xprt; + + err1: + rdma_destroy_id(listen_id); + err0: + kfree(cma_xprt); + return ERR_PTR(ret); +} + +/* + * This is the xpo_recvfrom function for listening endpoints. Its + * purpose is to accept incoming connections. The CMA callback handler + * has already created a new transport and attached it to the new CMA + * ID. + * + * There is a queue of pending connections hung on the listening + * transport. This queue contains the new svc_xprt structure. This + * function takes svc_xprt structures off the accept_q and completes + * the connection. + */ +static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt) +{ + struct svcxprt_rdma *listen_rdma; + struct svcxprt_rdma *newxprt = NULL; + struct rdma_conn_param conn_param; + struct rpcrdma_connect_private pmsg; + struct ib_qp_init_attr qp_attr; + unsigned int ctxts, rq_depth; + struct ib_device *dev; + struct sockaddr *sap; + int ret = 0; + + listen_rdma = container_of(xprt, struct svcxprt_rdma, sc_xprt); + clear_bit(XPT_CONN, &xprt->xpt_flags); + /* Get the next entry off the accept list */ + spin_lock_bh(&listen_rdma->sc_lock); + if (!list_empty(&listen_rdma->sc_accept_q)) { + newxprt = list_entry(listen_rdma->sc_accept_q.next, + struct svcxprt_rdma, sc_accept_q); + list_del_init(&newxprt->sc_accept_q); + } + if (!list_empty(&listen_rdma->sc_accept_q)) + set_bit(XPT_CONN, &listen_rdma->sc_xprt.xpt_flags); + spin_unlock_bh(&listen_rdma->sc_lock); + if (!newxprt) + return NULL; + + dprintk("svcrdma: newxprt from accept queue = %p, cm_id=%p\n", + newxprt, newxprt->sc_cm_id); + + dev = newxprt->sc_cm_id->device; + newxprt->sc_port_num = newxprt->sc_cm_id->port_num; + + /* Qualify the transport resource defaults with the + * capabilities of this particular device */ + /* Transport header, head iovec, tail iovec */ + newxprt->sc_max_send_sges = 3; + /* Add one SGE per page list entry */ + newxprt->sc_max_send_sges += (svcrdma_max_req_size / PAGE_SIZE) + 1; + if (newxprt->sc_max_send_sges > dev->attrs.max_send_sge) + newxprt->sc_max_send_sges = dev->attrs.max_send_sge; + newxprt->sc_max_req_size = svcrdma_max_req_size; + newxprt->sc_max_requests = svcrdma_max_requests; + newxprt->sc_max_bc_requests = svcrdma_max_bc_requests; + rq_depth = newxprt->sc_max_requests + newxprt->sc_max_bc_requests; + if (rq_depth > dev->attrs.max_qp_wr) { + pr_warn("svcrdma: reducing receive depth to %d\n", + dev->attrs.max_qp_wr); + rq_depth = dev->attrs.max_qp_wr; + newxprt->sc_max_requests = rq_depth - 2; + newxprt->sc_max_bc_requests = 2; + } + newxprt->sc_fc_credits = cpu_to_be32(newxprt->sc_max_requests); + ctxts = rdma_rw_mr_factor(dev, newxprt->sc_port_num, RPCSVC_MAXPAGES); + ctxts *= newxprt->sc_max_requests; + newxprt->sc_sq_depth = rq_depth + ctxts; + if (newxprt->sc_sq_depth > dev->attrs.max_qp_wr) { + pr_warn("svcrdma: reducing send depth to %d\n", + dev->attrs.max_qp_wr); + newxprt->sc_sq_depth = dev->attrs.max_qp_wr; + } + atomic_set(&newxprt->sc_sq_avail, newxprt->sc_sq_depth); + + newxprt->sc_pd = ib_alloc_pd(dev, 0); + if (IS_ERR(newxprt->sc_pd)) { + dprintk("svcrdma: error creating PD for connect request\n"); + goto errout; + } + newxprt->sc_sq_cq = ib_alloc_cq(dev, newxprt, newxprt->sc_sq_depth, + 0, IB_POLL_WORKQUEUE); + if (IS_ERR(newxprt->sc_sq_cq)) { + dprintk("svcrdma: error creating SQ CQ for connect request\n"); + goto errout; + } + newxprt->sc_rq_cq = ib_alloc_cq(dev, newxprt, rq_depth, + 0, IB_POLL_WORKQUEUE); + if (IS_ERR(newxprt->sc_rq_cq)) { + dprintk("svcrdma: error creating RQ CQ for connect request\n"); + goto errout; + } + + memset(&qp_attr, 0, sizeof qp_attr); + qp_attr.event_handler = qp_event_handler; + qp_attr.qp_context = &newxprt->sc_xprt; + qp_attr.port_num = newxprt->sc_port_num; + qp_attr.cap.max_rdma_ctxs = ctxts; + qp_attr.cap.max_send_wr = newxprt->sc_sq_depth - ctxts; + qp_attr.cap.max_recv_wr = rq_depth; + qp_attr.cap.max_send_sge = newxprt->sc_max_send_sges; + qp_attr.cap.max_recv_sge = 1; + qp_attr.sq_sig_type = IB_SIGNAL_REQ_WR; + qp_attr.qp_type = IB_QPT_RC; + qp_attr.send_cq = newxprt->sc_sq_cq; + qp_attr.recv_cq = newxprt->sc_rq_cq; + dprintk("svcrdma: newxprt->sc_cm_id=%p, newxprt->sc_pd=%p\n", + newxprt->sc_cm_id, newxprt->sc_pd); + dprintk(" cap.max_send_wr = %d, cap.max_recv_wr = %d\n", + qp_attr.cap.max_send_wr, qp_attr.cap.max_recv_wr); + dprintk(" cap.max_send_sge = %d, cap.max_recv_sge = %d\n", + qp_attr.cap.max_send_sge, qp_attr.cap.max_recv_sge); + + ret = rdma_create_qp(newxprt->sc_cm_id, newxprt->sc_pd, &qp_attr); + if (ret) { + dprintk("svcrdma: failed to create QP, ret=%d\n", ret); + goto errout; + } + newxprt->sc_qp = newxprt->sc_cm_id->qp; + + if (!(dev->attrs.device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS)) + newxprt->sc_snd_w_inv = false; + if (!rdma_protocol_iwarp(dev, newxprt->sc_port_num) && + !rdma_ib_or_roce(dev, newxprt->sc_port_num)) + goto errout; + + if (!svc_rdma_post_recvs(newxprt)) + goto errout; + + /* Swap out the handler */ + newxprt->sc_cm_id->event_handler = rdma_cma_handler; + + /* Construct RDMA-CM private message */ + pmsg.cp_magic = rpcrdma_cmp_magic; + pmsg.cp_version = RPCRDMA_CMP_VERSION; + pmsg.cp_flags = 0; + pmsg.cp_send_size = pmsg.cp_recv_size = + rpcrdma_encode_buffer_size(newxprt->sc_max_req_size); + + /* Accept Connection */ + set_bit(RDMAXPRT_CONN_PENDING, &newxprt->sc_flags); + memset(&conn_param, 0, sizeof conn_param); + conn_param.responder_resources = 0; + conn_param.initiator_depth = min_t(int, newxprt->sc_ord, + dev->attrs.max_qp_init_rd_atom); + if (!conn_param.initiator_depth) { + dprintk("svcrdma: invalid ORD setting\n"); + ret = -EINVAL; + goto errout; + } + conn_param.private_data = &pmsg; + conn_param.private_data_len = sizeof(pmsg); + ret = rdma_accept(newxprt->sc_cm_id, &conn_param); + if (ret) + goto errout; + + dprintk("svcrdma: new connection %p accepted:\n", newxprt); + sap = (struct sockaddr *)&newxprt->sc_cm_id->route.addr.src_addr; + dprintk(" local address : %pIS:%u\n", sap, rpc_get_port(sap)); + sap = (struct sockaddr *)&newxprt->sc_cm_id->route.addr.dst_addr; + dprintk(" remote address : %pIS:%u\n", sap, rpc_get_port(sap)); + dprintk(" max_sge : %d\n", newxprt->sc_max_send_sges); + dprintk(" sq_depth : %d\n", newxprt->sc_sq_depth); + dprintk(" rdma_rw_ctxs : %d\n", ctxts); + dprintk(" max_requests : %d\n", newxprt->sc_max_requests); + dprintk(" ord : %d\n", conn_param.initiator_depth); + + trace_svcrdma_xprt_accept(&newxprt->sc_xprt); + return &newxprt->sc_xprt; + + errout: + dprintk("svcrdma: failure accepting new connection rc=%d.\n", ret); + trace_svcrdma_xprt_fail(&newxprt->sc_xprt); + /* Take a reference in case the DTO handler runs */ + svc_xprt_get(&newxprt->sc_xprt); + if (newxprt->sc_qp && !IS_ERR(newxprt->sc_qp)) + ib_destroy_qp(newxprt->sc_qp); + rdma_destroy_id(newxprt->sc_cm_id); + /* This call to put will destroy the transport */ + svc_xprt_put(&newxprt->sc_xprt); + return NULL; +} + +/* + * When connected, an svc_xprt has at least two references: + * + * - A reference held by the cm_id between the ESTABLISHED and + * DISCONNECTED events. If the remote peer disconnected first, this + * reference could be gone. + * + * - A reference held by the svc_recv code that called this function + * as part of close processing. + * + * At a minimum one references should still be held. + */ +static void svc_rdma_detach(struct svc_xprt *xprt) +{ + struct svcxprt_rdma *rdma = + container_of(xprt, struct svcxprt_rdma, sc_xprt); + + /* Disconnect and flush posted WQE */ + rdma_disconnect(rdma->sc_cm_id); +} + +static void __svc_rdma_free(struct work_struct *work) +{ + struct svcxprt_rdma *rdma = + container_of(work, struct svcxprt_rdma, sc_work); + struct svc_xprt *xprt = &rdma->sc_xprt; + + trace_svcrdma_xprt_free(xprt); + + if (rdma->sc_qp && !IS_ERR(rdma->sc_qp)) + ib_drain_qp(rdma->sc_qp); + + /* We should only be called from kref_put */ + if (kref_read(&xprt->xpt_ref) != 0) + pr_err("svcrdma: sc_xprt still in use? (%d)\n", + kref_read(&xprt->xpt_ref)); + + svc_rdma_flush_recv_queues(rdma); + + /* Final put of backchannel client transport */ + if (xprt->xpt_bc_xprt) { + xprt_put(xprt->xpt_bc_xprt); + xprt->xpt_bc_xprt = NULL; + } + + svc_rdma_destroy_rw_ctxts(rdma); + svc_rdma_send_ctxts_destroy(rdma); + svc_rdma_recv_ctxts_destroy(rdma); + + /* Destroy the QP if present (not a listener) */ + if (rdma->sc_qp && !IS_ERR(rdma->sc_qp)) + ib_destroy_qp(rdma->sc_qp); + + if (rdma->sc_sq_cq && !IS_ERR(rdma->sc_sq_cq)) + ib_free_cq(rdma->sc_sq_cq); + + if (rdma->sc_rq_cq && !IS_ERR(rdma->sc_rq_cq)) + ib_free_cq(rdma->sc_rq_cq); + + if (rdma->sc_pd && !IS_ERR(rdma->sc_pd)) + ib_dealloc_pd(rdma->sc_pd); + + /* Destroy the CM ID */ + rdma_destroy_id(rdma->sc_cm_id); + + kfree(rdma); +} + +static void svc_rdma_free(struct svc_xprt *xprt) +{ + struct svcxprt_rdma *rdma = + container_of(xprt, struct svcxprt_rdma, sc_xprt); + INIT_WORK(&rdma->sc_work, __svc_rdma_free); + queue_work(svc_rdma_wq, &rdma->sc_work); +} + +static int svc_rdma_has_wspace(struct svc_xprt *xprt) +{ + struct svcxprt_rdma *rdma = + container_of(xprt, struct svcxprt_rdma, sc_xprt); + + /* + * If there are already waiters on the SQ, + * return false. + */ + if (waitqueue_active(&rdma->sc_send_wait)) + return 0; + + /* Otherwise return true. */ + return 1; +} + +static void svc_rdma_secure_port(struct svc_rqst *rqstp) +{ + set_bit(RQ_SECURE, &rqstp->rq_flags); +} + +static void svc_rdma_kill_temp_xprt(struct svc_xprt *xprt) +{ +} diff --git a/net/sunrpc/xprtrdma/transport.c b/net/sunrpc/xprtrdma/transport.c new file mode 100644 index 000000000..e87a79be7 --- /dev/null +++ b/net/sunrpc/xprtrdma/transport.c @@ -0,0 +1,921 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2014-2017 Oracle. All rights reserved. + * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/* + * transport.c + * + * This file contains the top-level implementation of an RPC RDMA + * transport. + * + * Naming convention: functions beginning with xprt_ are part of the + * transport switch. All others are RPC RDMA internal. + */ + +#include <linux/module.h> +#include <linux/slab.h> +#include <linux/seq_file.h> +#include <linux/smp.h> + +#include <linux/sunrpc/addr.h> +#include <linux/sunrpc/svc_rdma.h> + +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_TRANS +#endif + +/* + * tunables + */ + +static unsigned int xprt_rdma_slot_table_entries = RPCRDMA_DEF_SLOT_TABLE; +unsigned int xprt_rdma_max_inline_read = RPCRDMA_DEF_INLINE; +static unsigned int xprt_rdma_max_inline_write = RPCRDMA_DEF_INLINE; +unsigned int xprt_rdma_memreg_strategy = RPCRDMA_FRWR; +int xprt_rdma_pad_optimize; + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) + +static unsigned int min_slot_table_size = RPCRDMA_MIN_SLOT_TABLE; +static unsigned int max_slot_table_size = RPCRDMA_MAX_SLOT_TABLE; +static unsigned int min_inline_size = RPCRDMA_MIN_INLINE; +static unsigned int max_inline_size = RPCRDMA_MAX_INLINE; +static unsigned int zero; +static unsigned int max_padding = PAGE_SIZE; +static unsigned int min_memreg = RPCRDMA_BOUNCEBUFFERS; +static unsigned int max_memreg = RPCRDMA_LAST - 1; +static unsigned int dummy; + +static struct ctl_table_header *sunrpc_table_header; + +static struct ctl_table xr_tunables_table[] = { + { + .procname = "rdma_slot_table_entries", + .data = &xprt_rdma_slot_table_entries, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_slot_table_size, + .extra2 = &max_slot_table_size + }, + { + .procname = "rdma_max_inline_read", + .data = &xprt_rdma_max_inline_read, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_inline_size, + .extra2 = &max_inline_size, + }, + { + .procname = "rdma_max_inline_write", + .data = &xprt_rdma_max_inline_write, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_inline_size, + .extra2 = &max_inline_size, + }, + { + .procname = "rdma_inline_write_padding", + .data = &dummy, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &zero, + .extra2 = &max_padding, + }, + { + .procname = "rdma_memreg_strategy", + .data = &xprt_rdma_memreg_strategy, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_memreg, + .extra2 = &max_memreg, + }, + { + .procname = "rdma_pad_optimize", + .data = &xprt_rdma_pad_optimize, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { }, +}; + +static struct ctl_table sunrpc_table[] = { + { + .procname = "sunrpc", + .mode = 0555, + .child = xr_tunables_table + }, + { }, +}; + +#endif + +static const struct rpc_xprt_ops xprt_rdma_procs; + +static void +xprt_rdma_format_addresses4(struct rpc_xprt *xprt, struct sockaddr *sap) +{ + struct sockaddr_in *sin = (struct sockaddr_in *)sap; + char buf[20]; + + snprintf(buf, sizeof(buf), "%08x", ntohl(sin->sin_addr.s_addr)); + xprt->address_strings[RPC_DISPLAY_HEX_ADDR] = kstrdup(buf, GFP_KERNEL); + + xprt->address_strings[RPC_DISPLAY_NETID] = RPCBIND_NETID_RDMA; +} + +static void +xprt_rdma_format_addresses6(struct rpc_xprt *xprt, struct sockaddr *sap) +{ + struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)sap; + char buf[40]; + + snprintf(buf, sizeof(buf), "%pi6", &sin6->sin6_addr); + xprt->address_strings[RPC_DISPLAY_HEX_ADDR] = kstrdup(buf, GFP_KERNEL); + + xprt->address_strings[RPC_DISPLAY_NETID] = RPCBIND_NETID_RDMA6; +} + +void +xprt_rdma_format_addresses(struct rpc_xprt *xprt, struct sockaddr *sap) +{ + char buf[128]; + + switch (sap->sa_family) { + case AF_INET: + xprt_rdma_format_addresses4(xprt, sap); + break; + case AF_INET6: + xprt_rdma_format_addresses6(xprt, sap); + break; + default: + pr_err("rpcrdma: Unrecognized address family\n"); + return; + } + + (void)rpc_ntop(sap, buf, sizeof(buf)); + xprt->address_strings[RPC_DISPLAY_ADDR] = kstrdup(buf, GFP_KERNEL); + + snprintf(buf, sizeof(buf), "%u", rpc_get_port(sap)); + xprt->address_strings[RPC_DISPLAY_PORT] = kstrdup(buf, GFP_KERNEL); + + snprintf(buf, sizeof(buf), "%4hx", rpc_get_port(sap)); + xprt->address_strings[RPC_DISPLAY_HEX_PORT] = kstrdup(buf, GFP_KERNEL); + + xprt->address_strings[RPC_DISPLAY_PROTO] = "rdma"; +} + +void +xprt_rdma_free_addresses(struct rpc_xprt *xprt) +{ + unsigned int i; + + for (i = 0; i < RPC_DISPLAY_MAX; i++) + switch (i) { + case RPC_DISPLAY_PROTO: + case RPC_DISPLAY_NETID: + continue; + default: + kfree(xprt->address_strings[i]); + } +} + +void +rpcrdma_conn_func(struct rpcrdma_ep *ep) +{ + schedule_delayed_work(&ep->rep_connect_worker, 0); +} + +void +rpcrdma_connect_worker(struct work_struct *work) +{ + struct rpcrdma_ep *ep = + container_of(work, struct rpcrdma_ep, rep_connect_worker.work); + struct rpcrdma_xprt *r_xprt = + container_of(ep, struct rpcrdma_xprt, rx_ep); + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + + spin_lock_bh(&xprt->transport_lock); + if (ep->rep_connected > 0) { + if (!xprt_test_and_set_connected(xprt)) { + xprt->stat.connect_count++; + xprt->stat.connect_time += (long)jiffies - + xprt->stat.connect_start; + xprt_wake_pending_tasks(xprt, 0); + } + } else { + if (xprt_test_and_clear_connected(xprt)) + xprt_wake_pending_tasks(xprt, -ENOTCONN); + } + spin_unlock_bh(&xprt->transport_lock); +} + +static void +xprt_rdma_connect_worker(struct work_struct *work) +{ + struct rpcrdma_xprt *r_xprt = container_of(work, struct rpcrdma_xprt, + rx_connect_worker.work); + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + int rc = 0; + + xprt_clear_connected(xprt); + + rc = rpcrdma_ep_connect(&r_xprt->rx_ep, &r_xprt->rx_ia); + if (rc) + xprt_wake_pending_tasks(xprt, rc); + + xprt_clear_connecting(xprt); +} + +static void +xprt_rdma_inject_disconnect(struct rpc_xprt *xprt) +{ + struct rpcrdma_xprt *r_xprt = container_of(xprt, struct rpcrdma_xprt, + rx_xprt); + + trace_xprtrdma_inject_dsc(r_xprt); + rdma_disconnect(r_xprt->rx_ia.ri_id); +} + +/* + * xprt_rdma_destroy + * + * Destroy the xprt. + * Free all memory associated with the object, including its own. + * NOTE: none of the *destroy methods free memory for their top-level + * objects, even though they may have allocated it (they do free + * private memory). It's up to the caller to handle it. In this + * case (RDMA transport), all structure memory is inlined with the + * struct rpcrdma_xprt. + */ +static void +xprt_rdma_destroy(struct rpc_xprt *xprt) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + + trace_xprtrdma_destroy(r_xprt); + + cancel_delayed_work_sync(&r_xprt->rx_connect_worker); + + xprt_clear_connected(xprt); + + rpcrdma_ep_destroy(&r_xprt->rx_ep, &r_xprt->rx_ia); + rpcrdma_buffer_destroy(&r_xprt->rx_buf); + rpcrdma_ia_close(&r_xprt->rx_ia); + + xprt_rdma_free_addresses(xprt); + xprt_free(xprt); + + module_put(THIS_MODULE); +} + +static const struct rpc_timeout xprt_rdma_default_timeout = { + .to_initval = 60 * HZ, + .to_maxval = 60 * HZ, +}; + +/** + * xprt_setup_rdma - Set up transport to use RDMA + * + * @args: rpc transport arguments + */ +static struct rpc_xprt * +xprt_setup_rdma(struct xprt_create *args) +{ + struct rpcrdma_create_data_internal cdata; + struct rpc_xprt *xprt; + struct rpcrdma_xprt *new_xprt; + struct rpcrdma_ep *new_ep; + struct sockaddr *sap; + int rc; + + if (args->addrlen > sizeof(xprt->addr)) { + dprintk("RPC: %s: address too large\n", __func__); + return ERR_PTR(-EBADF); + } + + xprt = xprt_alloc(args->net, sizeof(struct rpcrdma_xprt), 0, 0); + if (xprt == NULL) { + dprintk("RPC: %s: couldn't allocate rpcrdma_xprt\n", + __func__); + return ERR_PTR(-ENOMEM); + } + + /* 60 second timeout, no retries */ + xprt->timeout = &xprt_rdma_default_timeout; + xprt->bind_timeout = RPCRDMA_BIND_TO; + xprt->reestablish_timeout = RPCRDMA_INIT_REEST_TO; + xprt->idle_timeout = RPCRDMA_IDLE_DISC_TO; + + xprt->resvport = 0; /* privileged port not needed */ + xprt->tsh_size = 0; /* RPC-RDMA handles framing */ + xprt->ops = &xprt_rdma_procs; + + /* + * Set up RDMA-specific connect data. + */ + sap = args->dstaddr; + + /* Ensure xprt->addr holds valid server TCP (not RDMA) + * address, for any side protocols which peek at it */ + xprt->prot = IPPROTO_TCP; + xprt->addrlen = args->addrlen; + memcpy(&xprt->addr, sap, xprt->addrlen); + + if (rpc_get_port(sap)) + xprt_set_bound(xprt); + xprt_rdma_format_addresses(xprt, sap); + + cdata.max_requests = xprt_rdma_slot_table_entries; + + cdata.rsize = RPCRDMA_MAX_SEGS * PAGE_SIZE; /* RDMA write max */ + cdata.wsize = RPCRDMA_MAX_SEGS * PAGE_SIZE; /* RDMA read max */ + + cdata.inline_wsize = xprt_rdma_max_inline_write; + if (cdata.inline_wsize > cdata.wsize) + cdata.inline_wsize = cdata.wsize; + + cdata.inline_rsize = xprt_rdma_max_inline_read; + if (cdata.inline_rsize > cdata.rsize) + cdata.inline_rsize = cdata.rsize; + + /* + * Create new transport instance, which includes initialized + * o ia + * o endpoint + * o buffers + */ + + new_xprt = rpcx_to_rdmax(xprt); + + rc = rpcrdma_ia_open(new_xprt); + if (rc) + goto out1; + + /* + * initialize and create ep + */ + new_xprt->rx_data = cdata; + new_ep = &new_xprt->rx_ep; + + rc = rpcrdma_ep_create(&new_xprt->rx_ep, + &new_xprt->rx_ia, &new_xprt->rx_data); + if (rc) + goto out2; + + rc = rpcrdma_buffer_create(new_xprt); + if (rc) + goto out3; + + INIT_DELAYED_WORK(&new_xprt->rx_connect_worker, + xprt_rdma_connect_worker); + + xprt->max_payload = new_xprt->rx_ia.ri_ops->ro_maxpages(new_xprt); + if (xprt->max_payload == 0) + goto out4; + xprt->max_payload <<= PAGE_SHIFT; + dprintk("RPC: %s: transport data payload maximum: %zu bytes\n", + __func__, xprt->max_payload); + + if (!try_module_get(THIS_MODULE)) + goto out4; + + dprintk("RPC: %s: %s:%s\n", __func__, + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PORT]); + trace_xprtrdma_create(new_xprt); + return xprt; + +out4: + rpcrdma_buffer_destroy(&new_xprt->rx_buf); + rc = -ENODEV; +out3: + rpcrdma_ep_destroy(new_ep, &new_xprt->rx_ia); +out2: + rpcrdma_ia_close(&new_xprt->rx_ia); +out1: + trace_xprtrdma_destroy(new_xprt); + xprt_rdma_free_addresses(xprt); + xprt_free(xprt); + return ERR_PTR(rc); +} + +/** + * xprt_rdma_close - Close down RDMA connection + * @xprt: generic transport to be closed + * + * Called during transport shutdown reconnect, or device + * removal. Caller holds the transport's write lock. + */ +static void +xprt_rdma_close(struct rpc_xprt *xprt) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct rpcrdma_ep *ep = &r_xprt->rx_ep; + struct rpcrdma_ia *ia = &r_xprt->rx_ia; + + dprintk("RPC: %s: closing xprt %p\n", __func__, xprt); + + if (test_and_clear_bit(RPCRDMA_IAF_REMOVING, &ia->ri_flags)) { + xprt_clear_connected(xprt); + rpcrdma_ia_remove(ia); + return; + } + if (ep->rep_connected == -ENODEV) + return; + if (ep->rep_connected > 0) + xprt->reestablish_timeout = 0; + xprt_disconnect_done(xprt); + rpcrdma_ep_disconnect(ep, ia); + + /* Prepare @xprt for the next connection by reinitializing + * its credit grant to one (see RFC 8166, Section 3.3.3). + */ + r_xprt->rx_buf.rb_credits = 1; + xprt->cwnd = RPC_CWNDSHIFT; +} + +/** + * xprt_rdma_set_port - update server port with rpcbind result + * @xprt: controlling RPC transport + * @port: new port value + * + * Transport connect status is unchanged. + */ +static void +xprt_rdma_set_port(struct rpc_xprt *xprt, u16 port) +{ + struct sockaddr *sap = (struct sockaddr *)&xprt->addr; + char buf[8]; + + dprintk("RPC: %s: setting port for xprt %p (%s:%s) to %u\n", + __func__, xprt, + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PORT], + port); + + rpc_set_port(sap, port); + + kfree(xprt->address_strings[RPC_DISPLAY_PORT]); + snprintf(buf, sizeof(buf), "%u", port); + xprt->address_strings[RPC_DISPLAY_PORT] = kstrdup(buf, GFP_KERNEL); + + kfree(xprt->address_strings[RPC_DISPLAY_HEX_PORT]); + snprintf(buf, sizeof(buf), "%4hx", port); + xprt->address_strings[RPC_DISPLAY_HEX_PORT] = kstrdup(buf, GFP_KERNEL); +} + +/** + * xprt_rdma_timer - invoked when an RPC times out + * @xprt: controlling RPC transport + * @task: RPC task that timed out + * + * Invoked when the transport is still connected, but an RPC + * retransmit timeout occurs. + * + * Since RDMA connections don't have a keep-alive, forcibly + * disconnect and retry to connect. This drives full + * detection of the network path, and retransmissions of + * all pending RPCs. + */ +static void +xprt_rdma_timer(struct rpc_xprt *xprt, struct rpc_task *task) +{ + xprt_force_disconnect(xprt); +} + +static void +xprt_rdma_connect(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + + if (r_xprt->rx_ep.rep_connected != 0) { + /* Reconnect */ + schedule_delayed_work(&r_xprt->rx_connect_worker, + xprt->reestablish_timeout); + xprt->reestablish_timeout <<= 1; + if (xprt->reestablish_timeout > RPCRDMA_MAX_REEST_TO) + xprt->reestablish_timeout = RPCRDMA_MAX_REEST_TO; + else if (xprt->reestablish_timeout < RPCRDMA_INIT_REEST_TO) + xprt->reestablish_timeout = RPCRDMA_INIT_REEST_TO; + } else { + schedule_delayed_work(&r_xprt->rx_connect_worker, 0); + if (!RPC_IS_ASYNC(task)) + flush_delayed_work(&r_xprt->rx_connect_worker); + } +} + +/** + * xprt_rdma_alloc_slot - allocate an rpc_rqst + * @xprt: controlling RPC transport + * @task: RPC task requesting a fresh rpc_rqst + * + * tk_status values: + * %0 if task->tk_rqstp points to a fresh rpc_rqst + * %-EAGAIN if no rpc_rqst is available; queued on backlog + */ +static void +xprt_rdma_alloc_slot(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct rpcrdma_req *req; + + req = rpcrdma_buffer_get(&r_xprt->rx_buf); + if (!req) + goto out_sleep; + task->tk_rqstp = &req->rl_slot; + task->tk_status = 0; + return; + +out_sleep: + rpc_sleep_on(&xprt->backlog, task, NULL); + task->tk_status = -EAGAIN; +} + +/** + * xprt_rdma_free_slot - release an rpc_rqst + * @xprt: controlling RPC transport + * @rqst: rpc_rqst to release + * + */ +static void +xprt_rdma_free_slot(struct rpc_xprt *xprt, struct rpc_rqst *rqst) +{ + memset(rqst, 0, sizeof(*rqst)); + rpcrdma_buffer_put(rpcr_to_rdmar(rqst)); + rpc_wake_up_next(&xprt->backlog); +} + +static bool +rpcrdma_get_sendbuf(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req, + size_t size, gfp_t flags) +{ + struct rpcrdma_regbuf *rb; + + if (req->rl_sendbuf && rdmab_length(req->rl_sendbuf) >= size) + return true; + + rb = rpcrdma_alloc_regbuf(size, DMA_TO_DEVICE, flags); + if (IS_ERR(rb)) + return false; + + rpcrdma_free_regbuf(req->rl_sendbuf); + r_xprt->rx_stats.hardway_register_count += size; + req->rl_sendbuf = rb; + return true; +} + +/* The rq_rcv_buf is used only if a Reply chunk is necessary. + * The decision to use a Reply chunk is made later in + * rpcrdma_marshal_req. This buffer is registered at that time. + * + * Otherwise, the associated RPC Reply arrives in a separate + * Receive buffer, arbitrarily chosen by the HCA. The buffer + * allocated here for the RPC Reply is not utilized in that + * case. See rpcrdma_inline_fixup. + * + * A regbuf is used here to remember the buffer size. + */ +static bool +rpcrdma_get_recvbuf(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req, + size_t size, gfp_t flags) +{ + struct rpcrdma_regbuf *rb; + + if (req->rl_recvbuf && rdmab_length(req->rl_recvbuf) >= size) + return true; + + rb = rpcrdma_alloc_regbuf(size, DMA_NONE, flags); + if (IS_ERR(rb)) + return false; + + rpcrdma_free_regbuf(req->rl_recvbuf); + r_xprt->rx_stats.hardway_register_count += size; + req->rl_recvbuf = rb; + return true; +} + +/** + * xprt_rdma_allocate - allocate transport resources for an RPC + * @task: RPC task + * + * Return values: + * 0: Success; rq_buffer points to RPC buffer to use + * ENOMEM: Out of memory, call again later + * EIO: A permanent error occurred, do not retry + * + * The RDMA allocate/free functions need the task structure as a place + * to hide the struct rpcrdma_req, which is necessary for the actual + * send/recv sequence. + * + * xprt_rdma_allocate provides buffers that are already mapped for + * DMA, and a local DMA lkey is provided for each. + */ +static int +xprt_rdma_allocate(struct rpc_task *task) +{ + struct rpc_rqst *rqst = task->tk_rqstp; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(rqst->rq_xprt); + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + gfp_t flags; + + flags = RPCRDMA_DEF_GFP; + if (RPC_IS_ASYNC(task)) + flags = GFP_NOWAIT | __GFP_NOWARN; + if (RPC_IS_SWAPPER(task)) + flags |= __GFP_MEMALLOC; + + if (!rpcrdma_get_sendbuf(r_xprt, req, rqst->rq_callsize, flags)) + goto out_fail; + if (!rpcrdma_get_recvbuf(r_xprt, req, rqst->rq_rcvsize, flags)) + goto out_fail; + + rqst->rq_buffer = req->rl_sendbuf->rg_base; + rqst->rq_rbuffer = req->rl_recvbuf->rg_base; + trace_xprtrdma_allocate(task, req); + return 0; + +out_fail: + trace_xprtrdma_allocate(task, NULL); + return -ENOMEM; +} + +/** + * xprt_rdma_free - release resources allocated by xprt_rdma_allocate + * @task: RPC task + * + * Caller guarantees rqst->rq_buffer is non-NULL. + */ +static void +xprt_rdma_free(struct rpc_task *task) +{ + struct rpc_rqst *rqst = task->tk_rqstp; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(rqst->rq_xprt); + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + + if (test_bit(RPCRDMA_REQ_F_PENDING, &req->rl_flags)) + rpcrdma_release_rqst(r_xprt, req); + trace_xprtrdma_rpc_done(task, req); +} + +/** + * xprt_rdma_send_request - marshal and send an RPC request + * @task: RPC task with an RPC message in rq_snd_buf + * + * Caller holds the transport's write lock. + * + * Returns: + * %0 if the RPC message has been sent + * %-ENOTCONN if the caller should reconnect and call again + * %-EAGAIN if the caller should call again + * %-ENOBUFS if the caller should call again after a delay + * %-EIO if a permanent error occurred and the request was not + * sent. Do not try to send this message again. + */ +static int +xprt_rdma_send_request(struct rpc_task *task) +{ + struct rpc_rqst *rqst = task->tk_rqstp; + struct rpc_xprt *xprt = rqst->rq_xprt; + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + int rc = 0; + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) + if (unlikely(!rqst->rq_buffer)) + return xprt_rdma_bc_send_reply(rqst); +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + + if (!xprt_connected(xprt)) + goto drop_connection; + + rc = rpcrdma_marshal_req(r_xprt, rqst); + if (rc < 0) + goto failed_marshal; + + /* Must suppress retransmit to maintain credits */ + if (rqst->rq_connect_cookie == xprt->connect_cookie) + goto drop_connection; + rqst->rq_xtime = ktime_get(); + + __set_bit(RPCRDMA_REQ_F_PENDING, &req->rl_flags); + if (rpcrdma_ep_post(&r_xprt->rx_ia, &r_xprt->rx_ep, req)) + goto drop_connection; + + rqst->rq_xmit_bytes_sent += rqst->rq_snd_buf.len; + rqst->rq_bytes_sent = 0; + + /* An RPC with no reply will throw off credit accounting, + * so drop the connection to reset the credit grant. + */ + if (!rpc_reply_expected(task)) + goto drop_connection; + return 0; + +failed_marshal: + if (rc != -ENOTCONN) + return rc; +drop_connection: + xprt_disconnect_done(xprt); + return -ENOTCONN; /* implies disconnect */ +} + +void xprt_rdma_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + long idle_time = 0; + + if (xprt_connected(xprt)) + idle_time = (long)(jiffies - xprt->last_used) / HZ; + + seq_puts(seq, "\txprt:\trdma "); + seq_printf(seq, "%u %lu %lu %lu %ld %lu %lu %lu %llu %llu ", + 0, /* need a local port? */ + xprt->stat.bind_count, + xprt->stat.connect_count, + xprt->stat.connect_time, + idle_time, + xprt->stat.sends, + xprt->stat.recvs, + xprt->stat.bad_xids, + xprt->stat.req_u, + xprt->stat.bklog_u); + seq_printf(seq, "%lu %lu %lu %llu %llu %llu %llu %lu %lu %lu %lu ", + r_xprt->rx_stats.read_chunk_count, + r_xprt->rx_stats.write_chunk_count, + r_xprt->rx_stats.reply_chunk_count, + r_xprt->rx_stats.total_rdma_request, + r_xprt->rx_stats.total_rdma_reply, + r_xprt->rx_stats.pullup_copy_count, + r_xprt->rx_stats.fixup_copy_count, + r_xprt->rx_stats.hardway_register_count, + r_xprt->rx_stats.failed_marshal_count, + r_xprt->rx_stats.bad_reply_count, + r_xprt->rx_stats.nomsg_call_count); + seq_printf(seq, "%lu %lu %lu %lu %lu %lu\n", + r_xprt->rx_stats.mrs_recovered, + r_xprt->rx_stats.mrs_orphaned, + r_xprt->rx_stats.mrs_allocated, + r_xprt->rx_stats.local_inv_needed, + r_xprt->rx_stats.empty_sendctx_q, + r_xprt->rx_stats.reply_waits_for_send); +} + +static int +xprt_rdma_enable_swap(struct rpc_xprt *xprt) +{ + return 0; +} + +static void +xprt_rdma_disable_swap(struct rpc_xprt *xprt) +{ +} + +/* + * Plumbing for rpc transport switch and kernel module + */ + +static const struct rpc_xprt_ops xprt_rdma_procs = { + .reserve_xprt = xprt_reserve_xprt_cong, + .release_xprt = xprt_release_xprt_cong, /* sunrpc/xprt.c */ + .alloc_slot = xprt_rdma_alloc_slot, + .free_slot = xprt_rdma_free_slot, + .release_request = xprt_release_rqst_cong, /* ditto */ + .set_retrans_timeout = xprt_set_retrans_timeout_def, /* ditto */ + .timer = xprt_rdma_timer, + .rpcbind = rpcb_getport_async, /* sunrpc/rpcb_clnt.c */ + .set_port = xprt_rdma_set_port, + .connect = xprt_rdma_connect, + .buf_alloc = xprt_rdma_allocate, + .buf_free = xprt_rdma_free, + .send_request = xprt_rdma_send_request, + .close = xprt_rdma_close, + .destroy = xprt_rdma_destroy, + .print_stats = xprt_rdma_print_stats, + .enable_swap = xprt_rdma_enable_swap, + .disable_swap = xprt_rdma_disable_swap, + .inject_disconnect = xprt_rdma_inject_disconnect, +#if defined(CONFIG_SUNRPC_BACKCHANNEL) + .bc_setup = xprt_rdma_bc_setup, + .bc_up = xprt_rdma_bc_up, + .bc_maxpayload = xprt_rdma_bc_maxpayload, + .bc_free_rqst = xprt_rdma_bc_free_rqst, + .bc_destroy = xprt_rdma_bc_destroy, +#endif +}; + +static struct xprt_class xprt_rdma = { + .list = LIST_HEAD_INIT(xprt_rdma.list), + .name = "rdma", + .owner = THIS_MODULE, + .ident = XPRT_TRANSPORT_RDMA, + .setup = xprt_setup_rdma, + .netid = { "rdma", "rdma6", "" }, +}; + +void xprt_rdma_cleanup(void) +{ + int rc; + + dprintk("RPCRDMA Module Removed, deregister RPC RDMA transport\n"); +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) + if (sunrpc_table_header) { + unregister_sysctl_table(sunrpc_table_header); + sunrpc_table_header = NULL; + } +#endif + rc = xprt_unregister_transport(&xprt_rdma); + if (rc) + dprintk("RPC: %s: xprt_unregister returned %i\n", + __func__, rc); + + rpcrdma_destroy_wq(); + + rc = xprt_unregister_transport(&xprt_rdma_bc); + if (rc) + dprintk("RPC: %s: xprt_unregister(bc) returned %i\n", + __func__, rc); +} + +int xprt_rdma_init(void) +{ + int rc; + + rc = rpcrdma_alloc_wq(); + if (rc) + return rc; + + rc = xprt_register_transport(&xprt_rdma); + if (rc) { + rpcrdma_destroy_wq(); + return rc; + } + + rc = xprt_register_transport(&xprt_rdma_bc); + if (rc) { + xprt_unregister_transport(&xprt_rdma); + rpcrdma_destroy_wq(); + return rc; + } + + dprintk("RPCRDMA Module Init, register RPC RDMA transport\n"); + + dprintk("Defaults:\n"); + dprintk("\tSlots %d\n" + "\tMaxInlineRead %d\n\tMaxInlineWrite %d\n", + xprt_rdma_slot_table_entries, + xprt_rdma_max_inline_read, xprt_rdma_max_inline_write); + dprintk("\tPadding 0\n\tMemreg %d\n", xprt_rdma_memreg_strategy); + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) + if (!sunrpc_table_header) + sunrpc_table_header = register_sysctl_table(sunrpc_table); +#endif + return 0; +} diff --git a/net/sunrpc/xprtrdma/verbs.c b/net/sunrpc/xprtrdma/verbs.c new file mode 100644 index 000000000..ef1f3d076 --- /dev/null +++ b/net/sunrpc/xprtrdma/verbs.c @@ -0,0 +1,1572 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2014-2017 Oracle. All rights reserved. + * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/* + * verbs.c + * + * Encapsulates the major functions managing: + * o adapters + * o endpoints + * o connections + * o buffer memory + */ + +#include <linux/interrupt.h> +#include <linux/slab.h> +#include <linux/sunrpc/addr.h> +#include <linux/sunrpc/svc_rdma.h> + +#include <asm-generic/barrier.h> +#include <asm/bitops.h> + +#include <rdma/ib_cm.h> + +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +/* + * Globals/Macros + */ + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_TRANS +#endif + +/* + * internal functions + */ +static void rpcrdma_sendctx_put_locked(struct rpcrdma_sendctx *sc); +static void rpcrdma_mrs_create(struct rpcrdma_xprt *r_xprt); +static void rpcrdma_mrs_destroy(struct rpcrdma_buffer *buf); +static int rpcrdma_create_rep(struct rpcrdma_xprt *r_xprt, bool temp); +static void rpcrdma_dma_unmap_regbuf(struct rpcrdma_regbuf *rb); + +struct workqueue_struct *rpcrdma_receive_wq __read_mostly; + +int +rpcrdma_alloc_wq(void) +{ + struct workqueue_struct *recv_wq; + + recv_wq = alloc_workqueue("xprtrdma_receive", + WQ_MEM_RECLAIM | WQ_HIGHPRI, + 0); + if (!recv_wq) + return -ENOMEM; + + rpcrdma_receive_wq = recv_wq; + return 0; +} + +void +rpcrdma_destroy_wq(void) +{ + struct workqueue_struct *wq; + + if (rpcrdma_receive_wq) { + wq = rpcrdma_receive_wq; + rpcrdma_receive_wq = NULL; + destroy_workqueue(wq); + } +} + +static void +rpcrdma_qp_async_error_upcall(struct ib_event *event, void *context) +{ + struct rpcrdma_ep *ep = context; + struct rpcrdma_xprt *r_xprt = container_of(ep, struct rpcrdma_xprt, + rx_ep); + + trace_xprtrdma_qp_error(r_xprt, event); + pr_err("rpcrdma: %s on device %s ep %p\n", + ib_event_msg(event->event), event->device->name, context); + + if (ep->rep_connected == 1) { + ep->rep_connected = -EIO; + rpcrdma_conn_func(ep); + wake_up_all(&ep->rep_connect_wait); + } +} + +/** + * rpcrdma_wc_send - Invoked by RDMA provider for each polled Send WC + * @cq: completion queue (ignored) + * @wc: completed WR + * + */ +static void +rpcrdma_wc_send(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct rpcrdma_sendctx *sc = + container_of(cqe, struct rpcrdma_sendctx, sc_cqe); + + /* WARNING: Only wr_cqe and status are reliable at this point */ + trace_xprtrdma_wc_send(sc, wc); + if (wc->status != IB_WC_SUCCESS && wc->status != IB_WC_WR_FLUSH_ERR) + pr_err("rpcrdma: Send: %s (%u/0x%x)\n", + ib_wc_status_msg(wc->status), + wc->status, wc->vendor_err); + + rpcrdma_sendctx_put_locked(sc); +} + +/** + * rpcrdma_wc_receive - Invoked by RDMA provider for each polled Receive WC + * @cq: completion queue (ignored) + * @wc: completed WR + * + */ +static void +rpcrdma_wc_receive(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct rpcrdma_rep *rep = container_of(cqe, struct rpcrdma_rep, + rr_cqe); + + /* WARNING: Only wr_id and status are reliable at this point */ + trace_xprtrdma_wc_receive(wc); + if (wc->status != IB_WC_SUCCESS) + goto out_fail; + + /* status == SUCCESS means all fields in wc are trustworthy */ + rpcrdma_set_xdrlen(&rep->rr_hdrbuf, wc->byte_len); + rep->rr_wc_flags = wc->wc_flags; + rep->rr_inv_rkey = wc->ex.invalidate_rkey; + + ib_dma_sync_single_for_cpu(rdmab_device(rep->rr_rdmabuf), + rdmab_addr(rep->rr_rdmabuf), + wc->byte_len, DMA_FROM_DEVICE); + +out_schedule: + rpcrdma_reply_handler(rep); + return; + +out_fail: + if (wc->status != IB_WC_WR_FLUSH_ERR) + pr_err("rpcrdma: Recv: %s (%u/0x%x)\n", + ib_wc_status_msg(wc->status), + wc->status, wc->vendor_err); + rpcrdma_set_xdrlen(&rep->rr_hdrbuf, 0); + goto out_schedule; +} + +static void +rpcrdma_update_connect_private(struct rpcrdma_xprt *r_xprt, + struct rdma_conn_param *param) +{ + struct rpcrdma_create_data_internal *cdata = &r_xprt->rx_data; + const struct rpcrdma_connect_private *pmsg = param->private_data; + unsigned int rsize, wsize; + + /* Default settings for RPC-over-RDMA Version One */ + r_xprt->rx_ia.ri_implicit_roundup = xprt_rdma_pad_optimize; + rsize = RPCRDMA_V1_DEF_INLINE_SIZE; + wsize = RPCRDMA_V1_DEF_INLINE_SIZE; + + if (pmsg && + pmsg->cp_magic == rpcrdma_cmp_magic && + pmsg->cp_version == RPCRDMA_CMP_VERSION) { + r_xprt->rx_ia.ri_implicit_roundup = true; + rsize = rpcrdma_decode_buffer_size(pmsg->cp_send_size); + wsize = rpcrdma_decode_buffer_size(pmsg->cp_recv_size); + } + + if (rsize < cdata->inline_rsize) + cdata->inline_rsize = rsize; + if (wsize < cdata->inline_wsize) + cdata->inline_wsize = wsize; + dprintk("RPC: %s: max send %u, max recv %u\n", + __func__, cdata->inline_wsize, cdata->inline_rsize); + rpcrdma_set_max_header_sizes(r_xprt); +} + +static int +rpcrdma_conn_upcall(struct rdma_cm_id *id, struct rdma_cm_event *event) +{ + struct rpcrdma_xprt *xprt = id->context; + struct rpcrdma_ia *ia = &xprt->rx_ia; + struct rpcrdma_ep *ep = &xprt->rx_ep; + int connstate = 0; + + trace_xprtrdma_conn_upcall(xprt, event); + switch (event->event) { + case RDMA_CM_EVENT_ADDR_RESOLVED: + case RDMA_CM_EVENT_ROUTE_RESOLVED: + ia->ri_async_rc = 0; + complete(&ia->ri_done); + break; + case RDMA_CM_EVENT_ADDR_ERROR: + ia->ri_async_rc = -EPROTO; + complete(&ia->ri_done); + break; + case RDMA_CM_EVENT_ROUTE_ERROR: + ia->ri_async_rc = -ENETUNREACH; + complete(&ia->ri_done); + break; + case RDMA_CM_EVENT_DEVICE_REMOVAL: +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) + pr_info("rpcrdma: removing device %s for %s:%s\n", + ia->ri_device->name, + rpcrdma_addrstr(xprt), rpcrdma_portstr(xprt)); +#endif + init_completion(&ia->ri_remove_done); + set_bit(RPCRDMA_IAF_REMOVING, &ia->ri_flags); + ep->rep_connected = -ENODEV; + xprt_force_disconnect(&xprt->rx_xprt); + wait_for_completion(&ia->ri_remove_done); + + ia->ri_id = NULL; + ia->ri_device = NULL; + /* Return 1 to ensure the core destroys the id. */ + return 1; + case RDMA_CM_EVENT_ESTABLISHED: + ++xprt->rx_xprt.connect_cookie; + connstate = 1; + rpcrdma_update_connect_private(xprt, &event->param.conn); + goto connected; + case RDMA_CM_EVENT_CONNECT_ERROR: + connstate = -ENOTCONN; + goto connected; + case RDMA_CM_EVENT_UNREACHABLE: + connstate = -ENETUNREACH; + goto connected; + case RDMA_CM_EVENT_REJECTED: + dprintk("rpcrdma: connection to %s:%s rejected: %s\n", + rpcrdma_addrstr(xprt), rpcrdma_portstr(xprt), + rdma_reject_msg(id, event->status)); + connstate = -ECONNREFUSED; + if (event->status == IB_CM_REJ_STALE_CONN) + connstate = -EAGAIN; + goto connected; + case RDMA_CM_EVENT_DISCONNECTED: + ++xprt->rx_xprt.connect_cookie; + connstate = -ECONNABORTED; +connected: + ep->rep_connected = connstate; + rpcrdma_conn_func(ep); + wake_up_all(&ep->rep_connect_wait); + /*FALLTHROUGH*/ + default: + dprintk("RPC: %s: %s:%s on %s/%s (ep 0x%p): %s\n", + __func__, + rpcrdma_addrstr(xprt), rpcrdma_portstr(xprt), + ia->ri_device->name, ia->ri_ops->ro_displayname, + ep, rdma_event_msg(event->event)); + break; + } + + return 0; +} + +static struct rdma_cm_id * +rpcrdma_create_id(struct rpcrdma_xprt *xprt, struct rpcrdma_ia *ia) +{ + unsigned long wtimeout = msecs_to_jiffies(RDMA_RESOLVE_TIMEOUT) + 1; + struct rdma_cm_id *id; + int rc; + + trace_xprtrdma_conn_start(xprt); + + init_completion(&ia->ri_done); + + id = rdma_create_id(xprt->rx_xprt.xprt_net, rpcrdma_conn_upcall, + xprt, RDMA_PS_TCP, IB_QPT_RC); + if (IS_ERR(id)) { + rc = PTR_ERR(id); + dprintk("RPC: %s: rdma_create_id() failed %i\n", + __func__, rc); + return id; + } + + ia->ri_async_rc = -ETIMEDOUT; + rc = rdma_resolve_addr(id, NULL, + (struct sockaddr *)&xprt->rx_xprt.addr, + RDMA_RESOLVE_TIMEOUT); + if (rc) { + dprintk("RPC: %s: rdma_resolve_addr() failed %i\n", + __func__, rc); + goto out; + } + rc = wait_for_completion_interruptible_timeout(&ia->ri_done, wtimeout); + if (rc < 0) { + trace_xprtrdma_conn_tout(xprt); + goto out; + } + + rc = ia->ri_async_rc; + if (rc) + goto out; + + ia->ri_async_rc = -ETIMEDOUT; + rc = rdma_resolve_route(id, RDMA_RESOLVE_TIMEOUT); + if (rc) { + dprintk("RPC: %s: rdma_resolve_route() failed %i\n", + __func__, rc); + goto out; + } + rc = wait_for_completion_interruptible_timeout(&ia->ri_done, wtimeout); + if (rc < 0) { + trace_xprtrdma_conn_tout(xprt); + goto out; + } + rc = ia->ri_async_rc; + if (rc) + goto out; + + return id; + +out: + rdma_destroy_id(id); + return ERR_PTR(rc); +} + +/* + * Exported functions. + */ + +/** + * rpcrdma_ia_open - Open and initialize an Interface Adapter. + * @xprt: transport with IA to (re)initialize + * + * Returns 0 on success, negative errno if an appropriate + * Interface Adapter could not be found and opened. + */ +int +rpcrdma_ia_open(struct rpcrdma_xprt *xprt) +{ + struct rpcrdma_ia *ia = &xprt->rx_ia; + int rc; + + ia->ri_id = rpcrdma_create_id(xprt, ia); + if (IS_ERR(ia->ri_id)) { + rc = PTR_ERR(ia->ri_id); + goto out_err; + } + ia->ri_device = ia->ri_id->device; + + ia->ri_pd = ib_alloc_pd(ia->ri_device, 0); + if (IS_ERR(ia->ri_pd)) { + rc = PTR_ERR(ia->ri_pd); + pr_err("rpcrdma: ib_alloc_pd() returned %d\n", rc); + goto out_err; + } + + switch (xprt_rdma_memreg_strategy) { + case RPCRDMA_FRWR: + if (frwr_is_supported(ia)) { + ia->ri_ops = &rpcrdma_frwr_memreg_ops; + break; + } + /*FALLTHROUGH*/ + case RPCRDMA_MTHCAFMR: + if (fmr_is_supported(ia)) { + ia->ri_ops = &rpcrdma_fmr_memreg_ops; + break; + } + /*FALLTHROUGH*/ + default: + pr_err("rpcrdma: Device %s does not support memreg mode %d\n", + ia->ri_device->name, xprt_rdma_memreg_strategy); + rc = -EINVAL; + goto out_err; + } + + return 0; + +out_err: + rpcrdma_ia_close(ia); + return rc; +} + +/** + * rpcrdma_ia_remove - Handle device driver unload + * @ia: interface adapter being removed + * + * Divest transport H/W resources associated with this adapter, + * but allow it to be restored later. + */ +void +rpcrdma_ia_remove(struct rpcrdma_ia *ia) +{ + struct rpcrdma_xprt *r_xprt = container_of(ia, struct rpcrdma_xprt, + rx_ia); + struct rpcrdma_ep *ep = &r_xprt->rx_ep; + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_req *req; + struct rpcrdma_rep *rep; + + cancel_delayed_work_sync(&buf->rb_refresh_worker); + + /* This is similar to rpcrdma_ep_destroy, but: + * - Don't cancel the connect worker. + * - Don't call rpcrdma_ep_disconnect, which waits + * for another conn upcall, which will deadlock. + * - rdma_disconnect is unneeded, the underlying + * connection is already gone. + */ + if (ia->ri_id->qp) { + ib_drain_qp(ia->ri_id->qp); + rdma_destroy_qp(ia->ri_id); + ia->ri_id->qp = NULL; + } + ib_free_cq(ep->rep_attr.recv_cq); + ep->rep_attr.recv_cq = NULL; + ib_free_cq(ep->rep_attr.send_cq); + ep->rep_attr.send_cq = NULL; + + /* The ULP is responsible for ensuring all DMA + * mappings and MRs are gone. + */ + list_for_each_entry(rep, &buf->rb_recv_bufs, rr_list) + rpcrdma_dma_unmap_regbuf(rep->rr_rdmabuf); + list_for_each_entry(req, &buf->rb_allreqs, rl_all) { + rpcrdma_dma_unmap_regbuf(req->rl_rdmabuf); + rpcrdma_dma_unmap_regbuf(req->rl_sendbuf); + rpcrdma_dma_unmap_regbuf(req->rl_recvbuf); + } + rpcrdma_mrs_destroy(buf); + ib_dealloc_pd(ia->ri_pd); + ia->ri_pd = NULL; + + /* Allow waiters to continue */ + complete(&ia->ri_remove_done); + + trace_xprtrdma_remove(r_xprt); +} + +/** + * rpcrdma_ia_close - Clean up/close an IA. + * @ia: interface adapter to close + * + */ +void +rpcrdma_ia_close(struct rpcrdma_ia *ia) +{ + if (ia->ri_id != NULL && !IS_ERR(ia->ri_id)) { + if (ia->ri_id->qp) + rdma_destroy_qp(ia->ri_id); + rdma_destroy_id(ia->ri_id); + } + ia->ri_id = NULL; + ia->ri_device = NULL; + + /* If the pd is still busy, xprtrdma missed freeing a resource */ + if (ia->ri_pd && !IS_ERR(ia->ri_pd)) + ib_dealloc_pd(ia->ri_pd); + ia->ri_pd = NULL; +} + +/* + * Create unconnected endpoint. + */ +int +rpcrdma_ep_create(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia, + struct rpcrdma_create_data_internal *cdata) +{ + struct rpcrdma_connect_private *pmsg = &ep->rep_cm_private; + struct ib_cq *sendcq, *recvcq; + unsigned int max_sge; + int rc; + + max_sge = min_t(unsigned int, ia->ri_device->attrs.max_send_sge, + RPCRDMA_MAX_SEND_SGES); + if (max_sge < RPCRDMA_MIN_SEND_SGES) { + pr_warn("rpcrdma: HCA provides only %d send SGEs\n", max_sge); + return -ENOMEM; + } + ia->ri_max_send_sges = max_sge; + + rc = ia->ri_ops->ro_open(ia, ep, cdata); + if (rc) + return rc; + + ep->rep_attr.event_handler = rpcrdma_qp_async_error_upcall; + ep->rep_attr.qp_context = ep; + ep->rep_attr.srq = NULL; + ep->rep_attr.cap.max_send_sge = max_sge; + ep->rep_attr.cap.max_recv_sge = 1; + ep->rep_attr.cap.max_inline_data = 0; + ep->rep_attr.sq_sig_type = IB_SIGNAL_REQ_WR; + ep->rep_attr.qp_type = IB_QPT_RC; + ep->rep_attr.port_num = ~0; + + dprintk("RPC: %s: requested max: dtos: send %d recv %d; " + "iovs: send %d recv %d\n", + __func__, + ep->rep_attr.cap.max_send_wr, + ep->rep_attr.cap.max_recv_wr, + ep->rep_attr.cap.max_send_sge, + ep->rep_attr.cap.max_recv_sge); + + /* set trigger for requesting send completion */ + ep->rep_send_batch = min_t(unsigned int, RPCRDMA_MAX_SEND_BATCH, + cdata->max_requests >> 2); + ep->rep_send_count = ep->rep_send_batch; + init_waitqueue_head(&ep->rep_connect_wait); + INIT_DELAYED_WORK(&ep->rep_connect_worker, rpcrdma_connect_worker); + + sendcq = ib_alloc_cq(ia->ri_device, NULL, + ep->rep_attr.cap.max_send_wr + 1, + ia->ri_device->num_comp_vectors > 1 ? 1 : 0, + IB_POLL_WORKQUEUE); + if (IS_ERR(sendcq)) { + rc = PTR_ERR(sendcq); + dprintk("RPC: %s: failed to create send CQ: %i\n", + __func__, rc); + goto out1; + } + + recvcq = ib_alloc_cq(ia->ri_device, NULL, + ep->rep_attr.cap.max_recv_wr + 1, + 0, IB_POLL_WORKQUEUE); + if (IS_ERR(recvcq)) { + rc = PTR_ERR(recvcq); + dprintk("RPC: %s: failed to create recv CQ: %i\n", + __func__, rc); + goto out2; + } + + ep->rep_attr.send_cq = sendcq; + ep->rep_attr.recv_cq = recvcq; + + /* Initialize cma parameters */ + memset(&ep->rep_remote_cma, 0, sizeof(ep->rep_remote_cma)); + + /* Prepare RDMA-CM private message */ + pmsg->cp_magic = rpcrdma_cmp_magic; + pmsg->cp_version = RPCRDMA_CMP_VERSION; + pmsg->cp_flags |= ia->ri_ops->ro_send_w_inv_ok; + pmsg->cp_send_size = rpcrdma_encode_buffer_size(cdata->inline_wsize); + pmsg->cp_recv_size = rpcrdma_encode_buffer_size(cdata->inline_rsize); + ep->rep_remote_cma.private_data = pmsg; + ep->rep_remote_cma.private_data_len = sizeof(*pmsg); + + /* Client offers RDMA Read but does not initiate */ + ep->rep_remote_cma.initiator_depth = 0; + ep->rep_remote_cma.responder_resources = + min_t(int, U8_MAX, ia->ri_device->attrs.max_qp_rd_atom); + + /* Limit transport retries so client can detect server + * GID changes quickly. RPC layer handles re-establishing + * transport connection and retransmission. + */ + ep->rep_remote_cma.retry_count = 6; + + /* RPC-over-RDMA handles its own flow control. In addition, + * make all RNR NAKs visible so we know that RPC-over-RDMA + * flow control is working correctly (no NAKs should be seen). + */ + ep->rep_remote_cma.flow_control = 0; + ep->rep_remote_cma.rnr_retry_count = 0; + + return 0; + +out2: + ib_free_cq(sendcq); +out1: + return rc; +} + +/* + * rpcrdma_ep_destroy + * + * Disconnect and destroy endpoint. After this, the only + * valid operations on the ep are to free it (if dynamically + * allocated) or re-create it. + */ +void +rpcrdma_ep_destroy(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia) +{ + cancel_delayed_work_sync(&ep->rep_connect_worker); + + if (ia->ri_id && ia->ri_id->qp) { + rpcrdma_ep_disconnect(ep, ia); + rdma_destroy_qp(ia->ri_id); + ia->ri_id->qp = NULL; + } + + if (ep->rep_attr.recv_cq) + ib_free_cq(ep->rep_attr.recv_cq); + if (ep->rep_attr.send_cq) + ib_free_cq(ep->rep_attr.send_cq); +} + +/* Re-establish a connection after a device removal event. + * Unlike a normal reconnection, a fresh PD and a new set + * of MRs and buffers is needed. + */ +static int +rpcrdma_ep_recreate_xprt(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_ep *ep, struct rpcrdma_ia *ia) +{ + int rc, err; + + trace_xprtrdma_reinsert(r_xprt); + + rc = -EHOSTUNREACH; + if (rpcrdma_ia_open(r_xprt)) + goto out1; + + rc = -ENOMEM; + err = rpcrdma_ep_create(ep, ia, &r_xprt->rx_data); + if (err) { + pr_err("rpcrdma: rpcrdma_ep_create returned %d\n", err); + goto out2; + } + + rc = -ENETUNREACH; + err = rdma_create_qp(ia->ri_id, ia->ri_pd, &ep->rep_attr); + if (err) { + pr_err("rpcrdma: rdma_create_qp returned %d\n", err); + goto out3; + } + + rpcrdma_mrs_create(r_xprt); + return 0; + +out3: + rpcrdma_ep_destroy(ep, ia); +out2: + rpcrdma_ia_close(ia); +out1: + return rc; +} + +static int +rpcrdma_ep_reconnect(struct rpcrdma_xprt *r_xprt, struct rpcrdma_ep *ep, + struct rpcrdma_ia *ia) +{ + struct rdma_cm_id *id, *old; + int err, rc; + + trace_xprtrdma_reconnect(r_xprt); + + rpcrdma_ep_disconnect(ep, ia); + + rc = -EHOSTUNREACH; + id = rpcrdma_create_id(r_xprt, ia); + if (IS_ERR(id)) + goto out; + + /* As long as the new ID points to the same device as the + * old ID, we can reuse the transport's existing PD and all + * previously allocated MRs. Also, the same device means + * the transport's previous DMA mappings are still valid. + * + * This is a sanity check only. There should be no way these + * point to two different devices here. + */ + old = id; + rc = -ENETUNREACH; + if (ia->ri_device != id->device) { + pr_err("rpcrdma: can't reconnect on different device!\n"); + goto out_destroy; + } + + err = rdma_create_qp(id, ia->ri_pd, &ep->rep_attr); + if (err) { + dprintk("RPC: %s: rdma_create_qp returned %d\n", + __func__, err); + goto out_destroy; + } + + /* Atomically replace the transport's ID and QP. */ + rc = 0; + old = ia->ri_id; + ia->ri_id = id; + rdma_destroy_qp(old); + +out_destroy: + rdma_destroy_id(old); +out: + return rc; +} + +/* + * Connect unconnected endpoint. + */ +int +rpcrdma_ep_connect(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia) +{ + struct rpcrdma_xprt *r_xprt = container_of(ia, struct rpcrdma_xprt, + rx_ia); + int rc; + +retry: + switch (ep->rep_connected) { + case 0: + dprintk("RPC: %s: connecting...\n", __func__); + rc = rdma_create_qp(ia->ri_id, ia->ri_pd, &ep->rep_attr); + if (rc) { + dprintk("RPC: %s: rdma_create_qp failed %i\n", + __func__, rc); + rc = -ENETUNREACH; + goto out_noupdate; + } + break; + case -ENODEV: + rc = rpcrdma_ep_recreate_xprt(r_xprt, ep, ia); + if (rc) + goto out_noupdate; + break; + default: + rc = rpcrdma_ep_reconnect(r_xprt, ep, ia); + if (rc) + goto out; + } + + ep->rep_connected = 0; + rpcrdma_post_recvs(r_xprt, true); + + rc = rdma_connect(ia->ri_id, &ep->rep_remote_cma); + if (rc) { + dprintk("RPC: %s: rdma_connect() failed with %i\n", + __func__, rc); + goto out; + } + + wait_event_interruptible(ep->rep_connect_wait, ep->rep_connected != 0); + if (ep->rep_connected <= 0) { + if (ep->rep_connected == -EAGAIN) + goto retry; + rc = ep->rep_connected; + goto out; + } + + dprintk("RPC: %s: connected\n", __func__); + +out: + if (rc) + ep->rep_connected = rc; + +out_noupdate: + return rc; +} + +/* + * rpcrdma_ep_disconnect + * + * This is separate from destroy to facilitate the ability + * to reconnect without recreating the endpoint. + * + * This call is not reentrant, and must not be made in parallel + * on the same endpoint. + */ +void +rpcrdma_ep_disconnect(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia) +{ + int rc; + + rc = rdma_disconnect(ia->ri_id); + if (!rc) + /* returns without wait if not connected */ + wait_event_interruptible(ep->rep_connect_wait, + ep->rep_connected != 1); + else + ep->rep_connected = rc; + trace_xprtrdma_disconnect(container_of(ep, struct rpcrdma_xprt, + rx_ep), rc); + + ib_drain_qp(ia->ri_id->qp); +} + +/* Fixed-size circular FIFO queue. This implementation is wait-free and + * lock-free. + * + * Consumer is the code path that posts Sends. This path dequeues a + * sendctx for use by a Send operation. Multiple consumer threads + * are serialized by the RPC transport lock, which allows only one + * ->send_request call at a time. + * + * Producer is the code path that handles Send completions. This path + * enqueues a sendctx that has been completed. Multiple producer + * threads are serialized by the ib_poll_cq() function. + */ + +/* rpcrdma_sendctxs_destroy() assumes caller has already quiesced + * queue activity, and ib_drain_qp has flushed all remaining Send + * requests. + */ +static void rpcrdma_sendctxs_destroy(struct rpcrdma_buffer *buf) +{ + unsigned long i; + + for (i = 0; i <= buf->rb_sc_last; i++) + kfree(buf->rb_sc_ctxs[i]); + kfree(buf->rb_sc_ctxs); +} + +static struct rpcrdma_sendctx *rpcrdma_sendctx_create(struct rpcrdma_ia *ia) +{ + struct rpcrdma_sendctx *sc; + + sc = kzalloc(sizeof(*sc) + + ia->ri_max_send_sges * sizeof(struct ib_sge), + GFP_KERNEL); + if (!sc) + return NULL; + + sc->sc_wr.wr_cqe = &sc->sc_cqe; + sc->sc_wr.sg_list = sc->sc_sges; + sc->sc_wr.opcode = IB_WR_SEND; + sc->sc_cqe.done = rpcrdma_wc_send; + return sc; +} + +static int rpcrdma_sendctxs_create(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_sendctx *sc; + unsigned long i; + + /* Maximum number of concurrent outstanding Send WRs. Capping + * the circular queue size stops Send Queue overflow by causing + * the ->send_request call to fail temporarily before too many + * Sends are posted. + */ + i = buf->rb_max_requests + RPCRDMA_MAX_BC_REQUESTS; + dprintk("RPC: %s: allocating %lu send_ctxs\n", __func__, i); + buf->rb_sc_ctxs = kcalloc(i, sizeof(sc), GFP_KERNEL); + if (!buf->rb_sc_ctxs) + return -ENOMEM; + + buf->rb_sc_last = i - 1; + for (i = 0; i <= buf->rb_sc_last; i++) { + sc = rpcrdma_sendctx_create(&r_xprt->rx_ia); + if (!sc) + return -ENOMEM; + + sc->sc_xprt = r_xprt; + buf->rb_sc_ctxs[i] = sc; + } + buf->rb_flags = 0; + + return 0; +} + +/* The sendctx queue is not guaranteed to have a size that is a + * power of two, thus the helpers in circ_buf.h cannot be used. + * The other option is to use modulus (%), which can be expensive. + */ +static unsigned long rpcrdma_sendctx_next(struct rpcrdma_buffer *buf, + unsigned long item) +{ + return likely(item < buf->rb_sc_last) ? item + 1 : 0; +} + +/** + * rpcrdma_sendctx_get_locked - Acquire a send context + * @buf: transport buffers from which to acquire an unused context + * + * Returns pointer to a free send completion context; or NULL if + * the queue is empty. + * + * Usage: Called to acquire an SGE array before preparing a Send WR. + * + * The caller serializes calls to this function (per rpcrdma_buffer), + * and provides an effective memory barrier that flushes the new value + * of rb_sc_head. + */ +struct rpcrdma_sendctx *rpcrdma_sendctx_get_locked(struct rpcrdma_buffer *buf) +{ + struct rpcrdma_xprt *r_xprt; + struct rpcrdma_sendctx *sc; + unsigned long next_head; + + next_head = rpcrdma_sendctx_next(buf, buf->rb_sc_head); + + if (next_head == READ_ONCE(buf->rb_sc_tail)) + goto out_emptyq; + + /* ORDER: item must be accessed _before_ head is updated */ + sc = buf->rb_sc_ctxs[next_head]; + + /* Releasing the lock in the caller acts as a memory + * barrier that flushes rb_sc_head. + */ + buf->rb_sc_head = next_head; + + return sc; + +out_emptyq: + /* The queue is "empty" if there have not been enough Send + * completions recently. This is a sign the Send Queue is + * backing up. Cause the caller to pause and try again. + */ + set_bit(RPCRDMA_BUF_F_EMPTY_SCQ, &buf->rb_flags); + r_xprt = container_of(buf, struct rpcrdma_xprt, rx_buf); + r_xprt->rx_stats.empty_sendctx_q++; + return NULL; +} + +/** + * rpcrdma_sendctx_put_locked - Release a send context + * @sc: send context to release + * + * Usage: Called from Send completion to return a sendctxt + * to the queue. + * + * The caller serializes calls to this function (per rpcrdma_buffer). + */ +static void +rpcrdma_sendctx_put_locked(struct rpcrdma_sendctx *sc) +{ + struct rpcrdma_buffer *buf = &sc->sc_xprt->rx_buf; + unsigned long next_tail; + + /* Unmap SGEs of previously completed by unsignaled + * Sends by walking up the queue until @sc is found. + */ + next_tail = buf->rb_sc_tail; + do { + next_tail = rpcrdma_sendctx_next(buf, next_tail); + + /* ORDER: item must be accessed _before_ tail is updated */ + rpcrdma_unmap_sendctx(buf->rb_sc_ctxs[next_tail]); + + } while (buf->rb_sc_ctxs[next_tail] != sc); + + /* Paired with READ_ONCE */ + smp_store_release(&buf->rb_sc_tail, next_tail); + + if (test_and_clear_bit(RPCRDMA_BUF_F_EMPTY_SCQ, &buf->rb_flags)) { + smp_mb__after_atomic(); + xprt_write_space(&sc->sc_xprt->rx_xprt); + } +} + +static void +rpcrdma_mr_recovery_worker(struct work_struct *work) +{ + struct rpcrdma_buffer *buf = container_of(work, struct rpcrdma_buffer, + rb_recovery_worker.work); + struct rpcrdma_mr *mr; + + spin_lock(&buf->rb_recovery_lock); + while (!list_empty(&buf->rb_stale_mrs)) { + mr = rpcrdma_mr_pop(&buf->rb_stale_mrs); + spin_unlock(&buf->rb_recovery_lock); + + trace_xprtrdma_recover_mr(mr); + mr->mr_xprt->rx_ia.ri_ops->ro_recover_mr(mr); + + spin_lock(&buf->rb_recovery_lock); + } + spin_unlock(&buf->rb_recovery_lock); +} + +void +rpcrdma_mr_defer_recovery(struct rpcrdma_mr *mr) +{ + struct rpcrdma_xprt *r_xprt = mr->mr_xprt; + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + + spin_lock(&buf->rb_recovery_lock); + rpcrdma_mr_push(mr, &buf->rb_stale_mrs); + spin_unlock(&buf->rb_recovery_lock); + + schedule_delayed_work(&buf->rb_recovery_worker, 0); +} + +static void +rpcrdma_mrs_create(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_ia *ia = &r_xprt->rx_ia; + unsigned int count; + LIST_HEAD(free); + LIST_HEAD(all); + + for (count = 0; count < 3; count++) { + struct rpcrdma_mr *mr; + int rc; + + mr = kzalloc(sizeof(*mr), GFP_KERNEL); + if (!mr) + break; + + rc = ia->ri_ops->ro_init_mr(ia, mr); + if (rc) { + kfree(mr); + break; + } + + mr->mr_xprt = r_xprt; + + list_add(&mr->mr_list, &free); + list_add(&mr->mr_all, &all); + } + + spin_lock(&buf->rb_mrlock); + list_splice(&free, &buf->rb_mrs); + list_splice(&all, &buf->rb_all); + r_xprt->rx_stats.mrs_allocated += count; + spin_unlock(&buf->rb_mrlock); + trace_xprtrdma_createmrs(r_xprt, count); + + xprt_write_space(&r_xprt->rx_xprt); +} + +static void +rpcrdma_mr_refresh_worker(struct work_struct *work) +{ + struct rpcrdma_buffer *buf = container_of(work, struct rpcrdma_buffer, + rb_refresh_worker.work); + struct rpcrdma_xprt *r_xprt = container_of(buf, struct rpcrdma_xprt, + rx_buf); + + rpcrdma_mrs_create(r_xprt); +} + +struct rpcrdma_req * +rpcrdma_create_req(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buffer = &r_xprt->rx_buf; + struct rpcrdma_regbuf *rb; + struct rpcrdma_req *req; + + req = kzalloc(sizeof(*req), GFP_KERNEL); + if (req == NULL) + return ERR_PTR(-ENOMEM); + + rb = rpcrdma_alloc_regbuf(RPCRDMA_HDRBUF_SIZE, + DMA_TO_DEVICE, GFP_KERNEL); + if (IS_ERR(rb)) { + kfree(req); + return ERR_PTR(-ENOMEM); + } + req->rl_rdmabuf = rb; + xdr_buf_init(&req->rl_hdrbuf, rb->rg_base, rdmab_length(rb)); + req->rl_buffer = buffer; + INIT_LIST_HEAD(&req->rl_registered); + + spin_lock(&buffer->rb_reqslock); + list_add(&req->rl_all, &buffer->rb_allreqs); + spin_unlock(&buffer->rb_reqslock); + return req; +} + +static int +rpcrdma_create_rep(struct rpcrdma_xprt *r_xprt, bool temp) +{ + struct rpcrdma_create_data_internal *cdata = &r_xprt->rx_data; + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_rep *rep; + int rc; + + rc = -ENOMEM; + rep = kzalloc(sizeof(*rep), GFP_KERNEL); + if (rep == NULL) + goto out; + + rep->rr_rdmabuf = rpcrdma_alloc_regbuf(cdata->inline_rsize, + DMA_FROM_DEVICE, GFP_KERNEL); + if (IS_ERR(rep->rr_rdmabuf)) { + rc = PTR_ERR(rep->rr_rdmabuf); + goto out_free; + } + xdr_buf_init(&rep->rr_hdrbuf, rep->rr_rdmabuf->rg_base, + rdmab_length(rep->rr_rdmabuf)); + + rep->rr_cqe.done = rpcrdma_wc_receive; + rep->rr_rxprt = r_xprt; + INIT_WORK(&rep->rr_work, rpcrdma_deferred_completion); + rep->rr_recv_wr.next = NULL; + rep->rr_recv_wr.wr_cqe = &rep->rr_cqe; + rep->rr_recv_wr.sg_list = &rep->rr_rdmabuf->rg_iov; + rep->rr_recv_wr.num_sge = 1; + rep->rr_temp = temp; + + spin_lock(&buf->rb_lock); + list_add(&rep->rr_list, &buf->rb_recv_bufs); + spin_unlock(&buf->rb_lock); + return 0; + +out_free: + kfree(rep); +out: + dprintk("RPC: %s: reply buffer %d alloc failed\n", + __func__, rc); + return rc; +} + +int +rpcrdma_buffer_create(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + int i, rc; + + buf->rb_max_requests = r_xprt->rx_data.max_requests; + buf->rb_bc_srv_max_requests = 0; + spin_lock_init(&buf->rb_mrlock); + spin_lock_init(&buf->rb_lock); + spin_lock_init(&buf->rb_recovery_lock); + INIT_LIST_HEAD(&buf->rb_mrs); + INIT_LIST_HEAD(&buf->rb_all); + INIT_LIST_HEAD(&buf->rb_stale_mrs); + INIT_DELAYED_WORK(&buf->rb_refresh_worker, + rpcrdma_mr_refresh_worker); + INIT_DELAYED_WORK(&buf->rb_recovery_worker, + rpcrdma_mr_recovery_worker); + + rpcrdma_mrs_create(r_xprt); + + INIT_LIST_HEAD(&buf->rb_send_bufs); + INIT_LIST_HEAD(&buf->rb_allreqs); + spin_lock_init(&buf->rb_reqslock); + for (i = 0; i < buf->rb_max_requests; i++) { + struct rpcrdma_req *req; + + req = rpcrdma_create_req(r_xprt); + if (IS_ERR(req)) { + dprintk("RPC: %s: request buffer %d alloc" + " failed\n", __func__, i); + rc = PTR_ERR(req); + goto out; + } + list_add(&req->rl_list, &buf->rb_send_bufs); + } + + buf->rb_credits = 1; + buf->rb_posted_receives = 0; + INIT_LIST_HEAD(&buf->rb_recv_bufs); + + rc = rpcrdma_sendctxs_create(r_xprt); + if (rc) + goto out; + + return 0; +out: + rpcrdma_buffer_destroy(buf); + return rc; +} + +static void +rpcrdma_destroy_rep(struct rpcrdma_rep *rep) +{ + rpcrdma_free_regbuf(rep->rr_rdmabuf); + kfree(rep); +} + +void +rpcrdma_destroy_req(struct rpcrdma_req *req) +{ + rpcrdma_free_regbuf(req->rl_recvbuf); + rpcrdma_free_regbuf(req->rl_sendbuf); + rpcrdma_free_regbuf(req->rl_rdmabuf); + kfree(req); +} + +static void +rpcrdma_mrs_destroy(struct rpcrdma_buffer *buf) +{ + struct rpcrdma_xprt *r_xprt = container_of(buf, struct rpcrdma_xprt, + rx_buf); + struct rpcrdma_ia *ia = rdmab_to_ia(buf); + struct rpcrdma_mr *mr; + unsigned int count; + + count = 0; + spin_lock(&buf->rb_mrlock); + while (!list_empty(&buf->rb_all)) { + mr = list_entry(buf->rb_all.next, struct rpcrdma_mr, mr_all); + list_del(&mr->mr_all); + + spin_unlock(&buf->rb_mrlock); + + /* Ensure MW is not on any rl_registered list */ + if (!list_empty(&mr->mr_list)) + list_del(&mr->mr_list); + + ia->ri_ops->ro_release_mr(mr); + count++; + spin_lock(&buf->rb_mrlock); + } + spin_unlock(&buf->rb_mrlock); + r_xprt->rx_stats.mrs_allocated = 0; + + dprintk("RPC: %s: released %u MRs\n", __func__, count); +} + +void +rpcrdma_buffer_destroy(struct rpcrdma_buffer *buf) +{ + cancel_delayed_work_sync(&buf->rb_recovery_worker); + cancel_delayed_work_sync(&buf->rb_refresh_worker); + + rpcrdma_sendctxs_destroy(buf); + + while (!list_empty(&buf->rb_recv_bufs)) { + struct rpcrdma_rep *rep; + + rep = list_first_entry(&buf->rb_recv_bufs, + struct rpcrdma_rep, rr_list); + list_del(&rep->rr_list); + rpcrdma_destroy_rep(rep); + } + + spin_lock(&buf->rb_reqslock); + while (!list_empty(&buf->rb_allreqs)) { + struct rpcrdma_req *req; + + req = list_first_entry(&buf->rb_allreqs, + struct rpcrdma_req, rl_all); + list_del(&req->rl_all); + + spin_unlock(&buf->rb_reqslock); + rpcrdma_destroy_req(req); + spin_lock(&buf->rb_reqslock); + } + spin_unlock(&buf->rb_reqslock); + + rpcrdma_mrs_destroy(buf); +} + +/** + * rpcrdma_mr_get - Allocate an rpcrdma_mr object + * @r_xprt: controlling transport + * + * Returns an initialized rpcrdma_mr or NULL if no free + * rpcrdma_mr objects are available. + */ +struct rpcrdma_mr * +rpcrdma_mr_get(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_mr *mr = NULL; + + spin_lock(&buf->rb_mrlock); + if (!list_empty(&buf->rb_mrs)) + mr = rpcrdma_mr_pop(&buf->rb_mrs); + spin_unlock(&buf->rb_mrlock); + + if (!mr) + goto out_nomrs; + return mr; + +out_nomrs: + trace_xprtrdma_nomrs(r_xprt); + if (r_xprt->rx_ep.rep_connected != -ENODEV) + schedule_delayed_work(&buf->rb_refresh_worker, 0); + + /* Allow the reply handler and refresh worker to run */ + cond_resched(); + + return NULL; +} + +static void +__rpcrdma_mr_put(struct rpcrdma_buffer *buf, struct rpcrdma_mr *mr) +{ + spin_lock(&buf->rb_mrlock); + rpcrdma_mr_push(mr, &buf->rb_mrs); + spin_unlock(&buf->rb_mrlock); +} + +/** + * rpcrdma_mr_put - Release an rpcrdma_mr object + * @mr: object to release + * + */ +void +rpcrdma_mr_put(struct rpcrdma_mr *mr) +{ + __rpcrdma_mr_put(&mr->mr_xprt->rx_buf, mr); +} + +/** + * rpcrdma_mr_unmap_and_put - DMA unmap an MR and release it + * @mr: object to release + * + */ +void +rpcrdma_mr_unmap_and_put(struct rpcrdma_mr *mr) +{ + struct rpcrdma_xprt *r_xprt = mr->mr_xprt; + + trace_xprtrdma_dma_unmap(mr); + ib_dma_unmap_sg(r_xprt->rx_ia.ri_device, + mr->mr_sg, mr->mr_nents, mr->mr_dir); + __rpcrdma_mr_put(&r_xprt->rx_buf, mr); +} + +/** + * rpcrdma_buffer_get - Get a request buffer + * @buffers: Buffer pool from which to obtain a buffer + * + * Returns a fresh rpcrdma_req, or NULL if none are available. + */ +struct rpcrdma_req * +rpcrdma_buffer_get(struct rpcrdma_buffer *buffers) +{ + struct rpcrdma_req *req; + + spin_lock(&buffers->rb_lock); + req = list_first_entry_or_null(&buffers->rb_send_bufs, + struct rpcrdma_req, rl_list); + if (req) + list_del_init(&req->rl_list); + spin_unlock(&buffers->rb_lock); + return req; +} + +/** + * rpcrdma_buffer_put - Put request/reply buffers back into pool + * @req: object to return + * + */ +void +rpcrdma_buffer_put(struct rpcrdma_req *req) +{ + struct rpcrdma_buffer *buffers = req->rl_buffer; + struct rpcrdma_rep *rep = req->rl_reply; + + req->rl_reply = NULL; + + spin_lock(&buffers->rb_lock); + list_add(&req->rl_list, &buffers->rb_send_bufs); + if (rep) { + if (!rep->rr_temp) { + list_add(&rep->rr_list, &buffers->rb_recv_bufs); + rep = NULL; + } + } + spin_unlock(&buffers->rb_lock); + if (rep) + rpcrdma_destroy_rep(rep); +} + +/* + * Put reply buffers back into pool when not attached to + * request. This happens in error conditions. + */ +void +rpcrdma_recv_buffer_put(struct rpcrdma_rep *rep) +{ + struct rpcrdma_buffer *buffers = &rep->rr_rxprt->rx_buf; + + if (!rep->rr_temp) { + spin_lock(&buffers->rb_lock); + list_add(&rep->rr_list, &buffers->rb_recv_bufs); + spin_unlock(&buffers->rb_lock); + } else { + rpcrdma_destroy_rep(rep); + } +} + +/** + * rpcrdma_alloc_regbuf - allocate and DMA-map memory for SEND/RECV buffers + * @size: size of buffer to be allocated, in bytes + * @direction: direction of data movement + * @flags: GFP flags + * + * Returns an ERR_PTR, or a pointer to a regbuf, a buffer that + * can be persistently DMA-mapped for I/O. + * + * xprtrdma uses a regbuf for posting an outgoing RDMA SEND, or for + * receiving the payload of RDMA RECV operations. During Long Calls + * or Replies they may be registered externally via ro_map. + */ +struct rpcrdma_regbuf * +rpcrdma_alloc_regbuf(size_t size, enum dma_data_direction direction, + gfp_t flags) +{ + struct rpcrdma_regbuf *rb; + + rb = kmalloc(sizeof(*rb) + size, flags); + if (rb == NULL) + return ERR_PTR(-ENOMEM); + + rb->rg_device = NULL; + rb->rg_direction = direction; + rb->rg_iov.length = size; + + return rb; +} + +/** + * __rpcrdma_map_regbuf - DMA-map a regbuf + * @ia: controlling rpcrdma_ia + * @rb: regbuf to be mapped + */ +bool +__rpcrdma_dma_map_regbuf(struct rpcrdma_ia *ia, struct rpcrdma_regbuf *rb) +{ + struct ib_device *device = ia->ri_device; + + if (rb->rg_direction == DMA_NONE) + return false; + + rb->rg_iov.addr = ib_dma_map_single(device, + (void *)rb->rg_base, + rdmab_length(rb), + rb->rg_direction); + if (ib_dma_mapping_error(device, rdmab_addr(rb))) + return false; + + rb->rg_device = device; + rb->rg_iov.lkey = ia->ri_pd->local_dma_lkey; + return true; +} + +static void +rpcrdma_dma_unmap_regbuf(struct rpcrdma_regbuf *rb) +{ + if (!rb) + return; + + if (!rpcrdma_regbuf_is_mapped(rb)) + return; + + ib_dma_unmap_single(rb->rg_device, rdmab_addr(rb), + rdmab_length(rb), rb->rg_direction); + rb->rg_device = NULL; +} + +/** + * rpcrdma_free_regbuf - deregister and free registered buffer + * @rb: regbuf to be deregistered and freed + */ +void +rpcrdma_free_regbuf(struct rpcrdma_regbuf *rb) +{ + rpcrdma_dma_unmap_regbuf(rb); + kfree(rb); +} + +/* + * Prepost any receive buffer, then post send. + * + * Receive buffer is donated to hardware, reclaimed upon recv completion. + */ +int +rpcrdma_ep_post(struct rpcrdma_ia *ia, + struct rpcrdma_ep *ep, + struct rpcrdma_req *req) +{ + struct ib_send_wr *send_wr = &req->rl_sendctx->sc_wr; + int rc; + + if (!ep->rep_send_count || + test_bit(RPCRDMA_REQ_F_TX_RESOURCES, &req->rl_flags)) { + send_wr->send_flags |= IB_SEND_SIGNALED; + ep->rep_send_count = ep->rep_send_batch; + } else { + send_wr->send_flags &= ~IB_SEND_SIGNALED; + --ep->rep_send_count; + } + + rc = ia->ri_ops->ro_send(ia, req); + trace_xprtrdma_post_send(req, rc); + if (rc) + return -ENOTCONN; + return 0; +} + +/** + * rpcrdma_post_recvs - Maybe post some Receive buffers + * @r_xprt: controlling transport + * @temp: when true, allocate temp rpcrdma_rep objects + * + */ +void +rpcrdma_post_recvs(struct rpcrdma_xprt *r_xprt, bool temp) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct ib_recv_wr *wr, *bad_wr; + int needed, count, rc; + + needed = buf->rb_credits + (buf->rb_bc_srv_max_requests << 1); + if (buf->rb_posted_receives > needed) + return; + needed -= buf->rb_posted_receives; + + count = 0; + wr = NULL; + while (needed) { + struct rpcrdma_regbuf *rb; + struct rpcrdma_rep *rep; + + spin_lock(&buf->rb_lock); + rep = list_first_entry_or_null(&buf->rb_recv_bufs, + struct rpcrdma_rep, rr_list); + if (likely(rep)) + list_del(&rep->rr_list); + spin_unlock(&buf->rb_lock); + if (!rep) { + if (rpcrdma_create_rep(r_xprt, temp)) + break; + continue; + } + + rb = rep->rr_rdmabuf; + if (!rpcrdma_regbuf_is_mapped(rb)) { + if (!__rpcrdma_dma_map_regbuf(&r_xprt->rx_ia, rb)) { + rpcrdma_recv_buffer_put(rep); + break; + } + } + + trace_xprtrdma_post_recv(rep->rr_recv_wr.wr_cqe); + rep->rr_recv_wr.next = wr; + wr = &rep->rr_recv_wr; + ++count; + --needed; + } + if (!count) + return; + + rc = ib_post_recv(r_xprt->rx_ia.ri_id->qp, wr, + (const struct ib_recv_wr **)&bad_wr); + if (rc) { + for (wr = bad_wr; wr;) { + struct rpcrdma_rep *rep; + + rep = container_of(wr, struct rpcrdma_rep, rr_recv_wr); + wr = wr->next; + rpcrdma_recv_buffer_put(rep); + --count; + } + } + buf->rb_posted_receives += count; + trace_xprtrdma_post_recvs(r_xprt, count, rc); +} diff --git a/net/sunrpc/xprtrdma/xprt_rdma.h b/net/sunrpc/xprtrdma/xprt_rdma.h new file mode 100644 index 000000000..2ca14f7c2 --- /dev/null +++ b/net/sunrpc/xprtrdma/xprt_rdma.h @@ -0,0 +1,675 @@ +/* SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause */ +/* + * Copyright (c) 2014-2017 Oracle. All rights reserved. + * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef _LINUX_SUNRPC_XPRT_RDMA_H +#define _LINUX_SUNRPC_XPRT_RDMA_H + +#include <linux/wait.h> /* wait_queue_head_t, etc */ +#include <linux/spinlock.h> /* spinlock_t, etc */ +#include <linux/atomic.h> /* atomic_t, etc */ +#include <linux/workqueue.h> /* struct work_struct */ + +#include <rdma/rdma_cm.h> /* RDMA connection api */ +#include <rdma/ib_verbs.h> /* RDMA verbs api */ + +#include <linux/sunrpc/clnt.h> /* rpc_xprt */ +#include <linux/sunrpc/rpc_rdma.h> /* RPC/RDMA protocol */ +#include <linux/sunrpc/xprtrdma.h> /* xprt parameters */ + +#define RDMA_RESOLVE_TIMEOUT (5000) /* 5 seconds */ +#define RDMA_CONNECT_RETRY_MAX (2) /* retries if no listener backlog */ + +#define RPCRDMA_BIND_TO (60U * HZ) +#define RPCRDMA_INIT_REEST_TO (5U * HZ) +#define RPCRDMA_MAX_REEST_TO (30U * HZ) +#define RPCRDMA_IDLE_DISC_TO (5U * 60 * HZ) + +/* + * Interface Adapter -- one per transport instance + */ +struct rpcrdma_ia { + const struct rpcrdma_memreg_ops *ri_ops; + struct ib_device *ri_device; + struct rdma_cm_id *ri_id; + struct ib_pd *ri_pd; + struct completion ri_done; + struct completion ri_remove_done; + int ri_async_rc; + unsigned int ri_max_segs; + unsigned int ri_max_frwr_depth; + unsigned int ri_max_inline_write; + unsigned int ri_max_inline_read; + unsigned int ri_max_send_sges; + bool ri_implicit_roundup; + enum ib_mr_type ri_mrtype; + unsigned long ri_flags; + struct ib_qp_attr ri_qp_attr; + struct ib_qp_init_attr ri_qp_init_attr; +}; + +enum { + RPCRDMA_IAF_REMOVING = 0, +}; + +/* + * RDMA Endpoint -- one per transport instance + */ + +struct rpcrdma_ep { + unsigned int rep_send_count; + unsigned int rep_send_batch; + int rep_connected; + struct ib_qp_init_attr rep_attr; + wait_queue_head_t rep_connect_wait; + struct rpcrdma_connect_private rep_cm_private; + struct rdma_conn_param rep_remote_cma; + struct delayed_work rep_connect_worker; +}; + +/* Pre-allocate extra Work Requests for handling backward receives + * and sends. This is a fixed value because the Work Queues are + * allocated when the forward channel is set up. + */ +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +#define RPCRDMA_BACKWARD_WRS (8) +#else +#define RPCRDMA_BACKWARD_WRS (0) +#endif + +/* Registered buffer -- registered kmalloc'd memory for RDMA SEND/RECV + * + * The below structure appears at the front of a large region of kmalloc'd + * memory, which always starts on a good alignment boundary. + */ + +struct rpcrdma_regbuf { + struct ib_sge rg_iov; + struct ib_device *rg_device; + enum dma_data_direction rg_direction; + __be32 rg_base[0] __attribute__ ((aligned(256))); +}; + +static inline u64 +rdmab_addr(struct rpcrdma_regbuf *rb) +{ + return rb->rg_iov.addr; +} + +static inline u32 +rdmab_length(struct rpcrdma_regbuf *rb) +{ + return rb->rg_iov.length; +} + +static inline u32 +rdmab_lkey(struct rpcrdma_regbuf *rb) +{ + return rb->rg_iov.lkey; +} + +static inline struct ib_device * +rdmab_device(struct rpcrdma_regbuf *rb) +{ + return rb->rg_device; +} + +#define RPCRDMA_DEF_GFP (GFP_NOIO | __GFP_NOWARN) + +/* To ensure a transport can always make forward progress, + * the number of RDMA segments allowed in header chunk lists + * is capped at 8. This prevents less-capable devices and + * memory registrations from overrunning the Send buffer + * while building chunk lists. + * + * Elements of the Read list take up more room than the + * Write list or Reply chunk. 8 read segments means the Read + * list (or Write list or Reply chunk) cannot consume more + * than + * + * ((8 + 2) * read segment size) + 1 XDR words, or 244 bytes. + * + * And the fixed part of the header is another 24 bytes. + * + * The smallest inline threshold is 1024 bytes, ensuring that + * at least 750 bytes are available for RPC messages. + */ +enum { + RPCRDMA_MAX_HDR_SEGS = 8, + RPCRDMA_HDRBUF_SIZE = 256, +}; + +/* + * struct rpcrdma_rep -- this structure encapsulates state required + * to receive and complete an RPC Reply, asychronously. It needs + * several pieces of state: + * + * o receive buffer and ib_sge (donated to provider) + * o status of receive (success or not, length, inv rkey) + * o bookkeeping state to get run by reply handler (XDR stream) + * + * These structures are allocated during transport initialization. + * N of these are associated with a transport instance, managed by + * struct rpcrdma_buffer. N is the max number of outstanding RPCs. + */ + +struct rpcrdma_rep { + struct ib_cqe rr_cqe; + __be32 rr_xid; + __be32 rr_vers; + __be32 rr_proc; + int rr_wc_flags; + u32 rr_inv_rkey; + bool rr_temp; + struct rpcrdma_regbuf *rr_rdmabuf; + struct rpcrdma_xprt *rr_rxprt; + struct work_struct rr_work; + struct xdr_buf rr_hdrbuf; + struct xdr_stream rr_stream; + struct rpc_rqst *rr_rqst; + struct list_head rr_list; + struct ib_recv_wr rr_recv_wr; +}; + +/* struct rpcrdma_sendctx - DMA mapped SGEs to unmap after Send completes + */ +struct rpcrdma_req; +struct rpcrdma_xprt; +struct rpcrdma_sendctx { + struct ib_send_wr sc_wr; + struct ib_cqe sc_cqe; + struct rpcrdma_xprt *sc_xprt; + struct rpcrdma_req *sc_req; + unsigned int sc_unmap_count; + struct ib_sge sc_sges[]; +}; + +/* Limit the number of SGEs that can be unmapped during one + * Send completion. This caps the amount of work a single + * completion can do before returning to the provider. + * + * Setting this to zero disables Send completion batching. + */ +enum { + RPCRDMA_MAX_SEND_BATCH = 7, +}; + +/* + * struct rpcrdma_mr - external memory region metadata + * + * An external memory region is any buffer or page that is registered + * on the fly (ie, not pre-registered). + * + * Each rpcrdma_buffer has a list of free MWs anchored in rb_mrs. During + * call_allocate, rpcrdma_buffer_get() assigns one to each segment in + * an rpcrdma_req. Then rpcrdma_register_external() grabs these to keep + * track of registration metadata while each RPC is pending. + * rpcrdma_deregister_external() uses this metadata to unmap and + * release these resources when an RPC is complete. + */ +enum rpcrdma_frwr_state { + FRWR_IS_INVALID, /* ready to be used */ + FRWR_IS_VALID, /* in use */ + FRWR_FLUSHED_FR, /* flushed FASTREG WR */ + FRWR_FLUSHED_LI, /* flushed LOCALINV WR */ +}; + +struct rpcrdma_frwr { + struct ib_mr *fr_mr; + struct ib_cqe fr_cqe; + enum rpcrdma_frwr_state fr_state; + struct completion fr_linv_done; + union { + struct ib_reg_wr fr_regwr; + struct ib_send_wr fr_invwr; + }; +}; + +struct rpcrdma_fmr { + struct ib_fmr *fm_mr; + u64 *fm_physaddrs; +}; + +struct rpcrdma_mr { + struct list_head mr_list; + struct scatterlist *mr_sg; + int mr_nents; + enum dma_data_direction mr_dir; + union { + struct rpcrdma_fmr fmr; + struct rpcrdma_frwr frwr; + }; + struct rpcrdma_xprt *mr_xprt; + u32 mr_handle; + u32 mr_length; + u64 mr_offset; + struct list_head mr_all; +}; + +/* + * struct rpcrdma_req -- structure central to the request/reply sequence. + * + * N of these are associated with a transport instance, and stored in + * struct rpcrdma_buffer. N is the max number of outstanding requests. + * + * It includes pre-registered buffer memory for send AND recv. + * The recv buffer, however, is not owned by this structure, and + * is "donated" to the hardware when a recv is posted. When a + * reply is handled, the recv buffer used is given back to the + * struct rpcrdma_req associated with the request. + * + * In addition to the basic memory, this structure includes an array + * of iovs for send operations. The reason is that the iovs passed to + * ib_post_{send,recv} must not be modified until the work request + * completes. + */ + +/* Maximum number of page-sized "segments" per chunk list to be + * registered or invalidated. Must handle a Reply chunk: + */ +enum { + RPCRDMA_MAX_IOV_SEGS = 3, + RPCRDMA_MAX_DATA_SEGS = ((1 * 1024 * 1024) / PAGE_SIZE) + 1, + RPCRDMA_MAX_SEGS = RPCRDMA_MAX_DATA_SEGS + + RPCRDMA_MAX_IOV_SEGS, +}; + +struct rpcrdma_mr_seg { /* chunk descriptors */ + u32 mr_len; /* length of chunk or segment */ + struct page *mr_page; /* owning page, if any */ + char *mr_offset; /* kva if no page, else offset */ +}; + +/* The Send SGE array is provisioned to send a maximum size + * inline request: + * - RPC-over-RDMA header + * - xdr_buf head iovec + * - RPCRDMA_MAX_INLINE bytes, in pages + * - xdr_buf tail iovec + * + * The actual number of array elements consumed by each RPC + * depends on the device's max_sge limit. + */ +enum { + RPCRDMA_MIN_SEND_SGES = 3, + RPCRDMA_MAX_PAGE_SGES = RPCRDMA_MAX_INLINE >> PAGE_SHIFT, + RPCRDMA_MAX_SEND_SGES = 1 + 1 + RPCRDMA_MAX_PAGE_SGES + 1, +}; + +struct rpcrdma_buffer; +struct rpcrdma_req { + struct list_head rl_list; + struct rpc_rqst rl_slot; + struct rpcrdma_buffer *rl_buffer; + struct rpcrdma_rep *rl_reply; + struct xdr_stream rl_stream; + struct xdr_buf rl_hdrbuf; + struct rpcrdma_sendctx *rl_sendctx; + struct rpcrdma_regbuf *rl_rdmabuf; /* xprt header */ + struct rpcrdma_regbuf *rl_sendbuf; /* rq_snd_buf */ + struct rpcrdma_regbuf *rl_recvbuf; /* rq_rcv_buf */ + + struct list_head rl_all; + unsigned long rl_flags; + + struct list_head rl_registered; /* registered segments */ + struct rpcrdma_mr_seg rl_segments[RPCRDMA_MAX_SEGS]; +}; + +/* rl_flags */ +enum { + RPCRDMA_REQ_F_PENDING = 0, + RPCRDMA_REQ_F_TX_RESOURCES, +}; + +static inline struct rpcrdma_req * +rpcr_to_rdmar(const struct rpc_rqst *rqst) +{ + return container_of(rqst, struct rpcrdma_req, rl_slot); +} + +static inline void +rpcrdma_mr_push(struct rpcrdma_mr *mr, struct list_head *list) +{ + list_add_tail(&mr->mr_list, list); +} + +static inline struct rpcrdma_mr * +rpcrdma_mr_pop(struct list_head *list) +{ + struct rpcrdma_mr *mr; + + mr = list_first_entry(list, struct rpcrdma_mr, mr_list); + list_del_init(&mr->mr_list); + return mr; +} + +/* + * struct rpcrdma_buffer -- holds list/queue of pre-registered memory for + * inline requests/replies, and client/server credits. + * + * One of these is associated with a transport instance + */ +struct rpcrdma_buffer { + spinlock_t rb_mrlock; /* protect rb_mrs list */ + struct list_head rb_mrs; + struct list_head rb_all; + + unsigned long rb_sc_head; + unsigned long rb_sc_tail; + unsigned long rb_sc_last; + struct rpcrdma_sendctx **rb_sc_ctxs; + + spinlock_t rb_lock; /* protect buf lists */ + struct list_head rb_send_bufs; + struct list_head rb_recv_bufs; + unsigned long rb_flags; + u32 rb_max_requests; + u32 rb_credits; /* most recent credit grant */ + int rb_posted_receives; + + u32 rb_bc_srv_max_requests; + spinlock_t rb_reqslock; /* protect rb_allreqs */ + struct list_head rb_allreqs; + + u32 rb_bc_max_requests; + + spinlock_t rb_recovery_lock; /* protect rb_stale_mrs */ + struct list_head rb_stale_mrs; + struct delayed_work rb_recovery_worker; + struct delayed_work rb_refresh_worker; +}; +#define rdmab_to_ia(b) (&container_of((b), struct rpcrdma_xprt, rx_buf)->rx_ia) + +/* rb_flags */ +enum { + RPCRDMA_BUF_F_EMPTY_SCQ = 0, +}; + +/* + * Internal structure for transport instance creation. This + * exists primarily for modularity. + * + * This data should be set with mount options + */ +struct rpcrdma_create_data_internal { + unsigned int max_requests; /* max requests (slots) in flight */ + unsigned int rsize; /* mount rsize - max read hdr+data */ + unsigned int wsize; /* mount wsize - max write hdr+data */ + unsigned int inline_rsize; /* max non-rdma read data payload */ + unsigned int inline_wsize; /* max non-rdma write data payload */ +}; + +/* + * Statistics for RPCRDMA + */ +struct rpcrdma_stats { + /* accessed when sending a call */ + unsigned long read_chunk_count; + unsigned long write_chunk_count; + unsigned long reply_chunk_count; + unsigned long long total_rdma_request; + + /* rarely accessed error counters */ + unsigned long long pullup_copy_count; + unsigned long hardway_register_count; + unsigned long failed_marshal_count; + unsigned long bad_reply_count; + unsigned long mrs_recovered; + unsigned long mrs_orphaned; + unsigned long mrs_allocated; + unsigned long empty_sendctx_q; + + /* accessed when receiving a reply */ + unsigned long long total_rdma_reply; + unsigned long long fixup_copy_count; + unsigned long reply_waits_for_send; + unsigned long local_inv_needed; + unsigned long nomsg_call_count; + unsigned long bcall_count; +}; + +/* + * Per-registration mode operations + */ +struct rpcrdma_xprt; +struct rpcrdma_memreg_ops { + struct rpcrdma_mr_seg * + (*ro_map)(struct rpcrdma_xprt *, + struct rpcrdma_mr_seg *, int, bool, + struct rpcrdma_mr **); + int (*ro_send)(struct rpcrdma_ia *ia, + struct rpcrdma_req *req); + void (*ro_reminv)(struct rpcrdma_rep *rep, + struct list_head *mrs); + void (*ro_unmap_sync)(struct rpcrdma_xprt *, + struct list_head *); + void (*ro_recover_mr)(struct rpcrdma_mr *mr); + int (*ro_open)(struct rpcrdma_ia *, + struct rpcrdma_ep *, + struct rpcrdma_create_data_internal *); + size_t (*ro_maxpages)(struct rpcrdma_xprt *); + int (*ro_init_mr)(struct rpcrdma_ia *, + struct rpcrdma_mr *); + void (*ro_release_mr)(struct rpcrdma_mr *mr); + const char *ro_displayname; + const int ro_send_w_inv_ok; +}; + +extern const struct rpcrdma_memreg_ops rpcrdma_fmr_memreg_ops; +extern const struct rpcrdma_memreg_ops rpcrdma_frwr_memreg_ops; + +/* + * RPCRDMA transport -- encapsulates the structures above for + * integration with RPC. + * + * The contained structures are embedded, not pointers, + * for convenience. This structure need not be visible externally. + * + * It is allocated and initialized during mount, and released + * during unmount. + */ +struct rpcrdma_xprt { + struct rpc_xprt rx_xprt; + struct rpcrdma_ia rx_ia; + struct rpcrdma_ep rx_ep; + struct rpcrdma_buffer rx_buf; + struct rpcrdma_create_data_internal rx_data; + struct delayed_work rx_connect_worker; + struct rpcrdma_stats rx_stats; +}; + +#define rpcx_to_rdmax(x) container_of(x, struct rpcrdma_xprt, rx_xprt) +#define rpcx_to_rdmad(x) (rpcx_to_rdmax(x)->rx_data) + +static inline const char * +rpcrdma_addrstr(const struct rpcrdma_xprt *r_xprt) +{ + return r_xprt->rx_xprt.address_strings[RPC_DISPLAY_ADDR]; +} + +static inline const char * +rpcrdma_portstr(const struct rpcrdma_xprt *r_xprt) +{ + return r_xprt->rx_xprt.address_strings[RPC_DISPLAY_PORT]; +} + +/* Setting this to 0 ensures interoperability with early servers. + * Setting this to 1 enhances certain unaligned read/write performance. + * Default is 0, see sysctl entry and rpc_rdma.c rpcrdma_convert_iovs() */ +extern int xprt_rdma_pad_optimize; + +/* This setting controls the hunt for a supported memory + * registration strategy. + */ +extern unsigned int xprt_rdma_memreg_strategy; + +/* + * Interface Adapter calls - xprtrdma/verbs.c + */ +int rpcrdma_ia_open(struct rpcrdma_xprt *xprt); +void rpcrdma_ia_remove(struct rpcrdma_ia *ia); +void rpcrdma_ia_close(struct rpcrdma_ia *); +bool frwr_is_supported(struct rpcrdma_ia *); +bool fmr_is_supported(struct rpcrdma_ia *); + +extern struct workqueue_struct *rpcrdma_receive_wq; + +/* + * Endpoint calls - xprtrdma/verbs.c + */ +int rpcrdma_ep_create(struct rpcrdma_ep *, struct rpcrdma_ia *, + struct rpcrdma_create_data_internal *); +void rpcrdma_ep_destroy(struct rpcrdma_ep *, struct rpcrdma_ia *); +int rpcrdma_ep_connect(struct rpcrdma_ep *, struct rpcrdma_ia *); +void rpcrdma_conn_func(struct rpcrdma_ep *ep); +void rpcrdma_ep_disconnect(struct rpcrdma_ep *, struct rpcrdma_ia *); + +int rpcrdma_ep_post(struct rpcrdma_ia *, struct rpcrdma_ep *, + struct rpcrdma_req *); +void rpcrdma_post_recvs(struct rpcrdma_xprt *r_xprt, bool temp); + +/* + * Buffer calls - xprtrdma/verbs.c + */ +struct rpcrdma_req *rpcrdma_create_req(struct rpcrdma_xprt *); +void rpcrdma_destroy_req(struct rpcrdma_req *); +int rpcrdma_buffer_create(struct rpcrdma_xprt *); +void rpcrdma_buffer_destroy(struct rpcrdma_buffer *); +struct rpcrdma_sendctx *rpcrdma_sendctx_get_locked(struct rpcrdma_buffer *buf); + +struct rpcrdma_mr *rpcrdma_mr_get(struct rpcrdma_xprt *r_xprt); +void rpcrdma_mr_put(struct rpcrdma_mr *mr); +void rpcrdma_mr_unmap_and_put(struct rpcrdma_mr *mr); +void rpcrdma_mr_defer_recovery(struct rpcrdma_mr *mr); + +struct rpcrdma_req *rpcrdma_buffer_get(struct rpcrdma_buffer *); +void rpcrdma_buffer_put(struct rpcrdma_req *); +void rpcrdma_recv_buffer_put(struct rpcrdma_rep *); + +struct rpcrdma_regbuf *rpcrdma_alloc_regbuf(size_t, enum dma_data_direction, + gfp_t); +bool __rpcrdma_dma_map_regbuf(struct rpcrdma_ia *, struct rpcrdma_regbuf *); +void rpcrdma_free_regbuf(struct rpcrdma_regbuf *); + +static inline bool +rpcrdma_regbuf_is_mapped(struct rpcrdma_regbuf *rb) +{ + return rb->rg_device != NULL; +} + +static inline bool +rpcrdma_dma_map_regbuf(struct rpcrdma_ia *ia, struct rpcrdma_regbuf *rb) +{ + if (likely(rpcrdma_regbuf_is_mapped(rb))) + return true; + return __rpcrdma_dma_map_regbuf(ia, rb); +} + +int rpcrdma_alloc_wq(void); +void rpcrdma_destroy_wq(void); + +/* + * Wrappers for chunk registration, shared by read/write chunk code. + */ + +static inline enum dma_data_direction +rpcrdma_data_dir(bool writing) +{ + return writing ? DMA_FROM_DEVICE : DMA_TO_DEVICE; +} + +/* + * RPC/RDMA protocol calls - xprtrdma/rpc_rdma.c + */ + +enum rpcrdma_chunktype { + rpcrdma_noch = 0, + rpcrdma_readch, + rpcrdma_areadch, + rpcrdma_writech, + rpcrdma_replych +}; + +int rpcrdma_prepare_send_sges(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, u32 hdrlen, + struct xdr_buf *xdr, + enum rpcrdma_chunktype rtype); +void rpcrdma_unmap_sendctx(struct rpcrdma_sendctx *sc); +int rpcrdma_marshal_req(struct rpcrdma_xprt *r_xprt, struct rpc_rqst *rqst); +void rpcrdma_set_max_header_sizes(struct rpcrdma_xprt *); +void rpcrdma_complete_rqst(struct rpcrdma_rep *rep); +void rpcrdma_reply_handler(struct rpcrdma_rep *rep); +void rpcrdma_release_rqst(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req); +void rpcrdma_deferred_completion(struct work_struct *work); + +static inline void rpcrdma_set_xdrlen(struct xdr_buf *xdr, size_t len) +{ + xdr->head[0].iov_len = len; + xdr->len = len; +} + +/* RPC/RDMA module init - xprtrdma/transport.c + */ +extern unsigned int xprt_rdma_max_inline_read; +void xprt_rdma_format_addresses(struct rpc_xprt *xprt, struct sockaddr *sap); +void xprt_rdma_free_addresses(struct rpc_xprt *xprt); +void rpcrdma_connect_worker(struct work_struct *work); +void xprt_rdma_print_stats(struct rpc_xprt *xprt, struct seq_file *seq); +int xprt_rdma_init(void); +void xprt_rdma_cleanup(void); + +/* Backchannel calls - xprtrdma/backchannel.c + */ +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +int xprt_rdma_bc_setup(struct rpc_xprt *, unsigned int); +int xprt_rdma_bc_up(struct svc_serv *, struct net *); +size_t xprt_rdma_bc_maxpayload(struct rpc_xprt *); +int rpcrdma_bc_post_recv(struct rpcrdma_xprt *, unsigned int); +void rpcrdma_bc_receive_call(struct rpcrdma_xprt *, struct rpcrdma_rep *); +int xprt_rdma_bc_send_reply(struct rpc_rqst *rqst); +void xprt_rdma_bc_free_rqst(struct rpc_rqst *); +void xprt_rdma_bc_destroy(struct rpc_xprt *, unsigned int); +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +extern struct xprt_class xprt_rdma_bc; + +#endif /* _LINUX_SUNRPC_XPRT_RDMA_H */ diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c new file mode 100644 index 000000000..a0a82d9a5 --- /dev/null +++ b/net/sunrpc/xprtsock.c @@ -0,0 +1,3382 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * linux/net/sunrpc/xprtsock.c + * + * Client-side transport implementation for sockets. + * + * TCP callback races fixes (C) 1998 Red Hat + * TCP send fixes (C) 1998 Red Hat + * TCP NFS related read + write fixes + * (C) 1999 Dave Airlie, University of Limerick, Ireland <airlied@linux.ie> + * + * Rewrite of larges part of the code in order to stabilize TCP stuff. + * Fix behaviour when socket buffer is full. + * (C) 1999 Trond Myklebust <trond.myklebust@fys.uio.no> + * + * IP socket transport implementation, (C) 2005 Chuck Lever <cel@netapp.com> + * + * IPv6 support contributed by Gilles Quillard, Bull Open Source, 2005. + * <gilles.quillard@bull.net> + */ + +#include <linux/types.h> +#include <linux/string.h> +#include <linux/slab.h> +#include <linux/module.h> +#include <linux/capability.h> +#include <linux/pagemap.h> +#include <linux/errno.h> +#include <linux/socket.h> +#include <linux/in.h> +#include <linux/net.h> +#include <linux/mm.h> +#include <linux/un.h> +#include <linux/udp.h> +#include <linux/tcp.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/addr.h> +#include <linux/sunrpc/sched.h> +#include <linux/sunrpc/svcsock.h> +#include <linux/sunrpc/xprtsock.h> +#include <linux/file.h> +#ifdef CONFIG_SUNRPC_BACKCHANNEL +#include <linux/sunrpc/bc_xprt.h> +#endif + +#include <net/sock.h> +#include <net/checksum.h> +#include <net/udp.h> +#include <net/tcp.h> + +#include <trace/events/sunrpc.h> + +#include "sunrpc.h" + +#define RPC_TCP_READ_CHUNK_SZ (3*512*1024) + +static void xs_close(struct rpc_xprt *xprt); +static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt, + struct socket *sock); + +/* + * xprtsock tunables + */ +static unsigned int xprt_udp_slot_table_entries = RPC_DEF_SLOT_TABLE; +static unsigned int xprt_tcp_slot_table_entries = RPC_MIN_SLOT_TABLE; +static unsigned int xprt_max_tcp_slot_table_entries = RPC_MAX_SLOT_TABLE; + +static unsigned int xprt_min_resvport = RPC_DEF_MIN_RESVPORT; +static unsigned int xprt_max_resvport = RPC_DEF_MAX_RESVPORT; + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) + +#define XS_TCP_LINGER_TO (15U * HZ) +static unsigned int xs_tcp_fin_timeout __read_mostly = XS_TCP_LINGER_TO; + +/* + * We can register our own files under /proc/sys/sunrpc by + * calling register_sysctl_table() again. The files in that + * directory become the union of all files registered there. + * + * We simply need to make sure that we don't collide with + * someone else's file names! + */ + +static unsigned int min_slot_table_size = RPC_MIN_SLOT_TABLE; +static unsigned int max_slot_table_size = RPC_MAX_SLOT_TABLE; +static unsigned int max_tcp_slot_table_limit = RPC_MAX_SLOT_TABLE_LIMIT; +static unsigned int xprt_min_resvport_limit = RPC_MIN_RESVPORT; +static unsigned int xprt_max_resvport_limit = RPC_MAX_RESVPORT; + +static struct ctl_table_header *sunrpc_table_header; + +/* + * FIXME: changing the UDP slot table size should also resize the UDP + * socket buffers for existing UDP transports + */ +static struct ctl_table xs_tunables_table[] = { + { + .procname = "udp_slot_table_entries", + .data = &xprt_udp_slot_table_entries, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_slot_table_size, + .extra2 = &max_slot_table_size + }, + { + .procname = "tcp_slot_table_entries", + .data = &xprt_tcp_slot_table_entries, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_slot_table_size, + .extra2 = &max_slot_table_size + }, + { + .procname = "tcp_max_slot_table_entries", + .data = &xprt_max_tcp_slot_table_entries, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_slot_table_size, + .extra2 = &max_tcp_slot_table_limit + }, + { + .procname = "min_resvport", + .data = &xprt_min_resvport, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &xprt_min_resvport_limit, + .extra2 = &xprt_max_resvport_limit + }, + { + .procname = "max_resvport", + .data = &xprt_max_resvport, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &xprt_min_resvport_limit, + .extra2 = &xprt_max_resvport_limit + }, + { + .procname = "tcp_fin_timeout", + .data = &xs_tcp_fin_timeout, + .maxlen = sizeof(xs_tcp_fin_timeout), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { }, +}; + +static struct ctl_table sunrpc_table[] = { + { + .procname = "sunrpc", + .mode = 0555, + .child = xs_tunables_table + }, + { }, +}; + +#endif + +/* + * Wait duration for a reply from the RPC portmapper. + */ +#define XS_BIND_TO (60U * HZ) + +/* + * Delay if a UDP socket connect error occurs. This is most likely some + * kind of resource problem on the local host. + */ +#define XS_UDP_REEST_TO (2U * HZ) + +/* + * The reestablish timeout allows clients to delay for a bit before attempting + * to reconnect to a server that just dropped our connection. + * + * We implement an exponential backoff when trying to reestablish a TCP + * transport connection with the server. Some servers like to drop a TCP + * connection when they are overworked, so we start with a short timeout and + * increase over time if the server is down or not responding. + */ +#define XS_TCP_INIT_REEST_TO (3U * HZ) + +/* + * TCP idle timeout; client drops the transport socket if it is idle + * for this long. Note that we also timeout UDP sockets to prevent + * holding port numbers when there is no RPC traffic. + */ +#define XS_IDLE_DISC_TO (5U * 60 * HZ) + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# undef RPC_DEBUG_DATA +# define RPCDBG_FACILITY RPCDBG_TRANS +#endif + +#ifdef RPC_DEBUG_DATA +static void xs_pktdump(char *msg, u32 *packet, unsigned int count) +{ + u8 *buf = (u8 *) packet; + int j; + + dprintk("RPC: %s\n", msg); + for (j = 0; j < count && j < 128; j += 4) { + if (!(j & 31)) { + if (j) + dprintk("\n"); + dprintk("0x%04x ", j); + } + dprintk("%02x%02x%02x%02x ", + buf[j], buf[j+1], buf[j+2], buf[j+3]); + } + dprintk("\n"); +} +#else +static inline void xs_pktdump(char *msg, u32 *packet, unsigned int count) +{ + /* NOP */ +} +#endif + +static inline struct rpc_xprt *xprt_from_sock(struct sock *sk) +{ + return (struct rpc_xprt *) sk->sk_user_data; +} + +static inline struct sockaddr *xs_addr(struct rpc_xprt *xprt) +{ + return (struct sockaddr *) &xprt->addr; +} + +static inline struct sockaddr_un *xs_addr_un(struct rpc_xprt *xprt) +{ + return (struct sockaddr_un *) &xprt->addr; +} + +static inline struct sockaddr_in *xs_addr_in(struct rpc_xprt *xprt) +{ + return (struct sockaddr_in *) &xprt->addr; +} + +static inline struct sockaddr_in6 *xs_addr_in6(struct rpc_xprt *xprt) +{ + return (struct sockaddr_in6 *) &xprt->addr; +} + +static void xs_format_common_peer_addresses(struct rpc_xprt *xprt) +{ + struct sockaddr *sap = xs_addr(xprt); + struct sockaddr_in6 *sin6; + struct sockaddr_in *sin; + struct sockaddr_un *sun; + char buf[128]; + + switch (sap->sa_family) { + case AF_LOCAL: + sun = xs_addr_un(xprt); + strlcpy(buf, sun->sun_path, sizeof(buf)); + xprt->address_strings[RPC_DISPLAY_ADDR] = + kstrdup(buf, GFP_KERNEL); + break; + case AF_INET: + (void)rpc_ntop(sap, buf, sizeof(buf)); + xprt->address_strings[RPC_DISPLAY_ADDR] = + kstrdup(buf, GFP_KERNEL); + sin = xs_addr_in(xprt); + snprintf(buf, sizeof(buf), "%08x", ntohl(sin->sin_addr.s_addr)); + break; + case AF_INET6: + (void)rpc_ntop(sap, buf, sizeof(buf)); + xprt->address_strings[RPC_DISPLAY_ADDR] = + kstrdup(buf, GFP_KERNEL); + sin6 = xs_addr_in6(xprt); + snprintf(buf, sizeof(buf), "%pi6", &sin6->sin6_addr); + break; + default: + BUG(); + } + + xprt->address_strings[RPC_DISPLAY_HEX_ADDR] = kstrdup(buf, GFP_KERNEL); +} + +static void xs_format_common_peer_ports(struct rpc_xprt *xprt) +{ + struct sockaddr *sap = xs_addr(xprt); + char buf[128]; + + snprintf(buf, sizeof(buf), "%u", rpc_get_port(sap)); + xprt->address_strings[RPC_DISPLAY_PORT] = kstrdup(buf, GFP_KERNEL); + + snprintf(buf, sizeof(buf), "%4hx", rpc_get_port(sap)); + xprt->address_strings[RPC_DISPLAY_HEX_PORT] = kstrdup(buf, GFP_KERNEL); +} + +static void xs_format_peer_addresses(struct rpc_xprt *xprt, + const char *protocol, + const char *netid) +{ + xprt->address_strings[RPC_DISPLAY_PROTO] = protocol; + xprt->address_strings[RPC_DISPLAY_NETID] = netid; + xs_format_common_peer_addresses(xprt); + xs_format_common_peer_ports(xprt); +} + +static void xs_update_peer_port(struct rpc_xprt *xprt) +{ + kfree(xprt->address_strings[RPC_DISPLAY_HEX_PORT]); + kfree(xprt->address_strings[RPC_DISPLAY_PORT]); + + xs_format_common_peer_ports(xprt); +} + +static void xs_free_peer_addresses(struct rpc_xprt *xprt) +{ + unsigned int i; + + for (i = 0; i < RPC_DISPLAY_MAX; i++) + switch (i) { + case RPC_DISPLAY_PROTO: + case RPC_DISPLAY_NETID: + continue; + default: + kfree(xprt->address_strings[i]); + } +} + +#define XS_SENDMSG_FLAGS (MSG_DONTWAIT | MSG_NOSIGNAL) + +static int xs_send_kvec(struct socket *sock, struct sockaddr *addr, int addrlen, struct kvec *vec, unsigned int base, int more) +{ + struct msghdr msg = { + .msg_name = addr, + .msg_namelen = addrlen, + .msg_flags = XS_SENDMSG_FLAGS | (more ? MSG_MORE : 0), + }; + struct kvec iov = { + .iov_base = vec->iov_base + base, + .iov_len = vec->iov_len - base, + }; + + if (iov.iov_len != 0) + return kernel_sendmsg(sock, &msg, &iov, 1, iov.iov_len); + return kernel_sendmsg(sock, &msg, NULL, 0, 0); +} + +static int xs_send_pagedata(struct socket *sock, struct xdr_buf *xdr, unsigned int base, int more, bool zerocopy, int *sent_p) +{ + ssize_t (*do_sendpage)(struct socket *sock, struct page *page, + int offset, size_t size, int flags); + struct page **ppage; + unsigned int remainder; + int err; + + remainder = xdr->page_len - base; + base += xdr->page_base; + ppage = xdr->pages + (base >> PAGE_SHIFT); + base &= ~PAGE_MASK; + do_sendpage = sock->ops->sendpage; + if (!zerocopy) + do_sendpage = sock_no_sendpage; + for(;;) { + unsigned int len = min_t(unsigned int, PAGE_SIZE - base, remainder); + int flags = XS_SENDMSG_FLAGS; + + remainder -= len; + if (more) + flags |= MSG_MORE; + if (remainder != 0) + flags |= MSG_SENDPAGE_NOTLAST | MSG_MORE; + err = do_sendpage(sock, *ppage, base, len, flags); + if (remainder == 0 || err != len) + break; + *sent_p += err; + ppage++; + base = 0; + } + if (err > 0) { + *sent_p += err; + err = 0; + } + return err; +} + +/** + * xs_sendpages - write pages directly to a socket + * @sock: socket to send on + * @addr: UDP only -- address of destination + * @addrlen: UDP only -- length of destination address + * @xdr: buffer containing this request + * @base: starting position in the buffer + * @zerocopy: true if it is safe to use sendpage() + * @sent_p: return the total number of bytes successfully queued for sending + * + */ +static int xs_sendpages(struct socket *sock, struct sockaddr *addr, int addrlen, struct xdr_buf *xdr, unsigned int base, bool zerocopy, int *sent_p) +{ + unsigned int remainder = xdr->len - base; + int err = 0; + int sent = 0; + + if (unlikely(!sock)) + return -ENOTSOCK; + + if (base != 0) { + addr = NULL; + addrlen = 0; + } + + if (base < xdr->head[0].iov_len || addr != NULL) { + unsigned int len = xdr->head[0].iov_len - base; + remainder -= len; + err = xs_send_kvec(sock, addr, addrlen, &xdr->head[0], base, remainder != 0); + if (remainder == 0 || err != len) + goto out; + *sent_p += err; + base = 0; + } else + base -= xdr->head[0].iov_len; + + if (base < xdr->page_len) { + unsigned int len = xdr->page_len - base; + remainder -= len; + err = xs_send_pagedata(sock, xdr, base, remainder != 0, zerocopy, &sent); + *sent_p += sent; + if (remainder == 0 || sent != len) + goto out; + base = 0; + } else + base -= xdr->page_len; + + if (base >= xdr->tail[0].iov_len) + return 0; + err = xs_send_kvec(sock, NULL, 0, &xdr->tail[0], base, 0); +out: + if (err > 0) { + *sent_p += err; + err = 0; + } + return err; +} + +static void xs_nospace_callback(struct rpc_task *task) +{ + struct sock_xprt *transport = container_of(task->tk_rqstp->rq_xprt, struct sock_xprt, xprt); + + transport->inet->sk_write_pending--; +} + +/** + * xs_nospace - place task on wait queue if transmit was incomplete + * @task: task to put to sleep + * + */ +static int xs_nospace(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct sock *sk = transport->inet; + int ret = -EAGAIN; + + dprintk("RPC: %5u xmit incomplete (%u left of %u)\n", + task->tk_pid, req->rq_slen - req->rq_bytes_sent, + req->rq_slen); + + /* Protect against races with write_space */ + spin_lock_bh(&xprt->transport_lock); + + /* Don't race with disconnect */ + if (xprt_connected(xprt)) { + /* wait for more buffer space */ + sk->sk_write_pending++; + xprt_wait_for_buffer_space(task, xs_nospace_callback); + } else + ret = -ENOTCONN; + + spin_unlock_bh(&xprt->transport_lock); + + /* Race breaker in case memory is freed before above code is called */ + if (ret == -EAGAIN) { + struct socket_wq *wq; + + rcu_read_lock(); + wq = rcu_dereference(sk->sk_wq); + set_bit(SOCKWQ_ASYNC_NOSPACE, &wq->flags); + rcu_read_unlock(); + + sk->sk_write_space(sk); + } + return ret; +} + +/* + * Construct a stream transport record marker in @buf. + */ +static inline void xs_encode_stream_record_marker(struct xdr_buf *buf) +{ + u32 reclen = buf->len - sizeof(rpc_fraghdr); + rpc_fraghdr *base = buf->head[0].iov_base; + *base = cpu_to_be32(RPC_LAST_STREAM_FRAGMENT | reclen); +} + +/** + * xs_local_send_request - write an RPC request to an AF_LOCAL socket + * @task: RPC task that manages the state of an RPC request + * + * Return values: + * 0: The request has been sent + * EAGAIN: The socket was blocked, please call again later to + * complete the request + * ENOTCONN: Caller needs to invoke connect logic then call again + * other: Some other error occured, the request was not sent + */ +static int xs_local_send_request(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + struct sock_xprt *transport = + container_of(xprt, struct sock_xprt, xprt); + struct xdr_buf *xdr = &req->rq_snd_buf; + int status; + int sent = 0; + + xs_encode_stream_record_marker(&req->rq_snd_buf); + + xs_pktdump("packet data:", + req->rq_svec->iov_base, req->rq_svec->iov_len); + + req->rq_xtime = ktime_get(); + status = xs_sendpages(transport->sock, NULL, 0, xdr, req->rq_bytes_sent, + true, &sent); + dprintk("RPC: %s(%u) = %d\n", + __func__, xdr->len - req->rq_bytes_sent, status); + + if (status == -EAGAIN && sock_writeable(transport->inet)) + status = -ENOBUFS; + + if (likely(sent > 0) || status == 0) { + req->rq_bytes_sent += sent; + req->rq_xmit_bytes_sent += sent; + if (likely(req->rq_bytes_sent >= req->rq_slen)) { + req->rq_bytes_sent = 0; + return 0; + } + status = -EAGAIN; + } + + switch (status) { + case -ENOBUFS: + break; + case -EAGAIN: + status = xs_nospace(task); + break; + default: + dprintk("RPC: sendmsg returned unrecognized error %d\n", + -status); + /* fall through */ + case -EPIPE: + xs_close(xprt); + status = -ENOTCONN; + } + + return status; +} + +/** + * xs_udp_send_request - write an RPC request to a UDP socket + * @task: address of RPC task that manages the state of an RPC request + * + * Return values: + * 0: The request has been sent + * EAGAIN: The socket was blocked, please call again later to + * complete the request + * ENOTCONN: Caller needs to invoke connect logic then call again + * other: Some other error occurred, the request was not sent + */ +static int xs_udp_send_request(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct xdr_buf *xdr = &req->rq_snd_buf; + int sent = 0; + int status; + + xs_pktdump("packet data:", + req->rq_svec->iov_base, + req->rq_svec->iov_len); + + if (!xprt_bound(xprt)) + return -ENOTCONN; + req->rq_xtime = ktime_get(); + status = xs_sendpages(transport->sock, xs_addr(xprt), xprt->addrlen, + xdr, req->rq_bytes_sent, true, &sent); + + dprintk("RPC: xs_udp_send_request(%u) = %d\n", + xdr->len - req->rq_bytes_sent, status); + + /* firewall is blocking us, don't return -EAGAIN or we end up looping */ + if (status == -EPERM) + goto process_status; + + if (status == -EAGAIN && sock_writeable(transport->inet)) + status = -ENOBUFS; + + if (sent > 0 || status == 0) { + req->rq_xmit_bytes_sent += sent; + if (sent >= req->rq_slen) + return 0; + /* Still some bytes left; set up for a retry later. */ + status = -EAGAIN; + } + +process_status: + switch (status) { + case -ENOTSOCK: + status = -ENOTCONN; + /* Should we call xs_close() here? */ + break; + case -EAGAIN: + status = xs_nospace(task); + break; + case -ENETUNREACH: + case -ENOBUFS: + case -EPIPE: + case -ECONNREFUSED: + case -EPERM: + /* When the server has died, an ICMP port unreachable message + * prompts ECONNREFUSED. */ + break; + default: + dprintk("RPC: sendmsg returned unrecognized error %d\n", + -status); + } + + return status; +} + +/** + * xs_tcp_send_request - write an RPC request to a TCP socket + * @task: address of RPC task that manages the state of an RPC request + * + * Return values: + * 0: The request has been sent + * EAGAIN: The socket was blocked, please call again later to + * complete the request + * ENOTCONN: Caller needs to invoke connect logic then call again + * other: Some other error occurred, the request was not sent + * + * XXX: In the case of soft timeouts, should we eventually give up + * if sendmsg is not able to make progress? + */ +static int xs_tcp_send_request(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct xdr_buf *xdr = &req->rq_snd_buf; + bool zerocopy = true; + bool vm_wait = false; + int status; + int sent; + + xs_encode_stream_record_marker(&req->rq_snd_buf); + + xs_pktdump("packet data:", + req->rq_svec->iov_base, + req->rq_svec->iov_len); + /* Don't use zero copy if this is a resend. If the RPC call + * completes while the socket holds a reference to the pages, + * then we may end up resending corrupted data. + */ + if (task->tk_flags & RPC_TASK_SENT) + zerocopy = false; + + if (test_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state)) + xs_tcp_set_socket_timeouts(xprt, transport->sock); + + /* Continue transmitting the packet/record. We must be careful + * to cope with writespace callbacks arriving _after_ we have + * called sendmsg(). */ + req->rq_xtime = ktime_get(); + while (1) { + sent = 0; + status = xs_sendpages(transport->sock, NULL, 0, xdr, + req->rq_bytes_sent, zerocopy, &sent); + + dprintk("RPC: xs_tcp_send_request(%u) = %d\n", + xdr->len - req->rq_bytes_sent, status); + + /* If we've sent the entire packet, immediately + * reset the count of bytes sent. */ + req->rq_bytes_sent += sent; + req->rq_xmit_bytes_sent += sent; + if (likely(req->rq_bytes_sent >= req->rq_slen)) { + req->rq_bytes_sent = 0; + return 0; + } + + WARN_ON_ONCE(sent == 0 && status == 0); + + if (status == -EAGAIN ) { + /* + * Return EAGAIN if we're sure we're hitting the + * socket send buffer limits. + */ + if (test_bit(SOCK_NOSPACE, &transport->sock->flags)) + break; + /* + * Did we hit a memory allocation failure? + */ + if (sent == 0) { + status = -ENOBUFS; + if (vm_wait) + break; + /* Retry, knowing now that we're below the + * socket send buffer limit + */ + vm_wait = true; + } + continue; + } + if (status < 0) + break; + vm_wait = false; + } + + switch (status) { + case -ENOTSOCK: + status = -ENOTCONN; + /* Should we call xs_close() here? */ + break; + case -EAGAIN: + status = xs_nospace(task); + break; + case -ECONNRESET: + case -ECONNREFUSED: + case -ENOTCONN: + case -EADDRINUSE: + case -ENOBUFS: + case -EPIPE: + break; + default: + dprintk("RPC: sendmsg returned unrecognized error %d\n", + -status); + } + + return status; +} + +/** + * xs_tcp_release_xprt - clean up after a tcp transmission + * @xprt: transport + * @task: rpc task + * + * This cleans up if an error causes us to abort the transmission of a request. + * In this case, the socket may need to be reset in order to avoid confusing + * the server. + */ +static void xs_tcp_release_xprt(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct rpc_rqst *req; + + if (task != xprt->snd_task) + return; + if (task == NULL) + goto out_release; + req = task->tk_rqstp; + if (req == NULL) + goto out_release; + if (req->rq_bytes_sent == 0) + goto out_release; + if (req->rq_bytes_sent == req->rq_snd_buf.len) + goto out_release; + set_bit(XPRT_CLOSE_WAIT, &xprt->state); +out_release: + xprt_release_xprt(xprt, task); +} + +static void xs_save_old_callbacks(struct sock_xprt *transport, struct sock *sk) +{ + transport->old_data_ready = sk->sk_data_ready; + transport->old_state_change = sk->sk_state_change; + transport->old_write_space = sk->sk_write_space; + transport->old_error_report = sk->sk_error_report; +} + +static void xs_restore_old_callbacks(struct sock_xprt *transport, struct sock *sk) +{ + sk->sk_data_ready = transport->old_data_ready; + sk->sk_state_change = transport->old_state_change; + sk->sk_write_space = transport->old_write_space; + sk->sk_error_report = transport->old_error_report; +} + +static void xs_sock_reset_state_flags(struct rpc_xprt *xprt) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state); +} + +static void xs_sock_reset_connection_flags(struct rpc_xprt *xprt) +{ + smp_mb__before_atomic(); + clear_bit(XPRT_CLOSE_WAIT, &xprt->state); + clear_bit(XPRT_CLOSING, &xprt->state); + xs_sock_reset_state_flags(xprt); + smp_mb__after_atomic(); +} + +/** + * xs_error_report - callback to handle TCP socket state errors + * @sk: socket + * + * Note: we don't call sock_error() since there may be a rpc_task + * using the socket, and so we don't want to clear sk->sk_err. + */ +static void xs_error_report(struct sock *sk) +{ + struct rpc_xprt *xprt; + int err; + + read_lock_bh(&sk->sk_callback_lock); + if (!(xprt = xprt_from_sock(sk))) + goto out; + + err = -sk->sk_err; + if (err == 0) + goto out; + dprintk("RPC: xs_error_report client %p, error=%d...\n", + xprt, -err); + trace_rpc_socket_error(xprt, sk->sk_socket, err); + xprt_wake_pending_tasks(xprt, err); + out: + read_unlock_bh(&sk->sk_callback_lock); +} + +static void xs_reset_transport(struct sock_xprt *transport) +{ + struct socket *sock = transport->sock; + struct sock *sk = transport->inet; + struct rpc_xprt *xprt = &transport->xprt; + + if (sk == NULL) + return; + + if (atomic_read(&transport->xprt.swapper)) + sk_clear_memalloc(sk); + + kernel_sock_shutdown(sock, SHUT_RDWR); + + mutex_lock(&transport->recv_mutex); + write_lock_bh(&sk->sk_callback_lock); + transport->inet = NULL; + transport->sock = NULL; + + sk->sk_user_data = NULL; + + xs_restore_old_callbacks(transport, sk); + xprt_clear_connected(xprt); + write_unlock_bh(&sk->sk_callback_lock); + xs_sock_reset_connection_flags(xprt); + mutex_unlock(&transport->recv_mutex); + + trace_rpc_socket_close(xprt, sock); + sock_release(sock); +} + +/** + * xs_close - close a socket + * @xprt: transport + * + * This is used when all requests are complete; ie, no DRC state remains + * on the server we want to save. + * + * The caller _must_ be holding XPRT_LOCKED in order to avoid issues with + * xs_reset_transport() zeroing the socket from underneath a writer. + */ +static void xs_close(struct rpc_xprt *xprt) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + dprintk("RPC: xs_close xprt %p\n", xprt); + + xs_reset_transport(transport); + xprt->reestablish_timeout = 0; + + xprt_disconnect_done(xprt); +} + +static void xs_inject_disconnect(struct rpc_xprt *xprt) +{ + dprintk("RPC: injecting transport disconnect on xprt=%p\n", + xprt); + xprt_disconnect_done(xprt); +} + +static void xs_xprt_free(struct rpc_xprt *xprt) +{ + xs_free_peer_addresses(xprt); + xprt_free(xprt); +} + +/** + * xs_destroy - prepare to shutdown a transport + * @xprt: doomed transport + * + */ +static void xs_destroy(struct rpc_xprt *xprt) +{ + struct sock_xprt *transport = container_of(xprt, + struct sock_xprt, xprt); + dprintk("RPC: xs_destroy xprt %p\n", xprt); + + cancel_delayed_work_sync(&transport->connect_worker); + xs_close(xprt); + cancel_work_sync(&transport->recv_worker); + xs_xprt_free(xprt); + module_put(THIS_MODULE); +} + +static int xs_local_copy_to_xdr(struct xdr_buf *xdr, struct sk_buff *skb) +{ + struct xdr_skb_reader desc = { + .skb = skb, + .offset = sizeof(rpc_fraghdr), + .count = skb->len - sizeof(rpc_fraghdr), + }; + + if (xdr_partial_copy_from_skb(xdr, 0, &desc, xdr_skb_read_bits) < 0) + return -1; + if (desc.count) + return -1; + return 0; +} + +/** + * xs_local_data_read_skb + * @xprt: transport + * @sk: socket + * @skb: skbuff + * + * Currently this assumes we can read the whole reply in a single gulp. + */ +static void xs_local_data_read_skb(struct rpc_xprt *xprt, + struct sock *sk, + struct sk_buff *skb) +{ + struct rpc_task *task; + struct rpc_rqst *rovr; + int repsize, copied; + u32 _xid; + __be32 *xp; + + repsize = skb->len - sizeof(rpc_fraghdr); + if (repsize < 4) { + dprintk("RPC: impossible RPC reply size %d\n", repsize); + return; + } + + /* Copy the XID from the skb... */ + xp = skb_header_pointer(skb, sizeof(rpc_fraghdr), sizeof(_xid), &_xid); + if (xp == NULL) + return; + + /* Look up and lock the request corresponding to the given XID */ + spin_lock(&xprt->recv_lock); + rovr = xprt_lookup_rqst(xprt, *xp); + if (!rovr) + goto out_unlock; + xprt_pin_rqst(rovr); + spin_unlock(&xprt->recv_lock); + task = rovr->rq_task; + + copied = rovr->rq_private_buf.buflen; + if (copied > repsize) + copied = repsize; + + if (xs_local_copy_to_xdr(&rovr->rq_private_buf, skb)) { + dprintk("RPC: sk_buff copy failed\n"); + spin_lock(&xprt->recv_lock); + goto out_unpin; + } + + spin_lock(&xprt->recv_lock); + xprt_complete_rqst(task, copied); +out_unpin: + xprt_unpin_rqst(rovr); + out_unlock: + spin_unlock(&xprt->recv_lock); +} + +static void xs_local_data_receive(struct sock_xprt *transport) +{ + struct sk_buff *skb; + struct sock *sk; + int err; + +restart: + mutex_lock(&transport->recv_mutex); + sk = transport->inet; + if (sk == NULL) + goto out; + for (;;) { + skb = skb_recv_datagram(sk, 0, 1, &err); + if (skb != NULL) { + xs_local_data_read_skb(&transport->xprt, sk, skb); + skb_free_datagram(sk, skb); + continue; + } + if (!test_and_clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state)) + break; + if (need_resched()) { + mutex_unlock(&transport->recv_mutex); + cond_resched(); + goto restart; + } + } +out: + mutex_unlock(&transport->recv_mutex); +} + +static void xs_local_data_receive_workfn(struct work_struct *work) +{ + struct sock_xprt *transport = + container_of(work, struct sock_xprt, recv_worker); + xs_local_data_receive(transport); +} + +/** + * xs_udp_data_read_skb - receive callback for UDP sockets + * @xprt: transport + * @sk: socket + * @skb: skbuff + * + */ +static void xs_udp_data_read_skb(struct rpc_xprt *xprt, + struct sock *sk, + struct sk_buff *skb) +{ + struct rpc_task *task; + struct rpc_rqst *rovr; + int repsize, copied; + u32 _xid; + __be32 *xp; + + repsize = skb->len; + if (repsize < 4) { + dprintk("RPC: impossible RPC reply size %d!\n", repsize); + return; + } + + /* Copy the XID from the skb... */ + xp = skb_header_pointer(skb, 0, sizeof(_xid), &_xid); + if (xp == NULL) + return; + + /* Look up and lock the request corresponding to the given XID */ + spin_lock(&xprt->recv_lock); + rovr = xprt_lookup_rqst(xprt, *xp); + if (!rovr) + goto out_unlock; + xprt_pin_rqst(rovr); + xprt_update_rtt(rovr->rq_task); + spin_unlock(&xprt->recv_lock); + task = rovr->rq_task; + + if ((copied = rovr->rq_private_buf.buflen) > repsize) + copied = repsize; + + /* Suck it into the iovec, verify checksum if not done by hw. */ + if (csum_partial_copy_to_xdr(&rovr->rq_private_buf, skb)) { + spin_lock(&xprt->recv_lock); + __UDPX_INC_STATS(sk, UDP_MIB_INERRORS); + goto out_unpin; + } + + + spin_lock_bh(&xprt->transport_lock); + xprt_adjust_cwnd(xprt, task, copied); + spin_unlock_bh(&xprt->transport_lock); + spin_lock(&xprt->recv_lock); + xprt_complete_rqst(task, copied); + __UDPX_INC_STATS(sk, UDP_MIB_INDATAGRAMS); +out_unpin: + xprt_unpin_rqst(rovr); + out_unlock: + spin_unlock(&xprt->recv_lock); +} + +static void xs_udp_data_receive(struct sock_xprt *transport) +{ + struct sk_buff *skb; + struct sock *sk; + int err; + +restart: + mutex_lock(&transport->recv_mutex); + sk = transport->inet; + if (sk == NULL) + goto out; + for (;;) { + skb = skb_recv_udp(sk, 0, 1, &err); + if (skb != NULL) { + xs_udp_data_read_skb(&transport->xprt, sk, skb); + consume_skb(skb); + continue; + } + if (!test_and_clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state)) + break; + if (need_resched()) { + mutex_unlock(&transport->recv_mutex); + cond_resched(); + goto restart; + } + } +out: + mutex_unlock(&transport->recv_mutex); +} + +static void xs_udp_data_receive_workfn(struct work_struct *work) +{ + struct sock_xprt *transport = + container_of(work, struct sock_xprt, recv_worker); + xs_udp_data_receive(transport); +} + +/** + * xs_data_ready - "data ready" callback for UDP sockets + * @sk: socket with data to read + * + */ +static void xs_data_ready(struct sock *sk) +{ + struct rpc_xprt *xprt; + + read_lock_bh(&sk->sk_callback_lock); + dprintk("RPC: xs_data_ready...\n"); + xprt = xprt_from_sock(sk); + if (xprt != NULL) { + struct sock_xprt *transport = container_of(xprt, + struct sock_xprt, xprt); + transport->old_data_ready(sk); + /* Any data means we had a useful conversation, so + * then we don't need to delay the next reconnect + */ + if (xprt->reestablish_timeout) + xprt->reestablish_timeout = 0; + if (!test_and_set_bit(XPRT_SOCK_DATA_READY, &transport->sock_state)) + queue_work(xprtiod_workqueue, &transport->recv_worker); + } + read_unlock_bh(&sk->sk_callback_lock); +} + +/* + * Helper function to force a TCP close if the server is sending + * junk and/or it has put us in CLOSE_WAIT + */ +static void xs_tcp_force_close(struct rpc_xprt *xprt) +{ + xprt_force_disconnect(xprt); +} + +static inline void xs_tcp_read_fraghdr(struct rpc_xprt *xprt, struct xdr_skb_reader *desc) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + size_t len, used; + char *p; + + p = ((char *) &transport->tcp_fraghdr) + transport->tcp_offset; + len = sizeof(transport->tcp_fraghdr) - transport->tcp_offset; + used = xdr_skb_read_bits(desc, p, len); + transport->tcp_offset += used; + if (used != len) + return; + + transport->tcp_reclen = ntohl(transport->tcp_fraghdr); + if (transport->tcp_reclen & RPC_LAST_STREAM_FRAGMENT) + transport->tcp_flags |= TCP_RCV_LAST_FRAG; + else + transport->tcp_flags &= ~TCP_RCV_LAST_FRAG; + transport->tcp_reclen &= RPC_FRAGMENT_SIZE_MASK; + + transport->tcp_flags &= ~TCP_RCV_COPY_FRAGHDR; + transport->tcp_offset = 0; + + /* Sanity check of the record length */ + if (unlikely(transport->tcp_reclen < 8)) { + dprintk("RPC: invalid TCP record fragment length\n"); + xs_tcp_force_close(xprt); + return; + } + dprintk("RPC: reading TCP record fragment of length %d\n", + transport->tcp_reclen); +} + +static void xs_tcp_check_fraghdr(struct sock_xprt *transport) +{ + if (transport->tcp_offset == transport->tcp_reclen) { + transport->tcp_flags |= TCP_RCV_COPY_FRAGHDR; + transport->tcp_offset = 0; + if (transport->tcp_flags & TCP_RCV_LAST_FRAG) { + transport->tcp_flags &= ~TCP_RCV_COPY_DATA; + transport->tcp_flags |= TCP_RCV_COPY_XID; + transport->tcp_copied = 0; + } + } +} + +static inline void xs_tcp_read_xid(struct sock_xprt *transport, struct xdr_skb_reader *desc) +{ + size_t len, used; + char *p; + + len = sizeof(transport->tcp_xid) - transport->tcp_offset; + dprintk("RPC: reading XID (%zu bytes)\n", len); + p = ((char *) &transport->tcp_xid) + transport->tcp_offset; + used = xdr_skb_read_bits(desc, p, len); + transport->tcp_offset += used; + if (used != len) + return; + transport->tcp_flags &= ~TCP_RCV_COPY_XID; + transport->tcp_flags |= TCP_RCV_READ_CALLDIR; + transport->tcp_copied = 4; + dprintk("RPC: reading %s XID %08x\n", + (transport->tcp_flags & TCP_RPC_REPLY) ? "reply for" + : "request with", + ntohl(transport->tcp_xid)); + xs_tcp_check_fraghdr(transport); +} + +static inline void xs_tcp_read_calldir(struct sock_xprt *transport, + struct xdr_skb_reader *desc) +{ + size_t len, used; + u32 offset; + char *p; + + /* + * We want transport->tcp_offset to be 8 at the end of this routine + * (4 bytes for the xid and 4 bytes for the call/reply flag). + * When this function is called for the first time, + * transport->tcp_offset is 4 (after having already read the xid). + */ + offset = transport->tcp_offset - sizeof(transport->tcp_xid); + len = sizeof(transport->tcp_calldir) - offset; + dprintk("RPC: reading CALL/REPLY flag (%zu bytes)\n", len); + p = ((char *) &transport->tcp_calldir) + offset; + used = xdr_skb_read_bits(desc, p, len); + transport->tcp_offset += used; + if (used != len) + return; + transport->tcp_flags &= ~TCP_RCV_READ_CALLDIR; + /* + * We don't yet have the XDR buffer, so we will write the calldir + * out after we get the buffer from the 'struct rpc_rqst' + */ + switch (ntohl(transport->tcp_calldir)) { + case RPC_REPLY: + transport->tcp_flags |= TCP_RCV_COPY_CALLDIR; + transport->tcp_flags |= TCP_RCV_COPY_DATA; + transport->tcp_flags |= TCP_RPC_REPLY; + break; + case RPC_CALL: + transport->tcp_flags |= TCP_RCV_COPY_CALLDIR; + transport->tcp_flags |= TCP_RCV_COPY_DATA; + transport->tcp_flags &= ~TCP_RPC_REPLY; + break; + default: + dprintk("RPC: invalid request message type\n"); + xs_tcp_force_close(&transport->xprt); + } + xs_tcp_check_fraghdr(transport); +} + +static inline void xs_tcp_read_common(struct rpc_xprt *xprt, + struct xdr_skb_reader *desc, + struct rpc_rqst *req) +{ + struct sock_xprt *transport = + container_of(xprt, struct sock_xprt, xprt); + struct xdr_buf *rcvbuf; + size_t len; + ssize_t r; + + rcvbuf = &req->rq_private_buf; + + if (transport->tcp_flags & TCP_RCV_COPY_CALLDIR) { + /* + * Save the RPC direction in the XDR buffer + */ + memcpy(rcvbuf->head[0].iov_base + transport->tcp_copied, + &transport->tcp_calldir, + sizeof(transport->tcp_calldir)); + transport->tcp_copied += sizeof(transport->tcp_calldir); + transport->tcp_flags &= ~TCP_RCV_COPY_CALLDIR; + } + + len = desc->count; + if (len > transport->tcp_reclen - transport->tcp_offset) + desc->count = transport->tcp_reclen - transport->tcp_offset; + r = xdr_partial_copy_from_skb(rcvbuf, transport->tcp_copied, + desc, xdr_skb_read_bits); + + if (desc->count) { + /* Error when copying to the receive buffer, + * usually because we weren't able to allocate + * additional buffer pages. All we can do now + * is turn off TCP_RCV_COPY_DATA, so the request + * will not receive any additional updates, + * and time out. + * Any remaining data from this record will + * be discarded. + */ + transport->tcp_flags &= ~TCP_RCV_COPY_DATA; + dprintk("RPC: XID %08x truncated request\n", + ntohl(transport->tcp_xid)); + dprintk("RPC: xprt = %p, tcp_copied = %lu, " + "tcp_offset = %u, tcp_reclen = %u\n", + xprt, transport->tcp_copied, + transport->tcp_offset, transport->tcp_reclen); + return; + } + + transport->tcp_copied += r; + transport->tcp_offset += r; + desc->count = len - r; + + dprintk("RPC: XID %08x read %zd bytes\n", + ntohl(transport->tcp_xid), r); + dprintk("RPC: xprt = %p, tcp_copied = %lu, tcp_offset = %u, " + "tcp_reclen = %u\n", xprt, transport->tcp_copied, + transport->tcp_offset, transport->tcp_reclen); + + if (transport->tcp_copied == req->rq_private_buf.buflen) + transport->tcp_flags &= ~TCP_RCV_COPY_DATA; + else if (transport->tcp_offset == transport->tcp_reclen) { + if (transport->tcp_flags & TCP_RCV_LAST_FRAG) + transport->tcp_flags &= ~TCP_RCV_COPY_DATA; + } +} + +/* + * Finds the request corresponding to the RPC xid and invokes the common + * tcp read code to read the data. + */ +static inline int xs_tcp_read_reply(struct rpc_xprt *xprt, + struct xdr_skb_reader *desc) +{ + struct sock_xprt *transport = + container_of(xprt, struct sock_xprt, xprt); + struct rpc_rqst *req; + + dprintk("RPC: read reply XID %08x\n", ntohl(transport->tcp_xid)); + + /* Find and lock the request corresponding to this xid */ + spin_lock(&xprt->recv_lock); + req = xprt_lookup_rqst(xprt, transport->tcp_xid); + if (!req) { + dprintk("RPC: XID %08x request not found!\n", + ntohl(transport->tcp_xid)); + spin_unlock(&xprt->recv_lock); + return -1; + } + xprt_pin_rqst(req); + spin_unlock(&xprt->recv_lock); + + xs_tcp_read_common(xprt, desc, req); + + spin_lock(&xprt->recv_lock); + if (!(transport->tcp_flags & TCP_RCV_COPY_DATA)) + xprt_complete_rqst(req->rq_task, transport->tcp_copied); + xprt_unpin_rqst(req); + spin_unlock(&xprt->recv_lock); + return 0; +} + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +/* + * Obtains an rpc_rqst previously allocated and invokes the common + * tcp read code to read the data. The result is placed in the callback + * queue. + * If we're unable to obtain the rpc_rqst we schedule the closing of the + * connection and return -1. + */ +static int xs_tcp_read_callback(struct rpc_xprt *xprt, + struct xdr_skb_reader *desc) +{ + struct sock_xprt *transport = + container_of(xprt, struct sock_xprt, xprt); + struct rpc_rqst *req; + + /* Look up the request corresponding to the given XID */ + req = xprt_lookup_bc_request(xprt, transport->tcp_xid); + if (req == NULL) { + printk(KERN_WARNING "Callback slot table overflowed\n"); + xprt_force_disconnect(xprt); + return -1; + } + + dprintk("RPC: read callback XID %08x\n", ntohl(req->rq_xid)); + xs_tcp_read_common(xprt, desc, req); + + if (!(transport->tcp_flags & TCP_RCV_COPY_DATA)) + xprt_complete_bc_request(req, transport->tcp_copied); + + return 0; +} + +static inline int _xs_tcp_read_data(struct rpc_xprt *xprt, + struct xdr_skb_reader *desc) +{ + struct sock_xprt *transport = + container_of(xprt, struct sock_xprt, xprt); + + return (transport->tcp_flags & TCP_RPC_REPLY) ? + xs_tcp_read_reply(xprt, desc) : + xs_tcp_read_callback(xprt, desc); +} + +static int xs_tcp_bc_up(struct svc_serv *serv, struct net *net) +{ + int ret; + + ret = svc_create_xprt(serv, "tcp-bc", net, PF_INET, 0, + SVC_SOCK_ANONYMOUS); + if (ret < 0) + return ret; + return 0; +} + +static size_t xs_tcp_bc_maxpayload(struct rpc_xprt *xprt) +{ + return PAGE_SIZE; +} +#else +static inline int _xs_tcp_read_data(struct rpc_xprt *xprt, + struct xdr_skb_reader *desc) +{ + return xs_tcp_read_reply(xprt, desc); +} +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +/* + * Read data off the transport. This can be either an RPC_CALL or an + * RPC_REPLY. Relay the processing to helper functions. + */ +static void xs_tcp_read_data(struct rpc_xprt *xprt, + struct xdr_skb_reader *desc) +{ + struct sock_xprt *transport = + container_of(xprt, struct sock_xprt, xprt); + + if (_xs_tcp_read_data(xprt, desc) == 0) + xs_tcp_check_fraghdr(transport); + else { + /* + * The transport_lock protects the request handling. + * There's no need to hold it to update the tcp_flags. + */ + transport->tcp_flags &= ~TCP_RCV_COPY_DATA; + } +} + +static inline void xs_tcp_read_discard(struct sock_xprt *transport, struct xdr_skb_reader *desc) +{ + size_t len; + + len = transport->tcp_reclen - transport->tcp_offset; + if (len > desc->count) + len = desc->count; + desc->count -= len; + desc->offset += len; + transport->tcp_offset += len; + dprintk("RPC: discarded %zu bytes\n", len); + xs_tcp_check_fraghdr(transport); +} + +static int xs_tcp_data_recv(read_descriptor_t *rd_desc, struct sk_buff *skb, unsigned int offset, size_t len) +{ + struct rpc_xprt *xprt = rd_desc->arg.data; + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct xdr_skb_reader desc = { + .skb = skb, + .offset = offset, + .count = len, + }; + size_t ret; + + dprintk("RPC: xs_tcp_data_recv started\n"); + do { + trace_xs_tcp_data_recv(transport); + /* Read in a new fragment marker if necessary */ + /* Can we ever really expect to get completely empty fragments? */ + if (transport->tcp_flags & TCP_RCV_COPY_FRAGHDR) { + xs_tcp_read_fraghdr(xprt, &desc); + continue; + } + /* Read in the xid if necessary */ + if (transport->tcp_flags & TCP_RCV_COPY_XID) { + xs_tcp_read_xid(transport, &desc); + continue; + } + /* Read in the call/reply flag */ + if (transport->tcp_flags & TCP_RCV_READ_CALLDIR) { + xs_tcp_read_calldir(transport, &desc); + continue; + } + /* Read in the request data */ + if (transport->tcp_flags & TCP_RCV_COPY_DATA) { + xs_tcp_read_data(xprt, &desc); + continue; + } + /* Skip over any trailing bytes on short reads */ + xs_tcp_read_discard(transport, &desc); + } while (desc.count); + ret = len - desc.count; + if (ret < rd_desc->count) + rd_desc->count -= ret; + else + rd_desc->count = 0; + trace_xs_tcp_data_recv(transport); + dprintk("RPC: xs_tcp_data_recv done\n"); + return ret; +} + +static void xs_tcp_data_receive(struct sock_xprt *transport) +{ + struct rpc_xprt *xprt = &transport->xprt; + struct sock *sk; + read_descriptor_t rd_desc = { + .arg.data = xprt, + }; + unsigned long total = 0; + int read = 0; + +restart: + mutex_lock(&transport->recv_mutex); + sk = transport->inet; + if (sk == NULL) + goto out; + + /* We use rd_desc to pass struct xprt to xs_tcp_data_recv */ + for (;;) { + rd_desc.count = RPC_TCP_READ_CHUNK_SZ; + lock_sock(sk); + read = tcp_read_sock(sk, &rd_desc, xs_tcp_data_recv); + if (rd_desc.count != 0 || read < 0) { + clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state); + release_sock(sk); + break; + } + release_sock(sk); + total += read; + if (need_resched()) { + mutex_unlock(&transport->recv_mutex); + cond_resched(); + goto restart; + } + } + if (test_bit(XPRT_SOCK_DATA_READY, &transport->sock_state)) + queue_work(xprtiod_workqueue, &transport->recv_worker); +out: + mutex_unlock(&transport->recv_mutex); + trace_xs_tcp_data_ready(xprt, read, total); +} + +static void xs_tcp_data_receive_workfn(struct work_struct *work) +{ + struct sock_xprt *transport = + container_of(work, struct sock_xprt, recv_worker); + xs_tcp_data_receive(transport); +} + +/** + * xs_tcp_state_change - callback to handle TCP socket state changes + * @sk: socket whose state has changed + * + */ +static void xs_tcp_state_change(struct sock *sk) +{ + struct rpc_xprt *xprt; + struct sock_xprt *transport; + + read_lock_bh(&sk->sk_callback_lock); + if (!(xprt = xprt_from_sock(sk))) + goto out; + dprintk("RPC: xs_tcp_state_change client %p...\n", xprt); + dprintk("RPC: state %x conn %d dead %d zapped %d sk_shutdown %d\n", + sk->sk_state, xprt_connected(xprt), + sock_flag(sk, SOCK_DEAD), + sock_flag(sk, SOCK_ZAPPED), + sk->sk_shutdown); + + transport = container_of(xprt, struct sock_xprt, xprt); + trace_rpc_socket_state_change(xprt, sk->sk_socket); + switch (sk->sk_state) { + case TCP_ESTABLISHED: + spin_lock(&xprt->transport_lock); + if (!xprt_test_and_set_connected(xprt)) { + + /* Reset TCP record info */ + transport->tcp_offset = 0; + transport->tcp_reclen = 0; + transport->tcp_copied = 0; + transport->tcp_flags = + TCP_RCV_COPY_FRAGHDR | TCP_RCV_COPY_XID; + xprt->connect_cookie++; + clear_bit(XPRT_SOCK_CONNECTING, &transport->sock_state); + xprt_clear_connecting(xprt); + + xprt->stat.connect_count++; + xprt->stat.connect_time += (long)jiffies - + xprt->stat.connect_start; + xprt_wake_pending_tasks(xprt, -EAGAIN); + } + spin_unlock(&xprt->transport_lock); + break; + case TCP_FIN_WAIT1: + /* The client initiated a shutdown of the socket */ + xprt->connect_cookie++; + xprt->reestablish_timeout = 0; + set_bit(XPRT_CLOSING, &xprt->state); + smp_mb__before_atomic(); + clear_bit(XPRT_CONNECTED, &xprt->state); + clear_bit(XPRT_CLOSE_WAIT, &xprt->state); + smp_mb__after_atomic(); + break; + case TCP_CLOSE_WAIT: + /* The server initiated a shutdown of the socket */ + xprt->connect_cookie++; + clear_bit(XPRT_CONNECTED, &xprt->state); + xs_tcp_force_close(xprt); + /* fall through */ + case TCP_CLOSING: + /* + * If the server closed down the connection, make sure that + * we back off before reconnecting + */ + if (xprt->reestablish_timeout < XS_TCP_INIT_REEST_TO) + xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; + break; + case TCP_LAST_ACK: + set_bit(XPRT_CLOSING, &xprt->state); + smp_mb__before_atomic(); + clear_bit(XPRT_CONNECTED, &xprt->state); + smp_mb__after_atomic(); + break; + case TCP_CLOSE: + if (test_and_clear_bit(XPRT_SOCK_CONNECTING, + &transport->sock_state)) + xprt_clear_connecting(xprt); + clear_bit(XPRT_CLOSING, &xprt->state); + if (sk->sk_err) + xprt_wake_pending_tasks(xprt, -sk->sk_err); + /* Trigger the socket release */ + xs_tcp_force_close(xprt); + } + out: + read_unlock_bh(&sk->sk_callback_lock); +} + +static void xs_write_space(struct sock *sk) +{ + struct socket_wq *wq; + struct rpc_xprt *xprt; + + if (!sk->sk_socket) + return; + clear_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + + if (unlikely(!(xprt = xprt_from_sock(sk)))) + return; + rcu_read_lock(); + wq = rcu_dereference(sk->sk_wq); + if (!wq || test_and_clear_bit(SOCKWQ_ASYNC_NOSPACE, &wq->flags) == 0) + goto out; + + xprt_write_space(xprt); +out: + rcu_read_unlock(); +} + +/** + * xs_udp_write_space - callback invoked when socket buffer space + * becomes available + * @sk: socket whose state has changed + * + * Called when more output buffer space is available for this socket. + * We try not to wake our writers until they can make "significant" + * progress, otherwise we'll waste resources thrashing kernel_sendmsg + * with a bunch of small requests. + */ +static void xs_udp_write_space(struct sock *sk) +{ + read_lock_bh(&sk->sk_callback_lock); + + /* from net/core/sock.c:sock_def_write_space */ + if (sock_writeable(sk)) + xs_write_space(sk); + + read_unlock_bh(&sk->sk_callback_lock); +} + +/** + * xs_tcp_write_space - callback invoked when socket buffer space + * becomes available + * @sk: socket whose state has changed + * + * Called when more output buffer space is available for this socket. + * We try not to wake our writers until they can make "significant" + * progress, otherwise we'll waste resources thrashing kernel_sendmsg + * with a bunch of small requests. + */ +static void xs_tcp_write_space(struct sock *sk) +{ + read_lock_bh(&sk->sk_callback_lock); + + /* from net/core/stream.c:sk_stream_write_space */ + if (sk_stream_is_writeable(sk)) + xs_write_space(sk); + + read_unlock_bh(&sk->sk_callback_lock); +} + +static void xs_udp_do_set_buffer_size(struct rpc_xprt *xprt) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct sock *sk = transport->inet; + + if (transport->rcvsize) { + sk->sk_userlocks |= SOCK_RCVBUF_LOCK; + sk->sk_rcvbuf = transport->rcvsize * xprt->max_reqs * 2; + } + if (transport->sndsize) { + sk->sk_userlocks |= SOCK_SNDBUF_LOCK; + sk->sk_sndbuf = transport->sndsize * xprt->max_reqs * 2; + sk->sk_write_space(sk); + } +} + +/** + * xs_udp_set_buffer_size - set send and receive limits + * @xprt: generic transport + * @sndsize: requested size of send buffer, in bytes + * @rcvsize: requested size of receive buffer, in bytes + * + * Set socket send and receive buffer size limits. + */ +static void xs_udp_set_buffer_size(struct rpc_xprt *xprt, size_t sndsize, size_t rcvsize) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + transport->sndsize = 0; + if (sndsize) + transport->sndsize = sndsize + 1024; + transport->rcvsize = 0; + if (rcvsize) + transport->rcvsize = rcvsize + 1024; + + xs_udp_do_set_buffer_size(xprt); +} + +/** + * xs_udp_timer - called when a retransmit timeout occurs on a UDP transport + * @task: task that timed out + * + * Adjust the congestion window after a retransmit timeout has occurred. + */ +static void xs_udp_timer(struct rpc_xprt *xprt, struct rpc_task *task) +{ + spin_lock_bh(&xprt->transport_lock); + xprt_adjust_cwnd(xprt, task, -ETIMEDOUT); + spin_unlock_bh(&xprt->transport_lock); +} + +static int xs_get_random_port(void) +{ + unsigned short min = xprt_min_resvport, max = xprt_max_resvport; + unsigned short range; + unsigned short rand; + + if (max < min) + return -EADDRINUSE; + range = max - min + 1; + rand = (unsigned short) prandom_u32() % range; + return rand + min; +} + +/** + * xs_set_reuseaddr_port - set the socket's port and address reuse options + * @sock: socket + * + * Note that this function has to be called on all sockets that share the + * same port, and it must be called before binding. + */ +static void xs_sock_set_reuseport(struct socket *sock) +{ + int opt = 1; + + kernel_setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, + (char *)&opt, sizeof(opt)); +} + +static unsigned short xs_sock_getport(struct socket *sock) +{ + struct sockaddr_storage buf; + unsigned short port = 0; + + if (kernel_getsockname(sock, (struct sockaddr *)&buf) < 0) + goto out; + switch (buf.ss_family) { + case AF_INET6: + port = ntohs(((struct sockaddr_in6 *)&buf)->sin6_port); + break; + case AF_INET: + port = ntohs(((struct sockaddr_in *)&buf)->sin_port); + } +out: + return port; +} + +/** + * xs_set_port - reset the port number in the remote endpoint address + * @xprt: generic transport + * @port: new port number + * + */ +static void xs_set_port(struct rpc_xprt *xprt, unsigned short port) +{ + dprintk("RPC: setting port for xprt %p to %u\n", xprt, port); + + rpc_set_port(xs_addr(xprt), port); + xs_update_peer_port(xprt); +} + +static void xs_set_srcport(struct sock_xprt *transport, struct socket *sock) +{ + if (transport->srcport == 0) + transport->srcport = xs_sock_getport(sock); +} + +static int xs_get_srcport(struct sock_xprt *transport) +{ + int port = transport->srcport; + + if (port == 0 && transport->xprt.resvport) + port = xs_get_random_port(); + return port; +} + +static unsigned short xs_next_srcport(struct sock_xprt *transport, unsigned short port) +{ + if (transport->srcport != 0) + transport->srcport = 0; + if (!transport->xprt.resvport) + return 0; + if (port <= xprt_min_resvport || port > xprt_max_resvport) + return xprt_max_resvport; + return --port; +} +static int xs_bind(struct sock_xprt *transport, struct socket *sock) +{ + struct sockaddr_storage myaddr; + int err, nloop = 0; + int port = xs_get_srcport(transport); + unsigned short last; + + /* + * If we are asking for any ephemeral port (i.e. port == 0 && + * transport->xprt.resvport == 0), don't bind. Let the local + * port selection happen implicitly when the socket is used + * (for example at connect time). + * + * This ensures that we can continue to establish TCP + * connections even when all local ephemeral ports are already + * a part of some TCP connection. This makes no difference + * for UDP sockets, but also doens't harm them. + * + * If we're asking for any reserved port (i.e. port == 0 && + * transport->xprt.resvport == 1) xs_get_srcport above will + * ensure that port is non-zero and we will bind as needed. + */ + if (port <= 0) + return port; + + memcpy(&myaddr, &transport->srcaddr, transport->xprt.addrlen); + do { + rpc_set_port((struct sockaddr *)&myaddr, port); + err = kernel_bind(sock, (struct sockaddr *)&myaddr, + transport->xprt.addrlen); + if (err == 0) { + transport->srcport = port; + break; + } + last = port; + port = xs_next_srcport(transport, port); + if (port > last) + nloop++; + } while (err == -EADDRINUSE && nloop != 2); + + if (myaddr.ss_family == AF_INET) + dprintk("RPC: %s %pI4:%u: %s (%d)\n", __func__, + &((struct sockaddr_in *)&myaddr)->sin_addr, + port, err ? "failed" : "ok", err); + else + dprintk("RPC: %s %pI6:%u: %s (%d)\n", __func__, + &((struct sockaddr_in6 *)&myaddr)->sin6_addr, + port, err ? "failed" : "ok", err); + return err; +} + +/* + * We don't support autobind on AF_LOCAL sockets + */ +static void xs_local_rpcbind(struct rpc_task *task) +{ + xprt_set_bound(task->tk_xprt); +} + +static void xs_local_set_port(struct rpc_xprt *xprt, unsigned short port) +{ +} + +#ifdef CONFIG_DEBUG_LOCK_ALLOC +static struct lock_class_key xs_key[2]; +static struct lock_class_key xs_slock_key[2]; + +static inline void xs_reclassify_socketu(struct socket *sock) +{ + struct sock *sk = sock->sk; + + sock_lock_init_class_and_name(sk, "slock-AF_LOCAL-RPC", + &xs_slock_key[1], "sk_lock-AF_LOCAL-RPC", &xs_key[1]); +} + +static inline void xs_reclassify_socket4(struct socket *sock) +{ + struct sock *sk = sock->sk; + + sock_lock_init_class_and_name(sk, "slock-AF_INET-RPC", + &xs_slock_key[0], "sk_lock-AF_INET-RPC", &xs_key[0]); +} + +static inline void xs_reclassify_socket6(struct socket *sock) +{ + struct sock *sk = sock->sk; + + sock_lock_init_class_and_name(sk, "slock-AF_INET6-RPC", + &xs_slock_key[1], "sk_lock-AF_INET6-RPC", &xs_key[1]); +} + +static inline void xs_reclassify_socket(int family, struct socket *sock) +{ + if (WARN_ON_ONCE(!sock_allow_reclassification(sock->sk))) + return; + + switch (family) { + case AF_LOCAL: + xs_reclassify_socketu(sock); + break; + case AF_INET: + xs_reclassify_socket4(sock); + break; + case AF_INET6: + xs_reclassify_socket6(sock); + break; + } +} +#else +static inline void xs_reclassify_socket(int family, struct socket *sock) +{ +} +#endif + +static void xs_dummy_setup_socket(struct work_struct *work) +{ +} + +static struct socket *xs_create_sock(struct rpc_xprt *xprt, + struct sock_xprt *transport, int family, int type, + int protocol, bool reuseport) +{ + struct socket *sock; + int err; + + err = __sock_create(xprt->xprt_net, family, type, protocol, &sock, 1); + if (err < 0) { + dprintk("RPC: can't create %d transport socket (%d).\n", + protocol, -err); + goto out; + } + xs_reclassify_socket(family, sock); + + if (reuseport) + xs_sock_set_reuseport(sock); + + err = xs_bind(transport, sock); + if (err) { + sock_release(sock); + goto out; + } + + return sock; +out: + return ERR_PTR(err); +} + +static int xs_local_finish_connecting(struct rpc_xprt *xprt, + struct socket *sock) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, + xprt); + + if (!transport->inet) { + struct sock *sk = sock->sk; + + write_lock_bh(&sk->sk_callback_lock); + + xs_save_old_callbacks(transport, sk); + + sk->sk_user_data = xprt; + sk->sk_data_ready = xs_data_ready; + sk->sk_write_space = xs_udp_write_space; + sock_set_flag(sk, SOCK_FASYNC); + sk->sk_error_report = xs_error_report; + sk->sk_allocation = GFP_NOIO; + + xprt_clear_connected(xprt); + + /* Reset to new socket */ + transport->sock = sock; + transport->inet = sk; + + write_unlock_bh(&sk->sk_callback_lock); + } + + /* Tell the socket layer to start connecting... */ + return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, 0); +} + +/** + * xs_local_setup_socket - create AF_LOCAL socket, connect to a local endpoint + * @transport: socket transport to connect + */ +static int xs_local_setup_socket(struct sock_xprt *transport) +{ + struct rpc_xprt *xprt = &transport->xprt; + struct socket *sock; + int status = -EIO; + + status = __sock_create(xprt->xprt_net, AF_LOCAL, + SOCK_STREAM, 0, &sock, 1); + if (status < 0) { + dprintk("RPC: can't create AF_LOCAL " + "transport socket (%d).\n", -status); + goto out; + } + xs_reclassify_socket(AF_LOCAL, sock); + + dprintk("RPC: worker connecting xprt %p via AF_LOCAL to %s\n", + xprt, xprt->address_strings[RPC_DISPLAY_ADDR]); + + status = xs_local_finish_connecting(xprt, sock); + trace_rpc_socket_connect(xprt, sock, status); + switch (status) { + case 0: + dprintk("RPC: xprt %p connected to %s\n", + xprt, xprt->address_strings[RPC_DISPLAY_ADDR]); + xprt->stat.connect_count++; + xprt->stat.connect_time += (long)jiffies - + xprt->stat.connect_start; + xprt_set_connected(xprt); + case -ENOBUFS: + break; + case -ENOENT: + dprintk("RPC: xprt %p: socket %s does not exist\n", + xprt, xprt->address_strings[RPC_DISPLAY_ADDR]); + break; + case -ECONNREFUSED: + dprintk("RPC: xprt %p: connection refused for %s\n", + xprt, xprt->address_strings[RPC_DISPLAY_ADDR]); + break; + default: + printk(KERN_ERR "%s: unhandled error (%d) connecting to %s\n", + __func__, -status, + xprt->address_strings[RPC_DISPLAY_ADDR]); + } + +out: + xprt_clear_connecting(xprt); + xprt_wake_pending_tasks(xprt, status); + return status; +} + +static void xs_local_connect(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + int ret; + + if (RPC_IS_ASYNC(task)) { + /* + * We want the AF_LOCAL connect to be resolved in the + * filesystem namespace of the process making the rpc + * call. Thus we connect synchronously. + * + * If we want to support asynchronous AF_LOCAL calls, + * we'll need to figure out how to pass a namespace to + * connect. + */ + rpc_exit(task, -ENOTCONN); + return; + } + ret = xs_local_setup_socket(transport); + if (ret && !RPC_IS_SOFTCONN(task)) + msleep_interruptible(15000); +} + +#if IS_ENABLED(CONFIG_SUNRPC_SWAP) +/* + * Note that this should be called with XPRT_LOCKED held (or when we otherwise + * know that we have exclusive access to the socket), to guard against + * races with xs_reset_transport. + */ +static void xs_set_memalloc(struct rpc_xprt *xprt) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, + xprt); + + /* + * If there's no sock, then we have nothing to set. The + * reconnecting process will get it for us. + */ + if (!transport->inet) + return; + if (atomic_read(&xprt->swapper)) + sk_set_memalloc(transport->inet); +} + +/** + * xs_enable_swap - Tag this transport as being used for swap. + * @xprt: transport to tag + * + * Take a reference to this transport on behalf of the rpc_clnt, and + * optionally mark it for swapping if it wasn't already. + */ +static int +xs_enable_swap(struct rpc_xprt *xprt) +{ + struct sock_xprt *xs = container_of(xprt, struct sock_xprt, xprt); + + if (atomic_inc_return(&xprt->swapper) != 1) + return 0; + if (wait_on_bit_lock(&xprt->state, XPRT_LOCKED, TASK_KILLABLE)) + return -ERESTARTSYS; + if (xs->inet) + sk_set_memalloc(xs->inet); + xprt_release_xprt(xprt, NULL); + return 0; +} + +/** + * xs_disable_swap - Untag this transport as being used for swap. + * @xprt: transport to tag + * + * Drop a "swapper" reference to this xprt on behalf of the rpc_clnt. If the + * swapper refcount goes to 0, untag the socket as a memalloc socket. + */ +static void +xs_disable_swap(struct rpc_xprt *xprt) +{ + struct sock_xprt *xs = container_of(xprt, struct sock_xprt, xprt); + + if (!atomic_dec_and_test(&xprt->swapper)) + return; + if (wait_on_bit_lock(&xprt->state, XPRT_LOCKED, TASK_KILLABLE)) + return; + if (xs->inet) + sk_clear_memalloc(xs->inet); + xprt_release_xprt(xprt, NULL); +} +#else +static void xs_set_memalloc(struct rpc_xprt *xprt) +{ +} + +static int +xs_enable_swap(struct rpc_xprt *xprt) +{ + return -EINVAL; +} + +static void +xs_disable_swap(struct rpc_xprt *xprt) +{ +} +#endif + +static void xs_udp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + if (!transport->inet) { + struct sock *sk = sock->sk; + + write_lock_bh(&sk->sk_callback_lock); + + xs_save_old_callbacks(transport, sk); + + sk->sk_user_data = xprt; + sk->sk_data_ready = xs_data_ready; + sk->sk_write_space = xs_udp_write_space; + sock_set_flag(sk, SOCK_FASYNC); + sk->sk_allocation = GFP_NOIO; + + xprt_set_connected(xprt); + + /* Reset to new socket */ + transport->sock = sock; + transport->inet = sk; + + xs_set_memalloc(xprt); + + write_unlock_bh(&sk->sk_callback_lock); + } + xs_udp_do_set_buffer_size(xprt); + + xprt->stat.connect_start = jiffies; +} + +static void xs_udp_setup_socket(struct work_struct *work) +{ + struct sock_xprt *transport = + container_of(work, struct sock_xprt, connect_worker.work); + struct rpc_xprt *xprt = &transport->xprt; + struct socket *sock; + int status = -EIO; + + sock = xs_create_sock(xprt, transport, + xs_addr(xprt)->sa_family, SOCK_DGRAM, + IPPROTO_UDP, false); + if (IS_ERR(sock)) + goto out; + + dprintk("RPC: worker connecting xprt %p via %s to " + "%s (port %s)\n", xprt, + xprt->address_strings[RPC_DISPLAY_PROTO], + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PORT]); + + xs_udp_finish_connecting(xprt, sock); + trace_rpc_socket_connect(xprt, sock, 0); + status = 0; +out: + xprt_clear_connecting(xprt); + xprt_unlock_connect(xprt, transport); + xprt_wake_pending_tasks(xprt, status); +} + +/** + * xs_tcp_shutdown - gracefully shut down a TCP socket + * @xprt: transport + * + * Initiates a graceful shutdown of the TCP socket by calling the + * equivalent of shutdown(SHUT_RDWR); + */ +static void xs_tcp_shutdown(struct rpc_xprt *xprt) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct socket *sock = transport->sock; + int skst = transport->inet ? transport->inet->sk_state : TCP_CLOSE; + + if (sock == NULL) + return; + switch (skst) { + default: + kernel_sock_shutdown(sock, SHUT_RDWR); + trace_rpc_socket_shutdown(xprt, sock); + break; + case TCP_CLOSE: + case TCP_TIME_WAIT: + xs_reset_transport(transport); + } +} + +static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt, + struct socket *sock) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + unsigned int keepidle; + unsigned int keepcnt; + unsigned int opt_on = 1; + unsigned int timeo; + + spin_lock_bh(&xprt->transport_lock); + keepidle = DIV_ROUND_UP(xprt->timeout->to_initval, HZ); + keepcnt = xprt->timeout->to_retries + 1; + timeo = jiffies_to_msecs(xprt->timeout->to_initval) * + (xprt->timeout->to_retries + 1); + clear_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state); + spin_unlock_bh(&xprt->transport_lock); + + /* TCP Keepalive options */ + kernel_setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, + (char *)&opt_on, sizeof(opt_on)); + kernel_setsockopt(sock, SOL_TCP, TCP_KEEPIDLE, + (char *)&keepidle, sizeof(keepidle)); + kernel_setsockopt(sock, SOL_TCP, TCP_KEEPINTVL, + (char *)&keepidle, sizeof(keepidle)); + kernel_setsockopt(sock, SOL_TCP, TCP_KEEPCNT, + (char *)&keepcnt, sizeof(keepcnt)); + + /* TCP user timeout (see RFC5482) */ + kernel_setsockopt(sock, SOL_TCP, TCP_USER_TIMEOUT, + (char *)&timeo, sizeof(timeo)); +} + +static void xs_tcp_set_connect_timeout(struct rpc_xprt *xprt, + unsigned long connect_timeout, + unsigned long reconnect_timeout) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct rpc_timeout to; + unsigned long initval; + + spin_lock_bh(&xprt->transport_lock); + if (reconnect_timeout < xprt->max_reconnect_timeout) + xprt->max_reconnect_timeout = reconnect_timeout; + if (connect_timeout < xprt->connect_timeout) { + memcpy(&to, xprt->timeout, sizeof(to)); + initval = DIV_ROUND_UP(connect_timeout, to.to_retries + 1); + /* Arbitrary lower limit */ + if (initval < XS_TCP_INIT_REEST_TO << 1) + initval = XS_TCP_INIT_REEST_TO << 1; + to.to_initval = initval; + to.to_maxval = initval; + memcpy(&transport->tcp_timeout, &to, + sizeof(transport->tcp_timeout)); + xprt->timeout = &transport->tcp_timeout; + xprt->connect_timeout = connect_timeout; + } + set_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state); + spin_unlock_bh(&xprt->transport_lock); +} + +static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + int ret = -ENOTCONN; + + if (!transport->inet) { + struct sock *sk = sock->sk; + unsigned int addr_pref = IPV6_PREFER_SRC_PUBLIC; + + /* Avoid temporary address, they are bad for long-lived + * connections such as NFS mounts. + * RFC4941, section 3.6 suggests that: + * Individual applications, which have specific + * knowledge about the normal duration of connections, + * MAY override this as appropriate. + */ + kernel_setsockopt(sock, SOL_IPV6, IPV6_ADDR_PREFERENCES, + (char *)&addr_pref, sizeof(addr_pref)); + + xs_tcp_set_socket_timeouts(xprt, sock); + + write_lock_bh(&sk->sk_callback_lock); + + xs_save_old_callbacks(transport, sk); + + sk->sk_user_data = xprt; + sk->sk_data_ready = xs_data_ready; + sk->sk_state_change = xs_tcp_state_change; + sk->sk_write_space = xs_tcp_write_space; + sock_set_flag(sk, SOCK_FASYNC); + sk->sk_error_report = xs_error_report; + sk->sk_allocation = GFP_NOIO; + + /* socket options */ + sock_reset_flag(sk, SOCK_LINGER); + tcp_sk(sk)->nonagle |= TCP_NAGLE_OFF; + + xprt_clear_connected(xprt); + + /* Reset to new socket */ + transport->sock = sock; + transport->inet = sk; + + write_unlock_bh(&sk->sk_callback_lock); + } + + if (!xprt_bound(xprt)) + goto out; + + xs_set_memalloc(xprt); + + /* Tell the socket layer to start connecting... */ + set_bit(XPRT_SOCK_CONNECTING, &transport->sock_state); + ret = kernel_connect(sock, xs_addr(xprt), xprt->addrlen, O_NONBLOCK); + switch (ret) { + case 0: + xs_set_srcport(transport, sock); + /* fall through */ + case -EINPROGRESS: + /* SYN_SENT! */ + if (xprt->reestablish_timeout < XS_TCP_INIT_REEST_TO) + xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; + break; + case -EADDRNOTAVAIL: + /* Source port number is unavailable. Try a new one! */ + transport->srcport = 0; + } +out: + return ret; +} + +/** + * xs_tcp_setup_socket - create a TCP socket and connect to a remote endpoint + * + * Invoked by a work queue tasklet. + */ +static void xs_tcp_setup_socket(struct work_struct *work) +{ + struct sock_xprt *transport = + container_of(work, struct sock_xprt, connect_worker.work); + struct socket *sock = transport->sock; + struct rpc_xprt *xprt = &transport->xprt; + int status = -EIO; + + if (!sock) { + sock = xs_create_sock(xprt, transport, + xs_addr(xprt)->sa_family, SOCK_STREAM, + IPPROTO_TCP, true); + if (IS_ERR(sock)) { + status = PTR_ERR(sock); + goto out; + } + } + + dprintk("RPC: worker connecting xprt %p via %s to " + "%s (port %s)\n", xprt, + xprt->address_strings[RPC_DISPLAY_PROTO], + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PORT]); + + status = xs_tcp_finish_connecting(xprt, sock); + trace_rpc_socket_connect(xprt, sock, status); + dprintk("RPC: %p connect status %d connected %d sock state %d\n", + xprt, -status, xprt_connected(xprt), + sock->sk->sk_state); + switch (status) { + default: + printk("%s: connect returned unhandled error %d\n", + __func__, status); + /* fall through */ + case -EADDRNOTAVAIL: + /* We're probably in TIME_WAIT. Get rid of existing socket, + * and retry + */ + xs_tcp_force_close(xprt); + break; + case 0: + case -EINPROGRESS: + case -EALREADY: + xprt_unlock_connect(xprt, transport); + return; + case -EINVAL: + /* Happens, for instance, if the user specified a link + * local IPv6 address without a scope-id. + */ + case -ECONNREFUSED: + case -ECONNRESET: + case -ENETDOWN: + case -ENETUNREACH: + case -EHOSTUNREACH: + case -EADDRINUSE: + case -ENOBUFS: + /* + * xs_tcp_force_close() wakes tasks with -EIO. + * We need to wake them first to ensure the + * correct error code. + */ + xprt_wake_pending_tasks(xprt, status); + xs_tcp_force_close(xprt); + goto out; + } + status = -EAGAIN; +out: + xprt_clear_connecting(xprt); + xprt_unlock_connect(xprt, transport); + xprt_wake_pending_tasks(xprt, status); +} + +static unsigned long xs_reconnect_delay(const struct rpc_xprt *xprt) +{ + unsigned long start, now = jiffies; + + start = xprt->stat.connect_start + xprt->reestablish_timeout; + if (time_after(start, now)) + return start - now; + return 0; +} + +static void xs_reconnect_backoff(struct rpc_xprt *xprt) +{ + xprt->reestablish_timeout <<= 1; + if (xprt->reestablish_timeout > xprt->max_reconnect_timeout) + xprt->reestablish_timeout = xprt->max_reconnect_timeout; + if (xprt->reestablish_timeout < XS_TCP_INIT_REEST_TO) + xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; +} + +/** + * xs_connect - connect a socket to a remote endpoint + * @xprt: pointer to transport structure + * @task: address of RPC task that manages state of connect request + * + * TCP: If the remote end dropped the connection, delay reconnecting. + * + * UDP socket connects are synchronous, but we use a work queue anyway + * to guarantee that even unprivileged user processes can set up a + * socket on a privileged port. + * + * If a UDP socket connect fails, the delay behavior here prevents + * retry floods (hard mounts). + */ +static void xs_connect(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + unsigned long delay = 0; + + WARN_ON_ONCE(!xprt_lock_connect(xprt, task, transport)); + + if (transport->sock != NULL) { + dprintk("RPC: xs_connect delayed xprt %p for %lu " + "seconds\n", + xprt, xprt->reestablish_timeout / HZ); + + /* Start by resetting any existing state */ + xs_reset_transport(transport); + + delay = xs_reconnect_delay(xprt); + xs_reconnect_backoff(xprt); + + } else + dprintk("RPC: xs_connect scheduled xprt %p\n", xprt); + + queue_delayed_work(xprtiod_workqueue, + &transport->connect_worker, + delay); +} + +/** + * xs_local_print_stats - display AF_LOCAL socket-specifc stats + * @xprt: rpc_xprt struct containing statistics + * @seq: output file + * + */ +static void xs_local_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) +{ + long idle_time = 0; + + if (xprt_connected(xprt)) + idle_time = (long)(jiffies - xprt->last_used) / HZ; + + seq_printf(seq, "\txprt:\tlocal %lu %lu %lu %ld %lu %lu %lu " + "%llu %llu %lu %llu %llu\n", + xprt->stat.bind_count, + xprt->stat.connect_count, + xprt->stat.connect_time, + idle_time, + xprt->stat.sends, + xprt->stat.recvs, + xprt->stat.bad_xids, + xprt->stat.req_u, + xprt->stat.bklog_u, + xprt->stat.max_slots, + xprt->stat.sending_u, + xprt->stat.pending_u); +} + +/** + * xs_udp_print_stats - display UDP socket-specifc stats + * @xprt: rpc_xprt struct containing statistics + * @seq: output file + * + */ +static void xs_udp_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + seq_printf(seq, "\txprt:\tudp %u %lu %lu %lu %lu %llu %llu " + "%lu %llu %llu\n", + transport->srcport, + xprt->stat.bind_count, + xprt->stat.sends, + xprt->stat.recvs, + xprt->stat.bad_xids, + xprt->stat.req_u, + xprt->stat.bklog_u, + xprt->stat.max_slots, + xprt->stat.sending_u, + xprt->stat.pending_u); +} + +/** + * xs_tcp_print_stats - display TCP socket-specifc stats + * @xprt: rpc_xprt struct containing statistics + * @seq: output file + * + */ +static void xs_tcp_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + long idle_time = 0; + + if (xprt_connected(xprt)) + idle_time = (long)(jiffies - xprt->last_used) / HZ; + + seq_printf(seq, "\txprt:\ttcp %u %lu %lu %lu %ld %lu %lu %lu " + "%llu %llu %lu %llu %llu\n", + transport->srcport, + xprt->stat.bind_count, + xprt->stat.connect_count, + xprt->stat.connect_time, + idle_time, + xprt->stat.sends, + xprt->stat.recvs, + xprt->stat.bad_xids, + xprt->stat.req_u, + xprt->stat.bklog_u, + xprt->stat.max_slots, + xprt->stat.sending_u, + xprt->stat.pending_u); +} + +/* + * Allocate a bunch of pages for a scratch buffer for the rpc code. The reason + * we allocate pages instead doing a kmalloc like rpc_malloc is because we want + * to use the server side send routines. + */ +static int bc_malloc(struct rpc_task *task) +{ + struct rpc_rqst *rqst = task->tk_rqstp; + size_t size = rqst->rq_callsize; + struct page *page; + struct rpc_buffer *buf; + + if (size > PAGE_SIZE - sizeof(struct rpc_buffer)) { + WARN_ONCE(1, "xprtsock: large bc buffer request (size %zu)\n", + size); + return -EINVAL; + } + + page = alloc_page(GFP_KERNEL); + if (!page) + return -ENOMEM; + + buf = page_address(page); + buf->len = PAGE_SIZE; + + rqst->rq_buffer = buf->data; + rqst->rq_rbuffer = (char *)rqst->rq_buffer + rqst->rq_callsize; + return 0; +} + +/* + * Free the space allocated in the bc_alloc routine + */ +static void bc_free(struct rpc_task *task) +{ + void *buffer = task->tk_rqstp->rq_buffer; + struct rpc_buffer *buf; + + buf = container_of(buffer, struct rpc_buffer, data); + free_page((unsigned long)buf); +} + +/* + * Use the svc_sock to send the callback. Must be called with svsk->sk_mutex + * held. Borrows heavily from svc_tcp_sendto and xs_tcp_send_request. + */ +static int bc_sendto(struct rpc_rqst *req) +{ + int len; + struct xdr_buf *xbufp = &req->rq_snd_buf; + struct rpc_xprt *xprt = req->rq_xprt; + struct sock_xprt *transport = + container_of(xprt, struct sock_xprt, xprt); + struct socket *sock = transport->sock; + unsigned long headoff; + unsigned long tailoff; + + xs_encode_stream_record_marker(xbufp); + + tailoff = (unsigned long)xbufp->tail[0].iov_base & ~PAGE_MASK; + headoff = (unsigned long)xbufp->head[0].iov_base & ~PAGE_MASK; + len = svc_send_common(sock, xbufp, + virt_to_page(xbufp->head[0].iov_base), headoff, + xbufp->tail[0].iov_base, tailoff); + + if (len != xbufp->len) { + printk(KERN_NOTICE "Error sending entire callback!\n"); + len = -EAGAIN; + } + + return len; +} + +/* + * The send routine. Borrows from svc_send + */ +static int bc_send_request(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct svc_xprt *xprt; + int len; + + dprintk("sending request with xid: %08x\n", ntohl(req->rq_xid)); + /* + * Get the server socket associated with this callback xprt + */ + xprt = req->rq_xprt->bc_xprt; + + /* + * Grab the mutex to serialize data as the connection is shared + * with the fore channel + */ + if (!mutex_trylock(&xprt->xpt_mutex)) { + rpc_sleep_on(&xprt->xpt_bc_pending, task, NULL); + if (!mutex_trylock(&xprt->xpt_mutex)) + return -EAGAIN; + rpc_wake_up_queued_task(&xprt->xpt_bc_pending, task); + } + if (test_bit(XPT_DEAD, &xprt->xpt_flags)) + len = -ENOTCONN; + else + len = bc_sendto(req); + mutex_unlock(&xprt->xpt_mutex); + + if (len > 0) + len = 0; + + return len; +} + +/* + * The close routine. Since this is client initiated, we do nothing + */ + +static void bc_close(struct rpc_xprt *xprt) +{ +} + +/* + * The xprt destroy routine. Again, because this connection is client + * initiated, we do nothing + */ + +static void bc_destroy(struct rpc_xprt *xprt) +{ + dprintk("RPC: bc_destroy xprt %p\n", xprt); + + xs_xprt_free(xprt); + module_put(THIS_MODULE); +} + +static const struct rpc_xprt_ops xs_local_ops = { + .reserve_xprt = xprt_reserve_xprt, + .release_xprt = xs_tcp_release_xprt, + .alloc_slot = xprt_alloc_slot, + .free_slot = xprt_free_slot, + .rpcbind = xs_local_rpcbind, + .set_port = xs_local_set_port, + .connect = xs_local_connect, + .buf_alloc = rpc_malloc, + .buf_free = rpc_free, + .send_request = xs_local_send_request, + .set_retrans_timeout = xprt_set_retrans_timeout_def, + .close = xs_close, + .destroy = xs_destroy, + .print_stats = xs_local_print_stats, + .enable_swap = xs_enable_swap, + .disable_swap = xs_disable_swap, +}; + +static const struct rpc_xprt_ops xs_udp_ops = { + .set_buffer_size = xs_udp_set_buffer_size, + .reserve_xprt = xprt_reserve_xprt_cong, + .release_xprt = xprt_release_xprt_cong, + .alloc_slot = xprt_alloc_slot, + .free_slot = xprt_free_slot, + .rpcbind = rpcb_getport_async, + .set_port = xs_set_port, + .connect = xs_connect, + .buf_alloc = rpc_malloc, + .buf_free = rpc_free, + .send_request = xs_udp_send_request, + .set_retrans_timeout = xprt_set_retrans_timeout_rtt, + .timer = xs_udp_timer, + .release_request = xprt_release_rqst_cong, + .close = xs_close, + .destroy = xs_destroy, + .print_stats = xs_udp_print_stats, + .enable_swap = xs_enable_swap, + .disable_swap = xs_disable_swap, + .inject_disconnect = xs_inject_disconnect, +}; + +static const struct rpc_xprt_ops xs_tcp_ops = { + .reserve_xprt = xprt_reserve_xprt, + .release_xprt = xs_tcp_release_xprt, + .alloc_slot = xprt_lock_and_alloc_slot, + .free_slot = xprt_free_slot, + .rpcbind = rpcb_getport_async, + .set_port = xs_set_port, + .connect = xs_connect, + .buf_alloc = rpc_malloc, + .buf_free = rpc_free, + .send_request = xs_tcp_send_request, + .set_retrans_timeout = xprt_set_retrans_timeout_def, + .close = xs_tcp_shutdown, + .destroy = xs_destroy, + .set_connect_timeout = xs_tcp_set_connect_timeout, + .print_stats = xs_tcp_print_stats, + .enable_swap = xs_enable_swap, + .disable_swap = xs_disable_swap, + .inject_disconnect = xs_inject_disconnect, +#ifdef CONFIG_SUNRPC_BACKCHANNEL + .bc_setup = xprt_setup_bc, + .bc_up = xs_tcp_bc_up, + .bc_maxpayload = xs_tcp_bc_maxpayload, + .bc_free_rqst = xprt_free_bc_rqst, + .bc_destroy = xprt_destroy_bc, +#endif +}; + +/* + * The rpc_xprt_ops for the server backchannel + */ + +static const struct rpc_xprt_ops bc_tcp_ops = { + .reserve_xprt = xprt_reserve_xprt, + .release_xprt = xprt_release_xprt, + .alloc_slot = xprt_alloc_slot, + .free_slot = xprt_free_slot, + .buf_alloc = bc_malloc, + .buf_free = bc_free, + .send_request = bc_send_request, + .set_retrans_timeout = xprt_set_retrans_timeout_def, + .close = bc_close, + .destroy = bc_destroy, + .print_stats = xs_tcp_print_stats, + .enable_swap = xs_enable_swap, + .disable_swap = xs_disable_swap, + .inject_disconnect = xs_inject_disconnect, +}; + +static int xs_init_anyaddr(const int family, struct sockaddr *sap) +{ + static const struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_ANY), + }; + static const struct sockaddr_in6 sin6 = { + .sin6_family = AF_INET6, + .sin6_addr = IN6ADDR_ANY_INIT, + }; + + switch (family) { + case AF_LOCAL: + break; + case AF_INET: + memcpy(sap, &sin, sizeof(sin)); + break; + case AF_INET6: + memcpy(sap, &sin6, sizeof(sin6)); + break; + default: + dprintk("RPC: %s: Bad address family\n", __func__); + return -EAFNOSUPPORT; + } + return 0; +} + +static struct rpc_xprt *xs_setup_xprt(struct xprt_create *args, + unsigned int slot_table_size, + unsigned int max_slot_table_size) +{ + struct rpc_xprt *xprt; + struct sock_xprt *new; + + if (args->addrlen > sizeof(xprt->addr)) { + dprintk("RPC: xs_setup_xprt: address too large\n"); + return ERR_PTR(-EBADF); + } + + xprt = xprt_alloc(args->net, sizeof(*new), slot_table_size, + max_slot_table_size); + if (xprt == NULL) { + dprintk("RPC: xs_setup_xprt: couldn't allocate " + "rpc_xprt\n"); + return ERR_PTR(-ENOMEM); + } + + new = container_of(xprt, struct sock_xprt, xprt); + mutex_init(&new->recv_mutex); + memcpy(&xprt->addr, args->dstaddr, args->addrlen); + xprt->addrlen = args->addrlen; + if (args->srcaddr) + memcpy(&new->srcaddr, args->srcaddr, args->addrlen); + else { + int err; + err = xs_init_anyaddr(args->dstaddr->sa_family, + (struct sockaddr *)&new->srcaddr); + if (err != 0) { + xprt_free(xprt); + return ERR_PTR(err); + } + } + + return xprt; +} + +static const struct rpc_timeout xs_local_default_timeout = { + .to_initval = 10 * HZ, + .to_maxval = 10 * HZ, + .to_retries = 2, +}; + +/** + * xs_setup_local - Set up transport to use an AF_LOCAL socket + * @args: rpc transport creation arguments + * + * AF_LOCAL is a "tpi_cots_ord" transport, just like TCP + */ +static struct rpc_xprt *xs_setup_local(struct xprt_create *args) +{ + struct sockaddr_un *sun = (struct sockaddr_un *)args->dstaddr; + struct sock_xprt *transport; + struct rpc_xprt *xprt; + struct rpc_xprt *ret; + + xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries, + xprt_max_tcp_slot_table_entries); + if (IS_ERR(xprt)) + return xprt; + transport = container_of(xprt, struct sock_xprt, xprt); + + xprt->prot = 0; + xprt->tsh_size = sizeof(rpc_fraghdr) / sizeof(u32); + xprt->max_payload = RPC_MAX_FRAGMENT_SIZE; + + xprt->bind_timeout = XS_BIND_TO; + xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; + xprt->idle_timeout = XS_IDLE_DISC_TO; + + xprt->ops = &xs_local_ops; + xprt->timeout = &xs_local_default_timeout; + + INIT_WORK(&transport->recv_worker, xs_local_data_receive_workfn); + INIT_DELAYED_WORK(&transport->connect_worker, + xs_dummy_setup_socket); + + switch (sun->sun_family) { + case AF_LOCAL: + if (sun->sun_path[0] != '/') { + dprintk("RPC: bad AF_LOCAL address: %s\n", + sun->sun_path); + ret = ERR_PTR(-EINVAL); + goto out_err; + } + xprt_set_bound(xprt); + xs_format_peer_addresses(xprt, "local", RPCBIND_NETID_LOCAL); + break; + default: + ret = ERR_PTR(-EAFNOSUPPORT); + goto out_err; + } + + dprintk("RPC: set up xprt to %s via AF_LOCAL\n", + xprt->address_strings[RPC_DISPLAY_ADDR]); + + if (try_module_get(THIS_MODULE)) + return xprt; + ret = ERR_PTR(-EINVAL); +out_err: + xs_xprt_free(xprt); + return ret; +} + +static const struct rpc_timeout xs_udp_default_timeout = { + .to_initval = 5 * HZ, + .to_maxval = 30 * HZ, + .to_increment = 5 * HZ, + .to_retries = 5, +}; + +/** + * xs_setup_udp - Set up transport to use a UDP socket + * @args: rpc transport creation arguments + * + */ +static struct rpc_xprt *xs_setup_udp(struct xprt_create *args) +{ + struct sockaddr *addr = args->dstaddr; + struct rpc_xprt *xprt; + struct sock_xprt *transport; + struct rpc_xprt *ret; + + xprt = xs_setup_xprt(args, xprt_udp_slot_table_entries, + xprt_udp_slot_table_entries); + if (IS_ERR(xprt)) + return xprt; + transport = container_of(xprt, struct sock_xprt, xprt); + + xprt->prot = IPPROTO_UDP; + xprt->tsh_size = 0; + /* XXX: header size can vary due to auth type, IPv6, etc. */ + xprt->max_payload = (1U << 16) - (MAX_HEADER << 3); + + xprt->bind_timeout = XS_BIND_TO; + xprt->reestablish_timeout = XS_UDP_REEST_TO; + xprt->idle_timeout = XS_IDLE_DISC_TO; + + xprt->ops = &xs_udp_ops; + + xprt->timeout = &xs_udp_default_timeout; + + INIT_WORK(&transport->recv_worker, xs_udp_data_receive_workfn); + INIT_DELAYED_WORK(&transport->connect_worker, xs_udp_setup_socket); + + switch (addr->sa_family) { + case AF_INET: + if (((struct sockaddr_in *)addr)->sin_port != htons(0)) + xprt_set_bound(xprt); + + xs_format_peer_addresses(xprt, "udp", RPCBIND_NETID_UDP); + break; + case AF_INET6: + if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0)) + xprt_set_bound(xprt); + + xs_format_peer_addresses(xprt, "udp", RPCBIND_NETID_UDP6); + break; + default: + ret = ERR_PTR(-EAFNOSUPPORT); + goto out_err; + } + + if (xprt_bound(xprt)) + dprintk("RPC: set up xprt to %s (port %s) via %s\n", + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PORT], + xprt->address_strings[RPC_DISPLAY_PROTO]); + else + dprintk("RPC: set up xprt to %s (autobind) via %s\n", + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PROTO]); + + if (try_module_get(THIS_MODULE)) + return xprt; + ret = ERR_PTR(-EINVAL); +out_err: + xs_xprt_free(xprt); + return ret; +} + +static const struct rpc_timeout xs_tcp_default_timeout = { + .to_initval = 60 * HZ, + .to_maxval = 60 * HZ, + .to_retries = 2, +}; + +/** + * xs_setup_tcp - Set up transport to use a TCP socket + * @args: rpc transport creation arguments + * + */ +static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args) +{ + struct sockaddr *addr = args->dstaddr; + struct rpc_xprt *xprt; + struct sock_xprt *transport; + struct rpc_xprt *ret; + unsigned int max_slot_table_size = xprt_max_tcp_slot_table_entries; + + if (args->flags & XPRT_CREATE_INFINITE_SLOTS) + max_slot_table_size = RPC_MAX_SLOT_TABLE_LIMIT; + + xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries, + max_slot_table_size); + if (IS_ERR(xprt)) + return xprt; + transport = container_of(xprt, struct sock_xprt, xprt); + + xprt->prot = IPPROTO_TCP; + xprt->tsh_size = sizeof(rpc_fraghdr) / sizeof(u32); + xprt->max_payload = RPC_MAX_FRAGMENT_SIZE; + + xprt->bind_timeout = XS_BIND_TO; + xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; + xprt->idle_timeout = XS_IDLE_DISC_TO; + + xprt->ops = &xs_tcp_ops; + xprt->timeout = &xs_tcp_default_timeout; + + xprt->max_reconnect_timeout = xprt->timeout->to_maxval; + xprt->connect_timeout = xprt->timeout->to_initval * + (xprt->timeout->to_retries + 1); + + INIT_WORK(&transport->recv_worker, xs_tcp_data_receive_workfn); + INIT_DELAYED_WORK(&transport->connect_worker, xs_tcp_setup_socket); + + switch (addr->sa_family) { + case AF_INET: + if (((struct sockaddr_in *)addr)->sin_port != htons(0)) + xprt_set_bound(xprt); + + xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP); + break; + case AF_INET6: + if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0)) + xprt_set_bound(xprt); + + xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP6); + break; + default: + ret = ERR_PTR(-EAFNOSUPPORT); + goto out_err; + } + + if (xprt_bound(xprt)) + dprintk("RPC: set up xprt to %s (port %s) via %s\n", + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PORT], + xprt->address_strings[RPC_DISPLAY_PROTO]); + else + dprintk("RPC: set up xprt to %s (autobind) via %s\n", + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PROTO]); + + if (try_module_get(THIS_MODULE)) + return xprt; + ret = ERR_PTR(-EINVAL); +out_err: + xs_xprt_free(xprt); + return ret; +} + +/** + * xs_setup_bc_tcp - Set up transport to use a TCP backchannel socket + * @args: rpc transport creation arguments + * + */ +static struct rpc_xprt *xs_setup_bc_tcp(struct xprt_create *args) +{ + struct sockaddr *addr = args->dstaddr; + struct rpc_xprt *xprt; + struct sock_xprt *transport; + struct svc_sock *bc_sock; + struct rpc_xprt *ret; + + xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries, + xprt_tcp_slot_table_entries); + if (IS_ERR(xprt)) + return xprt; + transport = container_of(xprt, struct sock_xprt, xprt); + + xprt->prot = IPPROTO_TCP; + xprt->tsh_size = sizeof(rpc_fraghdr) / sizeof(u32); + xprt->max_payload = RPC_MAX_FRAGMENT_SIZE; + xprt->timeout = &xs_tcp_default_timeout; + + /* backchannel */ + xprt_set_bound(xprt); + xprt->bind_timeout = 0; + xprt->reestablish_timeout = 0; + xprt->idle_timeout = 0; + + xprt->ops = &bc_tcp_ops; + + switch (addr->sa_family) { + case AF_INET: + xs_format_peer_addresses(xprt, "tcp", + RPCBIND_NETID_TCP); + break; + case AF_INET6: + xs_format_peer_addresses(xprt, "tcp", + RPCBIND_NETID_TCP6); + break; + default: + ret = ERR_PTR(-EAFNOSUPPORT); + goto out_err; + } + + dprintk("RPC: set up xprt to %s (port %s) via %s\n", + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PORT], + xprt->address_strings[RPC_DISPLAY_PROTO]); + + /* + * Once we've associated a backchannel xprt with a connection, + * we want to keep it around as long as the connection lasts, + * in case we need to start using it for a backchannel again; + * this reference won't be dropped until bc_xprt is destroyed. + */ + xprt_get(xprt); + args->bc_xprt->xpt_bc_xprt = xprt; + xprt->bc_xprt = args->bc_xprt; + bc_sock = container_of(args->bc_xprt, struct svc_sock, sk_xprt); + transport->sock = bc_sock->sk_sock; + transport->inet = bc_sock->sk_sk; + + /* + * Since we don't want connections for the backchannel, we set + * the xprt status to connected + */ + xprt_set_connected(xprt); + + if (try_module_get(THIS_MODULE)) + return xprt; + + args->bc_xprt->xpt_bc_xprt = NULL; + args->bc_xprt->xpt_bc_xps = NULL; + xprt_put(xprt); + ret = ERR_PTR(-EINVAL); +out_err: + xs_xprt_free(xprt); + return ret; +} + +static struct xprt_class xs_local_transport = { + .list = LIST_HEAD_INIT(xs_local_transport.list), + .name = "named UNIX socket", + .owner = THIS_MODULE, + .ident = XPRT_TRANSPORT_LOCAL, + .setup = xs_setup_local, + .netid = { "" }, +}; + +static struct xprt_class xs_udp_transport = { + .list = LIST_HEAD_INIT(xs_udp_transport.list), + .name = "udp", + .owner = THIS_MODULE, + .ident = XPRT_TRANSPORT_UDP, + .setup = xs_setup_udp, + .netid = { "udp", "udp6", "" }, +}; + +static struct xprt_class xs_tcp_transport = { + .list = LIST_HEAD_INIT(xs_tcp_transport.list), + .name = "tcp", + .owner = THIS_MODULE, + .ident = XPRT_TRANSPORT_TCP, + .setup = xs_setup_tcp, + .netid = { "tcp", "tcp6", "" }, +}; + +static struct xprt_class xs_bc_tcp_transport = { + .list = LIST_HEAD_INIT(xs_bc_tcp_transport.list), + .name = "tcp NFSv4.1 backchannel", + .owner = THIS_MODULE, + .ident = XPRT_TRANSPORT_BC_TCP, + .setup = xs_setup_bc_tcp, + .netid = { "" }, +}; + +/** + * init_socket_xprt - set up xprtsock's sysctls, register with RPC client + * + */ +int init_socket_xprt(void) +{ +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) + if (!sunrpc_table_header) + sunrpc_table_header = register_sysctl_table(sunrpc_table); +#endif + + xprt_register_transport(&xs_local_transport); + xprt_register_transport(&xs_udp_transport); + xprt_register_transport(&xs_tcp_transport); + xprt_register_transport(&xs_bc_tcp_transport); + + return 0; +} + +/** + * cleanup_socket_xprt - remove xprtsock's sysctls, unregister + * + */ +void cleanup_socket_xprt(void) +{ +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) + if (sunrpc_table_header) { + unregister_sysctl_table(sunrpc_table_header); + sunrpc_table_header = NULL; + } +#endif + + xprt_unregister_transport(&xs_local_transport); + xprt_unregister_transport(&xs_udp_transport); + xprt_unregister_transport(&xs_tcp_transport); + xprt_unregister_transport(&xs_bc_tcp_transport); +} + +static int param_set_uint_minmax(const char *val, + const struct kernel_param *kp, + unsigned int min, unsigned int max) +{ + unsigned int num; + int ret; + + if (!val) + return -EINVAL; + ret = kstrtouint(val, 0, &num); + if (ret) + return ret; + if (num < min || num > max) + return -EINVAL; + *((unsigned int *)kp->arg) = num; + return 0; +} + +static int param_set_portnr(const char *val, const struct kernel_param *kp) +{ + return param_set_uint_minmax(val, kp, + RPC_MIN_RESVPORT, + RPC_MAX_RESVPORT); +} + +static const struct kernel_param_ops param_ops_portnr = { + .set = param_set_portnr, + .get = param_get_uint, +}; + +#define param_check_portnr(name, p) \ + __param_check(name, p, unsigned int); + +module_param_named(min_resvport, xprt_min_resvport, portnr, 0644); +module_param_named(max_resvport, xprt_max_resvport, portnr, 0644); + +static int param_set_slot_table_size(const char *val, + const struct kernel_param *kp) +{ + return param_set_uint_minmax(val, kp, + RPC_MIN_SLOT_TABLE, + RPC_MAX_SLOT_TABLE); +} + +static const struct kernel_param_ops param_ops_slot_table_size = { + .set = param_set_slot_table_size, + .get = param_get_uint, +}; + +#define param_check_slot_table_size(name, p) \ + __param_check(name, p, unsigned int); + +static int param_set_max_slot_table_size(const char *val, + const struct kernel_param *kp) +{ + return param_set_uint_minmax(val, kp, + RPC_MIN_SLOT_TABLE, + RPC_MAX_SLOT_TABLE_LIMIT); +} + +static const struct kernel_param_ops param_ops_max_slot_table_size = { + .set = param_set_max_slot_table_size, + .get = param_get_uint, +}; + +#define param_check_max_slot_table_size(name, p) \ + __param_check(name, p, unsigned int); + +module_param_named(tcp_slot_table_entries, xprt_tcp_slot_table_entries, + slot_table_size, 0644); +module_param_named(tcp_max_slot_table_entries, xprt_max_tcp_slot_table_entries, + max_slot_table_size, 0644); +module_param_named(udp_slot_table_entries, xprt_udp_slot_table_entries, + slot_table_size, 0644); |