path: root/net/sunrpc
author     Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-11 08:27:49 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-11 08:27:49 +0000
commit     ace9429bb58fd418f0c81d4c2835699bddf6bde6 (patch)
tree       b2d64bc10158fdd5497876388cd68142ca374ed3 /net/sunrpc
parent     Initial commit. (diff)
Adding upstream version 6.6.15.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'net/sunrpc')
-rw-r--r--  net/sunrpc/.kunitconfig  29
-rw-r--r--  net/sunrpc/Kconfig  117
-rw-r--r--  net/sunrpc/Makefile  21
-rw-r--r--  net/sunrpc/addr.c  354
-rw-r--r--  net/sunrpc/auth.c  893
-rw-r--r--  net/sunrpc/auth_gss/Makefile  17
-rw-r--r--  net/sunrpc/auth_gss/auth_gss.c  2298
-rw-r--r--  net/sunrpc/auth_gss/auth_gss_internal.h  45
-rw-r--r--  net/sunrpc/auth_gss/gss_generic_token.c  231
-rw-r--r--  net/sunrpc/auth_gss/gss_krb5_crypto.c  1154
-rw-r--r--  net/sunrpc/auth_gss/gss_krb5_internal.h  209
-rw-r--r--  net/sunrpc/auth_gss/gss_krb5_keys.c  546
-rw-r--r--  net/sunrpc/auth_gss/gss_krb5_mech.c  655
-rw-r--r--  net/sunrpc/auth_gss/gss_krb5_seal.c  133
-rw-r--r--  net/sunrpc/auth_gss/gss_krb5_test.c  1859
-rw-r--r--  net/sunrpc/auth_gss/gss_krb5_unseal.c  128
-rw-r--r--  net/sunrpc/auth_gss/gss_krb5_wrap.c  237
-rw-r--r--  net/sunrpc/auth_gss/gss_mech_switch.c  448
-rw-r--r--  net/sunrpc/auth_gss/gss_rpc_upcall.c  403
-rw-r--r--  net/sunrpc/auth_gss/gss_rpc_upcall.h  36
-rw-r--r--  net/sunrpc/auth_gss/gss_rpc_xdr.c  838
-rw-r--r--  net/sunrpc/auth_gss/gss_rpc_xdr.h  252
-rw-r--r--  net/sunrpc/auth_gss/svcauth_gss.c  2134
-rw-r--r--  net/sunrpc/auth_gss/trace.c  14
-rw-r--r--  net/sunrpc/auth_null.c  143
-rw-r--r--  net/sunrpc/auth_tls.c  175
-rw-r--r--  net/sunrpc/auth_unix.c  243
-rw-r--r--  net/sunrpc/backchannel_rqst.c  376
-rw-r--r--  net/sunrpc/cache.c  1918
-rw-r--r--  net/sunrpc/clnt.c  3395
-rw-r--r--  net/sunrpc/debugfs.c  294
-rw-r--r--  net/sunrpc/fail.h  25
-rw-r--r--  net/sunrpc/netns.h  44
-rw-r--r--  net/sunrpc/rpc_pipe.c  1517
-rw-r--r--  net/sunrpc/rpcb_clnt.c  1121
-rw-r--r--  net/sunrpc/sched.c  1361
-rw-r--r--  net/sunrpc/socklib.c  324
-rw-r--r--  net/sunrpc/socklib.h  15
-rw-r--r--  net/sunrpc/stats.c  348
-rw-r--r--  net/sunrpc/sunrpc.h  42
-rw-r--r--  net/sunrpc/sunrpc_syms.c  153
-rw-r--r--  net/sunrpc/svc.c  1764
-rw-r--r--  net/sunrpc/svc_xprt.c  1450
-rw-r--r--  net/sunrpc/svcauth.c  260
-rw-r--r--  net/sunrpc/svcauth_unix.c  1061
-rw-r--r--  net/sunrpc/svcsock.c  1644
-rw-r--r--  net/sunrpc/sysctl.c  181
-rw-r--r--  net/sunrpc/sysfs.c  627
-rw-r--r--  net/sunrpc/sysfs.h  35
-rw-r--r--  net/sunrpc/timer.c  123
-rw-r--r--  net/sunrpc/xdr.c  2413
-rw-r--r--  net/sunrpc/xprt.c  2192
-rw-r--r--  net/sunrpc/xprtmultipath.c  655
-rw-r--r--  net/sunrpc/xprtrdma/Makefile  8
-rw-r--r--  net/sunrpc/xprtrdma/backchannel.c  282
-rw-r--r--  net/sunrpc/xprtrdma/frwr_ops.c  696
-rw-r--r--  net/sunrpc/xprtrdma/module.c  52
-rw-r--r--  net/sunrpc/xprtrdma/rpc_rdma.c  1510
-rw-r--r--  net/sunrpc/xprtrdma/svc_rdma.c  283
-rw-r--r--  net/sunrpc/xprtrdma/svc_rdma_backchannel.c  287
-rw-r--r--  net/sunrpc/xprtrdma/svc_rdma_pcl.c  306
-rw-r--r--  net/sunrpc/xprtrdma/svc_rdma_recvfrom.c  863
-rw-r--r--  net/sunrpc/xprtrdma/svc_rdma_rw.c  1169
-rw-r--r--  net/sunrpc/xprtrdma/svc_rdma_sendto.c  1062
-rw-r--r--  net/sunrpc/xprtrdma/svc_rdma_transport.c  603
-rw-r--r--  net/sunrpc/xprtrdma/transport.c  796
-rw-r--r--  net/sunrpc/xprtrdma/verbs.c  1396
-rw-r--r--  net/sunrpc/xprtrdma/xprt_rdma.h  604
-rw-r--r--  net/sunrpc/xprtsock.c  3718
69 files changed, 50585 insertions, 0 deletions
diff --git a/net/sunrpc/.kunitconfig b/net/sunrpc/.kunitconfig
new file mode 100644
index 0000000000..eb02b906c2
--- /dev/null
+++ b/net/sunrpc/.kunitconfig
@@ -0,0 +1,29 @@
+CONFIG_KUNIT=y
+CONFIG_UBSAN=y
+CONFIG_STACKTRACE=y
+CONFIG_NET=y
+CONFIG_NETWORK_FILESYSTEMS=y
+CONFIG_INET=y
+CONFIG_FILE_LOCKING=y
+CONFIG_MULTIUSER=y
+CONFIG_CRYPTO=y
+CONFIG_CRYPTO_CBC=y
+CONFIG_CRYPTO_CTS=y
+CONFIG_CRYPTO_ECB=y
+CONFIG_CRYPTO_HMAC=y
+CONFIG_CRYPTO_CMAC=y
+CONFIG_CRYPTO_MD5=y
+CONFIG_CRYPTO_SHA1=y
+CONFIG_CRYPTO_SHA256=y
+CONFIG_CRYPTO_SHA512=y
+CONFIG_CRYPTO_DES=y
+CONFIG_CRYPTO_AES=y
+CONFIG_CRYPTO_CAMELLIA=y
+CONFIG_NFS_FS=y
+CONFIG_SUNRPC=y
+CONFIG_SUNRPC_GSS=y
+CONFIG_RPCSEC_GSS_KRB5=y
+CONFIG_RPCSEC_GSS_KRB5_ENCTYPES_AES_SHA1=y
+CONFIG_RPCSEC_GSS_KRB5_ENCTYPES_CAMELLIA=y
+CONFIG_RPCSEC_GSS_KRB5_ENCTYPES_AES_SHA2=y
+CONFIG_RPCSEC_GSS_KRB5_KUNIT_TEST=y
diff --git a/net/sunrpc/Kconfig b/net/sunrpc/Kconfig
new file mode 100644
index 0000000000..2d8b67dac7
--- /dev/null
+++ b/net/sunrpc/Kconfig
@@ -0,0 +1,117 @@
+# SPDX-License-Identifier: GPL-2.0-only
+config SUNRPC
+ tristate
+ depends on MULTIUSER
+
+config SUNRPC_GSS
+ tristate
+ select OID_REGISTRY
+ depends on MULTIUSER
+
+config SUNRPC_BACKCHANNEL
+ bool
+ depends on SUNRPC
+
+config SUNRPC_SWAP
+ bool
+ depends on SUNRPC
+
+config RPCSEC_GSS_KRB5
+ tristate "Secure RPC: Kerberos V mechanism"
+ depends on SUNRPC && CRYPTO
+ default y
+ select SUNRPC_GSS
+ select CRYPTO_SKCIPHER
+ select CRYPTO_HASH
+ help
+ Choose Y here to enable Secure RPC using the Kerberos version 5
+ GSS-API mechanism (RFC 1964).
+
+ Secure RPC calls with Kerberos require an auxiliary user-space
+ daemon which may be found in the Linux nfs-utils package
+ available from http://linux-nfs.org/. In addition, user-space
+ Kerberos support should be installed.
+
+ If unsure, say Y.
+
+config RPCSEC_GSS_KRB5_ENCTYPES_AES_SHA1
+ bool "Enable Kerberos enctypes based on AES and SHA-1"
+ depends on RPCSEC_GSS_KRB5
+ depends on CRYPTO_CBC && CRYPTO_CTS
+ depends on CRYPTO_HMAC && CRYPTO_SHA1
+ depends on CRYPTO_AES
+ default y
+ help
+ Choose Y to enable the use of Kerberos 5 encryption types
+ that utilize Advanced Encryption Standard (AES) ciphers and
+ SHA-1 digests. These include aes128-cts-hmac-sha1-96 and
+ aes256-cts-hmac-sha1-96.
+
+config RPCSEC_GSS_KRB5_ENCTYPES_CAMELLIA
+ bool "Enable Kerberos encryption types based on Camellia and CMAC"
+ depends on RPCSEC_GSS_KRB5
+ depends on CRYPTO_CBC && CRYPTO_CTS && CRYPTO_CAMELLIA
+ depends on CRYPTO_CMAC
+ default n
+ help
+ Choose Y to enable the use of Kerberos 5 encryption types
+ that utilize Camellia ciphers (RFC 3713) and CMAC digests
+ (NIST Special Publication 800-38B). These include
+ camellia128-cts-cmac and camellia256-cts-cmac.
+
+config RPCSEC_GSS_KRB5_ENCTYPES_AES_SHA2
+ bool "Enable Kerberos enctypes based on AES and SHA-2"
+ depends on RPCSEC_GSS_KRB5
+ depends on CRYPTO_CBC && CRYPTO_CTS
+ depends on CRYPTO_HMAC && CRYPTO_SHA256 && CRYPTO_SHA512
+ depends on CRYPTO_AES
+ default n
+ help
+ Choose Y to enable the use of Kerberos 5 encryption types
+ that utilize Advanced Encryption Standard (AES) ciphers and
+ SHA-2 digests. These include aes128-cts-hmac-sha256-128 and
+ aes256-cts-hmac-sha384-192.
+
+config RPCSEC_GSS_KRB5_KUNIT_TEST
+ tristate "KUnit tests for RPCSEC GSS Kerberos" if !KUNIT_ALL_TESTS
+ depends on RPCSEC_GSS_KRB5 && KUNIT
+ default KUNIT_ALL_TESTS
+ help
+ This builds the KUnit tests for RPCSEC GSS Kerberos 5.
+
+ KUnit tests run during boot and output the results to the debug
+ log in TAP format (https://testanything.org/). They are only useful
+ for kernel developers running the KUnit test harness and are not for
+ inclusion in a production build.
+
+ For more information on KUnit and unit tests in general, refer
+ to the KUnit documentation in Documentation/dev-tools/kunit/.
+
+config SUNRPC_DEBUG
+ bool "RPC: Enable dprintk debugging"
+ depends on SUNRPC && SYSCTL
+ select DEBUG_FS
+ help
+ This option enables a sysctl-based debugging interface
+ that can be used by the 'rpcdebug' utility to turn on or off
+ logging of different aspects of the kernel RPC activity.
+
+ Disabling this option will make your kernel slightly smaller,
+ but makes troubleshooting NFS issues significantly harder.
+
+ If unsure, say Y.
+
+config SUNRPC_XPRT_RDMA
+ tristate "RPC-over-RDMA transport"
+ depends on SUNRPC && INFINIBAND && INFINIBAND_ADDR_TRANS
+ default SUNRPC && INFINIBAND
+ select SG_POOL
+ help
+ This option allows the NFS client and server to use RDMA
+ transports (InfiniBand, iWARP, or RoCE).
+
+ To compile this support as a module, choose M. The module
+ will be called rpcrdma.ko.
+
+ If unsure, or you know there is no RDMA capability on your
+ hardware platform, say N.
diff --git a/net/sunrpc/Makefile b/net/sunrpc/Makefile
new file mode 100644
index 0000000000..f89c10fe7e
--- /dev/null
+++ b/net/sunrpc/Makefile
@@ -0,0 +1,21 @@
+# SPDX-License-Identifier: GPL-2.0
+#
+# Makefile for Linux kernel SUN RPC
+#
+
+
+obj-$(CONFIG_SUNRPC) += sunrpc.o
+obj-$(CONFIG_SUNRPC_GSS) += auth_gss/
+obj-$(CONFIG_SUNRPC_XPRT_RDMA) += xprtrdma/
+
+sunrpc-y := clnt.o xprt.o socklib.o xprtsock.o sched.o \
+ auth.o auth_null.o auth_tls.o auth_unix.o \
+ svc.o svcsock.o svcauth.o svcauth_unix.o \
+ addr.o rpcb_clnt.o timer.o xdr.o \
+ sunrpc_syms.o cache.o rpc_pipe.o sysfs.o \
+ svc_xprt.o \
+ xprtmultipath.o
+sunrpc-$(CONFIG_SUNRPC_DEBUG) += debugfs.o
+sunrpc-$(CONFIG_SUNRPC_BACKCHANNEL) += backchannel_rqst.o
+sunrpc-$(CONFIG_PROC_FS) += stats.o
+sunrpc-$(CONFIG_SYSCTL) += sysctl.o
diff --git a/net/sunrpc/addr.c b/net/sunrpc/addr.c
new file mode 100644
index 0000000000..d435bffc61
--- /dev/null
+++ b/net/sunrpc/addr.c
@@ -0,0 +1,354 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * Copyright 2009, Oracle. All rights reserved.
+ *
+ * Convert socket addresses to presentation addresses and universal
+ * addresses, and vice versa.
+ *
+ * Universal addresses are introduced by RFC 1833 and further refined by
+ * recent RFCs describing NFSv4. The universal address format is part
+ * of the external (network) interface provided by rpcbind version 3
+ * and 4, and by NFSv4. Such an address is a string containing a
+ * presentation format IP address followed by a port number in
+ * "hibyte.lobyte" format.
+ *
+ * IPv6 addresses can also include a scope ID, typically denoted by
+ * a '%' followed by a device name or a non-negative integer. Refer to
+ * RFC 4291, Section 2.2 for details on IPv6 presentation formats.
+ */
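As a worked illustration of the "hibyte.lobyte" port suffix described in the comment above (a stand-alone sketch, not part of the patch): port 2049, the standard NFS port, splits into 2049 >> 8 == 8 and 2049 & 0xff == 1, so a server at 192.168.1.5 has the universal address "192.168.1.5.8.1".

/* Illustrative only: how a port number becomes the trailing
 * "hibyte.lobyte" pair of a universal address. */
#include <stdio.h>

int main(void)
{
        unsigned short port = 2049;     /* NFS */

        /* prints "192.168.1.5.8.1" */
        printf("192.168.1.5.%u.%u\n", port >> 8, port & 0xff);
        return 0;
}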
+
+#include <net/ipv6.h>
+#include <linux/sunrpc/addr.h>
+#include <linux/sunrpc/msg_prot.h>
+#include <linux/slab.h>
+#include <linux/export.h>
+
+#if IS_ENABLED(CONFIG_IPV6)
+
+static size_t rpc_ntop6_noscopeid(const struct sockaddr *sap,
+ char *buf, const int buflen)
+{
+ const struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)sap;
+ const struct in6_addr *addr = &sin6->sin6_addr;
+
+ /*
+ * RFC 4291, Section 2.2.2
+ *
+ * Shorthanded ANY address
+ */
+ if (ipv6_addr_any(addr))
+ return snprintf(buf, buflen, "::");
+
+ /*
+ * RFC 4291, Section 2.2.2
+ *
+ * Shorthanded loopback address
+ */
+ if (ipv6_addr_loopback(addr))
+ return snprintf(buf, buflen, "::1");
+
+ /*
+ * RFC 4291, Section 2.2.3
+ *
+ * Special presentation address format for mapped v4
+ * addresses.
+ */
+ if (ipv6_addr_v4mapped(addr))
+ return snprintf(buf, buflen, "::ffff:%pI4",
+ &addr->s6_addr32[3]);
+
+ /*
+ * RFC 4291, Section 2.2.1
+ */
+ return snprintf(buf, buflen, "%pI6c", addr);
+}
+
+static size_t rpc_ntop6(const struct sockaddr *sap,
+ char *buf, const size_t buflen)
+{
+ const struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)sap;
+ char scopebuf[IPV6_SCOPE_ID_LEN];
+ size_t len;
+ int rc;
+
+ len = rpc_ntop6_noscopeid(sap, buf, buflen);
+ if (unlikely(len == 0))
+ return len;
+
+ if (!(ipv6_addr_type(&sin6->sin6_addr) & IPV6_ADDR_LINKLOCAL))
+ return len;
+ if (sin6->sin6_scope_id == 0)
+ return len;
+
+ rc = snprintf(scopebuf, sizeof(scopebuf), "%c%u",
+ IPV6_SCOPE_DELIMITER, sin6->sin6_scope_id);
+ if (unlikely((size_t)rc >= sizeof(scopebuf)))
+ return 0;
+
+ len += rc;
+ if (unlikely(len >= buflen))
+ return 0;
+
+ strcat(buf, scopebuf);
+ return len;
+}
+
+#else /* !IS_ENABLED(CONFIG_IPV6) */
+
+static size_t rpc_ntop6_noscopeid(const struct sockaddr *sap,
+ char *buf, const int buflen)
+{
+ return 0;
+}
+
+static size_t rpc_ntop6(const struct sockaddr *sap,
+ char *buf, const size_t buflen)
+{
+ return 0;
+}
+
+#endif /* !IS_ENABLED(CONFIG_IPV6) */
+
+static int rpc_ntop4(const struct sockaddr *sap,
+ char *buf, const size_t buflen)
+{
+ const struct sockaddr_in *sin = (struct sockaddr_in *)sap;
+
+ return snprintf(buf, buflen, "%pI4", &sin->sin_addr);
+}
+
+/**
+ * rpc_ntop - construct a presentation address in @buf
+ * @sap: socket address
+ * @buf: construction area
+ * @buflen: size of @buf, in bytes
+ *
+ * Plants a %NUL-terminated string in @buf and returns the length
+ * of the string, excluding the %NUL. Otherwise zero is returned.
+ */
+size_t rpc_ntop(const struct sockaddr *sap, char *buf, const size_t buflen)
+{
+ switch (sap->sa_family) {
+ case AF_INET:
+ return rpc_ntop4(sap, buf, buflen);
+ case AF_INET6:
+ return rpc_ntop6(sap, buf, buflen);
+ }
+
+ return 0;
+}
+EXPORT_SYMBOL_GPL(rpc_ntop);
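A usage sketch for rpc_ntop() as documented above (not part of the patch; it assumes a kernel context with the sunrpc headers available, and the loopback address is arbitrary):

/* Sketch only: format an AF_INET sockaddr with rpc_ntop(). */
static void example_print_addr(void)
{
        struct sockaddr_in sin = {
                .sin_family = AF_INET,
                .sin_addr.s_addr = htonl(0x7f000001),   /* 127.0.0.1 */
        };
        char buf[INET6_ADDRSTRLEN];
        size_t len;

        len = rpc_ntop((struct sockaddr *)&sin, buf, sizeof(buf));
        if (len == 0)
                return;         /* unknown family or buffer too small */
        pr_info("presentation address: %s\n", buf);     /* "127.0.0.1" */
}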
+
+static size_t rpc_pton4(const char *buf, const size_t buflen,
+ struct sockaddr *sap, const size_t salen)
+{
+ struct sockaddr_in *sin = (struct sockaddr_in *)sap;
+ u8 *addr = (u8 *)&sin->sin_addr.s_addr;
+
+ if (buflen > INET_ADDRSTRLEN || salen < sizeof(struct sockaddr_in))
+ return 0;
+
+ memset(sap, 0, sizeof(struct sockaddr_in));
+
+ if (in4_pton(buf, buflen, addr, '\0', NULL) == 0)
+ return 0;
+
+ sin->sin_family = AF_INET;
+ return sizeof(struct sockaddr_in);
+}
+
+#if IS_ENABLED(CONFIG_IPV6)
+static int rpc_parse_scope_id(struct net *net, const char *buf,
+ const size_t buflen, const char *delim,
+ struct sockaddr_in6 *sin6)
+{
+ char p[IPV6_SCOPE_ID_LEN + 1];
+ size_t len;
+ u32 scope_id = 0;
+ struct net_device *dev;
+
+ if ((buf + buflen) == delim)
+ return 1;
+
+ if (*delim != IPV6_SCOPE_DELIMITER)
+ return 0;
+
+ if (!(ipv6_addr_type(&sin6->sin6_addr) & IPV6_ADDR_LINKLOCAL))
+ return 0;
+
+ len = (buf + buflen) - delim - 1;
+ if (len > IPV6_SCOPE_ID_LEN)
+ return 0;
+
+ memcpy(p, delim + 1, len);
+ p[len] = 0;
+
+ dev = dev_get_by_name(net, p);
+ if (dev != NULL) {
+ scope_id = dev->ifindex;
+ dev_put(dev);
+ } else {
+ if (kstrtou32(p, 10, &scope_id) != 0)
+ return 0;
+ }
+
+ sin6->sin6_scope_id = scope_id;
+ return 1;
+}
+
+static size_t rpc_pton6(struct net *net, const char *buf, const size_t buflen,
+ struct sockaddr *sap, const size_t salen)
+{
+ struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)sap;
+ u8 *addr = (u8 *)&sin6->sin6_addr.in6_u;
+ const char *delim;
+
+ if (buflen > (INET6_ADDRSTRLEN + IPV6_SCOPE_ID_LEN) ||
+ salen < sizeof(struct sockaddr_in6))
+ return 0;
+
+ memset(sap, 0, sizeof(struct sockaddr_in6));
+
+ if (in6_pton(buf, buflen, addr, IPV6_SCOPE_DELIMITER, &delim) == 0)
+ return 0;
+
+ if (!rpc_parse_scope_id(net, buf, buflen, delim, sin6))
+ return 0;
+
+ sin6->sin6_family = AF_INET6;
+ return sizeof(struct sockaddr_in6);
+}
+#else
+static size_t rpc_pton6(struct net *net, const char *buf, const size_t buflen,
+ struct sockaddr *sap, const size_t salen)
+{
+ return 0;
+}
+#endif
+
+/**
+ * rpc_pton - Construct a sockaddr in @sap
+ * @net: applicable network namespace
+ * @buf: C string containing presentation format IP address
+ * @buflen: length of presentation address in bytes
+ * @sap: buffer into which to plant socket address
+ * @salen: size of buffer in bytes
+ *
+ * Returns the size of the socket address if successful; otherwise
+ * zero is returned.
+ *
+ * Plants a socket address in @sap and returns the size of the
+ * socket address, if successful. Returns zero if an error
+ * occurred.
+ */
+size_t rpc_pton(struct net *net, const char *buf, const size_t buflen,
+ struct sockaddr *sap, const size_t salen)
+{
+ unsigned int i;
+
+ for (i = 0; i < buflen; i++)
+ if (buf[i] == ':')
+ return rpc_pton6(net, buf, buflen, sap, salen);
+ return rpc_pton4(buf, buflen, sap, salen);
+}
+EXPORT_SYMBOL_GPL(rpc_pton);
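A sketch of rpc_pton() exercising the IPv6 scope-ID handling implemented just above (not part of the patch; the link-local address, the numeric scope "%2", and the use of init_net are assumptions for the example):

/* Sketch only: parse a link-local IPv6 address with a numeric scope ID. */
static void example_parse_scoped(void)
{
        struct sockaddr_storage ss;
        const char *addr = "fe80::1%2";
        size_t salen;

        salen = rpc_pton(&init_net, addr, strlen(addr),
                         (struct sockaddr *)&ss, sizeof(ss));
        if (salen == sizeof(struct sockaddr_in6))
                pr_info("scope id: %u\n",
                        ((struct sockaddr_in6 *)&ss)->sin6_scope_id); /* 2 */
}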
+
+/**
+ * rpc_sockaddr2uaddr - Construct a universal address string from @sap.
+ * @sap: socket address
+ * @gfp_flags: allocation mode
+ *
+ * Returns a %NUL-terminated string in dynamically allocated memory;
+ * otherwise NULL is returned if an error occurred. Caller must
+ * free the returned string.
+ */
+char *rpc_sockaddr2uaddr(const struct sockaddr *sap, gfp_t gfp_flags)
+{
+ char portbuf[RPCBIND_MAXUADDRPLEN];
+ char addrbuf[RPCBIND_MAXUADDRLEN];
+ unsigned short port;
+
+ switch (sap->sa_family) {
+ case AF_INET:
+ if (rpc_ntop4(sap, addrbuf, sizeof(addrbuf)) == 0)
+ return NULL;
+ port = ntohs(((struct sockaddr_in *)sap)->sin_port);
+ break;
+ case AF_INET6:
+ if (rpc_ntop6_noscopeid(sap, addrbuf, sizeof(addrbuf)) == 0)
+ return NULL;
+ port = ntohs(((struct sockaddr_in6 *)sap)->sin6_port);
+ break;
+ default:
+ return NULL;
+ }
+
+ if (snprintf(portbuf, sizeof(portbuf),
+ ".%u.%u", port >> 8, port & 0xff) > (int)sizeof(portbuf))
+ return NULL;
+
+ if (strlcat(addrbuf, portbuf, sizeof(addrbuf)) > sizeof(addrbuf))
+ return NULL;
+
+ return kstrdup(addrbuf, gfp_flags);
+}
+
+/**
+ * rpc_uaddr2sockaddr - convert a universal address to a socket address.
+ * @net: applicable network namespace
+ * @uaddr: C string containing universal address to convert
+ * @uaddr_len: length of universal address string
+ * @sap: buffer into which to plant socket address
+ * @salen: size of buffer
+ *
+ * @uaddr does not have to be '\0'-terminated, but kstrtou8() and
+ * rpc_pton() require proper string termination to be successful.
+ *
+ * Returns the size of the socket address if successful; otherwise
+ * zero is returned.
+ */
+size_t rpc_uaddr2sockaddr(struct net *net, const char *uaddr,
+ const size_t uaddr_len, struct sockaddr *sap,
+ const size_t salen)
+{
+ char *c, buf[RPCBIND_MAXUADDRLEN + sizeof('\0')];
+ u8 portlo, porthi;
+ unsigned short port;
+
+ if (uaddr_len > RPCBIND_MAXUADDRLEN)
+ return 0;
+
+ memcpy(buf, uaddr, uaddr_len);
+
+ buf[uaddr_len] = '\0';
+ c = strrchr(buf, '.');
+ if (unlikely(c == NULL))
+ return 0;
+ if (unlikely(kstrtou8(c + 1, 10, &portlo) != 0))
+ return 0;
+
+ *c = '\0';
+ c = strrchr(buf, '.');
+ if (unlikely(c == NULL))
+ return 0;
+ if (unlikely(kstrtou8(c + 1, 10, &porthi) != 0))
+ return 0;
+
+ port = (unsigned short)((porthi << 8) | portlo);
+
+ *c = '\0';
+ if (rpc_pton(net, buf, strlen(buf), sap, salen) == 0)
+ return 0;
+
+ switch (sap->sa_family) {
+ case AF_INET:
+ ((struct sockaddr_in *)sap)->sin_port = htons(port);
+ return sizeof(struct sockaddr_in);
+ case AF_INET6:
+ ((struct sockaddr_in6 *)sap)->sin6_port = htons(port);
+ return sizeof(struct sockaddr_in6);
+ }
+
+ return 0;
+}
+EXPORT_SYMBOL_GPL(rpc_uaddr2sockaddr);
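A round-trip sketch tying the helpers above together (not part of the patch): parse a presentation address with rpc_pton(), build a universal address with rpc_sockaddr2uaddr(), then convert it back with rpc_uaddr2sockaddr(). The address, port, and use of init_net are illustrative assumptions.

/* Sketch only: universal-address round trip. */
static void example_uaddr_roundtrip(void)
{
        struct sockaddr_storage ss;
        struct sockaddr *sap = (struct sockaddr *)&ss;
        char *uaddr;

        if (rpc_pton(&init_net, "192.168.1.5", strlen("192.168.1.5"),
                     sap, sizeof(ss)) == 0)
                return;
        ((struct sockaddr_in *)sap)->sin_port = htons(2049);

        uaddr = rpc_sockaddr2uaddr(sap, GFP_KERNEL);
        if (!uaddr)
                return;
        pr_info("universal address: %s\n", uaddr);      /* "192.168.1.5.8.1" */

        /* ...and back to a sockaddr again */
        if (rpc_uaddr2sockaddr(&init_net, uaddr, strlen(uaddr),
                               sap, sizeof(ss)) == 0)
                pr_warn("conversion back failed\n");
        kfree(uaddr);
}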
diff --git a/net/sunrpc/auth.c b/net/sunrpc/auth.c
new file mode 100644
index 0000000000..ec41b26af7
--- /dev/null
+++ b/net/sunrpc/auth.c
@@ -0,0 +1,893 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * linux/net/sunrpc/auth.c
+ *
+ * Generic RPC client authentication API.
+ *
+ * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de>
+ */
+
+#include <linux/types.h>
+#include <linux/sched.h>
+#include <linux/cred.h>
+#include <linux/module.h>
+#include <linux/slab.h>
+#include <linux/errno.h>
+#include <linux/hash.h>
+#include <linux/sunrpc/clnt.h>
+#include <linux/sunrpc/gss_api.h>
+#include <linux/spinlock.h>
+
+#include <trace/events/sunrpc.h>
+
+#define RPC_CREDCACHE_DEFAULT_HASHBITS (4)
+struct rpc_cred_cache {
+ struct hlist_head *hashtable;
+ unsigned int hashbits;
+ spinlock_t lock;
+};
+
+static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS;
+
+static const struct rpc_authops __rcu *auth_flavors[RPC_AUTH_MAXFLAVOR] = {
+ [RPC_AUTH_NULL] = (const struct rpc_authops __force __rcu *)&authnull_ops,
+ [RPC_AUTH_UNIX] = (const struct rpc_authops __force __rcu *)&authunix_ops,
+ [RPC_AUTH_TLS] = (const struct rpc_authops __force __rcu *)&authtls_ops,
+};
+
+static LIST_HEAD(cred_unused);
+static unsigned long number_cred_unused;
+
+static struct cred machine_cred = {
+ .usage = ATOMIC_INIT(1),
+};
+
+/*
+ * Return the machine_cred pointer to be used whenever
+ * a generic machine credential is needed.
+ */
+const struct cred *rpc_machine_cred(void)
+{
+ return &machine_cred;
+}
+EXPORT_SYMBOL_GPL(rpc_machine_cred);
+
+#define MAX_HASHTABLE_BITS (14)
+static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp)
+{
+ unsigned long num;
+ unsigned int nbits;
+ int ret;
+
+ if (!val)
+ goto out_inval;
+ ret = kstrtoul(val, 0, &num);
+ if (ret)
+ goto out_inval;
+ nbits = fls(num - 1);
+ if (nbits > MAX_HASHTABLE_BITS || nbits < 2)
+ goto out_inval;
+ *(unsigned int *)kp->arg = nbits;
+ return 0;
+out_inval:
+ return -EINVAL;
+}
+
+static int param_get_hashtbl_sz(char *buffer, const struct kernel_param *kp)
+{
+ unsigned int nbits;
+
+ nbits = *(unsigned int *)kp->arg;
+ return sprintf(buffer, "%u\n", 1U << nbits);
+}
+
+#define param_check_hashtbl_sz(name, p) __param_check(name, p, unsigned int);
+
+static const struct kernel_param_ops param_ops_hashtbl_sz = {
+ .set = param_set_hashtbl_sz,
+ .get = param_get_hashtbl_sz,
+};
+
+module_param_named(auth_hashtable_size, auth_hashbits, hashtbl_sz, 0644);
+MODULE_PARM_DESC(auth_hashtable_size, "RPC credential cache hashtable size");
+
+static unsigned long auth_max_cred_cachesize = ULONG_MAX;
+module_param(auth_max_cred_cachesize, ulong, 0644);
+MODULE_PARM_DESC(auth_max_cred_cachesize, "RPC credential maximum total cache size");
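param_set_hashtbl_sz() above rounds the requested credential-cache size up to a power of two and stores only the bit count, rejecting requests that round to fewer than 2 or more than 14 bits (MAX_HASHTABLE_BITS). A small sketch of that rounding (illustrative, not part of the patch):

/* Sketch: auth_hashtable_size=100 becomes 7 bits, i.e. 128 buckets. */
static void example_hashtbl_rounding(void)
{
        unsigned long requested = 100;
        unsigned int nbits = fls(requested - 1);        /* fls(99) == 7 */

        pr_info("requested %lu, using %u buckets (2^%u)\n",
                requested, 1U << nbits, nbits);
}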
+
+static u32
+pseudoflavor_to_flavor(u32 flavor) {
+ if (flavor > RPC_AUTH_MAXFLAVOR)
+ return RPC_AUTH_GSS;
+ return flavor;
+}
+
+int
+rpcauth_register(const struct rpc_authops *ops)
+{
+ const struct rpc_authops *old;
+ rpc_authflavor_t flavor;
+
+ if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
+ return -EINVAL;
+ old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], NULL, ops);
+ if (old == NULL || old == ops)
+ return 0;
+ return -EPERM;
+}
+EXPORT_SYMBOL_GPL(rpcauth_register);
+
+int
+rpcauth_unregister(const struct rpc_authops *ops)
+{
+ const struct rpc_authops *old;
+ rpc_authflavor_t flavor;
+
+ if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
+ return -EINVAL;
+
+ old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], ops, NULL);
+ if (old == ops || old == NULL)
+ return 0;
+ return -EPERM;
+}
+EXPORT_SYMBOL_GPL(rpcauth_unregister);
+
+static const struct rpc_authops *
+rpcauth_get_authops(rpc_authflavor_t flavor)
+{
+ const struct rpc_authops *ops;
+
+ if (flavor >= RPC_AUTH_MAXFLAVOR)
+ return NULL;
+
+ rcu_read_lock();
+ ops = rcu_dereference(auth_flavors[flavor]);
+ if (ops == NULL) {
+ rcu_read_unlock();
+ request_module("rpc-auth-%u", flavor);
+ rcu_read_lock();
+ ops = rcu_dereference(auth_flavors[flavor]);
+ if (ops == NULL)
+ goto out;
+ }
+ if (!try_module_get(ops->owner))
+ ops = NULL;
+out:
+ rcu_read_unlock();
+ return ops;
+}
+
+static void
+rpcauth_put_authops(const struct rpc_authops *ops)
+{
+ module_put(ops->owner);
+}
+
+/**
+ * rpcauth_get_pseudoflavor - check if security flavor is supported
+ * @flavor: a security flavor
+ * @info: a GSS mech OID, quality of protection, and service value
+ *
+ * Verifies that an appropriate kernel module is available or already loaded.
+ * Returns an equivalent pseudoflavor, or RPC_AUTH_MAXFLAVOR if "flavor" is
+ * not supported locally.
+ */
+rpc_authflavor_t
+rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info)
+{
+ const struct rpc_authops *ops = rpcauth_get_authops(flavor);
+ rpc_authflavor_t pseudoflavor;
+
+ if (!ops)
+ return RPC_AUTH_MAXFLAVOR;
+ pseudoflavor = flavor;
+ if (ops->info2flavor != NULL)
+ pseudoflavor = ops->info2flavor(info);
+
+ rpcauth_put_authops(ops);
+ return pseudoflavor;
+}
+EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor);
+
+/**
+ * rpcauth_get_gssinfo - find GSS tuple matching a GSS pseudoflavor
+ * @pseudoflavor: GSS pseudoflavor to match
+ * @info: rpcsec_gss_info structure to fill in
+ *
+ * Returns zero and fills in "info" if pseudoflavor matches a
+ * supported mechanism.
+ */
+int
+rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info)
+{
+ rpc_authflavor_t flavor = pseudoflavor_to_flavor(pseudoflavor);
+ const struct rpc_authops *ops;
+ int result;
+
+ ops = rpcauth_get_authops(flavor);
+ if (ops == NULL)
+ return -ENOENT;
+
+ result = -ENOENT;
+ if (ops->flavor2info != NULL)
+ result = ops->flavor2info(pseudoflavor, info);
+
+ rpcauth_put_authops(ops);
+ return result;
+}
+EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo);
+
+struct rpc_auth *
+rpcauth_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt)
+{
+ struct rpc_auth *auth = ERR_PTR(-EINVAL);
+ const struct rpc_authops *ops;
+ u32 flavor = pseudoflavor_to_flavor(args->pseudoflavor);
+
+ ops = rpcauth_get_authops(flavor);
+ if (ops == NULL)
+ goto out;
+
+ auth = ops->create(args, clnt);
+
+ rpcauth_put_authops(ops);
+ if (IS_ERR(auth))
+ return auth;
+ if (clnt->cl_auth)
+ rpcauth_release(clnt->cl_auth);
+ clnt->cl_auth = auth;
+
+out:
+ return auth;
+}
+EXPORT_SYMBOL_GPL(rpcauth_create);
+
+void
+rpcauth_release(struct rpc_auth *auth)
+{
+ if (!refcount_dec_and_test(&auth->au_count))
+ return;
+ auth->au_ops->destroy(auth);
+}
+
+static DEFINE_SPINLOCK(rpc_credcache_lock);
+
+/*
+ * On success, the caller is responsible for freeing the reference
+ * held by the hashtable
+ */
+static bool
+rpcauth_unhash_cred_locked(struct rpc_cred *cred)
+{
+ if (!test_and_clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
+ return false;
+ hlist_del_rcu(&cred->cr_hash);
+ return true;
+}
+
+static bool
+rpcauth_unhash_cred(struct rpc_cred *cred)
+{
+ spinlock_t *cache_lock;
+ bool ret;
+
+ if (!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
+ return false;
+ cache_lock = &cred->cr_auth->au_credcache->lock;
+ spin_lock(cache_lock);
+ ret = rpcauth_unhash_cred_locked(cred);
+ spin_unlock(cache_lock);
+ return ret;
+}
+
+/*
+ * Initialize RPC credential cache
+ */
+int
+rpcauth_init_credcache(struct rpc_auth *auth)
+{
+ struct rpc_cred_cache *new;
+ unsigned int hashsize;
+
+ new = kmalloc(sizeof(*new), GFP_KERNEL);
+ if (!new)
+ goto out_nocache;
+ new->hashbits = auth_hashbits;
+ hashsize = 1U << new->hashbits;
+ new->hashtable = kcalloc(hashsize, sizeof(new->hashtable[0]), GFP_KERNEL);
+ if (!new->hashtable)
+ goto out_nohashtbl;
+ spin_lock_init(&new->lock);
+ auth->au_credcache = new;
+ return 0;
+out_nohashtbl:
+ kfree(new);
+out_nocache:
+ return -ENOMEM;
+}
+EXPORT_SYMBOL_GPL(rpcauth_init_credcache);
+
+char *
+rpcauth_stringify_acceptor(struct rpc_cred *cred)
+{
+ if (!cred->cr_ops->crstringify_acceptor)
+ return NULL;
+ return cred->cr_ops->crstringify_acceptor(cred);
+}
+EXPORT_SYMBOL_GPL(rpcauth_stringify_acceptor);
+
+/*
+ * Destroy a list of credentials
+ */
+static inline
+void rpcauth_destroy_credlist(struct list_head *head)
+{
+ struct rpc_cred *cred;
+
+ while (!list_empty(head)) {
+ cred = list_entry(head->next, struct rpc_cred, cr_lru);
+ list_del_init(&cred->cr_lru);
+ put_rpccred(cred);
+ }
+}
+
+static void
+rpcauth_lru_add_locked(struct rpc_cred *cred)
+{
+ if (!list_empty(&cred->cr_lru))
+ return;
+ number_cred_unused++;
+ list_add_tail(&cred->cr_lru, &cred_unused);
+}
+
+static void
+rpcauth_lru_add(struct rpc_cred *cred)
+{
+ if (!list_empty(&cred->cr_lru))
+ return;
+ spin_lock(&rpc_credcache_lock);
+ rpcauth_lru_add_locked(cred);
+ spin_unlock(&rpc_credcache_lock);
+}
+
+static void
+rpcauth_lru_remove_locked(struct rpc_cred *cred)
+{
+ if (list_empty(&cred->cr_lru))
+ return;
+ number_cred_unused--;
+ list_del_init(&cred->cr_lru);
+}
+
+static void
+rpcauth_lru_remove(struct rpc_cred *cred)
+{
+ if (list_empty(&cred->cr_lru))
+ return;
+ spin_lock(&rpc_credcache_lock);
+ rpcauth_lru_remove_locked(cred);
+ spin_unlock(&rpc_credcache_lock);
+}
+
+/*
+ * Clear the RPC credential cache, and delete those credentials
+ * that are not referenced.
+ */
+void
+rpcauth_clear_credcache(struct rpc_cred_cache *cache)
+{
+ LIST_HEAD(free);
+ struct hlist_head *head;
+ struct rpc_cred *cred;
+ unsigned int hashsize = 1U << cache->hashbits;
+ int i;
+
+ spin_lock(&rpc_credcache_lock);
+ spin_lock(&cache->lock);
+ for (i = 0; i < hashsize; i++) {
+ head = &cache->hashtable[i];
+ while (!hlist_empty(head)) {
+ cred = hlist_entry(head->first, struct rpc_cred, cr_hash);
+ rpcauth_unhash_cred_locked(cred);
+ /* Note: We now hold a reference to cred */
+ rpcauth_lru_remove_locked(cred);
+ list_add_tail(&cred->cr_lru, &free);
+ }
+ }
+ spin_unlock(&cache->lock);
+ spin_unlock(&rpc_credcache_lock);
+ rpcauth_destroy_credlist(&free);
+}
+
+/*
+ * Destroy the RPC credential cache
+ */
+void
+rpcauth_destroy_credcache(struct rpc_auth *auth)
+{
+ struct rpc_cred_cache *cache = auth->au_credcache;
+
+ if (cache) {
+ auth->au_credcache = NULL;
+ rpcauth_clear_credcache(cache);
+ kfree(cache->hashtable);
+ kfree(cache);
+ }
+}
+EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache);
+
+
+#define RPC_AUTH_EXPIRY_MORATORIUM (60 * HZ)
+
+/*
+ * Remove stale credentials. Avoid sleeping inside the loop.
+ */
+static long
+rpcauth_prune_expired(struct list_head *free, int nr_to_scan)
+{
+ struct rpc_cred *cred, *next;
+ unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM;
+ long freed = 0;
+
+ list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) {
+
+ if (nr_to_scan-- == 0)
+ break;
+ if (refcount_read(&cred->cr_count) > 1) {
+ rpcauth_lru_remove_locked(cred);
+ continue;
+ }
+ /*
+ * Enforce a 60 second garbage collection moratorium
+ * Note that the cred_unused list must be time-ordered.
+ */
+ if (time_in_range(cred->cr_expire, expired, jiffies))
+ continue;
+ if (!rpcauth_unhash_cred(cred))
+ continue;
+
+ rpcauth_lru_remove_locked(cred);
+ freed++;
+ list_add_tail(&cred->cr_lru, free);
+ }
+ return freed ? freed : SHRINK_STOP;
+}
+
+static unsigned long
+rpcauth_cache_do_shrink(int nr_to_scan)
+{
+ LIST_HEAD(free);
+ unsigned long freed;
+
+ spin_lock(&rpc_credcache_lock);
+ freed = rpcauth_prune_expired(&free, nr_to_scan);
+ spin_unlock(&rpc_credcache_lock);
+ rpcauth_destroy_credlist(&free);
+
+ return freed;
+}
+
+/*
+ * Run memory cache shrinker.
+ */
+static unsigned long
+rpcauth_cache_shrink_scan(struct shrinker *shrink, struct shrink_control *sc)
+
+{
+ if ((sc->gfp_mask & GFP_KERNEL) != GFP_KERNEL)
+ return SHRINK_STOP;
+
+ /* nothing left, don't come back */
+ if (list_empty(&cred_unused))
+ return SHRINK_STOP;
+
+ return rpcauth_cache_do_shrink(sc->nr_to_scan);
+}
+
+static unsigned long
+rpcauth_cache_shrink_count(struct shrinker *shrink, struct shrink_control *sc)
+
+{
+ return number_cred_unused * sysctl_vfs_cache_pressure / 100;
+}
+
+static void
+rpcauth_cache_enforce_limit(void)
+{
+ unsigned long diff;
+ unsigned int nr_to_scan;
+
+ if (number_cred_unused <= auth_max_cred_cachesize)
+ return;
+ diff = number_cred_unused - auth_max_cred_cachesize;
+ nr_to_scan = 100;
+ if (diff < nr_to_scan)
+ nr_to_scan = diff;
+ rpcauth_cache_do_shrink(nr_to_scan);
+}
+
+/*
+ * Look up a process' credentials in the authentication cache
+ */
+struct rpc_cred *
+rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,
+ int flags, gfp_t gfp)
+{
+ LIST_HEAD(free);
+ struct rpc_cred_cache *cache = auth->au_credcache;
+ struct rpc_cred *cred = NULL,
+ *entry, *new;
+ unsigned int nr;
+
+ nr = auth->au_ops->hash_cred(acred, cache->hashbits);
+
+ rcu_read_lock();
+ hlist_for_each_entry_rcu(entry, &cache->hashtable[nr], cr_hash) {
+ if (!entry->cr_ops->crmatch(acred, entry, flags))
+ continue;
+ cred = get_rpccred(entry);
+ if (cred)
+ break;
+ }
+ rcu_read_unlock();
+
+ if (cred != NULL)
+ goto found;
+
+ new = auth->au_ops->crcreate(auth, acred, flags, gfp);
+ if (IS_ERR(new)) {
+ cred = new;
+ goto out;
+ }
+
+ spin_lock(&cache->lock);
+ hlist_for_each_entry(entry, &cache->hashtable[nr], cr_hash) {
+ if (!entry->cr_ops->crmatch(acred, entry, flags))
+ continue;
+ cred = get_rpccred(entry);
+ if (cred)
+ break;
+ }
+ if (cred == NULL) {
+ cred = new;
+ set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags);
+ refcount_inc(&cred->cr_count);
+ hlist_add_head_rcu(&cred->cr_hash, &cache->hashtable[nr]);
+ } else
+ list_add_tail(&new->cr_lru, &free);
+ spin_unlock(&cache->lock);
+ rpcauth_cache_enforce_limit();
+found:
+ if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
+ cred->cr_ops->cr_init != NULL &&
+ !(flags & RPCAUTH_LOOKUP_NEW)) {
+ int res = cred->cr_ops->cr_init(auth, cred);
+ if (res < 0) {
+ put_rpccred(cred);
+ cred = ERR_PTR(res);
+ }
+ }
+ rpcauth_destroy_credlist(&free);
+out:
+ return cred;
+}
+EXPORT_SYMBOL_GPL(rpcauth_lookup_credcache);
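rpcauth_lookup_credcache() above follows a common kernel idiom: an RCU-protected lockless search first, then, on a miss, an unlocked (possibly sleeping) allocation followed by a re-check under the cache lock so that a concurrent insertion wins and the spare object is discarded. A generic sketch of that idiom (all names here are hypothetical, not from this file):

/* Sketch of the lookup / allocate / recheck-under-lock pattern.
 * 'cache', 'item', find_rcu(), find_locked(), alloc_item(),
 * insert_locked() and free_item() are illustrative helpers only. */
static struct item *lookup_or_create(struct cache *c, const struct key *k)
{
        struct item *found, *fresh;

        rcu_read_lock();
        found = find_rcu(c, k);                 /* lockless fast path */
        rcu_read_unlock();
        if (found)
                return found;

        fresh = alloc_item(k);                  /* may sleep; no locks held */
        if (IS_ERR(fresh))
                return fresh;

        spin_lock(&c->lock);
        found = find_locked(c, k);              /* did another CPU insert one? */
        if (!found) {
                insert_locked(c, fresh);
                found = fresh;
        }
        spin_unlock(&c->lock);

        if (found != fresh)
                free_item(fresh);               /* lost the race: drop our copy */
        return found;
}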
+
+struct rpc_cred *
+rpcauth_lookupcred(struct rpc_auth *auth, int flags)
+{
+ struct auth_cred acred;
+ struct rpc_cred *ret;
+ const struct cred *cred = current_cred();
+
+ memset(&acred, 0, sizeof(acred));
+ acred.cred = cred;
+ ret = auth->au_ops->lookup_cred(auth, &acred, flags);
+ return ret;
+}
+EXPORT_SYMBOL_GPL(rpcauth_lookupcred);
+
+void
+rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred,
+ struct rpc_auth *auth, const struct rpc_credops *ops)
+{
+ INIT_HLIST_NODE(&cred->cr_hash);
+ INIT_LIST_HEAD(&cred->cr_lru);
+ refcount_set(&cred->cr_count, 1);
+ cred->cr_auth = auth;
+ cred->cr_flags = 0;
+ cred->cr_ops = ops;
+ cred->cr_expire = jiffies;
+ cred->cr_cred = get_cred(acred->cred);
+}
+EXPORT_SYMBOL_GPL(rpcauth_init_cred);
+
+static struct rpc_cred *
+rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags)
+{
+ struct rpc_auth *auth = task->tk_client->cl_auth;
+ struct auth_cred acred = {
+ .cred = get_task_cred(&init_task),
+ };
+ struct rpc_cred *ret;
+
+ if (RPC_IS_ASYNC(task))
+ lookupflags |= RPCAUTH_LOOKUP_ASYNC;
+ ret = auth->au_ops->lookup_cred(auth, &acred, lookupflags);
+ put_cred(acred.cred);
+ return ret;
+}
+
+static struct rpc_cred *
+rpcauth_bind_machine_cred(struct rpc_task *task, int lookupflags)
+{
+ struct rpc_auth *auth = task->tk_client->cl_auth;
+ struct auth_cred acred = {
+ .principal = task->tk_client->cl_principal,
+ .cred = init_task.cred,
+ };
+
+ if (!acred.principal)
+ return NULL;
+ if (RPC_IS_ASYNC(task))
+ lookupflags |= RPCAUTH_LOOKUP_ASYNC;
+ return auth->au_ops->lookup_cred(auth, &acred, lookupflags);
+}
+
+static struct rpc_cred *
+rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags)
+{
+ struct rpc_auth *auth = task->tk_client->cl_auth;
+
+ return rpcauth_lookupcred(auth, lookupflags);
+}
+
+static int
+rpcauth_bindcred(struct rpc_task *task, const struct cred *cred, int flags)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+ struct rpc_cred *new = NULL;
+ int lookupflags = 0;
+ struct rpc_auth *auth = task->tk_client->cl_auth;
+ struct auth_cred acred = {
+ .cred = cred,
+ };
+
+ if (flags & RPC_TASK_ASYNC)
+ lookupflags |= RPCAUTH_LOOKUP_NEW | RPCAUTH_LOOKUP_ASYNC;
+ if (task->tk_op_cred)
+ /* Task must use exactly this rpc_cred */
+ new = get_rpccred(task->tk_op_cred);
+ else if (cred != NULL && cred != &machine_cred)
+ new = auth->au_ops->lookup_cred(auth, &acred, lookupflags);
+ else if (cred == &machine_cred)
+ new = rpcauth_bind_machine_cred(task, lookupflags);
+
+ /* If machine cred couldn't be bound, try a root cred */
+ if (new)
+ ;
+ else if (cred == &machine_cred)
+ new = rpcauth_bind_root_cred(task, lookupflags);
+ else if (flags & RPC_TASK_NULLCREDS)
+ new = authnull_ops.lookup_cred(NULL, NULL, 0);
+ else
+ new = rpcauth_bind_new_cred(task, lookupflags);
+ if (IS_ERR(new))
+ return PTR_ERR(new);
+ put_rpccred(req->rq_cred);
+ req->rq_cred = new;
+ return 0;
+}
+
+void
+put_rpccred(struct rpc_cred *cred)
+{
+ if (cred == NULL)
+ return;
+ rcu_read_lock();
+ if (refcount_dec_and_test(&cred->cr_count))
+ goto destroy;
+ if (refcount_read(&cred->cr_count) != 1 ||
+ !test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
+ goto out;
+ if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) {
+ cred->cr_expire = jiffies;
+ rpcauth_lru_add(cred);
+ /* Race breaker */
+ if (unlikely(!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags)))
+ rpcauth_lru_remove(cred);
+ } else if (rpcauth_unhash_cred(cred)) {
+ rpcauth_lru_remove(cred);
+ if (refcount_dec_and_test(&cred->cr_count))
+ goto destroy;
+ }
+out:
+ rcu_read_unlock();
+ return;
+destroy:
+ rcu_read_unlock();
+ cred->cr_ops->crdestroy(cred);
+}
+EXPORT_SYMBOL_GPL(put_rpccred);
+
+/**
+ * rpcauth_marshcred - Append RPC credential to end of @xdr
+ * @task: controlling RPC task
+ * @xdr: xdr_stream containing initial portion of RPC Call header
+ *
+ * On success, an appropriate verifier is added to @xdr, @xdr is
+ * updated to point past the verifier, and zero is returned.
+ * Otherwise, @xdr is in an undefined state and a negative errno
+ * is returned.
+ */
+int rpcauth_marshcred(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops;
+
+ return ops->crmarshal(task, xdr);
+}
+
+/**
+ * rpcauth_wrap_req_encode - XDR encode the RPC procedure
+ * @task: controlling RPC task
+ * @xdr: stream where on-the-wire bytes are to be marshalled
+ *
+ * On success, @xdr contains the encoded and wrapped message.
+ * Otherwise, @xdr is in an undefined state.
+ */
+int rpcauth_wrap_req_encode(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ kxdreproc_t encode = task->tk_msg.rpc_proc->p_encode;
+
+ encode(task->tk_rqstp, xdr, task->tk_msg.rpc_argp);
+ return 0;
+}
+EXPORT_SYMBOL_GPL(rpcauth_wrap_req_encode);
+
+/**
+ * rpcauth_wrap_req - XDR encode and wrap the RPC procedure
+ * @task: controlling RPC task
+ * @xdr: stream where on-the-wire bytes are to be marshalled
+ *
+ * On success, @xdr contains the encoded and wrapped message,
+ * and zero is returned. Otherwise, @xdr is in an undefined
+ * state and a negative errno is returned.
+ */
+int rpcauth_wrap_req(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops;
+
+ return ops->crwrap_req(task, xdr);
+}
+
+/**
+ * rpcauth_checkverf - Validate verifier in RPC Reply header
+ * @task: controlling RPC task
+ * @xdr: xdr_stream containing RPC Reply header
+ *
+ * Return values:
+ * %0: Verifier is valid. @xdr now points past the verifier.
+ * %-EIO: Verifier is corrupted or message ended early.
+ * %-EACCES: Verifier is intact but not valid.
+ * %-EPROTONOSUPPORT: Server does not support the requested auth type.
+ *
+ * When a negative errno is returned, @xdr is left in an undefined
+ * state.
+ */
+int
+rpcauth_checkverf(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops;
+
+ return ops->crvalidate(task, xdr);
+}
+
+/**
+ * rpcauth_unwrap_resp_decode - Invoke XDR decode function
+ * @task: controlling RPC task
+ * @xdr: stream where the Reply message resides
+ *
+ * Returns zero on success; otherwise a negative errno is returned.
+ */
+int
+rpcauth_unwrap_resp_decode(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ kxdrdproc_t decode = task->tk_msg.rpc_proc->p_decode;
+
+ return decode(task->tk_rqstp, xdr, task->tk_msg.rpc_resp);
+}
+EXPORT_SYMBOL_GPL(rpcauth_unwrap_resp_decode);
+
+/**
+ * rpcauth_unwrap_resp - Invoke unwrap and decode function for the cred
+ * @task: controlling RPC task
+ * @xdr: stream where the Reply message resides
+ *
+ * Returns zero on success; otherwise a negative errno is returned.
+ */
+int
+rpcauth_unwrap_resp(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops;
+
+ return ops->crunwrap_resp(task, xdr);
+}
+
+bool
+rpcauth_xmit_need_reencode(struct rpc_task *task)
+{
+ struct rpc_cred *cred = task->tk_rqstp->rq_cred;
+
+ if (!cred || !cred->cr_ops->crneed_reencode)
+ return false;
+ return cred->cr_ops->crneed_reencode(task);
+}
+
+int
+rpcauth_refreshcred(struct rpc_task *task)
+{
+ struct rpc_cred *cred;
+ int err;
+
+ cred = task->tk_rqstp->rq_cred;
+ if (cred == NULL) {
+ err = rpcauth_bindcred(task, task->tk_msg.rpc_cred, task->tk_flags);
+ if (err < 0)
+ goto out;
+ cred = task->tk_rqstp->rq_cred;
+ }
+
+ err = cred->cr_ops->crrefresh(task);
+out:
+ if (err < 0)
+ task->tk_status = err;
+ return err;
+}
+
+void
+rpcauth_invalcred(struct rpc_task *task)
+{
+ struct rpc_cred *cred = task->tk_rqstp->rq_cred;
+
+ if (cred)
+ clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
+}
+
+int
+rpcauth_uptodatecred(struct rpc_task *task)
+{
+ struct rpc_cred *cred = task->tk_rqstp->rq_cred;
+
+ return cred == NULL ||
+ test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0;
+}
+
+static struct shrinker rpc_cred_shrinker = {
+ .count_objects = rpcauth_cache_shrink_count,
+ .scan_objects = rpcauth_cache_shrink_scan,
+ .seeks = DEFAULT_SEEKS,
+};
+
+int __init rpcauth_init_module(void)
+{
+ int err;
+
+ err = rpc_init_authunix();
+ if (err < 0)
+ goto out1;
+ err = register_shrinker(&rpc_cred_shrinker, "sunrpc_cred");
+ if (err < 0)
+ goto out2;
+ return 0;
+out2:
+ rpc_destroy_authunix();
+out1:
+ return err;
+}
+
+void rpcauth_remove_module(void)
+{
+ rpc_destroy_authunix();
+ unregister_shrinker(&rpc_cred_shrinker);
+}
diff --git a/net/sunrpc/auth_gss/Makefile b/net/sunrpc/auth_gss/Makefile
new file mode 100644
index 0000000000..ad1736d93b
--- /dev/null
+++ b/net/sunrpc/auth_gss/Makefile
@@ -0,0 +1,17 @@
+# SPDX-License-Identifier: GPL-2.0
+#
+# Makefile for Linux kernel rpcsec_gss implementation
+#
+
+obj-$(CONFIG_SUNRPC_GSS) += auth_rpcgss.o
+
+auth_rpcgss-y := auth_gss.o gss_generic_token.o \
+ gss_mech_switch.o svcauth_gss.o \
+ gss_rpc_upcall.o gss_rpc_xdr.o trace.o
+
+obj-$(CONFIG_RPCSEC_GSS_KRB5) += rpcsec_gss_krb5.o
+
+rpcsec_gss_krb5-y := gss_krb5_mech.o gss_krb5_seal.o gss_krb5_unseal.o \
+ gss_krb5_wrap.o gss_krb5_crypto.o gss_krb5_keys.o
+
+obj-$(CONFIG_RPCSEC_GSS_KRB5_KUNIT_TEST) += gss_krb5_test.o
diff --git a/net/sunrpc/auth_gss/auth_gss.c b/net/sunrpc/auth_gss/auth_gss.c
new file mode 100644
index 0000000000..1af71fbb0d
--- /dev/null
+++ b/net/sunrpc/auth_gss/auth_gss.c
@@ -0,0 +1,2298 @@
+// SPDX-License-Identifier: BSD-3-Clause
+/*
+ * linux/net/sunrpc/auth_gss/auth_gss.c
+ *
+ * RPCSEC_GSS client authentication.
+ *
+ * Copyright (c) 2000 The Regents of the University of Michigan.
+ * All rights reserved.
+ *
+ * Dug Song <dugsong@monkey.org>
+ * Andy Adamson <andros@umich.edu>
+ */
+
+#include <linux/module.h>
+#include <linux/init.h>
+#include <linux/types.h>
+#include <linux/slab.h>
+#include <linux/sched.h>
+#include <linux/pagemap.h>
+#include <linux/sunrpc/clnt.h>
+#include <linux/sunrpc/auth.h>
+#include <linux/sunrpc/auth_gss.h>
+#include <linux/sunrpc/gss_krb5.h>
+#include <linux/sunrpc/svcauth_gss.h>
+#include <linux/sunrpc/gss_err.h>
+#include <linux/workqueue.h>
+#include <linux/sunrpc/rpc_pipe_fs.h>
+#include <linux/sunrpc/gss_api.h>
+#include <linux/uaccess.h>
+#include <linux/hashtable.h>
+
+#include "auth_gss_internal.h"
+#include "../netns.h"
+
+#include <trace/events/rpcgss.h>
+
+static const struct rpc_authops authgss_ops;
+
+static const struct rpc_credops gss_credops;
+static const struct rpc_credops gss_nullops;
+
+#define GSS_RETRY_EXPIRED 5
+static unsigned int gss_expired_cred_retry_delay = GSS_RETRY_EXPIRED;
+
+#define GSS_KEY_EXPIRE_TIMEO 240
+static unsigned int gss_key_expire_timeo = GSS_KEY_EXPIRE_TIMEO;
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+# define RPCDBG_FACILITY RPCDBG_AUTH
+#endif
+
+/*
+ * This compile-time check verifies that we will not exceed the
+ * slack space allotted by the client and server auth_gss code
+ * before they call gss_wrap().
+ */
+#define GSS_KRB5_MAX_SLACK_NEEDED \
+ (GSS_KRB5_TOK_HDR_LEN /* gss token header */ \
+ + GSS_KRB5_MAX_CKSUM_LEN /* gss token checksum */ \
+ + GSS_KRB5_MAX_BLOCKSIZE /* confounder */ \
+ + GSS_KRB5_MAX_BLOCKSIZE /* possible padding */ \
+ + GSS_KRB5_TOK_HDR_LEN /* encrypted hdr in v2 token */ \
+ + GSS_KRB5_MAX_CKSUM_LEN /* encryption hmac */ \
+ + XDR_UNIT * 2 /* RPC verifier */ \
+ + GSS_KRB5_TOK_HDR_LEN \
+ + GSS_KRB5_MAX_CKSUM_LEN)
+
+#define GSS_CRED_SLACK (RPC_MAX_AUTH_SIZE * 2)
+/* length of a krb5 verifier (48), plus data added before arguments when
+ * using integrity (two 4-byte integers): */
+#define GSS_VERF_SLACK 100
+
+static DEFINE_HASHTABLE(gss_auth_hash_table, 4);
+static DEFINE_SPINLOCK(gss_auth_hash_lock);
+
+struct gss_pipe {
+ struct rpc_pipe_dir_object pdo;
+ struct rpc_pipe *pipe;
+ struct rpc_clnt *clnt;
+ const char *name;
+ struct kref kref;
+};
+
+struct gss_auth {
+ struct kref kref;
+ struct hlist_node hash;
+ struct rpc_auth rpc_auth;
+ struct gss_api_mech *mech;
+ enum rpc_gss_svc service;
+ struct rpc_clnt *client;
+ struct net *net;
+ netns_tracker ns_tracker;
+ /*
+ * There are two upcall pipes; dentry[1], named "gssd", is used
+ * for the new text-based upcall; dentry[0] is named after the
+ * mechanism (for example, "krb5") and exists for
+ * backwards-compatibility with older gssd's.
+ */
+ struct gss_pipe *gss_pipe[2];
+ const char *target_name;
+};
+
+/* pipe_version >= 0 if and only if someone has a pipe open. */
+static DEFINE_SPINLOCK(pipe_version_lock);
+static struct rpc_wait_queue pipe_version_rpc_waitqueue;
+static DECLARE_WAIT_QUEUE_HEAD(pipe_version_waitqueue);
+static void gss_put_auth(struct gss_auth *gss_auth);
+
+static void gss_free_ctx(struct gss_cl_ctx *);
+static const struct rpc_pipe_ops gss_upcall_ops_v0;
+static const struct rpc_pipe_ops gss_upcall_ops_v1;
+
+static inline struct gss_cl_ctx *
+gss_get_ctx(struct gss_cl_ctx *ctx)
+{
+ refcount_inc(&ctx->count);
+ return ctx;
+}
+
+static inline void
+gss_put_ctx(struct gss_cl_ctx *ctx)
+{
+ if (refcount_dec_and_test(&ctx->count))
+ gss_free_ctx(ctx);
+}
+
+/* gss_cred_set_ctx:
+ * called by gss_upcall_callback and gss_create_upcall in order
+ * to set the gss context. The actual exchange of an old context
+ * and a new one is protected by the pipe->lock.
+ */
+static void
+gss_cred_set_ctx(struct rpc_cred *cred, struct gss_cl_ctx *ctx)
+{
+ struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
+
+ if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags))
+ return;
+ gss_get_ctx(ctx);
+ rcu_assign_pointer(gss_cred->gc_ctx, ctx);
+ set_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
+ smp_mb__before_atomic();
+ clear_bit(RPCAUTH_CRED_NEW, &cred->cr_flags);
+}
+
+static struct gss_cl_ctx *
+gss_cred_get_ctx(struct rpc_cred *cred)
+{
+ struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
+ struct gss_cl_ctx *ctx = NULL;
+
+ rcu_read_lock();
+ ctx = rcu_dereference(gss_cred->gc_ctx);
+ if (ctx)
+ gss_get_ctx(ctx);
+ rcu_read_unlock();
+ return ctx;
+}
+
+static struct gss_cl_ctx *
+gss_alloc_context(void)
+{
+ struct gss_cl_ctx *ctx;
+
+ ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
+ if (ctx != NULL) {
+ ctx->gc_proc = RPC_GSS_PROC_DATA;
+ ctx->gc_seq = 1; /* NetApp 6.4R1 doesn't accept seq. no. 0 */
+ spin_lock_init(&ctx->gc_seq_lock);
+ refcount_set(&ctx->count,1);
+ }
+ return ctx;
+}
+
+#define GSSD_MIN_TIMEOUT (60 * 60)
+static const void *
+gss_fill_context(const void *p, const void *end, struct gss_cl_ctx *ctx, struct gss_api_mech *gm)
+{
+ const void *q;
+ unsigned int seclen;
+ unsigned int timeout;
+ unsigned long now = jiffies;
+ u32 window_size;
+ int ret;
+
+ /* First unsigned int gives the remaining lifetime in seconds of the
+ * credential - e.g. the remaining TGT lifetime for Kerberos or
+ * the -t value passed to GSSD.
+ */
+ p = simple_get_bytes(p, end, &timeout, sizeof(timeout));
+ if (IS_ERR(p))
+ goto err;
+ if (timeout == 0)
+ timeout = GSSD_MIN_TIMEOUT;
+ ctx->gc_expiry = now + ((unsigned long)timeout * HZ);
+ /* Sequence number window. Determines the maximum number of
+ * simultaneous requests
+ */
+ p = simple_get_bytes(p, end, &window_size, sizeof(window_size));
+ if (IS_ERR(p))
+ goto err;
+ ctx->gc_win = window_size;
+ /* gssd signals an error by passing ctx->gc_win = 0: */
+ if (ctx->gc_win == 0) {
+ /*
+ * in which case, p points to an error code. Anything other
+ * than -EKEYEXPIRED gets converted to -EACCES.
+ */
+ p = simple_get_bytes(p, end, &ret, sizeof(ret));
+ if (!IS_ERR(p))
+ p = (ret == -EKEYEXPIRED) ? ERR_PTR(-EKEYEXPIRED) :
+ ERR_PTR(-EACCES);
+ goto err;
+ }
+ /* copy the opaque wire context */
+ p = simple_get_netobj(p, end, &ctx->gc_wire_ctx);
+ if (IS_ERR(p))
+ goto err;
+ /* import the opaque security context */
+ p = simple_get_bytes(p, end, &seclen, sizeof(seclen));
+ if (IS_ERR(p))
+ goto err;
+ q = (const void *)((const char *)p + seclen);
+ if (unlikely(q > end || q < p)) {
+ p = ERR_PTR(-EFAULT);
+ goto err;
+ }
+ ret = gss_import_sec_context(p, seclen, gm, &ctx->gc_gss_ctx, NULL, GFP_KERNEL);
+ if (ret < 0) {
+ trace_rpcgss_import_ctx(ret);
+ p = ERR_PTR(ret);
+ goto err;
+ }
+
+ /* is there any trailing data? */
+ if (q == end) {
+ p = q;
+ goto done;
+ }
+
+ /* pull in acceptor name (if there is one) */
+ p = simple_get_netobj(q, end, &ctx->gc_acceptor);
+ if (IS_ERR(p))
+ goto err;
+done:
+ trace_rpcgss_context(window_size, ctx->gc_expiry, now, timeout,
+ ctx->gc_acceptor.len, ctx->gc_acceptor.data);
+err:
+ return p;
+}
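For orientation, the downcall buffer that gss_fill_context() above walks carries the following fields in order; this layout is inferred from the parsing code and is not a normative description of the gssd protocol:

/*
 * Inferred downcall layout (each netobj is a u32 length followed by
 * that many opaque bytes):
 *
 *   u32     timeout       remaining credential lifetime in seconds
 *                         (0 is replaced by GSSD_MIN_TIMEOUT)
 *   u32     window_size   sequence window; 0 means an error code
 *                         follows instead of a context
 *   netobj  gc_wire_ctx   opaque wire context handle
 *   u32     seclen        length of the security context blob handed
 *                         to gss_import_sec_context()
 *   bytes   sec context   seclen bytes of mechanism-specific context
 *   netobj  acceptor      optional trailing acceptor principal name
 */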
+
+/* XXX: Need some documentation about why UPCALL_BUF_LEN is so small.
+ * Is user space expecting no more than UPCALL_BUF_LEN bytes?
+ * Note that there are now _two_ NI_MAXHOST sized data items
+ * being passed in this string.
+ */
+#define UPCALL_BUF_LEN 256
+
+struct gss_upcall_msg {
+ refcount_t count;
+ kuid_t uid;
+ const char *service_name;
+ struct rpc_pipe_msg msg;
+ struct list_head list;
+ struct gss_auth *auth;
+ struct rpc_pipe *pipe;
+ struct rpc_wait_queue rpc_waitqueue;
+ wait_queue_head_t waitqueue;
+ struct gss_cl_ctx *ctx;
+ char databuf[UPCALL_BUF_LEN];
+};
+
+static int get_pipe_version(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ int ret;
+
+ spin_lock(&pipe_version_lock);
+ if (sn->pipe_version >= 0) {
+ atomic_inc(&sn->pipe_users);
+ ret = sn->pipe_version;
+ } else
+ ret = -EAGAIN;
+ spin_unlock(&pipe_version_lock);
+ return ret;
+}
+
+static void put_pipe_version(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ if (atomic_dec_and_lock(&sn->pipe_users, &pipe_version_lock)) {
+ sn->pipe_version = -1;
+ spin_unlock(&pipe_version_lock);
+ }
+}
+
+static void
+gss_release_msg(struct gss_upcall_msg *gss_msg)
+{
+ struct net *net = gss_msg->auth->net;
+ if (!refcount_dec_and_test(&gss_msg->count))
+ return;
+ put_pipe_version(net);
+ BUG_ON(!list_empty(&gss_msg->list));
+ if (gss_msg->ctx != NULL)
+ gss_put_ctx(gss_msg->ctx);
+ rpc_destroy_wait_queue(&gss_msg->rpc_waitqueue);
+ gss_put_auth(gss_msg->auth);
+ kfree_const(gss_msg->service_name);
+ kfree(gss_msg);
+}
+
+static struct gss_upcall_msg *
+__gss_find_upcall(struct rpc_pipe *pipe, kuid_t uid, const struct gss_auth *auth)
+{
+ struct gss_upcall_msg *pos;
+ list_for_each_entry(pos, &pipe->in_downcall, list) {
+ if (!uid_eq(pos->uid, uid))
+ continue;
+ if (pos->auth->service != auth->service)
+ continue;
+ refcount_inc(&pos->count);
+ return pos;
+ }
+ return NULL;
+}
+
+/* Try to add an upcall to the pipefs queue.
+ * If an upcall owned by our uid already exists, then we return a reference
+ * to that upcall instead of adding the new upcall.
+ */
+static inline struct gss_upcall_msg *
+gss_add_msg(struct gss_upcall_msg *gss_msg)
+{
+ struct rpc_pipe *pipe = gss_msg->pipe;
+ struct gss_upcall_msg *old;
+
+ spin_lock(&pipe->lock);
+ old = __gss_find_upcall(pipe, gss_msg->uid, gss_msg->auth);
+ if (old == NULL) {
+ refcount_inc(&gss_msg->count);
+ list_add(&gss_msg->list, &pipe->in_downcall);
+ } else
+ gss_msg = old;
+ spin_unlock(&pipe->lock);
+ return gss_msg;
+}
+
+static void
+__gss_unhash_msg(struct gss_upcall_msg *gss_msg)
+{
+ list_del_init(&gss_msg->list);
+ rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
+ wake_up_all(&gss_msg->waitqueue);
+ refcount_dec(&gss_msg->count);
+}
+
+static void
+gss_unhash_msg(struct gss_upcall_msg *gss_msg)
+{
+ struct rpc_pipe *pipe = gss_msg->pipe;
+
+ if (list_empty(&gss_msg->list))
+ return;
+ spin_lock(&pipe->lock);
+ if (!list_empty(&gss_msg->list))
+ __gss_unhash_msg(gss_msg);
+ spin_unlock(&pipe->lock);
+}
+
+static void
+gss_handle_downcall_result(struct gss_cred *gss_cred, struct gss_upcall_msg *gss_msg)
+{
+ switch (gss_msg->msg.errno) {
+ case 0:
+ if (gss_msg->ctx == NULL)
+ break;
+ clear_bit(RPCAUTH_CRED_NEGATIVE, &gss_cred->gc_base.cr_flags);
+ gss_cred_set_ctx(&gss_cred->gc_base, gss_msg->ctx);
+ break;
+ case -EKEYEXPIRED:
+ set_bit(RPCAUTH_CRED_NEGATIVE, &gss_cred->gc_base.cr_flags);
+ }
+ gss_cred->gc_upcall_timestamp = jiffies;
+ gss_cred->gc_upcall = NULL;
+ rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
+}
+
+static void
+gss_upcall_callback(struct rpc_task *task)
+{
+ struct gss_cred *gss_cred = container_of(task->tk_rqstp->rq_cred,
+ struct gss_cred, gc_base);
+ struct gss_upcall_msg *gss_msg = gss_cred->gc_upcall;
+ struct rpc_pipe *pipe = gss_msg->pipe;
+
+ spin_lock(&pipe->lock);
+ gss_handle_downcall_result(gss_cred, gss_msg);
+ spin_unlock(&pipe->lock);
+ task->tk_status = gss_msg->msg.errno;
+ gss_release_msg(gss_msg);
+}
+
+static void gss_encode_v0_msg(struct gss_upcall_msg *gss_msg,
+ const struct cred *cred)
+{
+ struct user_namespace *userns = cred->user_ns;
+
+ uid_t uid = from_kuid_munged(userns, gss_msg->uid);
+ memcpy(gss_msg->databuf, &uid, sizeof(uid));
+ gss_msg->msg.data = gss_msg->databuf;
+ gss_msg->msg.len = sizeof(uid);
+
+ BUILD_BUG_ON(sizeof(uid) > sizeof(gss_msg->databuf));
+}
+
+static ssize_t
+gss_v0_upcall(struct file *file, struct rpc_pipe_msg *msg,
+ char __user *buf, size_t buflen)
+{
+ struct gss_upcall_msg *gss_msg = container_of(msg,
+ struct gss_upcall_msg,
+ msg);
+ if (msg->copied == 0)
+ gss_encode_v0_msg(gss_msg, file->f_cred);
+ return rpc_pipe_generic_upcall(file, msg, buf, buflen);
+}
+
+static int gss_encode_v1_msg(struct gss_upcall_msg *gss_msg,
+ const char *service_name,
+ const char *target_name,
+ const struct cred *cred)
+{
+ struct user_namespace *userns = cred->user_ns;
+ struct gss_api_mech *mech = gss_msg->auth->mech;
+ char *p = gss_msg->databuf;
+ size_t buflen = sizeof(gss_msg->databuf);
+ int len;
+
+ len = scnprintf(p, buflen, "mech=%s uid=%d", mech->gm_name,
+ from_kuid_munged(userns, gss_msg->uid));
+ buflen -= len;
+ p += len;
+ gss_msg->msg.len = len;
+
+ /*
+ * target= is a full service principal that names the remote
+ * identity that we are authenticating to.
+ */
+ if (target_name) {
+ len = scnprintf(p, buflen, " target=%s", target_name);
+ buflen -= len;
+ p += len;
+ gss_msg->msg.len += len;
+ }
+
+ /*
+ * gssd uses service= and srchost= to select a matching key from
+ * the system's keytab to use as the source principal.
+ *
+ * service= is the service name part of the source principal,
+ * or "*" (meaning choose any).
+ *
+ * srchost= is the hostname part of the source principal. When
+ * not provided, gssd uses the local hostname.
+ */
+ if (service_name) {
+ char *c = strchr(service_name, '@');
+
+ if (!c)
+ len = scnprintf(p, buflen, " service=%s",
+ service_name);
+ else
+ len = scnprintf(p, buflen,
+ " service=%.*s srchost=%s",
+ (int)(c - service_name),
+ service_name, c + 1);
+ buflen -= len;
+ p += len;
+ gss_msg->msg.len += len;
+ }
+
+ if (mech->gm_upcall_enctypes) {
+ len = scnprintf(p, buflen, " enctypes=%s",
+ mech->gm_upcall_enctypes);
+ buflen -= len;
+ p += len;
+ gss_msg->msg.len += len;
+ }
+ trace_rpcgss_upcall_msg(gss_msg->databuf);
+ len = scnprintf(p, buflen, "\n");
+ if (len == 0)
+ goto out_overflow;
+ gss_msg->msg.len += len;
+ gss_msg->msg.data = gss_msg->databuf;
+ return 0;
+out_overflow:
+ WARN_ON_ONCE(1);
+ return -ENOMEM;
+}
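For reference, the text-based (v1) message assembled above is one newline-terminated line of space-separated key=value pairs, in contrast to the v0 pipe, which carries only the raw uid bytes. An illustrative example of such a line — the uid, target principal, and enctype list are invented for the sketch:

/* Illustrative only: a possible v1 upcall line as written to the pipe. */
static const char example_v1_upcall[] =
        "mech=krb5 uid=1000 target=nfs@server.example.com"
        " service=* enctypes=18,17\n";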
+
+static ssize_t
+gss_v1_upcall(struct file *file, struct rpc_pipe_msg *msg,
+ char __user *buf, size_t buflen)
+{
+ struct gss_upcall_msg *gss_msg = container_of(msg,
+ struct gss_upcall_msg,
+ msg);
+ int err;
+ if (msg->copied == 0) {
+ err = gss_encode_v1_msg(gss_msg,
+ gss_msg->service_name,
+ gss_msg->auth->target_name,
+ file->f_cred);
+ if (err)
+ return err;
+ }
+ return rpc_pipe_generic_upcall(file, msg, buf, buflen);
+}
+
+static struct gss_upcall_msg *
+gss_alloc_msg(struct gss_auth *gss_auth,
+ kuid_t uid, const char *service_name)
+{
+ struct gss_upcall_msg *gss_msg;
+ int vers;
+ int err = -ENOMEM;
+
+ gss_msg = kzalloc(sizeof(*gss_msg), GFP_KERNEL);
+ if (gss_msg == NULL)
+ goto err;
+ vers = get_pipe_version(gss_auth->net);
+ err = vers;
+ if (err < 0)
+ goto err_free_msg;
+ gss_msg->pipe = gss_auth->gss_pipe[vers]->pipe;
+ INIT_LIST_HEAD(&gss_msg->list);
+ rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq");
+ init_waitqueue_head(&gss_msg->waitqueue);
+ refcount_set(&gss_msg->count, 1);
+ gss_msg->uid = uid;
+ gss_msg->auth = gss_auth;
+ kref_get(&gss_auth->kref);
+ if (service_name) {
+ gss_msg->service_name = kstrdup_const(service_name, GFP_KERNEL);
+ if (!gss_msg->service_name) {
+ err = -ENOMEM;
+ goto err_put_pipe_version;
+ }
+ }
+ return gss_msg;
+err_put_pipe_version:
+ put_pipe_version(gss_auth->net);
+err_free_msg:
+ kfree(gss_msg);
+err:
+ return ERR_PTR(err);
+}
+
+static struct gss_upcall_msg *
+gss_setup_upcall(struct gss_auth *gss_auth, struct rpc_cred *cred)
+{
+ struct gss_cred *gss_cred = container_of(cred,
+ struct gss_cred, gc_base);
+ struct gss_upcall_msg *gss_new, *gss_msg;
+ kuid_t uid = cred->cr_cred->fsuid;
+
+ gss_new = gss_alloc_msg(gss_auth, uid, gss_cred->gc_principal);
+ if (IS_ERR(gss_new))
+ return gss_new;
+ gss_msg = gss_add_msg(gss_new);
+ if (gss_msg == gss_new) {
+ int res;
+ refcount_inc(&gss_msg->count);
+ res = rpc_queue_upcall(gss_new->pipe, &gss_new->msg);
+ if (res) {
+ gss_unhash_msg(gss_new);
+ refcount_dec(&gss_msg->count);
+ gss_release_msg(gss_new);
+ gss_msg = ERR_PTR(res);
+ }
+ } else
+ gss_release_msg(gss_new);
+ return gss_msg;
+}
+
+static void warn_gssd(void)
+{
+ dprintk("AUTH_GSS upcall failed. Please check user daemon is running.\n");
+}
+
+static inline int
+gss_refresh_upcall(struct rpc_task *task)
+{
+ struct rpc_cred *cred = task->tk_rqstp->rq_cred;
+ struct gss_auth *gss_auth = container_of(cred->cr_auth,
+ struct gss_auth, rpc_auth);
+ struct gss_cred *gss_cred = container_of(cred,
+ struct gss_cred, gc_base);
+ struct gss_upcall_msg *gss_msg;
+ struct rpc_pipe *pipe;
+ int err = 0;
+
+ gss_msg = gss_setup_upcall(gss_auth, cred);
+ if (PTR_ERR(gss_msg) == -EAGAIN) {
+ /* XXX: warn only on the first attempt, on the assumption that
+ * we shouldn't normally hit this case on a refresh. */
+ warn_gssd();
+ rpc_sleep_on_timeout(&pipe_version_rpc_waitqueue,
+ task, NULL, jiffies + (15 * HZ));
+ err = -EAGAIN;
+ goto out;
+ }
+ if (IS_ERR(gss_msg)) {
+ err = PTR_ERR(gss_msg);
+ goto out;
+ }
+ pipe = gss_msg->pipe;
+ spin_lock(&pipe->lock);
+ if (gss_cred->gc_upcall != NULL)
+ rpc_sleep_on(&gss_cred->gc_upcall->rpc_waitqueue, task, NULL);
+ else if (gss_msg->ctx == NULL && gss_msg->msg.errno >= 0) {
+ gss_cred->gc_upcall = gss_msg;
+ /* gss_upcall_callback will release the reference to gss_upcall_msg */
+ refcount_inc(&gss_msg->count);
+ rpc_sleep_on(&gss_msg->rpc_waitqueue, task, gss_upcall_callback);
+ } else {
+ gss_handle_downcall_result(gss_cred, gss_msg);
+ err = gss_msg->msg.errno;
+ }
+ spin_unlock(&pipe->lock);
+ gss_release_msg(gss_msg);
+out:
+ trace_rpcgss_upcall_result(from_kuid(&init_user_ns,
+ cred->cr_cred->fsuid), err);
+ return err;
+}
+
+static inline int
+gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred)
+{
+ struct net *net = gss_auth->net;
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ struct rpc_pipe *pipe;
+ struct rpc_cred *cred = &gss_cred->gc_base;
+ struct gss_upcall_msg *gss_msg;
+ DEFINE_WAIT(wait);
+ int err;
+
+retry:
+ err = 0;
+ /* if gssd is down, just skip upcalling altogether */
+ if (!gssd_running(net)) {
+ warn_gssd();
+ err = -EACCES;
+ goto out;
+ }
+ gss_msg = gss_setup_upcall(gss_auth, cred);
+ if (PTR_ERR(gss_msg) == -EAGAIN) {
+ err = wait_event_interruptible_timeout(pipe_version_waitqueue,
+ sn->pipe_version >= 0, 15 * HZ);
+ if (sn->pipe_version < 0) {
+ warn_gssd();
+ err = -EACCES;
+ }
+ if (err < 0)
+ goto out;
+ goto retry;
+ }
+ if (IS_ERR(gss_msg)) {
+ err = PTR_ERR(gss_msg);
+ goto out;
+ }
+ pipe = gss_msg->pipe;
+ for (;;) {
+ prepare_to_wait(&gss_msg->waitqueue, &wait, TASK_KILLABLE);
+ spin_lock(&pipe->lock);
+ if (gss_msg->ctx != NULL || gss_msg->msg.errno < 0) {
+ break;
+ }
+ spin_unlock(&pipe->lock);
+ if (fatal_signal_pending(current)) {
+ err = -ERESTARTSYS;
+ goto out_intr;
+ }
+ schedule();
+ }
+ if (gss_msg->ctx) {
+ trace_rpcgss_ctx_init(gss_cred);
+ gss_cred_set_ctx(cred, gss_msg->ctx);
+ } else {
+ err = gss_msg->msg.errno;
+ }
+ spin_unlock(&pipe->lock);
+out_intr:
+ finish_wait(&gss_msg->waitqueue, &wait);
+ gss_release_msg(gss_msg);
+out:
+ trace_rpcgss_upcall_result(from_kuid(&init_user_ns,
+ cred->cr_cred->fsuid), err);
+ return err;
+}
+
+static struct gss_upcall_msg *
+gss_find_downcall(struct rpc_pipe *pipe, kuid_t uid)
+{
+ struct gss_upcall_msg *pos;
+ list_for_each_entry(pos, &pipe->in_downcall, list) {
+ if (!uid_eq(pos->uid, uid))
+ continue;
+ if (!rpc_msg_is_inflight(&pos->msg))
+ continue;
+ refcount_inc(&pos->count);
+ return pos;
+ }
+ return NULL;
+}
+
+#define MSG_BUF_MAXSIZE 1024
+
+static ssize_t
+gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen)
+{
+ const void *p, *end;
+ void *buf;
+ struct gss_upcall_msg *gss_msg;
+ struct rpc_pipe *pipe = RPC_I(file_inode(filp))->pipe;
+ struct gss_cl_ctx *ctx;
+ uid_t id;
+ kuid_t uid;
+ ssize_t err = -EFBIG;
+
+ if (mlen > MSG_BUF_MAXSIZE)
+ goto out;
+ err = -ENOMEM;
+ buf = kmalloc(mlen, GFP_KERNEL);
+ if (!buf)
+ goto out;
+
+ err = -EFAULT;
+ if (copy_from_user(buf, src, mlen))
+ goto err;
+
+ end = (const void *)((char *)buf + mlen);
+ p = simple_get_bytes(buf, end, &id, sizeof(id));
+ if (IS_ERR(p)) {
+ err = PTR_ERR(p);
+ goto err;
+ }
+
+ uid = make_kuid(current_user_ns(), id);
+ if (!uid_valid(uid)) {
+ err = -EINVAL;
+ goto err;
+ }
+
+ err = -ENOMEM;
+ ctx = gss_alloc_context();
+ if (ctx == NULL)
+ goto err;
+
+ err = -ENOENT;
+ /* Find a matching upcall */
+ spin_lock(&pipe->lock);
+ gss_msg = gss_find_downcall(pipe, uid);
+ if (gss_msg == NULL) {
+ spin_unlock(&pipe->lock);
+ goto err_put_ctx;
+ }
+ list_del_init(&gss_msg->list);
+ spin_unlock(&pipe->lock);
+
+ p = gss_fill_context(p, end, ctx, gss_msg->auth->mech);
+ if (IS_ERR(p)) {
+ err = PTR_ERR(p);
+ switch (err) {
+ case -EACCES:
+ case -EKEYEXPIRED:
+ gss_msg->msg.errno = err;
+ err = mlen;
+ break;
+ case -EFAULT:
+ case -ENOMEM:
+ case -EINVAL:
+ case -ENOSYS:
+ gss_msg->msg.errno = -EAGAIN;
+ break;
+ default:
+ printk(KERN_CRIT "%s: bad return from "
+ "gss_fill_context: %zd\n", __func__, err);
+ gss_msg->msg.errno = -EIO;
+ }
+ goto err_release_msg;
+ }
+ gss_msg->ctx = gss_get_ctx(ctx);
+ err = mlen;
+
+err_release_msg:
+ spin_lock(&pipe->lock);
+ __gss_unhash_msg(gss_msg);
+ spin_unlock(&pipe->lock);
+ gss_release_msg(gss_msg);
+err_put_ctx:
+ gss_put_ctx(ctx);
+err:
+ kfree(buf);
+out:
+ return err;
+}
+
+static int gss_pipe_open(struct inode *inode, int new_version)
+{
+ struct net *net = inode->i_sb->s_fs_info;
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ int ret = 0;
+
+ spin_lock(&pipe_version_lock);
+ if (sn->pipe_version < 0) {
+ /* First open of any gss pipe determines the version: */
+ sn->pipe_version = new_version;
+ rpc_wake_up(&pipe_version_rpc_waitqueue);
+ wake_up(&pipe_version_waitqueue);
+ } else if (sn->pipe_version != new_version) {
+ /* Trying to open a pipe of a different version */
+ ret = -EBUSY;
+ goto out;
+ }
+ atomic_inc(&sn->pipe_users);
+out:
+ spin_unlock(&pipe_version_lock);
+ return ret;
+
+}
+
+static int gss_pipe_open_v0(struct inode *inode)
+{
+ return gss_pipe_open(inode, 0);
+}
+
+static int gss_pipe_open_v1(struct inode *inode)
+{
+ return gss_pipe_open(inode, 1);
+}
+
+static void
+gss_pipe_release(struct inode *inode)
+{
+ struct net *net = inode->i_sb->s_fs_info;
+ struct rpc_pipe *pipe = RPC_I(inode)->pipe;
+ struct gss_upcall_msg *gss_msg;
+
+restart:
+ spin_lock(&pipe->lock);
+ list_for_each_entry(gss_msg, &pipe->in_downcall, list) {
+
+ if (!list_empty(&gss_msg->msg.list))
+ continue;
+ gss_msg->msg.errno = -EPIPE;
+ refcount_inc(&gss_msg->count);
+ __gss_unhash_msg(gss_msg);
+ spin_unlock(&pipe->lock);
+ gss_release_msg(gss_msg);
+ goto restart;
+ }
+ spin_unlock(&pipe->lock);
+
+ put_pipe_version(net);
+}
+
+static void
+gss_pipe_destroy_msg(struct rpc_pipe_msg *msg)
+{
+ struct gss_upcall_msg *gss_msg = container_of(msg, struct gss_upcall_msg, msg);
+
+ if (msg->errno < 0) {
+ refcount_inc(&gss_msg->count);
+ gss_unhash_msg(gss_msg);
+ if (msg->errno == -ETIMEDOUT)
+ warn_gssd();
+ gss_release_msg(gss_msg);
+ }
+ gss_release_msg(gss_msg);
+}
+
+static void gss_pipe_dentry_destroy(struct dentry *dir,
+ struct rpc_pipe_dir_object *pdo)
+{
+ struct gss_pipe *gss_pipe = pdo->pdo_data;
+ struct rpc_pipe *pipe = gss_pipe->pipe;
+
+ if (pipe->dentry != NULL) {
+ rpc_unlink(pipe->dentry);
+ pipe->dentry = NULL;
+ }
+}
+
+static int gss_pipe_dentry_create(struct dentry *dir,
+ struct rpc_pipe_dir_object *pdo)
+{
+ struct gss_pipe *p = pdo->pdo_data;
+ struct dentry *dentry;
+
+ dentry = rpc_mkpipe_dentry(dir, p->name, p->clnt, p->pipe);
+ if (IS_ERR(dentry))
+ return PTR_ERR(dentry);
+ p->pipe->dentry = dentry;
+ return 0;
+}
+
+static const struct rpc_pipe_dir_object_ops gss_pipe_dir_object_ops = {
+ .create = gss_pipe_dentry_create,
+ .destroy = gss_pipe_dentry_destroy,
+};
+
+static struct gss_pipe *gss_pipe_alloc(struct rpc_clnt *clnt,
+ const char *name,
+ const struct rpc_pipe_ops *upcall_ops)
+{
+ struct gss_pipe *p;
+ int err = -ENOMEM;
+
+ p = kmalloc(sizeof(*p), GFP_KERNEL);
+ if (p == NULL)
+ goto err;
+ p->pipe = rpc_mkpipe_data(upcall_ops, RPC_PIPE_WAIT_FOR_OPEN);
+ if (IS_ERR(p->pipe)) {
+ err = PTR_ERR(p->pipe);
+ goto err_free_gss_pipe;
+ }
+ p->name = name;
+ p->clnt = clnt;
+ kref_init(&p->kref);
+ rpc_init_pipe_dir_object(&p->pdo,
+ &gss_pipe_dir_object_ops,
+ p);
+ return p;
+err_free_gss_pipe:
+ kfree(p);
+err:
+ return ERR_PTR(err);
+}
+
+struct gss_alloc_pdo {
+ struct rpc_clnt *clnt;
+ const char *name;
+ const struct rpc_pipe_ops *upcall_ops;
+};
+
+static int gss_pipe_match_pdo(struct rpc_pipe_dir_object *pdo, void *data)
+{
+ struct gss_pipe *gss_pipe;
+ struct gss_alloc_pdo *args = data;
+
+ if (pdo->pdo_ops != &gss_pipe_dir_object_ops)
+ return 0;
+ gss_pipe = container_of(pdo, struct gss_pipe, pdo);
+ if (strcmp(gss_pipe->name, args->name) != 0)
+ return 0;
+ if (!kref_get_unless_zero(&gss_pipe->kref))
+ return 0;
+ return 1;
+}
+
+static struct rpc_pipe_dir_object *gss_pipe_alloc_pdo(void *data)
+{
+ struct gss_pipe *gss_pipe;
+ struct gss_alloc_pdo *args = data;
+
+ gss_pipe = gss_pipe_alloc(args->clnt, args->name, args->upcall_ops);
+ if (!IS_ERR(gss_pipe))
+ return &gss_pipe->pdo;
+ return NULL;
+}
+
+static struct gss_pipe *gss_pipe_get(struct rpc_clnt *clnt,
+ const char *name,
+ const struct rpc_pipe_ops *upcall_ops)
+{
+ struct net *net = rpc_net_ns(clnt);
+ struct rpc_pipe_dir_object *pdo;
+ struct gss_alloc_pdo args = {
+ .clnt = clnt,
+ .name = name,
+ .upcall_ops = upcall_ops,
+ };
+
+ pdo = rpc_find_or_alloc_pipe_dir_object(net,
+ &clnt->cl_pipedir_objects,
+ gss_pipe_match_pdo,
+ gss_pipe_alloc_pdo,
+ &args);
+ if (pdo != NULL)
+ return container_of(pdo, struct gss_pipe, pdo);
+ return ERR_PTR(-ENOMEM);
+}
+
+static void __gss_pipe_free(struct gss_pipe *p)
+{
+ struct rpc_clnt *clnt = p->clnt;
+ struct net *net = rpc_net_ns(clnt);
+
+ rpc_remove_pipe_dir_object(net,
+ &clnt->cl_pipedir_objects,
+ &p->pdo);
+ rpc_destroy_pipe_data(p->pipe);
+ kfree(p);
+}
+
+static void __gss_pipe_release(struct kref *kref)
+{
+ struct gss_pipe *p = container_of(kref, struct gss_pipe, kref);
+
+ __gss_pipe_free(p);
+}
+
+static void gss_pipe_free(struct gss_pipe *p)
+{
+ if (p != NULL)
+ kref_put(&p->kref, __gss_pipe_release);
+}
+
+/*
+ * NOTE: we have the opportunity to use different
+ * parameters based on the input flavor (which must be a pseudoflavor)
+ */
+static struct gss_auth *
+gss_create_new(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt)
+{
+ rpc_authflavor_t flavor = args->pseudoflavor;
+ struct gss_auth *gss_auth;
+ struct gss_pipe *gss_pipe;
+ struct rpc_auth * auth;
+ int err = -ENOMEM; /* XXX? */
+
+ if (!try_module_get(THIS_MODULE))
+ return ERR_PTR(err);
+ if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL)))
+ goto out_dec;
+ INIT_HLIST_NODE(&gss_auth->hash);
+ gss_auth->target_name = NULL;
+ if (args->target_name) {
+ gss_auth->target_name = kstrdup(args->target_name, GFP_KERNEL);
+ if (gss_auth->target_name == NULL)
+ goto err_free;
+ }
+ gss_auth->client = clnt;
+ gss_auth->net = get_net_track(rpc_net_ns(clnt), &gss_auth->ns_tracker,
+ GFP_KERNEL);
+ err = -EINVAL;
+ gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor);
+ if (!gss_auth->mech)
+ goto err_put_net;
+ gss_auth->service = gss_pseudoflavor_to_service(gss_auth->mech, flavor);
+ if (gss_auth->service == 0)
+ goto err_put_mech;
+ if (!gssd_running(gss_auth->net))
+ goto err_put_mech;
+ auth = &gss_auth->rpc_auth;
+ auth->au_cslack = GSS_CRED_SLACK >> 2;
+ BUILD_BUG_ON(GSS_KRB5_MAX_SLACK_NEEDED > RPC_MAX_AUTH_SIZE);
+ auth->au_rslack = GSS_KRB5_MAX_SLACK_NEEDED >> 2;
+ auth->au_verfsize = GSS_VERF_SLACK >> 2;
+ auth->au_ralign = GSS_VERF_SLACK >> 2;
+ __set_bit(RPCAUTH_AUTH_UPDATE_SLACK, &auth->au_flags);
+ auth->au_ops = &authgss_ops;
+ auth->au_flavor = flavor;
+ if (gss_pseudoflavor_to_datatouch(gss_auth->mech, flavor))
+ __set_bit(RPCAUTH_AUTH_DATATOUCH, &auth->au_flags);
+ refcount_set(&auth->au_count, 1);
+ kref_init(&gss_auth->kref);
+
+ err = rpcauth_init_credcache(auth);
+ if (err)
+ goto err_put_mech;
+ /*
+ * Note: if we created the old pipe first, then someone who
+ * examined the directory at the right moment might conclude
+ * that we supported only the old pipe. So we instead create
+ * the new pipe first.
+ */
+ gss_pipe = gss_pipe_get(clnt, "gssd", &gss_upcall_ops_v1);
+ if (IS_ERR(gss_pipe)) {
+ err = PTR_ERR(gss_pipe);
+ goto err_destroy_credcache;
+ }
+ gss_auth->gss_pipe[1] = gss_pipe;
+
+ gss_pipe = gss_pipe_get(clnt, gss_auth->mech->gm_name,
+ &gss_upcall_ops_v0);
+ if (IS_ERR(gss_pipe)) {
+ err = PTR_ERR(gss_pipe);
+ goto err_destroy_pipe_1;
+ }
+ gss_auth->gss_pipe[0] = gss_pipe;
+
+ return gss_auth;
+err_destroy_pipe_1:
+ gss_pipe_free(gss_auth->gss_pipe[1]);
+err_destroy_credcache:
+ rpcauth_destroy_credcache(auth);
+err_put_mech:
+ gss_mech_put(gss_auth->mech);
+err_put_net:
+ put_net_track(gss_auth->net, &gss_auth->ns_tracker);
+err_free:
+ kfree(gss_auth->target_name);
+ kfree(gss_auth);
+out_dec:
+ module_put(THIS_MODULE);
+ trace_rpcgss_createauth(flavor, err);
+ return ERR_PTR(err);
+}
+
+static void
+gss_free(struct gss_auth *gss_auth)
+{
+ gss_pipe_free(gss_auth->gss_pipe[0]);
+ gss_pipe_free(gss_auth->gss_pipe[1]);
+ gss_mech_put(gss_auth->mech);
+ put_net_track(gss_auth->net, &gss_auth->ns_tracker);
+ kfree(gss_auth->target_name);
+
+ kfree(gss_auth);
+ module_put(THIS_MODULE);
+}
+
+static void
+gss_free_callback(struct kref *kref)
+{
+ struct gss_auth *gss_auth = container_of(kref, struct gss_auth, kref);
+
+ gss_free(gss_auth);
+}
+
+static void
+gss_put_auth(struct gss_auth *gss_auth)
+{
+ kref_put(&gss_auth->kref, gss_free_callback);
+}
+
+static void
+gss_destroy(struct rpc_auth *auth)
+{
+ struct gss_auth *gss_auth = container_of(auth,
+ struct gss_auth, rpc_auth);
+
+ if (hash_hashed(&gss_auth->hash)) {
+ spin_lock(&gss_auth_hash_lock);
+ hash_del(&gss_auth->hash);
+ spin_unlock(&gss_auth_hash_lock);
+ }
+
+ gss_pipe_free(gss_auth->gss_pipe[0]);
+ gss_auth->gss_pipe[0] = NULL;
+ gss_pipe_free(gss_auth->gss_pipe[1]);
+ gss_auth->gss_pipe[1] = NULL;
+ rpcauth_destroy_credcache(auth);
+
+ gss_put_auth(gss_auth);
+}
+
+/*
+ * Auths may be shared between rpc clients that were cloned from a
+ * common client with the same xprt, if they also share the flavor and
+ * target_name.
+ *
+ * The auth is looked up from the oldest parent sharing the same
+ * cl_xprt, and the auth itself references only that common parent
+ * (which is guaranteed to last as long as any of its descendants).
+ */
+static struct gss_auth *
+gss_auth_find_or_add_hashed(const struct rpc_auth_create_args *args,
+ struct rpc_clnt *clnt,
+ struct gss_auth *new)
+{
+ struct gss_auth *gss_auth;
+ unsigned long hashval = (unsigned long)clnt;
+
+ spin_lock(&gss_auth_hash_lock);
+ hash_for_each_possible(gss_auth_hash_table,
+ gss_auth,
+ hash,
+ hashval) {
+ if (gss_auth->client != clnt)
+ continue;
+ if (gss_auth->rpc_auth.au_flavor != args->pseudoflavor)
+ continue;
+ if (gss_auth->target_name != args->target_name) {
+ if (gss_auth->target_name == NULL)
+ continue;
+ if (args->target_name == NULL)
+ continue;
+ if (strcmp(gss_auth->target_name, args->target_name))
+ continue;
+ }
+ if (!refcount_inc_not_zero(&gss_auth->rpc_auth.au_count))
+ continue;
+ goto out;
+ }
+ if (new)
+ hash_add(gss_auth_hash_table, &new->hash, hashval);
+ gss_auth = new;
+out:
+ spin_unlock(&gss_auth_hash_lock);
+ return gss_auth;
+}
+
+static struct gss_auth *
+gss_create_hashed(const struct rpc_auth_create_args *args,
+ struct rpc_clnt *clnt)
+{
+ struct gss_auth *gss_auth;
+ struct gss_auth *new;
+
+ gss_auth = gss_auth_find_or_add_hashed(args, clnt, NULL);
+ if (gss_auth != NULL)
+ goto out;
+ new = gss_create_new(args, clnt);
+ if (IS_ERR(new))
+ return new;
+ gss_auth = gss_auth_find_or_add_hashed(args, clnt, new);
+ if (gss_auth != new)
+ gss_destroy(&new->rpc_auth);
+out:
+ return gss_auth;
+}
+
+static struct rpc_auth *
+gss_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt)
+{
+ struct gss_auth *gss_auth;
+ struct rpc_xprt_switch *xps = rcu_access_pointer(clnt->cl_xpi.xpi_xpswitch);
+
+ while (clnt != clnt->cl_parent) {
+ struct rpc_clnt *parent = clnt->cl_parent;
+ /* Find the original parent for this transport */
+ if (rcu_access_pointer(parent->cl_xpi.xpi_xpswitch) != xps)
+ break;
+ clnt = parent;
+ }
+
+ gss_auth = gss_create_hashed(args, clnt);
+ if (IS_ERR(gss_auth))
+ return ERR_CAST(gss_auth);
+ return &gss_auth->rpc_auth;
+}
+
+static struct gss_cred *
+gss_dup_cred(struct gss_auth *gss_auth, struct gss_cred *gss_cred)
+{
+ struct gss_cred *new;
+
+ /* Make a copy of the cred so that we can reference count it */
+ new = kzalloc(sizeof(*gss_cred), GFP_KERNEL);
+ if (new) {
+ struct auth_cred acred = {
+ .cred = gss_cred->gc_base.cr_cred,
+ };
+ struct gss_cl_ctx *ctx =
+ rcu_dereference_protected(gss_cred->gc_ctx, 1);
+
+ rpcauth_init_cred(&new->gc_base, &acred,
+ &gss_auth->rpc_auth,
+ &gss_nullops);
+ new->gc_base.cr_flags = 1UL << RPCAUTH_CRED_UPTODATE;
+ new->gc_service = gss_cred->gc_service;
+ new->gc_principal = gss_cred->gc_principal;
+ kref_get(&gss_auth->kref);
+ rcu_assign_pointer(new->gc_ctx, ctx);
+ gss_get_ctx(ctx);
+ }
+ return new;
+}
+
+/*
+ * gss_send_destroy_context will cause the RPCSEC_GSS client to send a NULL RPC call
+ * to the server with the GSS control procedure field set to
+ * RPC_GSS_PROC_DESTROY. This should normally cause the server to release
+ * all RPCSEC_GSS state associated with that context.
+ */
+static void
+gss_send_destroy_context(struct rpc_cred *cred)
+{
+ struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
+ struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth);
+ struct gss_cl_ctx *ctx = rcu_dereference_protected(gss_cred->gc_ctx, 1);
+ struct gss_cred *new;
+ struct rpc_task *task;
+
+ new = gss_dup_cred(gss_auth, gss_cred);
+ if (new) {
+ ctx->gc_proc = RPC_GSS_PROC_DESTROY;
+
+ trace_rpcgss_ctx_destroy(gss_cred);
+ task = rpc_call_null(gss_auth->client, &new->gc_base,
+ RPC_TASK_ASYNC);
+ if (!IS_ERR(task))
+ rpc_put_task(task);
+
+ put_rpccred(&new->gc_base);
+ }
+}
+
+/* gss_destroy_cred (and gss_free_ctx) are used to clean up after failure
+ * to create a new cred or context, so they check that things have been
+ * allocated before freeing them. */
+static void
+gss_do_free_ctx(struct gss_cl_ctx *ctx)
+{
+ gss_delete_sec_context(&ctx->gc_gss_ctx);
+ kfree(ctx->gc_wire_ctx.data);
+ kfree(ctx->gc_acceptor.data);
+ kfree(ctx);
+}
+
+static void
+gss_free_ctx_callback(struct rcu_head *head)
+{
+ struct gss_cl_ctx *ctx = container_of(head, struct gss_cl_ctx, gc_rcu);
+ gss_do_free_ctx(ctx);
+}
+
+static void
+gss_free_ctx(struct gss_cl_ctx *ctx)
+{
+ call_rcu(&ctx->gc_rcu, gss_free_ctx_callback);
+}
+
+static void
+gss_free_cred(struct gss_cred *gss_cred)
+{
+ kfree(gss_cred);
+}
+
+static void
+gss_free_cred_callback(struct rcu_head *head)
+{
+ struct gss_cred *gss_cred = container_of(head, struct gss_cred, gc_base.cr_rcu);
+ gss_free_cred(gss_cred);
+}
+
+static void
+gss_destroy_nullcred(struct rpc_cred *cred)
+{
+ struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
+ struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth);
+ struct gss_cl_ctx *ctx = rcu_dereference_protected(gss_cred->gc_ctx, 1);
+
+ RCU_INIT_POINTER(gss_cred->gc_ctx, NULL);
+ put_cred(cred->cr_cred);
+ call_rcu(&cred->cr_rcu, gss_free_cred_callback);
+ if (ctx)
+ gss_put_ctx(ctx);
+ gss_put_auth(gss_auth);
+}
+
+static void
+gss_destroy_cred(struct rpc_cred *cred)
+{
+ if (test_and_clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0)
+ gss_send_destroy_context(cred);
+ gss_destroy_nullcred(cred);
+}
+
+static int
+gss_hash_cred(struct auth_cred *acred, unsigned int hashbits)
+{
+ return hash_64(from_kuid(&init_user_ns, acred->cred->fsuid), hashbits);
+}
+
+/*
+ * Lookup RPCSEC_GSS cred for the current process
+ */
+static struct rpc_cred *gss_lookup_cred(struct rpc_auth *auth,
+ struct auth_cred *acred, int flags)
+{
+ return rpcauth_lookup_credcache(auth, acred, flags,
+ rpc_task_gfp_mask());
+}
+
+static struct rpc_cred *
+gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags, gfp_t gfp)
+{
+ struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
+ struct gss_cred *cred = NULL;
+ int err = -ENOMEM;
+
+ if (!(cred = kzalloc(sizeof(*cred), gfp)))
+ goto out_err;
+
+ rpcauth_init_cred(&cred->gc_base, acred, auth, &gss_credops);
+ /*
+ * Note: in order to force a call to call_refresh(), we deliberately
+ * fail to flag the credential as RPCAUTH_CRED_UPTODATE.
+ */
+ cred->gc_base.cr_flags = 1UL << RPCAUTH_CRED_NEW;
+ cred->gc_service = gss_auth->service;
+ cred->gc_principal = acred->principal;
+ kref_get(&gss_auth->kref);
+ return &cred->gc_base;
+
+out_err:
+ return ERR_PTR(err);
+}
+
+static int
+gss_cred_init(struct rpc_auth *auth, struct rpc_cred *cred)
+{
+ struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
+ struct gss_cred *gss_cred = container_of(cred,struct gss_cred, gc_base);
+ int err;
+
+ do {
+ err = gss_create_upcall(gss_auth, gss_cred);
+ } while (err == -EAGAIN);
+ return err;
+}
+
+static char *
+gss_stringify_acceptor(struct rpc_cred *cred)
+{
+ char *string = NULL;
+ struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
+ struct gss_cl_ctx *ctx;
+ unsigned int len;
+ struct xdr_netobj *acceptor;
+
+ rcu_read_lock();
+ ctx = rcu_dereference(gss_cred->gc_ctx);
+ if (!ctx)
+ goto out;
+
+ len = ctx->gc_acceptor.len;
+ rcu_read_unlock();
+
+ /* no point if there's no string */
+ if (!len)
+ return NULL;
+realloc:
+ string = kmalloc(len + 1, GFP_KERNEL);
+ if (!string)
+ return NULL;
+
+ rcu_read_lock();
+ ctx = rcu_dereference(gss_cred->gc_ctx);
+
+ /* did the ctx disappear or was it replaced by one with no acceptor? */
+ if (!ctx || !ctx->gc_acceptor.len) {
+ kfree(string);
+ string = NULL;
+ goto out;
+ }
+
+ acceptor = &ctx->gc_acceptor;
+
+ /*
+ * Did we find a new acceptor that's longer than the original? Allocate
+ * a longer buffer and try again.
+ */
+ if (len < acceptor->len) {
+ len = acceptor->len;
+ rcu_read_unlock();
+ kfree(string);
+ goto realloc;
+ }
+
+ memcpy(string, acceptor->data, acceptor->len);
+ string[acceptor->len] = '\0';
+out:
+ rcu_read_unlock();
+ return string;
+}
+
+/*
+ * Returns -EACCES if the GSS context is NULL or will expire within
+ * gss_key_expire_timeo (seconds)
+ */
+static int
+gss_key_timeout(struct rpc_cred *rc)
+{
+ struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base);
+ struct gss_cl_ctx *ctx;
+ unsigned long timeout = jiffies + (gss_key_expire_timeo * HZ);
+ int ret = 0;
+
+ rcu_read_lock();
+ ctx = rcu_dereference(gss_cred->gc_ctx);
+ if (!ctx || time_after(timeout, ctx->gc_expiry))
+ ret = -EACCES;
+ rcu_read_unlock();
+
+ return ret;
+}
+
+static int
+gss_match(struct auth_cred *acred, struct rpc_cred *rc, int flags)
+{
+ struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base);
+ struct gss_cl_ctx *ctx;
+ int ret;
+
+ if (test_bit(RPCAUTH_CRED_NEW, &rc->cr_flags))
+ goto out;
+ /* Don't match with creds that have expired. */
+ rcu_read_lock();
+ ctx = rcu_dereference(gss_cred->gc_ctx);
+ if (!ctx || time_after(jiffies, ctx->gc_expiry)) {
+ rcu_read_unlock();
+ return 0;
+ }
+ rcu_read_unlock();
+ if (!test_bit(RPCAUTH_CRED_UPTODATE, &rc->cr_flags))
+ return 0;
+out:
+ if (acred->principal != NULL) {
+ if (gss_cred->gc_principal == NULL)
+ return 0;
+ ret = strcmp(acred->principal, gss_cred->gc_principal) == 0;
+ } else {
+ if (gss_cred->gc_principal != NULL)
+ return 0;
+ ret = uid_eq(rc->cr_cred->fsuid, acred->cred->fsuid);
+ }
+ return ret;
+}
+
+/*
+ * Marshal credentials.
+ *
+ * The expensive part is computing the verifier. We can't cache a
+ * pre-computed version of the verifier because the seqno, which
+ * is different every time, is included in the MIC.
+ */
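+/*
+ * For reference, paraphrased from RFC 2203 (not copied from this file):
+ * the credential body built below is the XDR encoding of
+ *
+ * struct rpc_gss_cred_vers_1_t {
+ * rpc_gss_proc_t gss_proc;
+ * unsigned int seq_num;
+ * rpc_gss_service_t service;
+ * opaque handle<>;
+ * };
+ *
+ * preceded by the RPCSEC_GSS version number, with handle<> carrying
+ * ctx->gc_wire_ctx.
+ */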
+static int gss_marshal(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+ struct rpc_cred *cred = req->rq_cred;
+ struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
+ gc_base);
+ struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
+ __be32 *p, *cred_len;
+ u32 maj_stat = 0;
+ struct xdr_netobj mic;
+ struct kvec iov;
+ struct xdr_buf verf_buf;
+ int status;
+
+ /* Credential */
+
+ p = xdr_reserve_space(xdr, 7 * sizeof(*p) +
+ ctx->gc_wire_ctx.len);
+ if (!p)
+ goto marshal_failed;
+ *p++ = rpc_auth_gss;
+ cred_len = p++;
+
+ spin_lock(&ctx->gc_seq_lock);
+ req->rq_seqno = (ctx->gc_seq < MAXSEQ) ? ctx->gc_seq++ : MAXSEQ;
+ spin_unlock(&ctx->gc_seq_lock);
+ if (req->rq_seqno == MAXSEQ)
+ goto expired;
+ trace_rpcgss_seqno(task);
+
+ *p++ = cpu_to_be32(RPC_GSS_VERSION);
+ *p++ = cpu_to_be32(ctx->gc_proc);
+ *p++ = cpu_to_be32(req->rq_seqno);
+ *p++ = cpu_to_be32(gss_cred->gc_service);
+ p = xdr_encode_netobj(p, &ctx->gc_wire_ctx);
+ *cred_len = cpu_to_be32((p - (cred_len + 1)) << 2);
+
+ /* Verifier */
+
+ /* We compute the checksum for the verifier over the xdr-encoded bytes
+ * starting with the xid and ending at the end of the credential: */
+ iov.iov_base = req->rq_snd_buf.head[0].iov_base;
+ iov.iov_len = (u8 *)p - (u8 *)iov.iov_base;
+ xdr_buf_from_iov(&iov, &verf_buf);
+
+ p = xdr_reserve_space(xdr, sizeof(*p));
+ if (!p)
+ goto marshal_failed;
+ *p++ = rpc_auth_gss;
+ mic.data = (u8 *)(p + 1);
+ maj_stat = gss_get_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
+ if (maj_stat == GSS_S_CONTEXT_EXPIRED)
+ goto expired;
+ else if (maj_stat != 0)
+ goto bad_mic;
+ if (xdr_stream_encode_opaque_inline(xdr, (void **)&p, mic.len) < 0)
+ goto marshal_failed;
+ status = 0;
+out:
+ gss_put_ctx(ctx);
+ return status;
+expired:
+ clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
+ status = -EKEYEXPIRED;
+ goto out;
+marshal_failed:
+ status = -EMSGSIZE;
+ goto out;
+bad_mic:
+ trace_rpcgss_get_mic(task, maj_stat);
+ status = -EIO;
+ goto out;
+}
+
+static int gss_renew_cred(struct rpc_task *task)
+{
+ struct rpc_cred *oldcred = task->tk_rqstp->rq_cred;
+ struct gss_cred *gss_cred = container_of(oldcred,
+ struct gss_cred,
+ gc_base);
+ struct rpc_auth *auth = oldcred->cr_auth;
+ struct auth_cred acred = {
+ .cred = oldcred->cr_cred,
+ .principal = gss_cred->gc_principal,
+ };
+ struct rpc_cred *new;
+
+ new = gss_lookup_cred(auth, &acred, RPCAUTH_LOOKUP_NEW);
+ if (IS_ERR(new))
+ return PTR_ERR(new);
+
+ task->tk_rqstp->rq_cred = new;
+ put_rpccred(oldcred);
+ return 0;
+}
+
+static int gss_cred_is_negative_entry(struct rpc_cred *cred)
+{
+ if (test_bit(RPCAUTH_CRED_NEGATIVE, &cred->cr_flags)) {
+ unsigned long now = jiffies;
+ unsigned long begin, expire;
+ struct gss_cred *gss_cred;
+
+ gss_cred = container_of(cred, struct gss_cred, gc_base);
+ begin = gss_cred->gc_upcall_timestamp;
+ expire = begin + gss_expired_cred_retry_delay * HZ;
+
+ if (time_in_range_open(now, begin, expire))
+ return 1;
+ }
+ return 0;
+}
+
+/*
+ * Refresh credentials. XXX - finish
+ */
+static int
+gss_refresh(struct rpc_task *task)
+{
+ struct rpc_cred *cred = task->tk_rqstp->rq_cred;
+ int ret = 0;
+
+ if (gss_cred_is_negative_entry(cred))
+ return -EKEYEXPIRED;
+
+ if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
+ !test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags)) {
+ ret = gss_renew_cred(task);
+ if (ret < 0)
+ goto out;
+ cred = task->tk_rqstp->rq_cred;
+ }
+
+ if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags))
+ ret = gss_refresh_upcall(task);
+out:
+ return ret;
+}
+
+/* Dummy refresh routine: used only when destroying the context */
+static int
+gss_refresh_null(struct rpc_task *task)
+{
+ return 0;
+}
+
+static int
+gss_validate(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ struct rpc_cred *cred = task->tk_rqstp->rq_cred;
+ struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
+ __be32 *p, *seq = NULL;
+ struct kvec iov;
+ struct xdr_buf verf_buf;
+ struct xdr_netobj mic;
+ u32 len, maj_stat;
+ int status;
+
+ p = xdr_inline_decode(xdr, 2 * sizeof(*p));
+ if (!p)
+ goto validate_failed;
+ if (*p++ != rpc_auth_gss)
+ goto validate_failed;
+ len = be32_to_cpup(p);
+ if (len > RPC_MAX_AUTH_SIZE)
+ goto validate_failed;
+ p = xdr_inline_decode(xdr, len);
+ if (!p)
+ goto validate_failed;
+
+ seq = kmalloc(4, GFP_KERNEL);
+ if (!seq)
+ goto validate_failed;
+ *seq = cpu_to_be32(task->tk_rqstp->rq_seqno);
+ iov.iov_base = seq;
+ iov.iov_len = 4;
+ xdr_buf_from_iov(&iov, &verf_buf);
+ mic.data = (u8 *)p;
+ mic.len = len;
+ maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
+ if (maj_stat == GSS_S_CONTEXT_EXPIRED)
+ clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
+ if (maj_stat)
+ goto bad_mic;
+
+ /* We leave it to unwrap to calculate au_rslack. For now we just
+ * calculate the length of the verifier: */
+ if (test_bit(RPCAUTH_AUTH_UPDATE_SLACK, &cred->cr_auth->au_flags))
+ cred->cr_auth->au_verfsize = XDR_QUADLEN(len) + 2;
+ status = 0;
+out:
+ gss_put_ctx(ctx);
+ kfree(seq);
+ return status;
+
+validate_failed:
+ status = -EIO;
+ goto out;
+bad_mic:
+ trace_rpcgss_verify_mic(task, maj_stat);
+ status = -EACCES;
+ goto out;
+}
+
+static noinline_for_stack int
+gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
+ struct rpc_task *task, struct xdr_stream *xdr)
+{
+ struct rpc_rqst *rqstp = task->tk_rqstp;
+ struct xdr_buf integ_buf, *snd_buf = &rqstp->rq_snd_buf;
+ struct xdr_netobj mic;
+ __be32 *p, *integ_len;
+ u32 offset, maj_stat;
+
+ p = xdr_reserve_space(xdr, 2 * sizeof(*p));
+ if (!p)
+ goto wrap_failed;
+ integ_len = p++;
+ *p = cpu_to_be32(rqstp->rq_seqno);
+
+ if (rpcauth_wrap_req_encode(task, xdr))
+ goto wrap_failed;
+
+ offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
+ if (xdr_buf_subsegment(snd_buf, &integ_buf,
+ offset, snd_buf->len - offset))
+ goto wrap_failed;
+ *integ_len = cpu_to_be32(integ_buf.len);
+
+ p = xdr_reserve_space(xdr, 0);
+ if (!p)
+ goto wrap_failed;
+ mic.data = (u8 *)(p + 1);
+ maj_stat = gss_get_mic(ctx->gc_gss_ctx, &integ_buf, &mic);
+ if (maj_stat == GSS_S_CONTEXT_EXPIRED)
+ clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
+ else if (maj_stat)
+ goto bad_mic;
+ /* Check that the trailing MIC fit in the buffer, after the fact */
+ if (xdr_stream_encode_opaque_inline(xdr, (void **)&p, mic.len) < 0)
+ goto wrap_failed;
+ return 0;
+wrap_failed:
+ return -EMSGSIZE;
+bad_mic:
+ trace_rpcgss_get_mic(task, maj_stat);
+ return -EIO;
+}
+
+static void
+priv_release_snd_buf(struct rpc_rqst *rqstp)
+{
+ int i;
+
+ for (i=0; i < rqstp->rq_enc_pages_num; i++)
+ __free_page(rqstp->rq_enc_pages[i]);
+ kfree(rqstp->rq_enc_pages);
+ rqstp->rq_release_snd_buf = NULL;
+}
+
+static int
+alloc_enc_pages(struct rpc_rqst *rqstp)
+{
+ struct xdr_buf *snd_buf = &rqstp->rq_snd_buf;
+ int first, last, i;
+
+ if (rqstp->rq_release_snd_buf)
+ rqstp->rq_release_snd_buf(rqstp);
+
+ if (snd_buf->page_len == 0) {
+ rqstp->rq_enc_pages_num = 0;
+ return 0;
+ }
+
+ first = snd_buf->page_base >> PAGE_SHIFT;
+ last = (snd_buf->page_base + snd_buf->page_len - 1) >> PAGE_SHIFT;
+ rqstp->rq_enc_pages_num = last - first + 1 + 1;
+ rqstp->rq_enc_pages
+ = kmalloc_array(rqstp->rq_enc_pages_num,
+ sizeof(struct page *),
+ GFP_KERNEL);
+ if (!rqstp->rq_enc_pages)
+ goto out;
+ for (i=0; i < rqstp->rq_enc_pages_num; i++) {
+ rqstp->rq_enc_pages[i] = alloc_page(GFP_KERNEL);
+ if (rqstp->rq_enc_pages[i] == NULL)
+ goto out_free;
+ }
+ rqstp->rq_release_snd_buf = priv_release_snd_buf;
+ return 0;
+out_free:
+ rqstp->rq_enc_pages_num = i;
+ priv_release_snd_buf(rqstp);
+out:
+ return -EAGAIN;
+}
+
+static noinline_for_stack int
+gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
+ struct rpc_task *task, struct xdr_stream *xdr)
+{
+ struct rpc_rqst *rqstp = task->tk_rqstp;
+ struct xdr_buf *snd_buf = &rqstp->rq_snd_buf;
+ u32 pad, offset, maj_stat;
+ int status;
+ __be32 *p, *opaque_len;
+ struct page **inpages;
+ int first;
+ struct kvec *iov;
+
+ status = -EIO;
+ p = xdr_reserve_space(xdr, 2 * sizeof(*p));
+ if (!p)
+ goto wrap_failed;
+ opaque_len = p++;
+ *p = cpu_to_be32(rqstp->rq_seqno);
+
+ if (rpcauth_wrap_req_encode(task, xdr))
+ goto wrap_failed;
+
+ status = alloc_enc_pages(rqstp);
+ if (unlikely(status))
+ goto wrap_failed;
+ first = snd_buf->page_base >> PAGE_SHIFT;
+ inpages = snd_buf->pages + first;
+ snd_buf->pages = rqstp->rq_enc_pages;
+ snd_buf->page_base -= first << PAGE_SHIFT;
+ /*
+ * Move the tail into its own page, in case gss_wrap needs
+ * more space in the head when wrapping.
+ *
+ * Still... Why can't gss_wrap just slide the tail down?
+ */
+ if (snd_buf->page_len || snd_buf->tail[0].iov_len) {
+ char *tmp;
+
+ tmp = page_address(rqstp->rq_enc_pages[rqstp->rq_enc_pages_num - 1]);
+ memcpy(tmp, snd_buf->tail[0].iov_base, snd_buf->tail[0].iov_len);
+ snd_buf->tail[0].iov_base = tmp;
+ }
+ offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
+ maj_stat = gss_wrap(ctx->gc_gss_ctx, offset, snd_buf, inpages);
+ /* slack space should prevent this from ever happening: */
+ if (unlikely(snd_buf->len > snd_buf->buflen))
+ goto wrap_failed;
+ /* We're assuming that when GSS_S_CONTEXT_EXPIRED, the encryption was
+ * done anyway, so it's safe to put the request on the wire: */
+ if (maj_stat == GSS_S_CONTEXT_EXPIRED)
+ clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
+ else if (maj_stat)
+ goto bad_wrap;
+
+ *opaque_len = cpu_to_be32(snd_buf->len - offset);
+ /* guess whether the pad goes into the head or the tail: */
+ if (snd_buf->page_len || snd_buf->tail[0].iov_len)
+ iov = snd_buf->tail;
+ else
+ iov = snd_buf->head;
+ p = iov->iov_base + iov->iov_len;
+ pad = xdr_pad_size(snd_buf->len - offset);
+ memset(p, 0, pad);
+ iov->iov_len += pad;
+ snd_buf->len += pad;
+
+ return 0;
+wrap_failed:
+ return status;
+bad_wrap:
+ trace_rpcgss_wrap(task, maj_stat);
+ return -EIO;
+}
+
+static int gss_wrap_req(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ struct rpc_cred *cred = task->tk_rqstp->rq_cred;
+ struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
+ gc_base);
+ struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
+ int status;
+
+ status = -EIO;
+ if (ctx->gc_proc != RPC_GSS_PROC_DATA) {
+ /* The spec seems a little ambiguous here, but I think that not
+ * wrapping context destruction requests makes the most sense.
+ */
+ status = rpcauth_wrap_req_encode(task, xdr);
+ goto out;
+ }
+ switch (gss_cred->gc_service) {
+ case RPC_GSS_SVC_NONE:
+ status = rpcauth_wrap_req_encode(task, xdr);
+ break;
+ case RPC_GSS_SVC_INTEGRITY:
+ status = gss_wrap_req_integ(cred, ctx, task, xdr);
+ break;
+ case RPC_GSS_SVC_PRIVACY:
+ status = gss_wrap_req_priv(cred, ctx, task, xdr);
+ break;
+ default:
+ status = -EIO;
+ }
+out:
+ gss_put_ctx(ctx);
+ return status;
+}
+
+/**
+ * gss_update_rslack - Possibly update RPC receive buffer size estimates
+ * @task: rpc_task for incoming RPC Reply being unwrapped
+ * @cred: controlling rpc_cred for @task
+ * @before: XDR words needed before each RPC Reply message
+ * @after: XDR words needed following each RPC Reply message
+ *
+ */
+static void gss_update_rslack(struct rpc_task *task, struct rpc_cred *cred,
+ unsigned int before, unsigned int after)
+{
+ struct rpc_auth *auth = cred->cr_auth;
+
+ if (test_and_clear_bit(RPCAUTH_AUTH_UPDATE_SLACK, &auth->au_flags)) {
+ auth->au_ralign = auth->au_verfsize + before;
+ auth->au_rslack = auth->au_verfsize + after;
+ trace_rpcgss_update_slack(task, auth);
+ }
+}
+
+static int
+gss_unwrap_resp_auth(struct rpc_task *task, struct rpc_cred *cred)
+{
+ gss_update_rslack(task, cred, 0, 0);
+ return 0;
+}
+
+/*
+ * RFC 2203, Section 5.3.2.2
+ *
+ * struct rpc_gss_integ_data {
+ * opaque databody_integ<>;
+ * opaque checksum<>;
+ * };
+ *
+ * struct rpc_gss_data_t {
+ * unsigned int seq_num;
+ * proc_req_arg_t arg;
+ * };
+ */
+static noinline_for_stack int
+gss_unwrap_resp_integ(struct rpc_task *task, struct rpc_cred *cred,
+ struct gss_cl_ctx *ctx, struct rpc_rqst *rqstp,
+ struct xdr_stream *xdr)
+{
+ struct xdr_buf gss_data, *rcv_buf = &rqstp->rq_rcv_buf;
+ u32 len, offset, seqno, maj_stat;
+ struct xdr_netobj mic;
+ int ret;
+
+ ret = -EIO;
+ mic.data = NULL;
+
+ /* opaque databody_integ<>; */
+ if (xdr_stream_decode_u32(xdr, &len))
+ goto unwrap_failed;
+ if (len & 3)
+ goto unwrap_failed;
+ offset = rcv_buf->len - xdr_stream_remaining(xdr);
+ if (xdr_stream_decode_u32(xdr, &seqno))
+ goto unwrap_failed;
+ if (seqno != rqstp->rq_seqno)
+ goto bad_seqno;
+ if (xdr_buf_subsegment(rcv_buf, &gss_data, offset, len))
+ goto unwrap_failed;
+
+ /*
+ * The xdr_stream now points to the beginning of the
+ * upper layer payload, to be passed below to
+ * rpcauth_unwrap_resp_decode(). The checksum, which
+ * follows the upper layer payload in @rcv_buf, is
+ * located and parsed without updating the xdr_stream.
+ */
+
+ /* opaque checksum<>; */
+ offset += len;
+ if (xdr_decode_word(rcv_buf, offset, &len))
+ goto unwrap_failed;
+ offset += sizeof(__be32);
+ if (offset + len > rcv_buf->len)
+ goto unwrap_failed;
+ mic.len = len;
+ mic.data = kmalloc(len, GFP_KERNEL);
+ if (ZERO_OR_NULL_PTR(mic.data))
+ goto unwrap_failed;
+ if (read_bytes_from_xdr_buf(rcv_buf, offset, mic.data, mic.len))
+ goto unwrap_failed;
+
+ maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &gss_data, &mic);
+ if (maj_stat == GSS_S_CONTEXT_EXPIRED)
+ clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
+ if (maj_stat != GSS_S_COMPLETE)
+ goto bad_mic;
+
+ gss_update_rslack(task, cred, 2, 2 + 1 + XDR_QUADLEN(mic.len));
+ ret = 0;
+
+out:
+ kfree(mic.data);
+ return ret;
+
+unwrap_failed:
+ trace_rpcgss_unwrap_failed(task);
+ goto out;
+bad_seqno:
+ trace_rpcgss_bad_seqno(task, rqstp->rq_seqno, seqno);
+ goto out;
+bad_mic:
+ trace_rpcgss_verify_mic(task, maj_stat);
+ goto out;
+}
+
+static noinline_for_stack int
+gss_unwrap_resp_priv(struct rpc_task *task, struct rpc_cred *cred,
+ struct gss_cl_ctx *ctx, struct rpc_rqst *rqstp,
+ struct xdr_stream *xdr)
+{
+ struct xdr_buf *rcv_buf = &rqstp->rq_rcv_buf;
+ struct kvec *head = rqstp->rq_rcv_buf.head;
+ u32 offset, opaque_len, maj_stat;
+ __be32 *p;
+
+ p = xdr_inline_decode(xdr, 2 * sizeof(*p));
+ if (unlikely(!p))
+ goto unwrap_failed;
+ opaque_len = be32_to_cpup(p++);
+ offset = (u8 *)(p) - (u8 *)head->iov_base;
+ if (offset + opaque_len > rcv_buf->len)
+ goto unwrap_failed;
+
+ maj_stat = gss_unwrap(ctx->gc_gss_ctx, offset,
+ offset + opaque_len, rcv_buf);
+ if (maj_stat == GSS_S_CONTEXT_EXPIRED)
+ clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
+ if (maj_stat != GSS_S_COMPLETE)
+ goto bad_unwrap;
+ /* gss_unwrap decrypted the sequence number */
+ if (be32_to_cpup(p++) != rqstp->rq_seqno)
+ goto bad_seqno;
+
+ /* gss_unwrap redacts the opaque blob from the head iovec.
+ * rcv_buf has changed, thus the stream needs to be reset.
+ */
+ xdr_init_decode(xdr, rcv_buf, p, rqstp);
+
+ gss_update_rslack(task, cred, 2 + ctx->gc_gss_ctx->align,
+ 2 + ctx->gc_gss_ctx->slack);
+
+ return 0;
+unwrap_failed:
+ trace_rpcgss_unwrap_failed(task);
+ return -EIO;
+bad_seqno:
+ trace_rpcgss_bad_seqno(task, rqstp->rq_seqno, be32_to_cpup(--p));
+ return -EIO;
+bad_unwrap:
+ trace_rpcgss_unwrap(task, maj_stat);
+ return -EIO;
+}
+
+static bool
+gss_seq_is_newer(u32 new, u32 old)
+{
+ return (s32)(new - old) > 0;
+}
+
+static bool
+gss_xmit_need_reencode(struct rpc_task *task)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+ struct rpc_cred *cred = req->rq_cred;
+ struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
+ u32 win, seq_xmit = 0;
+ bool ret = true;
+
+ if (!ctx)
+ goto out;
+
+ if (gss_seq_is_newer(req->rq_seqno, READ_ONCE(ctx->gc_seq)))
+ goto out_ctx;
+
+ seq_xmit = READ_ONCE(ctx->gc_seq_xmit);
+ while (gss_seq_is_newer(req->rq_seqno, seq_xmit)) {
+ u32 tmp = seq_xmit;
+
+ seq_xmit = cmpxchg(&ctx->gc_seq_xmit, tmp, req->rq_seqno);
+ if (seq_xmit == tmp) {
+ ret = false;
+ goto out_ctx;
+ }
+ }
+
+ win = ctx->gc_win;
+ if (win > 0)
+ ret = !gss_seq_is_newer(req->rq_seqno, seq_xmit - win);
+
+out_ctx:
+ gss_put_ctx(ctx);
+out:
+ trace_rpcgss_need_reencode(task, seq_xmit, ret);
+ return ret;
+}
+
+static int
+gss_unwrap_resp(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ struct rpc_rqst *rqstp = task->tk_rqstp;
+ struct rpc_cred *cred = rqstp->rq_cred;
+ struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
+ gc_base);
+ struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
+ int status = -EIO;
+
+ if (ctx->gc_proc != RPC_GSS_PROC_DATA)
+ goto out_decode;
+ switch (gss_cred->gc_service) {
+ case RPC_GSS_SVC_NONE:
+ status = gss_unwrap_resp_auth(task, cred);
+ break;
+ case RPC_GSS_SVC_INTEGRITY:
+ status = gss_unwrap_resp_integ(task, cred, ctx, rqstp, xdr);
+ break;
+ case RPC_GSS_SVC_PRIVACY:
+ status = gss_unwrap_resp_priv(task, cred, ctx, rqstp, xdr);
+ break;
+ }
+ if (status)
+ goto out;
+
+out_decode:
+ status = rpcauth_unwrap_resp_decode(task, xdr);
+out:
+ gss_put_ctx(ctx);
+ return status;
+}
+
+static const struct rpc_authops authgss_ops = {
+ .owner = THIS_MODULE,
+ .au_flavor = RPC_AUTH_GSS,
+ .au_name = "RPCSEC_GSS",
+ .create = gss_create,
+ .destroy = gss_destroy,
+ .hash_cred = gss_hash_cred,
+ .lookup_cred = gss_lookup_cred,
+ .crcreate = gss_create_cred,
+ .info2flavor = gss_mech_info2flavor,
+ .flavor2info = gss_mech_flavor2info,
+};
+
+static const struct rpc_credops gss_credops = {
+ .cr_name = "AUTH_GSS",
+ .crdestroy = gss_destroy_cred,
+ .cr_init = gss_cred_init,
+ .crmatch = gss_match,
+ .crmarshal = gss_marshal,
+ .crrefresh = gss_refresh,
+ .crvalidate = gss_validate,
+ .crwrap_req = gss_wrap_req,
+ .crunwrap_resp = gss_unwrap_resp,
+ .crkey_timeout = gss_key_timeout,
+ .crstringify_acceptor = gss_stringify_acceptor,
+ .crneed_reencode = gss_xmit_need_reencode,
+};
+
+static const struct rpc_credops gss_nullops = {
+ .cr_name = "AUTH_GSS",
+ .crdestroy = gss_destroy_nullcred,
+ .crmatch = gss_match,
+ .crmarshal = gss_marshal,
+ .crrefresh = gss_refresh_null,
+ .crvalidate = gss_validate,
+ .crwrap_req = gss_wrap_req,
+ .crunwrap_resp = gss_unwrap_resp,
+ .crstringify_acceptor = gss_stringify_acceptor,
+};
+
+static const struct rpc_pipe_ops gss_upcall_ops_v0 = {
+ .upcall = gss_v0_upcall,
+ .downcall = gss_pipe_downcall,
+ .destroy_msg = gss_pipe_destroy_msg,
+ .open_pipe = gss_pipe_open_v0,
+ .release_pipe = gss_pipe_release,
+};
+
+static const struct rpc_pipe_ops gss_upcall_ops_v1 = {
+ .upcall = gss_v1_upcall,
+ .downcall = gss_pipe_downcall,
+ .destroy_msg = gss_pipe_destroy_msg,
+ .open_pipe = gss_pipe_open_v1,
+ .release_pipe = gss_pipe_release,
+};
+
+static __net_init int rpcsec_gss_init_net(struct net *net)
+{
+ return gss_svc_init_net(net);
+}
+
+static __net_exit void rpcsec_gss_exit_net(struct net *net)
+{
+ gss_svc_shutdown_net(net);
+}
+
+static struct pernet_operations rpcsec_gss_net_ops = {
+ .init = rpcsec_gss_init_net,
+ .exit = rpcsec_gss_exit_net,
+};
+
+/*
+ * Initialize RPCSEC_GSS module
+ */
+static int __init init_rpcsec_gss(void)
+{
+ int err = 0;
+
+ err = rpcauth_register(&authgss_ops);
+ if (err)
+ goto out;
+ err = gss_svc_init();
+ if (err)
+ goto out_unregister;
+ err = register_pernet_subsys(&rpcsec_gss_net_ops);
+ if (err)
+ goto out_svc_exit;
+ rpc_init_wait_queue(&pipe_version_rpc_waitqueue, "gss pipe version");
+ return 0;
+out_svc_exit:
+ gss_svc_shutdown();
+out_unregister:
+ rpcauth_unregister(&authgss_ops);
+out:
+ return err;
+}
+
+static void __exit exit_rpcsec_gss(void)
+{
+ unregister_pernet_subsys(&rpcsec_gss_net_ops);
+ gss_svc_shutdown();
+ rpcauth_unregister(&authgss_ops);
+ rcu_barrier(); /* Wait for completion of call_rcu()'s */
+}
+
+MODULE_ALIAS("rpc-auth-6");
+MODULE_LICENSE("GPL");
+module_param_named(expired_cred_retry_delay,
+ gss_expired_cred_retry_delay,
+ uint, 0644);
+MODULE_PARM_DESC(expired_cred_retry_delay, "Timeout (in seconds) until "
+ "the RPC engine retries an expired credential");
+
+module_param_named(key_expire_timeo,
+ gss_key_expire_timeo,
+ uint, 0644);
+MODULE_PARM_DESC(key_expire_timeo, "Time (in seconds) at the end of a "
+ "credential keys lifetime where the NFS layer cleans up "
+ "prior to key expiration");
+
+module_init(init_rpcsec_gss)
+module_exit(exit_rpcsec_gss)
diff --git a/net/sunrpc/auth_gss/auth_gss_internal.h b/net/sunrpc/auth_gss/auth_gss_internal.h
new file mode 100644
index 0000000000..c53b329092
--- /dev/null
+++ b/net/sunrpc/auth_gss/auth_gss_internal.h
@@ -0,0 +1,45 @@
+// SPDX-License-Identifier: BSD-3-Clause
+/*
+ * linux/net/sunrpc/auth_gss/auth_gss_internal.h
+ *
+ * Internal definitions for RPCSEC_GSS client authentication
+ *
+ * Copyright (c) 2000 The Regents of the University of Michigan.
+ * All rights reserved.
+ *
+ */
+#include <linux/err.h>
+#include <linux/string.h>
+#include <linux/sunrpc/xdr.h>
+
+static inline const void *
+simple_get_bytes(const void *p, const void *end, void *res, size_t len)
+{
+ const void *q = (const void *)((const char *)p + len);
+ if (unlikely(q > end || q < p))
+ return ERR_PTR(-EFAULT);
+ memcpy(res, p, len);
+ return q;
+}
+
+static inline const void *
+simple_get_netobj(const void *p, const void *end, struct xdr_netobj *dest)
+{
+ const void *q;
+ unsigned int len;
+
+ p = simple_get_bytes(p, end, &len, sizeof(len));
+ if (IS_ERR(p))
+ return p;
+ q = (const void *)((const char *)p + len);
+ if (unlikely(q > end || q < p))
+ return ERR_PTR(-EFAULT);
+ if (len) {
+ dest->data = kmemdup(p, len, GFP_KERNEL);
+ if (unlikely(dest->data == NULL))
+ return ERR_PTR(-ENOMEM);
+ } else
+ dest->data = NULL;
+ dest->len = len;
+ return q;
+}
diff --git a/net/sunrpc/auth_gss/gss_generic_token.c b/net/sunrpc/auth_gss/gss_generic_token.c
new file mode 100644
index 0000000000..4a4082bb22
--- /dev/null
+++ b/net/sunrpc/auth_gss/gss_generic_token.c
@@ -0,0 +1,231 @@
+/*
+ * linux/net/sunrpc/auth_gss/gss_generic_token.c
+ *
+ * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/generic/util_token.c
+ *
+ * Copyright (c) 2000 The Regents of the University of Michigan.
+ * All rights reserved.
+ *
+ * Andy Adamson <andros@umich.edu>
+ */
+
+/*
+ * Copyright 1993 by OpenVision Technologies, Inc.
+ *
+ * Permission to use, copy, modify, distribute, and sell this software
+ * and its documentation for any purpose is hereby granted without fee,
+ * provided that the above copyright notice appears in all copies and
+ * that both that copyright notice and this permission notice appear in
+ * supporting documentation, and that the name of OpenVision not be used
+ * in advertising or publicity pertaining to distribution of the software
+ * without specific, written prior permission. OpenVision makes no
+ * representations about the suitability of this software for any
+ * purpose. It is provided "as is" without express or implied warranty.
+ *
+ * OPENVISION DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
+ * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO
+ * EVENT SHALL OPENVISION BE LIABLE FOR ANY SPECIAL, INDIRECT OR
+ * CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF
+ * USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
+ * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
+ * PERFORMANCE OF THIS SOFTWARE.
+ */
+
+#include <linux/types.h>
+#include <linux/module.h>
+#include <linux/string.h>
+#include <linux/sunrpc/sched.h>
+#include <linux/sunrpc/gss_asn1.h>
+
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+# define RPCDBG_FACILITY RPCDBG_AUTH
+#endif
+
+
+/* TWRITE_STR from gssapiP_generic.h */
+#define TWRITE_STR(ptr, str, len) \
+ memcpy((ptr), (char *) (str), (len)); \
+ (ptr) += (len);
+
+/* XXXX this code currently makes the assumption that a mech oid will
+ never be longer than 127 bytes. This assumption is not inherent in
+ the interfaces, so the code can be fixed if the OSI namespace
+ balloons unexpectedly. */
+
+/* Each token looks like this:
+
+0x60 tag for APPLICATION 0, SEQUENCE
+ (constructed, definite-length)
+ <length> possible multiple bytes, need to parse/generate
+ 0x06 tag for OBJECT IDENTIFIER
+ <moid_length> compile-time constant string (assume 1 byte)
+ <moid_bytes> compile-time constant string
+ <inner_bytes> the ANY containing the application token
+ bytes 0,1 are the token type
+ bytes 2,n are the token data
+
+For the purposes of this abstraction, the token "header" consists of
+the sequence tag and length octets, the mech OID DER encoding, and the
+first two inner bytes, which indicate the token type. The token
+"body" consists of everything else.
+
+*/
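+/*
+ * Worked example (illustrative, not part of the original source): for
+ * a 9-octet mech OID and a 100-octet inner body, the sequence contents
+ * are 2 + 9 + 100 = 111 octets. Since 111 < 128 the DER length fits in
+ * a single octet, g_token_size() returns 1 + 1 + 111 = 113, and the
+ * header begins 0x60 0x6f 0x06 0x09 <OID octets> ...
+ */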
+
+static int
+der_length_size( int length)
+{
+ if (length < (1<<7))
+ return 1;
+ else if (length < (1<<8))
+ return 2;
+#if (SIZEOF_INT == 2)
+ else
+ return 3;
+#else
+ else if (length < (1<<16))
+ return 3;
+ else if (length < (1<<24))
+ return 4;
+ else
+ return 5;
+#endif
+}
+
+static void
+der_write_length(unsigned char **buf, int length)
+{
+ if (length < (1<<7)) {
+ *(*buf)++ = (unsigned char) length;
+ } else {
+ *(*buf)++ = (unsigned char) (der_length_size(length)+127);
+#if (SIZEOF_INT > 2)
+ if (length >= (1<<24))
+ *(*buf)++ = (unsigned char) (length>>24);
+ if (length >= (1<<16))
+ *(*buf)++ = (unsigned char) ((length>>16)&0xff);
+#endif
+ if (length >= (1<<8))
+ *(*buf)++ = (unsigned char) ((length>>8)&0xff);
+ *(*buf)++ = (unsigned char) (length&0xff);
+ }
+}
+
+/* returns decoded length, or < 0 on failure. Advances buf and
+ decrements bufsize */
+
+static int
+der_read_length(unsigned char **buf, int *bufsize)
+{
+ unsigned char sf;
+ int ret;
+
+ if (*bufsize < 1)
+ return -1;
+ sf = *(*buf)++;
+ (*bufsize)--;
+ if (sf & 0x80) {
+ if ((sf &= 0x7f) > ((*bufsize)-1))
+ return -1;
+ if (sf > SIZEOF_INT)
+ return -1;
+ ret = 0;
+ for (; sf; sf--) {
+ ret = (ret<<8) + (*(*buf)++);
+ (*bufsize)--;
+ }
+ } else {
+ ret = sf;
+ }
+
+ return ret;
+}
+
+/* returns the length of a token, given the mech oid and the body size */
+
+int
+g_token_size(struct xdr_netobj *mech, unsigned int body_size)
+{
+ /* set body_size to sequence contents size */
+ body_size += 2 + (int) mech->len; /* NEED overflow check */
+ return 1 + der_length_size(body_size) + body_size;
+}
+
+EXPORT_SYMBOL_GPL(g_token_size);
+
+/* fills in a buffer with the token header. The buffer is assumed to
+ be the right size. buf is advanced past the token header */
+
+void
+g_make_token_header(struct xdr_netobj *mech, int body_size, unsigned char **buf)
+{
+ *(*buf)++ = 0x60;
+ der_write_length(buf, 2 + mech->len + body_size);
+ *(*buf)++ = 0x06;
+ *(*buf)++ = (unsigned char) mech->len;
+ TWRITE_STR(*buf, mech->data, ((int) mech->len));
+}
+
+EXPORT_SYMBOL_GPL(g_make_token_header);
+
+/*
+ * Given a buffer containing a token, reads and verifies the token,
+ * leaving buf advanced past the token header, and setting body_size
+ * to the number of remaining bytes. Returns 0 on success,
+ * G_BAD_TOK_HEADER for a variety of errors, and G_WRONG_MECH if the
+ * mechanism in the token does not match the mech argument. buf and
+ * *body_size are left unmodified on error.
+ */
+u32
+g_verify_token_header(struct xdr_netobj *mech, int *body_size,
+ unsigned char **buf_in, int toksize)
+{
+ unsigned char *buf = *buf_in;
+ int seqsize;
+ struct xdr_netobj toid;
+ int ret = 0;
+
+ if ((toksize-=1) < 0)
+ return G_BAD_TOK_HEADER;
+ if (*buf++ != 0x60)
+ return G_BAD_TOK_HEADER;
+
+ if ((seqsize = der_read_length(&buf, &toksize)) < 0)
+ return G_BAD_TOK_HEADER;
+
+ if (seqsize != toksize)
+ return G_BAD_TOK_HEADER;
+
+ if ((toksize-=1) < 0)
+ return G_BAD_TOK_HEADER;
+ if (*buf++ != 0x06)
+ return G_BAD_TOK_HEADER;
+
+ if ((toksize-=1) < 0)
+ return G_BAD_TOK_HEADER;
+ toid.len = *buf++;
+
+ if ((toksize-=toid.len) < 0)
+ return G_BAD_TOK_HEADER;
+ toid.data = buf;
+ buf+=toid.len;
+
+ if (! g_OID_equal(&toid, mech))
+ ret = G_WRONG_MECH;
+
+ /* G_WRONG_MECH is not returned immediately because it's more important
+ to return G_BAD_TOK_HEADER if the token header is in fact bad */
+
+ if ((toksize-=2) < 0)
+ return G_BAD_TOK_HEADER;
+
+ if (ret)
+ return ret;
+
+ *buf_in = buf;
+ *body_size = toksize;
+
+ return ret;
+}
+
+EXPORT_SYMBOL_GPL(g_verify_token_header);
diff --git a/net/sunrpc/auth_gss/gss_krb5_crypto.c b/net/sunrpc/auth_gss/gss_krb5_crypto.c
new file mode 100644
index 0000000000..9734e1d9f9
--- /dev/null
+++ b/net/sunrpc/auth_gss/gss_krb5_crypto.c
@@ -0,0 +1,1154 @@
+/*
+ * linux/net/sunrpc/auth_gss/gss_krb5_crypto.c
+ *
+ * Copyright (c) 2000-2008 The Regents of the University of Michigan.
+ * All rights reserved.
+ *
+ * Andy Adamson <andros@umich.edu>
+ * Bruce Fields <bfields@umich.edu>
+ */
+
+/*
+ * Copyright (C) 1998 by the FundsXpress, INC.
+ *
+ * All rights reserved.
+ *
+ * Export of this software from the United States of America may require
+ * a specific license from the United States Government. It is the
+ * responsibility of any person or organization contemplating export to
+ * obtain such a license before exporting.
+ *
+ * WITHIN THAT CONSTRAINT, permission to use, copy, modify, and
+ * distribute this software and its documentation for any purpose and
+ * without fee is hereby granted, provided that the above copyright
+ * notice appear in all copies and that both that copyright notice and
+ * this permission notice appear in supporting documentation, and that
+ * the name of FundsXpress. not be used in advertising or publicity pertaining
+ * to distribution of the software without specific, written prior
+ * permission. FundsXpress makes no representations about the suitability of
+ * this software for any purpose. It is provided "as is" without express
+ * or implied warranty.
+ *
+ * THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED
+ * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE.
+ */
+
+#include <crypto/algapi.h>
+#include <crypto/hash.h>
+#include <crypto/skcipher.h>
+#include <linux/err.h>
+#include <linux/types.h>
+#include <linux/mm.h>
+#include <linux/scatterlist.h>
+#include <linux/highmem.h>
+#include <linux/pagemap.h>
+#include <linux/random.h>
+#include <linux/sunrpc/gss_krb5.h>
+#include <linux/sunrpc/xdr.h>
+#include <kunit/visibility.h>
+
+#include "gss_krb5_internal.h"
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+# define RPCDBG_FACILITY RPCDBG_AUTH
+#endif
+
+/**
+ * krb5_make_confounder - Generate a confounder string
+ * @p: memory location into which to write the string
+ * @conflen: string length to write, in octets
+ *
+ * RFCs 1964 and 3961 mention only "a random confounder" without going
+ * into detail about its function or cryptographic requirements. The
+ * assumed purpose is to prevent repeated encryption of a plaintext with
+ * the same key from generating the same ciphertext. It is also used to
+ * pad minimum plaintext length to at least a single cipher block.
+ *
+ * However, in situations like the GSS Kerberos 5 mechanism, where the
+ * encryption IV is always all zeroes, the confounder also effectively
+ * functions like an IV. Thus, not only must it be unique from message
+ * to message, but it must also be difficult to predict. Otherwise an
+ * attacker can correlate the confounder to previous or future values,
+ * making the encryption easier to break.
+ *
+ * Given that the primary consumer of this encryption mechanism is a
+ * network storage protocol, a type of traffic that often carries
+ * predictable payloads (eg, all zeroes when reading unallocated blocks
+ * from a file), our confounder generation has to be cryptographically
+ * strong.
+ */
+void krb5_make_confounder(u8 *p, int conflen)
+{
+ get_random_bytes(p, conflen);
+}
+
+/**
+ * krb5_encrypt - simple encryption of an RPCSEC GSS payload
+ * @tfm: initialized cipher transform
+ * @iv: pointer to an IV
+ * @in: plaintext to encrypt
+ * @out: OUT: ciphertext
+ * @length: length of input and output buffers, in bytes
+ *
+ * @iv may be NULL to force the use of an all-zero IV.
+ * The buffer containing the IV must be as large as the
+ * cipher's ivsize.
+ *
+ * Return values:
+ * %0: @in successfully encrypted into @out
+ * negative errno: @in not encrypted
+ */
+u32
+krb5_encrypt(
+ struct crypto_sync_skcipher *tfm,
+ void * iv,
+ void * in,
+ void * out,
+ int length)
+{
+ u32 ret = -EINVAL;
+ struct scatterlist sg[1];
+ u8 local_iv[GSS_KRB5_MAX_BLOCKSIZE] = {0};
+ SYNC_SKCIPHER_REQUEST_ON_STACK(req, tfm);
+
+ if (length % crypto_sync_skcipher_blocksize(tfm) != 0)
+ goto out;
+
+ if (crypto_sync_skcipher_ivsize(tfm) > GSS_KRB5_MAX_BLOCKSIZE) {
+ dprintk("RPC: gss_k5encrypt: tfm iv size too large %d\n",
+ crypto_sync_skcipher_ivsize(tfm));
+ goto out;
+ }
+
+ if (iv)
+ memcpy(local_iv, iv, crypto_sync_skcipher_ivsize(tfm));
+
+ memcpy(out, in, length);
+ sg_init_one(sg, out, length);
+
+ skcipher_request_set_sync_tfm(req, tfm);
+ skcipher_request_set_callback(req, 0, NULL, NULL);
+ skcipher_request_set_crypt(req, sg, sg, length, local_iv);
+
+ ret = crypto_skcipher_encrypt(req);
+ skcipher_request_zero(req);
+out:
+ dprintk("RPC: krb5_encrypt returns %d\n", ret);
+ return ret;
+}
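+
+/*
+ * Usage sketch (illustrative): the RFC 3961 DK function in
+ * gss_krb5_keys.c encrypts one cipher block at a time with an
+ * all-zero IV by passing a NULL @iv:
+ *
+ *	krb5_encrypt(cipher, NULL, inblock.data, outblock.data,
+ *		     inblock.len);
+ */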
+
+/**
+ * krb5_decrypt - simple decryption of an RPCSEC GSS payload
+ * @tfm: initialized cipher transform
+ * @iv: pointer to an IV
+ * @in: ciphertext to decrypt
+ * @out: OUT: plaintext
+ * @length: length of input and output buffers, in bytes
+ *
+ * @iv may be NULL to force the use of an all-zero IV.
+ * The buffer containing the IV must be as large as the
+ * cipher's ivsize.
+ *
+ * Return values:
+ * %0: @in successfully decrypted into @out
+ * negative errno: @in not decrypted
+ */
+u32
+krb5_decrypt(
+ struct crypto_sync_skcipher *tfm,
+ void * iv,
+ void * in,
+ void * out,
+ int length)
+{
+ u32 ret = -EINVAL;
+ struct scatterlist sg[1];
+ u8 local_iv[GSS_KRB5_MAX_BLOCKSIZE] = {0};
+ SYNC_SKCIPHER_REQUEST_ON_STACK(req, tfm);
+
+ if (length % crypto_sync_skcipher_blocksize(tfm) != 0)
+ goto out;
+
+ if (crypto_sync_skcipher_ivsize(tfm) > GSS_KRB5_MAX_BLOCKSIZE) {
+ dprintk("RPC: gss_k5decrypt: tfm iv size too large %d\n",
+ crypto_sync_skcipher_ivsize(tfm));
+ goto out;
+ }
+ if (iv)
+ memcpy(local_iv, iv, crypto_sync_skcipher_ivsize(tfm));
+
+ memcpy(out, in, length);
+ sg_init_one(sg, out, length);
+
+ skcipher_request_set_sync_tfm(req, tfm);
+ skcipher_request_set_callback(req, 0, NULL, NULL);
+ skcipher_request_set_crypt(req, sg, sg, length, local_iv);
+
+ ret = crypto_skcipher_decrypt(req);
+ skcipher_request_zero(req);
+out:
+	dprintk("RPC: krb5_decrypt returns %d\n", ret);
+ return ret;
+}
+
+static int
+checksummer(struct scatterlist *sg, void *data)
+{
+ struct ahash_request *req = data;
+
+ ahash_request_set_crypt(req, sg, NULL, sg->length);
+
+ return crypto_ahash_update(req);
+}
+
+/*
+ * checksum the plaintext data and hdrlen bytes of the token header
+ * The checksum is performed over the first 8 bytes of the
+ * gss token header and then over the data body
+ */
+u32
+make_checksum(struct krb5_ctx *kctx, char *header, int hdrlen,
+ struct xdr_buf *body, int body_offset, u8 *cksumkey,
+ unsigned int usage, struct xdr_netobj *cksumout)
+{
+ struct crypto_ahash *tfm;
+ struct ahash_request *req;
+ struct scatterlist sg[1];
+ int err = -1;
+ u8 *checksumdata;
+ unsigned int checksumlen;
+
+ if (cksumout->len < kctx->gk5e->cksumlength) {
+ dprintk("%s: checksum buffer length, %u, too small for %s\n",
+ __func__, cksumout->len, kctx->gk5e->name);
+ return GSS_S_FAILURE;
+ }
+
+ checksumdata = kmalloc(GSS_KRB5_MAX_CKSUM_LEN, GFP_KERNEL);
+ if (checksumdata == NULL)
+ return GSS_S_FAILURE;
+
+ tfm = crypto_alloc_ahash(kctx->gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC);
+ if (IS_ERR(tfm))
+ goto out_free_cksum;
+
+ req = ahash_request_alloc(tfm, GFP_KERNEL);
+ if (!req)
+ goto out_free_ahash;
+
+ ahash_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP, NULL, NULL);
+
+ checksumlen = crypto_ahash_digestsize(tfm);
+
+ if (cksumkey != NULL) {
+ err = crypto_ahash_setkey(tfm, cksumkey,
+ kctx->gk5e->keylength);
+ if (err)
+ goto out;
+ }
+
+ err = crypto_ahash_init(req);
+ if (err)
+ goto out;
+ sg_init_one(sg, header, hdrlen);
+ ahash_request_set_crypt(req, sg, NULL, hdrlen);
+ err = crypto_ahash_update(req);
+ if (err)
+ goto out;
+ err = xdr_process_buf(body, body_offset, body->len - body_offset,
+ checksummer, req);
+ if (err)
+ goto out;
+ ahash_request_set_crypt(req, NULL, checksumdata, 0);
+ err = crypto_ahash_final(req);
+ if (err)
+ goto out;
+
+ switch (kctx->gk5e->ctype) {
+ case CKSUMTYPE_RSA_MD5:
+ err = krb5_encrypt(kctx->seq, NULL, checksumdata,
+ checksumdata, checksumlen);
+ if (err)
+ goto out;
+ memcpy(cksumout->data,
+ checksumdata + checksumlen - kctx->gk5e->cksumlength,
+ kctx->gk5e->cksumlength);
+ break;
+ case CKSUMTYPE_HMAC_SHA1_DES3:
+ memcpy(cksumout->data, checksumdata, kctx->gk5e->cksumlength);
+ break;
+ default:
+ BUG();
+ break;
+ }
+ cksumout->len = kctx->gk5e->cksumlength;
+out:
+ ahash_request_free(req);
+out_free_ahash:
+ crypto_free_ahash(tfm);
+out_free_cksum:
+ kfree(checksumdata);
+ return err ? GSS_S_FAILURE : 0;
+}
+
+/**
+ * gss_krb5_checksum - Compute the MAC for a GSS Wrap or MIC token
+ * @tfm: an initialized hash transform
+ * @header: pointer to a buffer containing the token header, or NULL
+ * @hdrlen: number of octets in @header
+ * @body: xdr_buf containing an RPC message (body.len is the message length)
+ * @body_offset: byte offset into @body to start checksumming
+ * @cksumout: OUT: a buffer to be filled in with the computed HMAC
+ *
+ * Usually expressed as H = HMAC(K, message)[1..h].
+ *
+ * Caller provides the truncation length of the output token (h) in
+ * cksumout.len.
+ *
+ * Return values:
+ * %GSS_S_COMPLETE: Digest computed, @cksumout filled in
+ * %GSS_S_FAILURE: Call failed
+ */
+u32
+gss_krb5_checksum(struct crypto_ahash *tfm, char *header, int hdrlen,
+ const struct xdr_buf *body, int body_offset,
+ struct xdr_netobj *cksumout)
+{
+ struct ahash_request *req;
+ int err = -ENOMEM;
+ u8 *checksumdata;
+
+ checksumdata = kmalloc(crypto_ahash_digestsize(tfm), GFP_KERNEL);
+ if (!checksumdata)
+ return GSS_S_FAILURE;
+
+ req = ahash_request_alloc(tfm, GFP_KERNEL);
+ if (!req)
+ goto out_free_cksum;
+ ahash_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP, NULL, NULL);
+ err = crypto_ahash_init(req);
+ if (err)
+ goto out_free_ahash;
+
+ /*
+ * Per RFC 4121 Section 4.2.4, the checksum is performed over the
+ * data body first, then over the octets in "header".
+ */
+ err = xdr_process_buf(body, body_offset, body->len - body_offset,
+ checksummer, req);
+ if (err)
+ goto out_free_ahash;
+ if (header) {
+ struct scatterlist sg[1];
+
+ sg_init_one(sg, header, hdrlen);
+ ahash_request_set_crypt(req, sg, NULL, hdrlen);
+ err = crypto_ahash_update(req);
+ if (err)
+ goto out_free_ahash;
+ }
+
+ ahash_request_set_crypt(req, NULL, checksumdata, 0);
+ err = crypto_ahash_final(req);
+ if (err)
+ goto out_free_ahash;
+
+ memcpy(cksumout->data, checksumdata,
+ min_t(int, cksumout->len, crypto_ahash_digestsize(tfm)));
+
+out_free_ahash:
+ ahash_request_free(req);
+out_free_cksum:
+ kfree_sensitive(checksumdata);
+ return err ? GSS_S_FAILURE : GSS_S_COMPLETE;
+}
+EXPORT_SYMBOL_IF_KUNIT(gss_krb5_checksum);
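+
+/*
+ * Illustrative call pattern (see gss_krb5_aes_encrypt() below): the
+ * caller picks the truncation length h by setting cksumout->len and
+ * points cksumout->data at where the MAC should land:
+ *
+ *	hmac.len = kctx->gk5e->cksumlength;
+ *	hmac.data = buf->tail[0].iov_base + buf->tail[0].iov_len;
+ *	err = gss_krb5_checksum(ahash, NULL, 0, buf,
+ *				offset + GSS_KRB5_TOK_HDR_LEN, &hmac);
+ */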
+
+struct encryptor_desc {
+ u8 iv[GSS_KRB5_MAX_BLOCKSIZE];
+ struct skcipher_request *req;
+ int pos;
+ struct xdr_buf *outbuf;
+ struct page **pages;
+ struct scatterlist infrags[4];
+ struct scatterlist outfrags[4];
+ int fragno;
+ int fraglen;
+};
+
+static int
+encryptor(struct scatterlist *sg, void *data)
+{
+ struct encryptor_desc *desc = data;
+ struct xdr_buf *outbuf = desc->outbuf;
+ struct crypto_sync_skcipher *tfm =
+ crypto_sync_skcipher_reqtfm(desc->req);
+ struct page *in_page;
+ int thislen = desc->fraglen + sg->length;
+ int fraglen, ret;
+ int page_pos;
+
+ /* Worst case is 4 fragments: head, end of page 1, start
+ * of page 2, tail. Anything more is a bug. */
+ BUG_ON(desc->fragno > 3);
+
+ page_pos = desc->pos - outbuf->head[0].iov_len;
+ if (page_pos >= 0 && page_pos < outbuf->page_len) {
+ /* pages are not in place: */
+ int i = (page_pos + outbuf->page_base) >> PAGE_SHIFT;
+ in_page = desc->pages[i];
+ } else {
+ in_page = sg_page(sg);
+ }
+ sg_set_page(&desc->infrags[desc->fragno], in_page, sg->length,
+ sg->offset);
+ sg_set_page(&desc->outfrags[desc->fragno], sg_page(sg), sg->length,
+ sg->offset);
+ desc->fragno++;
+ desc->fraglen += sg->length;
+ desc->pos += sg->length;
+
+ fraglen = thislen & (crypto_sync_skcipher_blocksize(tfm) - 1);
+ thislen -= fraglen;
+
+ if (thislen == 0)
+ return 0;
+
+ sg_mark_end(&desc->infrags[desc->fragno - 1]);
+ sg_mark_end(&desc->outfrags[desc->fragno - 1]);
+
+ skcipher_request_set_crypt(desc->req, desc->infrags, desc->outfrags,
+ thislen, desc->iv);
+
+ ret = crypto_skcipher_encrypt(desc->req);
+ if (ret)
+ return ret;
+
+ sg_init_table(desc->infrags, 4);
+ sg_init_table(desc->outfrags, 4);
+
+ if (fraglen) {
+ sg_set_page(&desc->outfrags[0], sg_page(sg), fraglen,
+ sg->offset + sg->length - fraglen);
+ desc->infrags[0] = desc->outfrags[0];
+ sg_assign_page(&desc->infrags[0], in_page);
+ desc->fragno = 1;
+ desc->fraglen = fraglen;
+ } else {
+ desc->fragno = 0;
+ desc->fraglen = 0;
+ }
+ return 0;
+}
+
+int
+gss_encrypt_xdr_buf(struct crypto_sync_skcipher *tfm, struct xdr_buf *buf,
+ int offset, struct page **pages)
+{
+ int ret;
+ struct encryptor_desc desc;
+ SYNC_SKCIPHER_REQUEST_ON_STACK(req, tfm);
+
+ BUG_ON((buf->len - offset) % crypto_sync_skcipher_blocksize(tfm) != 0);
+
+ skcipher_request_set_sync_tfm(req, tfm);
+ skcipher_request_set_callback(req, 0, NULL, NULL);
+
+ memset(desc.iv, 0, sizeof(desc.iv));
+ desc.req = req;
+ desc.pos = offset;
+ desc.outbuf = buf;
+ desc.pages = pages;
+ desc.fragno = 0;
+ desc.fraglen = 0;
+
+ sg_init_table(desc.infrags, 4);
+ sg_init_table(desc.outfrags, 4);
+
+ ret = xdr_process_buf(buf, offset, buf->len - offset, encryptor, &desc);
+ skcipher_request_zero(req);
+ return ret;
+}
+
+struct decryptor_desc {
+ u8 iv[GSS_KRB5_MAX_BLOCKSIZE];
+ struct skcipher_request *req;
+ struct scatterlist frags[4];
+ int fragno;
+ int fraglen;
+};
+
+static int
+decryptor(struct scatterlist *sg, void *data)
+{
+ struct decryptor_desc *desc = data;
+ int thislen = desc->fraglen + sg->length;
+ struct crypto_sync_skcipher *tfm =
+ crypto_sync_skcipher_reqtfm(desc->req);
+ int fraglen, ret;
+
+ /* Worst case is 4 fragments: head, end of page 1, start
+ * of page 2, tail. Anything more is a bug. */
+ BUG_ON(desc->fragno > 3);
+ sg_set_page(&desc->frags[desc->fragno], sg_page(sg), sg->length,
+ sg->offset);
+ desc->fragno++;
+ desc->fraglen += sg->length;
+
+ fraglen = thislen & (crypto_sync_skcipher_blocksize(tfm) - 1);
+ thislen -= fraglen;
+
+ if (thislen == 0)
+ return 0;
+
+ sg_mark_end(&desc->frags[desc->fragno - 1]);
+
+ skcipher_request_set_crypt(desc->req, desc->frags, desc->frags,
+ thislen, desc->iv);
+
+ ret = crypto_skcipher_decrypt(desc->req);
+ if (ret)
+ return ret;
+
+ sg_init_table(desc->frags, 4);
+
+ if (fraglen) {
+ sg_set_page(&desc->frags[0], sg_page(sg), fraglen,
+ sg->offset + sg->length - fraglen);
+ desc->fragno = 1;
+ desc->fraglen = fraglen;
+ } else {
+ desc->fragno = 0;
+ desc->fraglen = 0;
+ }
+ return 0;
+}
+
+int
+gss_decrypt_xdr_buf(struct crypto_sync_skcipher *tfm, struct xdr_buf *buf,
+ int offset)
+{
+ int ret;
+ struct decryptor_desc desc;
+ SYNC_SKCIPHER_REQUEST_ON_STACK(req, tfm);
+
+ /* XXXJBF: */
+ BUG_ON((buf->len - offset) % crypto_sync_skcipher_blocksize(tfm) != 0);
+
+ skcipher_request_set_sync_tfm(req, tfm);
+ skcipher_request_set_callback(req, 0, NULL, NULL);
+
+ memset(desc.iv, 0, sizeof(desc.iv));
+ desc.req = req;
+ desc.fragno = 0;
+ desc.fraglen = 0;
+
+ sg_init_table(desc.frags, 4);
+
+ ret = xdr_process_buf(buf, offset, buf->len - offset, decryptor, &desc);
+ skcipher_request_zero(req);
+ return ret;
+}
+
+/*
+ * This function makes the assumption that it was ultimately called
+ * from gss_wrap().
+ *
+ * The client auth_gss code moves any existing tail data into a
+ * separate page before calling gss_wrap.
+ * The server svcauth_gss code ensures that both the head and the
+ * tail have slack space of RPC_MAX_AUTH_SIZE before calling gss_wrap.
+ *
+ * Even with that guarantee, this function may be called more than
+ * once in the processing of gss_wrap(). The best we can do is
+ * verify at compile-time (see GSS_KRB5_SLACK_CHECK) that the
+ * largest expected shift will fit within RPC_MAX_AUTH_SIZE.
+ * At run-time we can verify that a single invocation of this
+ * function doesn't attempt to use more than RPC_MAX_AUTH_SIZE.
+ */
+
+int
+xdr_extend_head(struct xdr_buf *buf, unsigned int base, unsigned int shiftlen)
+{
+ u8 *p;
+
+ if (shiftlen == 0)
+ return 0;
+
+ BUG_ON(shiftlen > RPC_MAX_AUTH_SIZE);
+
+ p = buf->head[0].iov_base + base;
+
+ memmove(p + shiftlen, p, buf->head[0].iov_len - base);
+
+ buf->head[0].iov_len += shiftlen;
+ buf->len += shiftlen;
+
+ return 0;
+}
+
+static u32
+gss_krb5_cts_crypt(struct crypto_sync_skcipher *cipher, struct xdr_buf *buf,
+ u32 offset, u8 *iv, struct page **pages, int encrypt)
+{
+ u32 ret;
+ struct scatterlist sg[1];
+ SYNC_SKCIPHER_REQUEST_ON_STACK(req, cipher);
+ u8 *data;
+ struct page **save_pages;
+ u32 len = buf->len - offset;
+
+ if (len > GSS_KRB5_MAX_BLOCKSIZE * 2) {
+		WARN_ON(1);
+ return -ENOMEM;
+ }
+ data = kmalloc(GSS_KRB5_MAX_BLOCKSIZE * 2, GFP_KERNEL);
+ if (!data)
+ return -ENOMEM;
+
+ /*
+ * For encryption, we want to read from the cleartext
+ * page cache pages, and write the encrypted data to
+ * the supplied xdr_buf pages.
+ */
+ save_pages = buf->pages;
+ if (encrypt)
+ buf->pages = pages;
+
+ ret = read_bytes_from_xdr_buf(buf, offset, data, len);
+ buf->pages = save_pages;
+ if (ret)
+ goto out;
+
+ sg_init_one(sg, data, len);
+
+ skcipher_request_set_sync_tfm(req, cipher);
+ skcipher_request_set_callback(req, 0, NULL, NULL);
+ skcipher_request_set_crypt(req, sg, sg, len, iv);
+
+ if (encrypt)
+ ret = crypto_skcipher_encrypt(req);
+ else
+ ret = crypto_skcipher_decrypt(req);
+
+ skcipher_request_zero(req);
+
+ if (ret)
+ goto out;
+
+ ret = write_bytes_to_xdr_buf(buf, offset, data, len);
+
+#if IS_ENABLED(CONFIG_KUNIT)
+ /*
+ * CBC-CTS does not define an output IV but RFC 3962 defines it as the
+ * penultimate block of ciphertext, so copy that into the IV buffer
+ * before returning.
+ */
+ if (encrypt)
+ memcpy(iv, data, crypto_sync_skcipher_ivsize(cipher));
+#endif
+
+out:
+ kfree(data);
+ return ret;
+}
+
+/**
+ * krb5_cbc_cts_encrypt - encrypt in CBC mode with CTS
+ * @cts_tfm: CBC cipher with CTS
+ * @cbc_tfm: base CBC cipher
+ * @offset: starting byte offset for plaintext
+ * @buf: OUT: output buffer
+ * @pages: plaintext
+ * @iv: output CBC initialization vector, or NULL
+ * @ivsize: size of @iv, in octets
+ *
+ * To provide confidentiality, encrypt using cipher block chaining
+ * with ciphertext stealing. Message integrity is handled separately.
+ *
+ * Return values:
+ * %0: encryption successful
+ * negative errno: encryption could not be completed
+ */
+VISIBLE_IF_KUNIT
+int krb5_cbc_cts_encrypt(struct crypto_sync_skcipher *cts_tfm,
+ struct crypto_sync_skcipher *cbc_tfm,
+ u32 offset, struct xdr_buf *buf, struct page **pages,
+ u8 *iv, unsigned int ivsize)
+{
+ u32 blocksize, nbytes, nblocks, cbcbytes;
+ struct encryptor_desc desc;
+ int err;
+
+ blocksize = crypto_sync_skcipher_blocksize(cts_tfm);
+ nbytes = buf->len - offset;
+ nblocks = (nbytes + blocksize - 1) / blocksize;
+ cbcbytes = 0;
+ if (nblocks > 2)
+ cbcbytes = (nblocks - 2) * blocksize;
+
+ memset(desc.iv, 0, sizeof(desc.iv));
+
+ /* Handle block-sized chunks of plaintext with CBC. */
+ if (cbcbytes) {
+ SYNC_SKCIPHER_REQUEST_ON_STACK(req, cbc_tfm);
+
+ desc.pos = offset;
+ desc.fragno = 0;
+ desc.fraglen = 0;
+ desc.pages = pages;
+ desc.outbuf = buf;
+ desc.req = req;
+
+ skcipher_request_set_sync_tfm(req, cbc_tfm);
+ skcipher_request_set_callback(req, 0, NULL, NULL);
+
+ sg_init_table(desc.infrags, 4);
+ sg_init_table(desc.outfrags, 4);
+
+ err = xdr_process_buf(buf, offset, cbcbytes, encryptor, &desc);
+ skcipher_request_zero(req);
+ if (err)
+ return err;
+ }
+
+ /* Remaining plaintext is handled with CBC-CTS. */
+ err = gss_krb5_cts_crypt(cts_tfm, buf, offset + cbcbytes,
+ desc.iv, pages, 1);
+ if (err)
+ return err;
+
+ if (unlikely(iv))
+ memcpy(iv, desc.iv, ivsize);
+ return 0;
+}
+EXPORT_SYMBOL_IF_KUNIT(krb5_cbc_cts_encrypt);
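+
+/*
+ * Illustrative call pattern (see gss_krb5_aes_encrypt() below): when
+ * the caller does not need the output IV, @iv and @ivsize are simply
+ * NULL and 0:
+ *
+ *	err = krb5_cbc_cts_encrypt(cipher, aux_cipher,
+ *				   offset + GSS_KRB5_TOK_HDR_LEN,
+ *				   buf, pages, NULL, 0);
+ */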
+
+/**
+ * krb5_cbc_cts_decrypt - decrypt in CBC mode with CTS
+ * @cts_tfm: CBC cipher with CTS
+ * @cbc_tfm: base CBC cipher
+ * @offset: starting byte offset for plaintext
+ * @buf: OUT: output buffer
+ *
+ * Return values:
+ * %0: decryption successful
+ * negative errno: decryption could not be completed
+ */
+VISIBLE_IF_KUNIT
+int krb5_cbc_cts_decrypt(struct crypto_sync_skcipher *cts_tfm,
+ struct crypto_sync_skcipher *cbc_tfm,
+ u32 offset, struct xdr_buf *buf)
+{
+ u32 blocksize, nblocks, cbcbytes;
+ struct decryptor_desc desc;
+ int err;
+
+ blocksize = crypto_sync_skcipher_blocksize(cts_tfm);
+ nblocks = (buf->len + blocksize - 1) / blocksize;
+ cbcbytes = 0;
+ if (nblocks > 2)
+ cbcbytes = (nblocks - 2) * blocksize;
+
+ memset(desc.iv, 0, sizeof(desc.iv));
+
+ /* Handle block-sized chunks of plaintext with CBC. */
+ if (cbcbytes) {
+ SYNC_SKCIPHER_REQUEST_ON_STACK(req, cbc_tfm);
+
+ desc.fragno = 0;
+ desc.fraglen = 0;
+ desc.req = req;
+
+ skcipher_request_set_sync_tfm(req, cbc_tfm);
+ skcipher_request_set_callback(req, 0, NULL, NULL);
+
+ sg_init_table(desc.frags, 4);
+
+ err = xdr_process_buf(buf, 0, cbcbytes, decryptor, &desc);
+ skcipher_request_zero(req);
+ if (err)
+ return err;
+ }
+
+ /* Remaining plaintext is handled with CBC-CTS. */
+ return gss_krb5_cts_crypt(cts_tfm, buf, cbcbytes, desc.iv, NULL, 0);
+}
+EXPORT_SYMBOL_IF_KUNIT(krb5_cbc_cts_decrypt);
+
+u32
+gss_krb5_aes_encrypt(struct krb5_ctx *kctx, u32 offset,
+ struct xdr_buf *buf, struct page **pages)
+{
+ u32 err;
+ struct xdr_netobj hmac;
+ u8 *ecptr;
+ struct crypto_sync_skcipher *cipher, *aux_cipher;
+ struct crypto_ahash *ahash;
+ struct page **save_pages;
+ unsigned int conflen;
+
+ if (kctx->initiate) {
+ cipher = kctx->initiator_enc;
+ aux_cipher = kctx->initiator_enc_aux;
+ ahash = kctx->initiator_integ;
+ } else {
+ cipher = kctx->acceptor_enc;
+ aux_cipher = kctx->acceptor_enc_aux;
+ ahash = kctx->acceptor_integ;
+ }
+ conflen = crypto_sync_skcipher_blocksize(cipher);
+
+ /* hide the gss token header and insert the confounder */
+ offset += GSS_KRB5_TOK_HDR_LEN;
+ if (xdr_extend_head(buf, offset, conflen))
+ return GSS_S_FAILURE;
+ krb5_make_confounder(buf->head[0].iov_base + offset, conflen);
+ offset -= GSS_KRB5_TOK_HDR_LEN;
+
+ if (buf->tail[0].iov_base != NULL) {
+ ecptr = buf->tail[0].iov_base + buf->tail[0].iov_len;
+ } else {
+ buf->tail[0].iov_base = buf->head[0].iov_base
+ + buf->head[0].iov_len;
+ buf->tail[0].iov_len = 0;
+ ecptr = buf->tail[0].iov_base;
+ }
+
+ /* copy plaintext gss token header after filler (if any) */
+ memcpy(ecptr, buf->head[0].iov_base + offset, GSS_KRB5_TOK_HDR_LEN);
+ buf->tail[0].iov_len += GSS_KRB5_TOK_HDR_LEN;
+ buf->len += GSS_KRB5_TOK_HDR_LEN;
+
+ hmac.len = kctx->gk5e->cksumlength;
+ hmac.data = buf->tail[0].iov_base + buf->tail[0].iov_len;
+
+ /*
+ * When we are called, pages points to the real page cache
+ * data -- which we can't go and encrypt! buf->pages points
+ * to scratch pages which we are going to send off to the
+ * client/server. Swap in the plaintext pages to calculate
+ * the hmac.
+ */
+ save_pages = buf->pages;
+ buf->pages = pages;
+
+ err = gss_krb5_checksum(ahash, NULL, 0, buf,
+ offset + GSS_KRB5_TOK_HDR_LEN, &hmac);
+ buf->pages = save_pages;
+ if (err)
+ return GSS_S_FAILURE;
+
+ err = krb5_cbc_cts_encrypt(cipher, aux_cipher,
+ offset + GSS_KRB5_TOK_HDR_LEN,
+ buf, pages, NULL, 0);
+ if (err)
+ return GSS_S_FAILURE;
+
+ /* Now update buf to account for HMAC */
+ buf->tail[0].iov_len += kctx->gk5e->cksumlength;
+ buf->len += kctx->gk5e->cksumlength;
+
+ return GSS_S_COMPLETE;
+}
+
+u32
+gss_krb5_aes_decrypt(struct krb5_ctx *kctx, u32 offset, u32 len,
+ struct xdr_buf *buf, u32 *headskip, u32 *tailskip)
+{
+ struct crypto_sync_skcipher *cipher, *aux_cipher;
+ struct crypto_ahash *ahash;
+ struct xdr_netobj our_hmac_obj;
+ u8 our_hmac[GSS_KRB5_MAX_CKSUM_LEN];
+ u8 pkt_hmac[GSS_KRB5_MAX_CKSUM_LEN];
+ struct xdr_buf subbuf;
+ u32 ret = 0;
+
+ if (kctx->initiate) {
+ cipher = kctx->acceptor_enc;
+ aux_cipher = kctx->acceptor_enc_aux;
+ ahash = kctx->acceptor_integ;
+ } else {
+ cipher = kctx->initiator_enc;
+ aux_cipher = kctx->initiator_enc_aux;
+ ahash = kctx->initiator_integ;
+ }
+
+ /* create a segment skipping the header and leaving out the checksum */
+ xdr_buf_subsegment(buf, &subbuf, offset + GSS_KRB5_TOK_HDR_LEN,
+ (len - offset - GSS_KRB5_TOK_HDR_LEN -
+ kctx->gk5e->cksumlength));
+
+ ret = krb5_cbc_cts_decrypt(cipher, aux_cipher, 0, &subbuf);
+ if (ret)
+ goto out_err;
+
+ our_hmac_obj.len = kctx->gk5e->cksumlength;
+ our_hmac_obj.data = our_hmac;
+ ret = gss_krb5_checksum(ahash, NULL, 0, &subbuf, 0, &our_hmac_obj);
+ if (ret)
+ goto out_err;
+
+ /* Get the packet's hmac value */
+ ret = read_bytes_from_xdr_buf(buf, len - kctx->gk5e->cksumlength,
+ pkt_hmac, kctx->gk5e->cksumlength);
+ if (ret)
+ goto out_err;
+
+ if (crypto_memneq(pkt_hmac, our_hmac, kctx->gk5e->cksumlength) != 0) {
+ ret = GSS_S_BAD_SIG;
+ goto out_err;
+ }
+ *headskip = crypto_sync_skcipher_blocksize(cipher);
+ *tailskip = kctx->gk5e->cksumlength;
+out_err:
+ if (ret && ret != GSS_S_BAD_SIG)
+ ret = GSS_S_FAILURE;
+ return ret;
+}
+
+/**
+ * krb5_etm_checksum - Compute a MAC for a GSS Wrap token
+ * @cipher: an initialized cipher transform
+ * @tfm: an initialized hash transform
+ * @body: xdr_buf containing an RPC message (body.len is the message length)
+ * @body_offset: byte offset into @body to start checksumming
+ * @cksumout: OUT: a buffer to be filled in with the computed HMAC
+ *
+ * Usually expressed as H = HMAC(K, IV | ciphertext)[1..h].
+ *
+ * Caller provides the truncation length of the output token (h) in
+ * cksumout.len.
+ *
+ * Return values:
+ * %GSS_S_COMPLETE: Digest computed, @cksumout filled in
+ * %GSS_S_FAILURE: Call failed
+ */
+VISIBLE_IF_KUNIT
+u32 krb5_etm_checksum(struct crypto_sync_skcipher *cipher,
+ struct crypto_ahash *tfm, const struct xdr_buf *body,
+ int body_offset, struct xdr_netobj *cksumout)
+{
+ unsigned int ivsize = crypto_sync_skcipher_ivsize(cipher);
+ struct ahash_request *req;
+ struct scatterlist sg[1];
+ u8 *iv, *checksumdata;
+ int err = -ENOMEM;
+
+ checksumdata = kmalloc(crypto_ahash_digestsize(tfm), GFP_KERNEL);
+ if (!checksumdata)
+ return GSS_S_FAILURE;
+ /* For RPCSEC, the "initial cipher state" is always all zeroes. */
+ iv = kzalloc(ivsize, GFP_KERNEL);
+ if (!iv)
+ goto out_free_mem;
+
+ req = ahash_request_alloc(tfm, GFP_KERNEL);
+ if (!req)
+ goto out_free_mem;
+ ahash_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP, NULL, NULL);
+ err = crypto_ahash_init(req);
+ if (err)
+ goto out_free_ahash;
+
+ sg_init_one(sg, iv, ivsize);
+ ahash_request_set_crypt(req, sg, NULL, ivsize);
+ err = crypto_ahash_update(req);
+ if (err)
+ goto out_free_ahash;
+ err = xdr_process_buf(body, body_offset, body->len - body_offset,
+ checksummer, req);
+ if (err)
+ goto out_free_ahash;
+
+ ahash_request_set_crypt(req, NULL, checksumdata, 0);
+ err = crypto_ahash_final(req);
+ if (err)
+ goto out_free_ahash;
+ memcpy(cksumout->data, checksumdata, cksumout->len);
+
+out_free_ahash:
+ ahash_request_free(req);
+out_free_mem:
+ kfree(iv);
+ kfree_sensitive(checksumdata);
+ return err ? GSS_S_FAILURE : GSS_S_COMPLETE;
+}
+EXPORT_SYMBOL_IF_KUNIT(krb5_etm_checksum);
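+
+/*
+ * Illustrative call pattern (see krb5_etm_encrypt() below): the MAC is
+ * computed over the already-encrypted payload and appended after the
+ * ciphertext in the send buffer's tail:
+ *
+ *	hmac.data = buf->tail[0].iov_base + buf->tail[0].iov_len;
+ *	hmac.len = kctx->gk5e->cksumlength;
+ *	err = krb5_etm_checksum(cipher, ahash, buf,
+ *				offset + GSS_KRB5_TOK_HDR_LEN, &hmac);
+ */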
+
+/**
+ * krb5_etm_encrypt - Encrypt using the RFC 8009 rules
+ * @kctx: Kerberos context
+ * @offset: starting offset of the payload, in bytes
+ * @buf: OUT: send buffer to contain the encrypted payload
+ * @pages: plaintext payload
+ *
+ * The main difference with aes_encrypt is that "The HMAC is
+ * calculated over the cipher state concatenated with the AES
+ * output, instead of being calculated over the confounder and
+ * plaintext. This allows the message receiver to verify the
+ * integrity of the message before decrypting the message."
+ *
+ * RFC 8009 Section 5:
+ *
+ * encryption function: as follows, where E() is AES encryption in
+ * CBC-CS3 mode, and h is the size of truncated HMAC (128 bits or
+ * 192 bits as described above).
+ *
+ * N = random value of length 128 bits (the AES block size)
+ * IV = cipher state
+ * C = E(Ke, N | plaintext, IV)
+ * H = HMAC(Ki, IV | C)
+ * ciphertext = C | H[1..h]
+ *
+ * This encryption formula provides AEAD EtM with key separation.
+ *
+ * Return values:
+ * %GSS_S_COMPLETE: Encryption successful
+ * %GSS_S_FAILURE: Encryption failed
+ */
+u32
+krb5_etm_encrypt(struct krb5_ctx *kctx, u32 offset,
+ struct xdr_buf *buf, struct page **pages)
+{
+ struct crypto_sync_skcipher *cipher, *aux_cipher;
+ struct crypto_ahash *ahash;
+ struct xdr_netobj hmac;
+ unsigned int conflen;
+ u8 *ecptr;
+ u32 err;
+
+ if (kctx->initiate) {
+ cipher = kctx->initiator_enc;
+ aux_cipher = kctx->initiator_enc_aux;
+ ahash = kctx->initiator_integ;
+ } else {
+ cipher = kctx->acceptor_enc;
+ aux_cipher = kctx->acceptor_enc_aux;
+ ahash = kctx->acceptor_integ;
+ }
+ conflen = crypto_sync_skcipher_blocksize(cipher);
+
+ offset += GSS_KRB5_TOK_HDR_LEN;
+ if (xdr_extend_head(buf, offset, conflen))
+ return GSS_S_FAILURE;
+ krb5_make_confounder(buf->head[0].iov_base + offset, conflen);
+ offset -= GSS_KRB5_TOK_HDR_LEN;
+
+ if (buf->tail[0].iov_base) {
+ ecptr = buf->tail[0].iov_base + buf->tail[0].iov_len;
+ } else {
+ buf->tail[0].iov_base = buf->head[0].iov_base
+ + buf->head[0].iov_len;
+ buf->tail[0].iov_len = 0;
+ ecptr = buf->tail[0].iov_base;
+ }
+
+ memcpy(ecptr, buf->head[0].iov_base + offset, GSS_KRB5_TOK_HDR_LEN);
+ buf->tail[0].iov_len += GSS_KRB5_TOK_HDR_LEN;
+ buf->len += GSS_KRB5_TOK_HDR_LEN;
+
+ err = krb5_cbc_cts_encrypt(cipher, aux_cipher,
+ offset + GSS_KRB5_TOK_HDR_LEN,
+ buf, pages, NULL, 0);
+ if (err)
+ return GSS_S_FAILURE;
+
+ hmac.data = buf->tail[0].iov_base + buf->tail[0].iov_len;
+ hmac.len = kctx->gk5e->cksumlength;
+ err = krb5_etm_checksum(cipher, ahash,
+ buf, offset + GSS_KRB5_TOK_HDR_LEN, &hmac);
+ if (err)
+ goto out_err;
+ buf->tail[0].iov_len += kctx->gk5e->cksumlength;
+ buf->len += kctx->gk5e->cksumlength;
+
+ return GSS_S_COMPLETE;
+
+out_err:
+ return GSS_S_FAILURE;
+}
+
+/**
+ * krb5_etm_decrypt - Decrypt using the RFC 8009 rules
+ * @kctx: Kerberos context
+ * @offset: starting offset of the ciphertext, in bytes
+ * @len: size of the ciphertext, in octets, counted from the start of @buf
+ * @buf: xdr_buf containing the Wrap token to decrypt; OUT: the plaintext
+ * @headskip: OUT: the enctype's confounder length, in octets
+ * @tailskip: OUT: the enctype's HMAC length, in octets
+ *
+ * RFC 8009 Section 5:
+ *
+ * decryption function: as follows, where D() is AES decryption in
+ * CBC-CS3 mode, and h is the size of truncated HMAC.
+ *
+ * (C, H) = ciphertext
+ * (Note: H is the last h bits of the ciphertext.)
+ * IV = cipher state
+ * if H != HMAC(Ki, IV | C)[1..h]
+ * stop, report error
+ * (N, P) = D(Ke, C, IV)
+ *
+ * Return values:
+ * %GSS_S_COMPLETE: Decryption successful
+ * %GSS_S_BAD_SIG: computed HMAC != received HMAC
+ * %GSS_S_FAILURE: Decryption failed
+ */
+u32
+krb5_etm_decrypt(struct krb5_ctx *kctx, u32 offset, u32 len,
+ struct xdr_buf *buf, u32 *headskip, u32 *tailskip)
+{
+ struct crypto_sync_skcipher *cipher, *aux_cipher;
+ u8 our_hmac[GSS_KRB5_MAX_CKSUM_LEN];
+ u8 pkt_hmac[GSS_KRB5_MAX_CKSUM_LEN];
+ struct xdr_netobj our_hmac_obj;
+ struct crypto_ahash *ahash;
+ struct xdr_buf subbuf;
+ u32 ret = 0;
+
+ if (kctx->initiate) {
+ cipher = kctx->acceptor_enc;
+ aux_cipher = kctx->acceptor_enc_aux;
+ ahash = kctx->acceptor_integ;
+ } else {
+ cipher = kctx->initiator_enc;
+ aux_cipher = kctx->initiator_enc_aux;
+ ahash = kctx->initiator_integ;
+ }
+
+ /* Extract the ciphertext into @subbuf. */
+ xdr_buf_subsegment(buf, &subbuf, offset + GSS_KRB5_TOK_HDR_LEN,
+ (len - offset - GSS_KRB5_TOK_HDR_LEN -
+ kctx->gk5e->cksumlength));
+
+ our_hmac_obj.data = our_hmac;
+ our_hmac_obj.len = kctx->gk5e->cksumlength;
+ ret = krb5_etm_checksum(cipher, ahash, &subbuf, 0, &our_hmac_obj);
+ if (ret)
+ goto out_err;
+ ret = read_bytes_from_xdr_buf(buf, len - kctx->gk5e->cksumlength,
+ pkt_hmac, kctx->gk5e->cksumlength);
+ if (ret)
+ goto out_err;
+ if (crypto_memneq(pkt_hmac, our_hmac, kctx->gk5e->cksumlength) != 0) {
+ ret = GSS_S_BAD_SIG;
+ goto out_err;
+ }
+
+ ret = krb5_cbc_cts_decrypt(cipher, aux_cipher, 0, &subbuf);
+ if (ret) {
+ ret = GSS_S_FAILURE;
+ goto out_err;
+ }
+
+ *headskip = crypto_sync_skcipher_blocksize(cipher);
+ *tailskip = kctx->gk5e->cksumlength;
+ return GSS_S_COMPLETE;
+
+out_err:
+ if (ret != GSS_S_BAD_SIG)
+ ret = GSS_S_FAILURE;
+ return ret;
+}
diff --git a/net/sunrpc/auth_gss/gss_krb5_internal.h b/net/sunrpc/auth_gss/gss_krb5_internal.h
new file mode 100644
index 0000000000..3afd4065bf
--- /dev/null
+++ b/net/sunrpc/auth_gss/gss_krb5_internal.h
@@ -0,0 +1,209 @@
+/* SPDX-License-Identifier: GPL-2.0 or BSD-3-Clause */
+/*
+ * SunRPC GSS Kerberos 5 mechanism internal definitions
+ *
+ * Copyright (c) 2022 Oracle and/or its affiliates.
+ */
+
+#ifndef _NET_SUNRPC_AUTH_GSS_KRB5_INTERNAL_H
+#define _NET_SUNRPC_AUTH_GSS_KRB5_INTERNAL_H
+
+/*
+ * The RFCs often specify payload lengths in bits. This helper
+ * converts a specified bit-length to the number of octets/bytes.
+ */
+#define BITS2OCTETS(x) ((x) / 8)
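+/* For example, BITS2OCTETS(96) is 12 octets and BITS2OCTETS(192) is 24. */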
+
+struct krb5_ctx;
+
+struct gss_krb5_enctype {
+ const u32 etype; /* encryption (key) type */
+ const u32 ctype; /* checksum type */
+ const char *name; /* "friendly" name */
+ const char *encrypt_name; /* crypto encrypt name */
+ const char *aux_cipher; /* aux encrypt cipher name */
+ const char *cksum_name; /* crypto checksum name */
+ const u16 signalg; /* signing algorithm */
+ const u16 sealalg; /* sealing algorithm */
+ const u32 cksumlength; /* checksum length */
+ const u32 keyed_cksum; /* is it a keyed cksum? */
+ const u32 keybytes; /* raw key len, in bytes */
+ const u32 keylength; /* protocol key length, in octets */
+ const u32 Kc_length; /* checksum subkey length, in octets */
+ const u32 Ke_length; /* encryption subkey length, in octets */
+ const u32 Ki_length; /* integrity subkey length, in octets */
+
+ int (*derive_key)(const struct gss_krb5_enctype *gk5e,
+ const struct xdr_netobj *in,
+ struct xdr_netobj *out,
+ const struct xdr_netobj *label,
+ gfp_t gfp_mask);
+ u32 (*encrypt)(struct krb5_ctx *kctx, u32 offset,
+ struct xdr_buf *buf, struct page **pages);
+ u32 (*decrypt)(struct krb5_ctx *kctx, u32 offset, u32 len,
+ struct xdr_buf *buf, u32 *headskip, u32 *tailskip);
+ u32 (*get_mic)(struct krb5_ctx *kctx, struct xdr_buf *text,
+ struct xdr_netobj *token);
+ u32 (*verify_mic)(struct krb5_ctx *kctx, struct xdr_buf *message_buffer,
+ struct xdr_netobj *read_token);
+ u32 (*wrap)(struct krb5_ctx *kctx, int offset,
+ struct xdr_buf *buf, struct page **pages);
+ u32 (*unwrap)(struct krb5_ctx *kctx, int offset, int len,
+ struct xdr_buf *buf, unsigned int *slack,
+ unsigned int *align);
+};
+
+/* krb5_ctx flags definitions */
+#define KRB5_CTX_FLAG_INITIATOR 0x00000001
+#define KRB5_CTX_FLAG_ACCEPTOR_SUBKEY 0x00000004
+
+struct krb5_ctx {
+ int initiate; /* 1 = initiating, 0 = accepting */
+ u32 enctype;
+ u32 flags;
+ const struct gss_krb5_enctype *gk5e; /* enctype-specific info */
+ struct crypto_sync_skcipher *enc;
+ struct crypto_sync_skcipher *seq;
+ struct crypto_sync_skcipher *acceptor_enc;
+ struct crypto_sync_skcipher *initiator_enc;
+ struct crypto_sync_skcipher *acceptor_enc_aux;
+ struct crypto_sync_skcipher *initiator_enc_aux;
+ struct crypto_ahash *acceptor_sign;
+ struct crypto_ahash *initiator_sign;
+ struct crypto_ahash *initiator_integ;
+ struct crypto_ahash *acceptor_integ;
+ u8 Ksess[GSS_KRB5_MAX_KEYLEN]; /* session key */
+ u8 cksum[GSS_KRB5_MAX_KEYLEN];
+ atomic_t seq_send;
+ atomic64_t seq_send64;
+ time64_t endtime;
+ struct xdr_netobj mech_used;
+};
+
+/*
+ * GSS Kerberos 5 mechanism Per-Message calls.
+ */
+
+u32 gss_krb5_get_mic_v2(struct krb5_ctx *ctx, struct xdr_buf *text,
+ struct xdr_netobj *token);
+
+u32 gss_krb5_verify_mic_v2(struct krb5_ctx *ctx, struct xdr_buf *message_buffer,
+ struct xdr_netobj *read_token);
+
+u32 gss_krb5_wrap_v2(struct krb5_ctx *kctx, int offset,
+ struct xdr_buf *buf, struct page **pages);
+
+u32 gss_krb5_unwrap_v2(struct krb5_ctx *kctx, int offset, int len,
+ struct xdr_buf *buf, unsigned int *slack,
+ unsigned int *align);
+
+/*
+ * Implementation internal functions
+ */
+
+/* Key Derivation Functions */
+
+int krb5_derive_key_v2(const struct gss_krb5_enctype *gk5e,
+ const struct xdr_netobj *inkey,
+ struct xdr_netobj *outkey,
+ const struct xdr_netobj *label,
+ gfp_t gfp_mask);
+
+int krb5_kdf_hmac_sha2(const struct gss_krb5_enctype *gk5e,
+ const struct xdr_netobj *inkey,
+ struct xdr_netobj *outkey,
+ const struct xdr_netobj *in_constant,
+ gfp_t gfp_mask);
+
+int krb5_kdf_feedback_cmac(const struct gss_krb5_enctype *gk5e,
+ const struct xdr_netobj *inkey,
+ struct xdr_netobj *outkey,
+ const struct xdr_netobj *in_constant,
+ gfp_t gfp_mask);
+
+/**
+ * krb5_derive_key - Derive a subkey from a protocol key
+ * @kctx: Kerberos 5 context
+ * @inkey: base protocol key
+ * @outkey: OUT: derived key
+ * @usage: key usage value
+ * @seed: key usage seed (one octet)
+ * @gfp_mask: memory allocation control flags
+ *
+ * Caller sets @outkey->len to the desired length of the derived key.
+ *
+ * On success, returns 0 and fills in @outkey. A negative errno value
+ * is returned on failure.
+ */
+static inline int krb5_derive_key(struct krb5_ctx *kctx,
+ const struct xdr_netobj *inkey,
+ struct xdr_netobj *outkey,
+ u32 usage, u8 seed, gfp_t gfp_mask)
+{
+ const struct gss_krb5_enctype *gk5e = kctx->gk5e;
+ u8 label_data[GSS_KRB5_K5CLENGTH];
+ struct xdr_netobj label = {
+ .len = sizeof(label_data),
+ .data = label_data,
+ };
+ __be32 *p = (__be32 *)label_data;
+
+ *p = cpu_to_be32(usage);
+ label_data[4] = seed;
+ return gk5e->derive_key(gk5e, inkey, outkey, &label, gfp_mask);
+}
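+
+/*
+ * Usage sketch (illustrative), mirroring the context import path in
+ * gss_krb5_mech.c: deriving the initiator seal encryption subkey Ke:
+ *
+ *	keyout.len = ctx->gk5e->Ke_length;
+ *	if (krb5_derive_key(ctx, &keyin, &keyout, KG_USAGE_INITIATOR_SEAL,
+ *			    KEY_USAGE_SEED_ENCRYPTION, gfp_mask))
+ *		goto out;
+ */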
+
+void krb5_make_confounder(u8 *p, int conflen);
+
+u32 make_checksum(struct krb5_ctx *kctx, char *header, int hdrlen,
+ struct xdr_buf *body, int body_offset, u8 *cksumkey,
+ unsigned int usage, struct xdr_netobj *cksumout);
+
+u32 gss_krb5_checksum(struct crypto_ahash *tfm, char *header, int hdrlen,
+ const struct xdr_buf *body, int body_offset,
+ struct xdr_netobj *cksumout);
+
+u32 krb5_encrypt(struct crypto_sync_skcipher *key, void *iv, void *in,
+ void *out, int length);
+
+u32 krb5_decrypt(struct crypto_sync_skcipher *key, void *iv, void *in,
+ void *out, int length);
+
+int xdr_extend_head(struct xdr_buf *buf, unsigned int base,
+ unsigned int shiftlen);
+
+int gss_encrypt_xdr_buf(struct crypto_sync_skcipher *tfm,
+ struct xdr_buf *outbuf, int offset,
+ struct page **pages);
+
+int gss_decrypt_xdr_buf(struct crypto_sync_skcipher *tfm,
+ struct xdr_buf *inbuf, int offset);
+
+u32 gss_krb5_aes_encrypt(struct krb5_ctx *kctx, u32 offset,
+ struct xdr_buf *buf, struct page **pages);
+
+u32 gss_krb5_aes_decrypt(struct krb5_ctx *kctx, u32 offset, u32 len,
+ struct xdr_buf *buf, u32 *plainoffset, u32 *plainlen);
+
+u32 krb5_etm_encrypt(struct krb5_ctx *kctx, u32 offset, struct xdr_buf *buf,
+ struct page **pages);
+
+u32 krb5_etm_decrypt(struct krb5_ctx *kctx, u32 offset, u32 len,
+ struct xdr_buf *buf, u32 *headskip, u32 *tailskip);
+
+#if IS_ENABLED(CONFIG_KUNIT)
+void krb5_nfold(u32 inbits, const u8 *in, u32 outbits, u8 *out);
+const struct gss_krb5_enctype *gss_krb5_lookup_enctype(u32 etype);
+int krb5_cbc_cts_encrypt(struct crypto_sync_skcipher *cts_tfm,
+ struct crypto_sync_skcipher *cbc_tfm, u32 offset,
+ struct xdr_buf *buf, struct page **pages,
+ u8 *iv, unsigned int ivsize);
+int krb5_cbc_cts_decrypt(struct crypto_sync_skcipher *cts_tfm,
+ struct crypto_sync_skcipher *cbc_tfm,
+ u32 offset, struct xdr_buf *buf);
+u32 krb5_etm_checksum(struct crypto_sync_skcipher *cipher,
+ struct crypto_ahash *tfm, const struct xdr_buf *body,
+ int body_offset, struct xdr_netobj *cksumout);
+#endif
+
+#endif /* _NET_SUNRPC_AUTH_GSS_KRB5_INTERNAL_H */
diff --git a/net/sunrpc/auth_gss/gss_krb5_keys.c b/net/sunrpc/auth_gss/gss_krb5_keys.c
new file mode 100644
index 0000000000..06d8ee0db0
--- /dev/null
+++ b/net/sunrpc/auth_gss/gss_krb5_keys.c
@@ -0,0 +1,546 @@
+/*
+ * COPYRIGHT (c) 2008
+ * The Regents of the University of Michigan
+ * ALL RIGHTS RESERVED
+ *
+ * Permission is granted to use, copy, create derivative works
+ * and redistribute this software and such derivative works
+ * for any purpose, so long as the name of The University of
+ * Michigan is not used in any advertising or publicity
+ * pertaining to the use of distribution of this software
+ * without specific, written prior authorization. If the
+ * above copyright notice or any other identification of the
+ * University of Michigan is included in any copy of any
+ * portion of this software, then the disclaimer below must
+ * also be included.
+ *
+ * THIS SOFTWARE IS PROVIDED AS IS, WITHOUT REPRESENTATION
+ * FROM THE UNIVERSITY OF MICHIGAN AS TO ITS FITNESS FOR ANY
+ * PURPOSE, AND WITHOUT WARRANTY BY THE UNIVERSITY OF
+ * MICHIGAN OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
+ * WITHOUT LIMITATION THE IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE
+ * REGENTS OF THE UNIVERSITY OF MICHIGAN SHALL NOT BE LIABLE
+ * FOR ANY DAMAGES, INCLUDING SPECIAL, INDIRECT, INCIDENTAL, OR
+ * CONSEQUENTIAL DAMAGES, WITH RESPECT TO ANY CLAIM ARISING
+ * OUT OF OR IN CONNECTION WITH THE USE OF THE SOFTWARE, EVEN
+ * IF IT HAS BEEN OR IS HEREAFTER ADVISED OF THE POSSIBILITY OF
+ * SUCH DAMAGES.
+ */
+
+/*
+ * Copyright (C) 1998 by the FundsXpress, INC.
+ *
+ * All rights reserved.
+ *
+ * Export of this software from the United States of America may require
+ * a specific license from the United States Government. It is the
+ * responsibility of any person or organization contemplating export to
+ * obtain such a license before exporting.
+ *
+ * WITHIN THAT CONSTRAINT, permission to use, copy, modify, and
+ * distribute this software and its documentation for any purpose and
+ * without fee is hereby granted, provided that the above copyright
+ * notice appear in all copies and that both that copyright notice and
+ * this permission notice appear in supporting documentation, and that
+ * the name of FundsXpress. not be used in advertising or publicity pertaining
+ * to distribution of the software without specific, written prior
+ * permission. FundsXpress makes no representations about the suitability of
+ * this software for any purpose. It is provided "as is" without express
+ * or implied warranty.
+ *
+ * THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED
+ * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE.
+ */
+
+#include <crypto/skcipher.h>
+#include <linux/err.h>
+#include <linux/types.h>
+#include <linux/sunrpc/gss_krb5.h>
+#include <linux/sunrpc/xdr.h>
+#include <linux/lcm.h>
+#include <crypto/hash.h>
+#include <kunit/visibility.h>
+
+#include "gss_krb5_internal.h"
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+# define RPCDBG_FACILITY RPCDBG_AUTH
+#endif
+
+/**
+ * krb5_nfold - n-fold function
+ * @inbits: number of bits in @in
+ * @in: buffer containing input to fold
+ * @outbits: number of bits in the output buffer
+ * @out: buffer to hold the result
+ *
+ * This is the n-fold function as described in rfc3961, sec 5.1
+ * Taken from MIT Kerberos and modified.
+ */
+VISIBLE_IF_KUNIT
+void krb5_nfold(u32 inbits, const u8 *in, u32 outbits, u8 *out)
+{
+ unsigned long ulcm;
+ int byte, i, msbit;
+
+ /* the code below is more readable if I make these bytes
+ instead of bits */
+
+ inbits >>= 3;
+ outbits >>= 3;
+
+ /* first compute lcm(n,k) */
+ ulcm = lcm(inbits, outbits);
+
+ /* now do the real work */
+
+ memset(out, 0, outbits);
+ byte = 0;
+
+ /* this will end up cycling through k lcm(k,n)/k times, which
+ is correct */
+ for (i = ulcm-1; i >= 0; i--) {
+ /* compute the msbit in k which gets added into this byte */
+ msbit = (
+ /* first, start with the msbit in the first,
+ * unrotated byte */
+ ((inbits << 3) - 1)
+ /* then, for each byte, shift to the right
+ * for each repetition */
+ + (((inbits << 3) + 13) * (i/inbits))
+ /* last, pick out the correct byte within
+ * that shifted repetition */
+ + ((inbits - (i % inbits)) << 3)
+ ) % (inbits << 3);
+
+ /* pull out the byte value itself */
+ byte += (((in[((inbits - 1) - (msbit >> 3)) % inbits] << 8)|
+ (in[((inbits) - (msbit >> 3)) % inbits]))
+ >> ((msbit & 7) + 1)) & 0xff;
+
+ /* do the addition */
+ byte += out[i % outbits];
+ out[i % outbits] = byte & 0xff;
+
+ /* keep around the carry bit, if any */
+ byte >>= 8;
+
+ }
+
+ /* if there's a carry bit left over, add it back in */
+ if (byte) {
+ for (i = outbits - 1; i >= 0; i--) {
+ /* do the addition */
+ byte += out[i];
+ out[i] = byte & 0xff;
+
+ /* keep around the carry bit, if any */
+ byte >>= 8;
+ }
+ }
+}
+EXPORT_SYMBOL_IF_KUNIT(krb5_nfold);
+
+/*
+ * This is the DK (derive_key) function as described in rfc3961, sec 5.1
+ * Taken from MIT Kerberos and modified.
+ */
+static int krb5_DK(const struct gss_krb5_enctype *gk5e,
+ const struct xdr_netobj *inkey, u8 *rawkey,
+ const struct xdr_netobj *in_constant, gfp_t gfp_mask)
+{
+ size_t blocksize, keybytes, keylength, n;
+ unsigned char *inblockdata, *outblockdata;
+ struct xdr_netobj inblock, outblock;
+ struct crypto_sync_skcipher *cipher;
+ int ret = -EINVAL;
+
+ keybytes = gk5e->keybytes;
+ keylength = gk5e->keylength;
+
+ if (inkey->len != keylength)
+ goto err_return;
+
+ cipher = crypto_alloc_sync_skcipher(gk5e->encrypt_name, 0, 0);
+ if (IS_ERR(cipher))
+ goto err_return;
+ blocksize = crypto_sync_skcipher_blocksize(cipher);
+ if (crypto_sync_skcipher_setkey(cipher, inkey->data, inkey->len))
+ goto err_return;
+
+ ret = -ENOMEM;
+ inblockdata = kmalloc(blocksize, gfp_mask);
+ if (inblockdata == NULL)
+ goto err_free_cipher;
+
+ outblockdata = kmalloc(blocksize, gfp_mask);
+ if (outblockdata == NULL)
+ goto err_free_in;
+
+ inblock.data = (char *) inblockdata;
+ inblock.len = blocksize;
+
+ outblock.data = (char *) outblockdata;
+ outblock.len = blocksize;
+
+ /* initialize the input block */
+
+ if (in_constant->len == inblock.len) {
+ memcpy(inblock.data, in_constant->data, inblock.len);
+ } else {
+ krb5_nfold(in_constant->len * 8, in_constant->data,
+ inblock.len * 8, inblock.data);
+ }
+
+ /* loop encrypting the blocks until enough key bytes are generated */
+
+ n = 0;
+ while (n < keybytes) {
+ krb5_encrypt(cipher, NULL, inblock.data, outblock.data,
+ inblock.len);
+
+ if ((keybytes - n) <= outblock.len) {
+ memcpy(rawkey + n, outblock.data, (keybytes - n));
+ break;
+ }
+
+ memcpy(rawkey + n, outblock.data, outblock.len);
+ memcpy(inblock.data, outblock.data, outblock.len);
+ n += outblock.len;
+ }
+
+ ret = 0;
+
+ kfree_sensitive(outblockdata);
+err_free_in:
+ kfree_sensitive(inblockdata);
+err_free_cipher:
+ crypto_free_sync_skcipher(cipher);
+err_return:
+ return ret;
+}
+
+/*
+ * This is the identity function, with some sanity checking.
+ */
+static int krb5_random_to_key_v2(const struct gss_krb5_enctype *gk5e,
+ struct xdr_netobj *randombits,
+ struct xdr_netobj *key)
+{
+ int ret = -EINVAL;
+
+ if (key->len != 16 && key->len != 32) {
+ dprintk("%s: key->len is %d\n", __func__, key->len);
+ goto err_out;
+ }
+ if (randombits->len != 16 && randombits->len != 32) {
+ dprintk("%s: randombits->len is %d\n",
+ __func__, randombits->len);
+ goto err_out;
+ }
+ if (randombits->len != key->len) {
+ dprintk("%s: randombits->len is %d, key->len is %d\n",
+ __func__, randombits->len, key->len);
+ goto err_out;
+ }
+ memcpy(key->data, randombits->data, key->len);
+ ret = 0;
+err_out:
+ return ret;
+}
+
+/**
+ * krb5_derive_key_v2 - Derive a subkey for an RFC 3962 enctype
+ * @gk5e: Kerberos 5 enctype profile
+ * @inkey: base protocol key
+ * @outkey: OUT: derived key
+ * @label: subkey usage label
+ * @gfp_mask: memory allocation control flags
+ *
+ * Caller sets @outkey->len to the desired length of the derived key.
+ *
+ * On success, returns 0 and fills in @outkey. A negative errno value
+ * is returned on failure.
+ */
+int krb5_derive_key_v2(const struct gss_krb5_enctype *gk5e,
+ const struct xdr_netobj *inkey,
+ struct xdr_netobj *outkey,
+ const struct xdr_netobj *label,
+ gfp_t gfp_mask)
+{
+ struct xdr_netobj inblock;
+ int ret;
+
+ inblock.len = gk5e->keybytes;
+ inblock.data = kmalloc(inblock.len, gfp_mask);
+ if (!inblock.data)
+ return -ENOMEM;
+
+ ret = krb5_DK(gk5e, inkey, inblock.data, label, gfp_mask);
+ if (!ret)
+ ret = krb5_random_to_key_v2(gk5e, &inblock, outkey);
+
+ kfree_sensitive(inblock.data);
+ return ret;
+}
+
+/*
+ * K(i) = CMAC(key, K(i-1) | i | constant | 0x00 | k)
+ *
+ * i: A block counter is used with a length of 4 bytes, represented
+ * in big-endian order.
+ *
+ * constant: The label input to the KDF is the usage constant supplied
+ * to the key derivation function
+ *
+ * k: The length of the output key in bits, represented as a 4-byte
+ * string in big-endian order.
+ *
+ * Caller fills in K(i-1) in @step, and receives the result K(i)
+ * in the same buffer.
+ */
+static int
+krb5_cmac_Ki(struct crypto_shash *tfm, const struct xdr_netobj *constant,
+ u32 outlen, u32 count, struct xdr_netobj *step)
+{
+ __be32 k = cpu_to_be32(outlen * 8);
+ SHASH_DESC_ON_STACK(desc, tfm);
+ __be32 i = cpu_to_be32(count);
+ u8 zero = 0;
+ int ret;
+
+ desc->tfm = tfm;
+ ret = crypto_shash_init(desc);
+ if (ret)
+ goto out_err;
+
+ ret = crypto_shash_update(desc, step->data, step->len);
+ if (ret)
+ goto out_err;
+ ret = crypto_shash_update(desc, (u8 *)&i, sizeof(i));
+ if (ret)
+ goto out_err;
+ ret = crypto_shash_update(desc, constant->data, constant->len);
+ if (ret)
+ goto out_err;
+ ret = crypto_shash_update(desc, &zero, sizeof(zero));
+ if (ret)
+ goto out_err;
+ ret = crypto_shash_update(desc, (u8 *)&k, sizeof(k));
+ if (ret)
+ goto out_err;
+ ret = crypto_shash_final(desc, step->data);
+ if (ret)
+ goto out_err;
+
+out_err:
+ shash_desc_zero(desc);
+ return ret;
+}
+
+/**
+ * krb5_kdf_feedback_cmac - Derive a subkey for a Camellia/CMAC-based enctype
+ * @gk5e: Kerberos 5 enctype parameters
+ * @inkey: base protocol key
+ * @outkey: OUT: derived key
+ * @constant: subkey usage label
+ * @gfp_mask: memory allocation control flags
+ *
+ * RFC 6803 Section 3:
+ *
+ * "We use a key derivation function from the family specified in
+ * [SP800-108], Section 5.2, 'KDF in Feedback Mode'."
+ *
+ * n = ceiling(k / 128)
+ * K(0) = zeros
+ * K(i) = CMAC(key, K(i-1) | i | constant | 0x00 | k)
+ * DR(key, constant) = k-truncate(K(1) | K(2) | ... | K(n))
+ * KDF-FEEDBACK-CMAC(key, constant) = random-to-key(DR(key, constant))
+ *
+ * Caller sets @outkey->len to the desired length of the derived key (k).
+ *
+ * On success, returns 0 and fills in @outkey. A negative errno value
+ * is returned on failure.
+ */
+int
+krb5_kdf_feedback_cmac(const struct gss_krb5_enctype *gk5e,
+ const struct xdr_netobj *inkey,
+ struct xdr_netobj *outkey,
+ const struct xdr_netobj *constant,
+ gfp_t gfp_mask)
+{
+ struct xdr_netobj step = { .data = NULL };
+ struct xdr_netobj DR = { .data = NULL };
+ unsigned int blocksize, offset;
+ struct crypto_shash *tfm;
+ int n, count, ret;
+
+ /*
+ * This implementation assumes the CMAC used for an enctype's
+ * key derivation is the same as the CMAC used for its
+ * checksumming. This happens to be true for enctypes that
+ * are currently supported by this implementation.
+ */
+ tfm = crypto_alloc_shash(gk5e->cksum_name, 0, 0);
+ if (IS_ERR(tfm)) {
+ ret = PTR_ERR(tfm);
+ goto out;
+ }
+ ret = crypto_shash_setkey(tfm, inkey->data, inkey->len);
+ if (ret)
+ goto out_free_tfm;
+
+ blocksize = crypto_shash_digestsize(tfm);
+ n = (outkey->len + blocksize - 1) / blocksize;
+
+ /* K(0) is all zeroes */
+ ret = -ENOMEM;
+ step.len = blocksize;
+ step.data = kzalloc(step.len, gfp_mask);
+ if (!step.data)
+ goto out_free_tfm;
+
+ DR.len = blocksize * n;
+ DR.data = kmalloc(DR.len, gfp_mask);
+ if (!DR.data)
+ goto out_free_tfm;
+
+ /* XXX: Does not handle partial-block key sizes */
+ for (offset = 0, count = 1; count <= n; count++) {
+ ret = krb5_cmac_Ki(tfm, constant, outkey->len, count, &step);
+ if (ret)
+ goto out_free_tfm;
+
+ memcpy(DR.data + offset, step.data, blocksize);
+ offset += blocksize;
+ }
+
+ /* k-truncate and random-to-key */
+ memcpy(outkey->data, DR.data, outkey->len);
+ ret = 0;
+
+out_free_tfm:
+ crypto_free_shash(tfm);
+out:
+ kfree_sensitive(step.data);
+ kfree_sensitive(DR.data);
+ return ret;
+}
+
+/*
+ * K1 = HMAC-SHA(key, 0x00000001 | label | 0x00 | k)
+ *
+ * key: The source of entropy from which subsequent keys are derived.
+ *
+ * label: An octet string describing the intended usage of the
+ * derived key.
+ *
+ * k: Length in bits of the key to be outputted, expressed in
+ * big-endian binary representation in 4 bytes.
+ */
+static int
+krb5_hmac_K1(struct crypto_shash *tfm, const struct xdr_netobj *label,
+ u32 outlen, struct xdr_netobj *K1)
+{
+ __be32 k = cpu_to_be32(outlen * 8);
+ SHASH_DESC_ON_STACK(desc, tfm);
+ __be32 one = cpu_to_be32(1);
+ u8 zero = 0;
+ int ret;
+
+ desc->tfm = tfm;
+ ret = crypto_shash_init(desc);
+ if (ret)
+ goto out_err;
+ ret = crypto_shash_update(desc, (u8 *)&one, sizeof(one));
+ if (ret)
+ goto out_err;
+ ret = crypto_shash_update(desc, label->data, label->len);
+ if (ret)
+ goto out_err;
+ ret = crypto_shash_update(desc, &zero, sizeof(zero));
+ if (ret)
+ goto out_err;
+ ret = crypto_shash_update(desc, (u8 *)&k, sizeof(k));
+ if (ret)
+ goto out_err;
+ ret = crypto_shash_final(desc, K1->data);
+ if (ret)
+ goto out_err;
+
+out_err:
+ shash_desc_zero(desc);
+ return ret;
+}
+
+/**
+ * krb5_kdf_hmac_sha2 - Derive a subkey for an AES/SHA2-based enctype
+ * @gk5e: Kerberos 5 enctype policy parameters
+ * @inkey: base protocol key
+ * @outkey: OUT: derived key
+ * @label: subkey usage label
+ * @gfp_mask: memory allocation control flags
+ *
+ * RFC 8009 Section 3:
+ *
+ * "We use a key derivation function from Section 5.1 of [SP800-108],
+ * which uses the HMAC algorithm as the PRF."
+ *
+ * function KDF-HMAC-SHA2(key, label, [context,] k):
+ * k-truncate(K1)
+ *
+ * Caller sets @outkey->len to the desired length of the derived key.
+ *
+ * On success, returns 0 and fills in @outkey. A negative errno value
+ * is returned on failure.
+ */
+int
+krb5_kdf_hmac_sha2(const struct gss_krb5_enctype *gk5e,
+ const struct xdr_netobj *inkey,
+ struct xdr_netobj *outkey,
+ const struct xdr_netobj *label,
+ gfp_t gfp_mask)
+{
+ struct crypto_shash *tfm;
+ struct xdr_netobj K1 = {
+ .data = NULL,
+ };
+ int ret;
+
+ /*
+ * This implementation assumes the HMAC used for an enctype's
+ * key derivation is the same as the HMAC used for its
+ * checksumming. This happens to be true for enctypes that
+ * are currently supported by this implementation.
+ */
+ tfm = crypto_alloc_shash(gk5e->cksum_name, 0, 0);
+ if (IS_ERR(tfm)) {
+ ret = PTR_ERR(tfm);
+ goto out;
+ }
+ ret = crypto_shash_setkey(tfm, inkey->data, inkey->len);
+ if (ret)
+ goto out_free_tfm;
+
+ K1.len = crypto_shash_digestsize(tfm);
+ K1.data = kmalloc(K1.len, gfp_mask);
+ if (!K1.data) {
+ ret = -ENOMEM;
+ goto out_free_tfm;
+ }
+
+ ret = krb5_hmac_K1(tfm, label, outkey->len, &K1);
+ if (ret)
+ goto out_free_tfm;
+
+ /* k-truncate and random-to-key */
+ memcpy(outkey->data, K1.data, outkey->len);
+
+out_free_tfm:
+ kfree_sensitive(K1.data);
+ crypto_free_shash(tfm);
+out:
+ return ret;
+}
diff --git a/net/sunrpc/auth_gss/gss_krb5_mech.c b/net/sunrpc/auth_gss/gss_krb5_mech.c
new file mode 100644
index 0000000000..e31cfdf7ea
--- /dev/null
+++ b/net/sunrpc/auth_gss/gss_krb5_mech.c
@@ -0,0 +1,655 @@
+// SPDX-License-Identifier: BSD-3-Clause
+/*
+ * linux/net/sunrpc/gss_krb5_mech.c
+ *
+ * Copyright (c) 2001-2008 The Regents of the University of Michigan.
+ * All rights reserved.
+ *
+ * Andy Adamson <andros@umich.edu>
+ * J. Bruce Fields <bfields@umich.edu>
+ */
+
+#include <crypto/hash.h>
+#include <crypto/skcipher.h>
+#include <linux/err.h>
+#include <linux/module.h>
+#include <linux/init.h>
+#include <linux/types.h>
+#include <linux/slab.h>
+#include <linux/sunrpc/auth.h>
+#include <linux/sunrpc/gss_krb5.h>
+#include <linux/sunrpc/xdr.h>
+#include <kunit/visibility.h>
+
+#include "auth_gss_internal.h"
+#include "gss_krb5_internal.h"
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+# define RPCDBG_FACILITY RPCDBG_AUTH
+#endif
+
+static struct gss_api_mech gss_kerberos_mech;
+
+static const struct gss_krb5_enctype supported_gss_krb5_enctypes[] = {
+#if defined(CONFIG_RPCSEC_GSS_KRB5_ENCTYPES_AES_SHA1)
+ /*
+ * AES-128 with SHA-1 (RFC 3962)
+ */
+ {
+ .etype = ENCTYPE_AES128_CTS_HMAC_SHA1_96,
+ .ctype = CKSUMTYPE_HMAC_SHA1_96_AES128,
+ .name = "aes128-cts",
+ .encrypt_name = "cts(cbc(aes))",
+ .aux_cipher = "cbc(aes)",
+ .cksum_name = "hmac(sha1)",
+ .derive_key = krb5_derive_key_v2,
+ .encrypt = gss_krb5_aes_encrypt,
+ .decrypt = gss_krb5_aes_decrypt,
+
+ .get_mic = gss_krb5_get_mic_v2,
+ .verify_mic = gss_krb5_verify_mic_v2,
+ .wrap = gss_krb5_wrap_v2,
+ .unwrap = gss_krb5_unwrap_v2,
+
+ .signalg = -1,
+ .sealalg = -1,
+ .keybytes = 16,
+ .keylength = BITS2OCTETS(128),
+ .Kc_length = BITS2OCTETS(128),
+ .Ke_length = BITS2OCTETS(128),
+ .Ki_length = BITS2OCTETS(128),
+ .cksumlength = BITS2OCTETS(96),
+ .keyed_cksum = 1,
+ },
+ /*
+ * AES-256 with SHA-1 (RFC 3962)
+ */
+ {
+ .etype = ENCTYPE_AES256_CTS_HMAC_SHA1_96,
+ .ctype = CKSUMTYPE_HMAC_SHA1_96_AES256,
+ .name = "aes256-cts",
+ .encrypt_name = "cts(cbc(aes))",
+ .aux_cipher = "cbc(aes)",
+ .cksum_name = "hmac(sha1)",
+ .derive_key = krb5_derive_key_v2,
+ .encrypt = gss_krb5_aes_encrypt,
+ .decrypt = gss_krb5_aes_decrypt,
+
+ .get_mic = gss_krb5_get_mic_v2,
+ .verify_mic = gss_krb5_verify_mic_v2,
+ .wrap = gss_krb5_wrap_v2,
+ .unwrap = gss_krb5_unwrap_v2,
+
+ .signalg = -1,
+ .sealalg = -1,
+ .keybytes = 32,
+ .keylength = BITS2OCTETS(256),
+ .Kc_length = BITS2OCTETS(256),
+ .Ke_length = BITS2OCTETS(256),
+ .Ki_length = BITS2OCTETS(256),
+ .cksumlength = BITS2OCTETS(96),
+ .keyed_cksum = 1,
+ },
+#endif
+
+#if defined(CONFIG_RPCSEC_GSS_KRB5_ENCTYPES_CAMELLIA)
+ /*
+ * Camellia-128 with CMAC (RFC 6803)
+ */
+ {
+ .etype = ENCTYPE_CAMELLIA128_CTS_CMAC,
+ .ctype = CKSUMTYPE_CMAC_CAMELLIA128,
+ .name = "camellia128-cts-cmac",
+ .encrypt_name = "cts(cbc(camellia))",
+ .aux_cipher = "cbc(camellia)",
+ .cksum_name = "cmac(camellia)",
+ .cksumlength = BITS2OCTETS(128),
+ .keyed_cksum = 1,
+ .keylength = BITS2OCTETS(128),
+ .Kc_length = BITS2OCTETS(128),
+ .Ke_length = BITS2OCTETS(128),
+ .Ki_length = BITS2OCTETS(128),
+
+ .derive_key = krb5_kdf_feedback_cmac,
+ .encrypt = gss_krb5_aes_encrypt,
+ .decrypt = gss_krb5_aes_decrypt,
+
+ .get_mic = gss_krb5_get_mic_v2,
+ .verify_mic = gss_krb5_verify_mic_v2,
+ .wrap = gss_krb5_wrap_v2,
+ .unwrap = gss_krb5_unwrap_v2,
+ },
+ /*
+ * Camellia-256 with CMAC (RFC 6803)
+ */
+ {
+ .etype = ENCTYPE_CAMELLIA256_CTS_CMAC,
+ .ctype = CKSUMTYPE_CMAC_CAMELLIA256,
+ .name = "camellia256-cts-cmac",
+ .encrypt_name = "cts(cbc(camellia))",
+ .aux_cipher = "cbc(camellia)",
+ .cksum_name = "cmac(camellia)",
+ .cksumlength = BITS2OCTETS(128),
+ .keyed_cksum = 1,
+ .keylength = BITS2OCTETS(256),
+ .Kc_length = BITS2OCTETS(256),
+ .Ke_length = BITS2OCTETS(256),
+ .Ki_length = BITS2OCTETS(256),
+
+ .derive_key = krb5_kdf_feedback_cmac,
+ .encrypt = gss_krb5_aes_encrypt,
+ .decrypt = gss_krb5_aes_decrypt,
+
+ .get_mic = gss_krb5_get_mic_v2,
+ .verify_mic = gss_krb5_verify_mic_v2,
+ .wrap = gss_krb5_wrap_v2,
+ .unwrap = gss_krb5_unwrap_v2,
+ },
+#endif
+
+#if defined(CONFIG_RPCSEC_GSS_KRB5_ENCTYPES_AES_SHA2)
+ /*
+ * AES-128 with SHA-256 (RFC 8009)
+ */
+ {
+ .etype = ENCTYPE_AES128_CTS_HMAC_SHA256_128,
+ .ctype = CKSUMTYPE_HMAC_SHA256_128_AES128,
+ .name = "aes128-cts-hmac-sha256-128",
+ .encrypt_name = "cts(cbc(aes))",
+ .aux_cipher = "cbc(aes)",
+ .cksum_name = "hmac(sha256)",
+ .cksumlength = BITS2OCTETS(128),
+ .keyed_cksum = 1,
+ .keylength = BITS2OCTETS(128),
+ .Kc_length = BITS2OCTETS(128),
+ .Ke_length = BITS2OCTETS(128),
+ .Ki_length = BITS2OCTETS(128),
+
+ .derive_key = krb5_kdf_hmac_sha2,
+ .encrypt = krb5_etm_encrypt,
+ .decrypt = krb5_etm_decrypt,
+
+ .get_mic = gss_krb5_get_mic_v2,
+ .verify_mic = gss_krb5_verify_mic_v2,
+ .wrap = gss_krb5_wrap_v2,
+ .unwrap = gss_krb5_unwrap_v2,
+ },
+ /*
+ * AES-256 with SHA-384 (RFC 8009)
+ */
+ {
+ .etype = ENCTYPE_AES256_CTS_HMAC_SHA384_192,
+ .ctype = CKSUMTYPE_HMAC_SHA384_192_AES256,
+ .name = "aes256-cts-hmac-sha384-192",
+ .encrypt_name = "cts(cbc(aes))",
+ .aux_cipher = "cbc(aes)",
+ .cksum_name = "hmac(sha384)",
+ .cksumlength = BITS2OCTETS(192),
+ .keyed_cksum = 1,
+ .keylength = BITS2OCTETS(256),
+ .Kc_length = BITS2OCTETS(192),
+ .Ke_length = BITS2OCTETS(256),
+ .Ki_length = BITS2OCTETS(192),
+
+ .derive_key = krb5_kdf_hmac_sha2,
+ .encrypt = krb5_etm_encrypt,
+ .decrypt = krb5_etm_decrypt,
+
+ .get_mic = gss_krb5_get_mic_v2,
+ .verify_mic = gss_krb5_verify_mic_v2,
+ .wrap = gss_krb5_wrap_v2,
+ .unwrap = gss_krb5_unwrap_v2,
+ },
+#endif
+};
+
+/*
+ * The list of advertised enctypes is specified in order of most
+ * preferred to least.
+ */
+static char gss_krb5_enctype_priority_list[64];
+
+static void gss_krb5_prepare_enctype_priority_list(void)
+{
+ static const u32 gss_krb5_enctypes[] = {
+#if defined(CONFIG_RPCSEC_GSS_KRB5_ENCTYPES_AES_SHA2)
+ ENCTYPE_AES256_CTS_HMAC_SHA384_192,
+ ENCTYPE_AES128_CTS_HMAC_SHA256_128,
+#endif
+#if defined(CONFIG_RPCSEC_GSS_KRB5_ENCTYPES_CAMELLIA)
+ ENCTYPE_CAMELLIA256_CTS_CMAC,
+ ENCTYPE_CAMELLIA128_CTS_CMAC,
+#endif
+#if defined(CONFIG_RPCSEC_GSS_KRB5_ENCTYPES_AES_SHA1)
+ ENCTYPE_AES256_CTS_HMAC_SHA1_96,
+ ENCTYPE_AES128_CTS_HMAC_SHA1_96,
+#endif
+ };
+ size_t total, i;
+ char buf[16];
+ char *sep;
+ int n;
+
+ sep = "";
+ gss_krb5_enctype_priority_list[0] = '\0';
+ for (total = 0, i = 0; i < ARRAY_SIZE(gss_krb5_enctypes); i++) {
+ n = sprintf(buf, "%s%u", sep, gss_krb5_enctypes[i]);
+ if (n < 0)
+ break;
+ if (total + n >= sizeof(gss_krb5_enctype_priority_list))
+ break;
+ strcat(gss_krb5_enctype_priority_list, buf);
+ sep = ",";
+ total += n;
+ }
+}
+
+/**
+ * gss_krb5_lookup_enctype - Retrieve profile information for a given enctype
+ * @etype: ENCTYPE value
+ *
+ * Returns a pointer to a gss_krb5_enctype structure, or NULL if no
+ * matching etype is found.
+ */
+VISIBLE_IF_KUNIT
+const struct gss_krb5_enctype *gss_krb5_lookup_enctype(u32 etype)
+{
+ size_t i;
+
+ for (i = 0; i < ARRAY_SIZE(supported_gss_krb5_enctypes); i++)
+ if (supported_gss_krb5_enctypes[i].etype == etype)
+ return &supported_gss_krb5_enctypes[i];
+ return NULL;
+}
+EXPORT_SYMBOL_IF_KUNIT(gss_krb5_lookup_enctype);
+
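+/*
+ * Allocate a sync skcipher for @cname and program it with @key.
+ * Returns NULL on any failure so callers need only NULL-check the
+ * result.
+ */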
+static struct crypto_sync_skcipher *
+gss_krb5_alloc_cipher_v2(const char *cname, const struct xdr_netobj *key)
+{
+ struct crypto_sync_skcipher *tfm;
+
+ tfm = crypto_alloc_sync_skcipher(cname, 0, 0);
+ if (IS_ERR(tfm))
+ return NULL;
+ if (crypto_sync_skcipher_setkey(tfm, key->data, key->len)) {
+ crypto_free_sync_skcipher(tfm);
+ return NULL;
+ }
+ return tfm;
+}
+
+static struct crypto_ahash *
+gss_krb5_alloc_hash_v2(struct krb5_ctx *kctx, const struct xdr_netobj *key)
+{
+ struct crypto_ahash *tfm;
+
+ tfm = crypto_alloc_ahash(kctx->gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC);
+ if (IS_ERR(tfm))
+ return NULL;
+ if (crypto_ahash_setkey(tfm, key->data, key->len)) {
+ crypto_free_ahash(tfm);
+ return NULL;
+ }
+ return tfm;
+}
+
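+/*
+ * Derive the RFC 4121 subkeys from the imported session key and
+ * allocate a crypto transform for each: encryption (Ke), signing
+ * (Kc), and integrity (Ki), for both the initiator and acceptor
+ * roles.
+ */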
+static int
+gss_krb5_import_ctx_v2(struct krb5_ctx *ctx, gfp_t gfp_mask)
+{
+ struct xdr_netobj keyin = {
+ .len = ctx->gk5e->keylength,
+ .data = ctx->Ksess,
+ };
+ struct xdr_netobj keyout;
+ int ret = -EINVAL;
+
+ keyout.data = kmalloc(GSS_KRB5_MAX_KEYLEN, gfp_mask);
+ if (!keyout.data)
+ return -ENOMEM;
+
+ /* initiator seal encryption */
+ keyout.len = ctx->gk5e->Ke_length;
+ if (krb5_derive_key(ctx, &keyin, &keyout, KG_USAGE_INITIATOR_SEAL,
+ KEY_USAGE_SEED_ENCRYPTION, gfp_mask))
+ goto out;
+ ctx->initiator_enc = gss_krb5_alloc_cipher_v2(ctx->gk5e->encrypt_name,
+ &keyout);
+ if (ctx->initiator_enc == NULL)
+ goto out;
+ if (ctx->gk5e->aux_cipher) {
+ ctx->initiator_enc_aux =
+ gss_krb5_alloc_cipher_v2(ctx->gk5e->aux_cipher,
+ &keyout);
+ if (ctx->initiator_enc_aux == NULL)
+ goto out_free;
+ }
+
+ /* acceptor seal encryption */
+ if (krb5_derive_key(ctx, &keyin, &keyout, KG_USAGE_ACCEPTOR_SEAL,
+ KEY_USAGE_SEED_ENCRYPTION, gfp_mask))
+ goto out_free;
+ ctx->acceptor_enc = gss_krb5_alloc_cipher_v2(ctx->gk5e->encrypt_name,
+ &keyout);
+ if (ctx->acceptor_enc == NULL)
+ goto out_free;
+ if (ctx->gk5e->aux_cipher) {
+ ctx->acceptor_enc_aux =
+ gss_krb5_alloc_cipher_v2(ctx->gk5e->aux_cipher,
+ &keyout);
+ if (ctx->acceptor_enc_aux == NULL)
+ goto out_free;
+ }
+
+ /* initiator sign checksum */
+ keyout.len = ctx->gk5e->Kc_length;
+ if (krb5_derive_key(ctx, &keyin, &keyout, KG_USAGE_INITIATOR_SIGN,
+ KEY_USAGE_SEED_CHECKSUM, gfp_mask))
+ goto out_free;
+ ctx->initiator_sign = gss_krb5_alloc_hash_v2(ctx, &keyout);
+ if (ctx->initiator_sign == NULL)
+ goto out_free;
+
+ /* acceptor sign checksum */
+ if (krb5_derive_key(ctx, &keyin, &keyout, KG_USAGE_ACCEPTOR_SIGN,
+ KEY_USAGE_SEED_CHECKSUM, gfp_mask))
+ goto out_free;
+ ctx->acceptor_sign = gss_krb5_alloc_hash_v2(ctx, &keyout);
+ if (ctx->acceptor_sign == NULL)
+ goto out_free;
+
+ /* initiator seal integrity */
+ keyout.len = ctx->gk5e->Ki_length;
+ if (krb5_derive_key(ctx, &keyin, &keyout, KG_USAGE_INITIATOR_SEAL,
+ KEY_USAGE_SEED_INTEGRITY, gfp_mask))
+ goto out_free;
+ ctx->initiator_integ = gss_krb5_alloc_hash_v2(ctx, &keyout);
+ if (ctx->initiator_integ == NULL)
+ goto out_free;
+
+ /* acceptor seal integrity */
+ if (krb5_derive_key(ctx, &keyin, &keyout, KG_USAGE_ACCEPTOR_SEAL,
+ KEY_USAGE_SEED_INTEGRITY, gfp_mask))
+ goto out_free;
+ ctx->acceptor_integ = gss_krb5_alloc_hash_v2(ctx, &keyout);
+ if (ctx->acceptor_integ == NULL)
+ goto out_free;
+
+ ret = 0;
+out:
+ kfree_sensitive(keyout.data);
+ return ret;
+
+out_free:
+ crypto_free_ahash(ctx->acceptor_integ);
+ crypto_free_ahash(ctx->initiator_integ);
+ crypto_free_ahash(ctx->acceptor_sign);
+ crypto_free_ahash(ctx->initiator_sign);
+ crypto_free_sync_skcipher(ctx->acceptor_enc_aux);
+ crypto_free_sync_skcipher(ctx->acceptor_enc);
+ crypto_free_sync_skcipher(ctx->initiator_enc_aux);
+ crypto_free_sync_skcipher(ctx->initiator_enc);
+ goto out;
+}
+
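+/*
+ * Unmarshal the serialized context received from user space:
+ * flags, context endtime, initial sequence number, enctype, and
+ * finally the session key, in that order.
+ */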
+static int
+gss_import_v2_context(const void *p, const void *end, struct krb5_ctx *ctx,
+ gfp_t gfp_mask)
+{
+ u64 seq_send64;
+ int keylen;
+ u32 time32;
+
+ p = simple_get_bytes(p, end, &ctx->flags, sizeof(ctx->flags));
+ if (IS_ERR(p))
+ goto out_err;
+ ctx->initiate = ctx->flags & KRB5_CTX_FLAG_INITIATOR;
+
+ p = simple_get_bytes(p, end, &time32, sizeof(time32));
+ if (IS_ERR(p))
+ goto out_err;
+ /* unsigned 32-bit time overflows in year 2106 */
+ ctx->endtime = (time64_t)time32;
+ p = simple_get_bytes(p, end, &seq_send64, sizeof(seq_send64));
+ if (IS_ERR(p))
+ goto out_err;
+ atomic64_set(&ctx->seq_send64, seq_send64);
+ /* set seq_send for use by "older" enctypes */
+ atomic_set(&ctx->seq_send, seq_send64);
+ if (seq_send64 != atomic_read(&ctx->seq_send)) {
+ dprintk("%s: seq_send64 %llx, seq_send %x overflow?\n", __func__,
+ seq_send64, atomic_read(&ctx->seq_send));
+ p = ERR_PTR(-EINVAL);
+ goto out_err;
+ }
+ p = simple_get_bytes(p, end, &ctx->enctype, sizeof(ctx->enctype));
+ if (IS_ERR(p))
+ goto out_err;
+ ctx->gk5e = gss_krb5_lookup_enctype(ctx->enctype);
+ if (ctx->gk5e == NULL) {
+ dprintk("gss_kerberos_mech: unsupported krb5 enctype %u\n",
+ ctx->enctype);
+ p = ERR_PTR(-EINVAL);
+ goto out_err;
+ }
+ keylen = ctx->gk5e->keylength;
+
+ p = simple_get_bytes(p, end, ctx->Ksess, keylen);
+ if (IS_ERR(p))
+ goto out_err;
+
+ if (p != end) {
+ p = ERR_PTR(-EINVAL);
+ goto out_err;
+ }
+
+ ctx->mech_used.data = kmemdup(gss_kerberos_mech.gm_oid.data,
+ gss_kerberos_mech.gm_oid.len, gfp_mask);
+ if (unlikely(ctx->mech_used.data == NULL)) {
+ p = ERR_PTR(-ENOMEM);
+ goto out_err;
+ }
+ ctx->mech_used.len = gss_kerberos_mech.gm_oid.len;
+
+ return gss_krb5_import_ctx_v2(ctx, gfp_mask);
+
+out_err:
+ return PTR_ERR(p);
+}
+
+static int
+gss_krb5_import_sec_context(const void *p, size_t len, struct gss_ctx *ctx_id,
+ time64_t *endtime, gfp_t gfp_mask)
+{
+ const void *end = (const void *)((const char *)p + len);
+ struct krb5_ctx *ctx;
+ int ret;
+
+ ctx = kzalloc(sizeof(*ctx), gfp_mask);
+ if (ctx == NULL)
+ return -ENOMEM;
+
+ ret = gss_import_v2_context(p, end, ctx, gfp_mask);
+ memzero_explicit(&ctx->Ksess, sizeof(ctx->Ksess));
+ if (ret) {
+ kfree(ctx);
+ return ret;
+ }
+
+ ctx_id->internal_ctx_id = ctx;
+ if (endtime)
+ *endtime = ctx->endtime;
+ return 0;
+}
+
+static void
+gss_krb5_delete_sec_context(void *internal_ctx)
+{
+ struct krb5_ctx *kctx = internal_ctx;
+
+ crypto_free_sync_skcipher(kctx->seq);
+ crypto_free_sync_skcipher(kctx->enc);
+ crypto_free_sync_skcipher(kctx->acceptor_enc);
+ crypto_free_sync_skcipher(kctx->initiator_enc);
+ crypto_free_sync_skcipher(kctx->acceptor_enc_aux);
+ crypto_free_sync_skcipher(kctx->initiator_enc_aux);
+ crypto_free_ahash(kctx->acceptor_sign);
+ crypto_free_ahash(kctx->initiator_sign);
+ crypto_free_ahash(kctx->acceptor_integ);
+ crypto_free_ahash(kctx->initiator_integ);
+ kfree(kctx->mech_used.data);
+ kfree(kctx);
+}
+
+/**
+ * gss_krb5_get_mic - get_mic for the Kerberos GSS mechanism
+ * @gctx: GSS context
+ * @text: plaintext to checksum
+ * @token: buffer into which to write the computed checksum
+ *
+ * Return values:
+ * %GSS_S_COMPLETE - success, and @token is filled in
+ * %GSS_S_FAILURE - checksum could not be generated
+ * %GSS_S_CONTEXT_EXPIRED - Kerberos context is no longer valid
+ */
+static u32 gss_krb5_get_mic(struct gss_ctx *gctx, struct xdr_buf *text,
+ struct xdr_netobj *token)
+{
+ struct krb5_ctx *kctx = gctx->internal_ctx_id;
+
+ return kctx->gk5e->get_mic(kctx, text, token);
+}
+
+/**
+ * gss_krb5_verify_mic - verify_mic for the Kerberos GSS mechanism
+ * @gctx: GSS context
+ * @message_buffer: plaintext to check
+ * @read_token: received checksum to check
+ *
+ * Return values:
+ * %GSS_S_COMPLETE - computed and received checksums match
+ * %GSS_S_DEFECTIVE_TOKEN - received checksum is not valid
+ * %GSS_S_BAD_SIG - computed and received checksums do not match
+ * %GSS_S_FAILURE - received checksum could not be checked
+ * %GSS_S_CONTEXT_EXPIRED - Kerberos context is no longer valid
+ */
+static u32 gss_krb5_verify_mic(struct gss_ctx *gctx,
+ struct xdr_buf *message_buffer,
+ struct xdr_netobj *read_token)
+{
+ struct krb5_ctx *kctx = gctx->internal_ctx_id;
+
+ return kctx->gk5e->verify_mic(kctx, message_buffer, read_token);
+}
+
+/**
+ * gss_krb5_wrap - gss_wrap for the Kerberos GSS mechanism
+ * @gctx: initialized GSS context
+ * @offset: byte offset in @buf to start writing the cipher text
+ * @buf: OUT: send buffer
+ * @pages: plaintext to wrap
+ *
+ * Return values:
+ * %GSS_S_COMPLETE - success, @buf has been updated
+ * %GSS_S_FAILURE - @buf could not be wrapped
+ * %GSS_S_CONTEXT_EXPIRED - Kerberos context is no longer valid
+ */
+static u32 gss_krb5_wrap(struct gss_ctx *gctx, int offset,
+ struct xdr_buf *buf, struct page **pages)
+{
+ struct krb5_ctx *kctx = gctx->internal_ctx_id;
+
+ return kctx->gk5e->wrap(kctx, offset, buf, pages);
+}
+
+/**
+ * gss_krb5_unwrap - gss_unwrap for the Kerberos GSS mechanism
+ * @gctx: initialized GSS context
+ * @offset: starting byte offset into @buf
+ * @len: size of ciphertext to unwrap
+ * @buf: ciphertext to unwrap
+ *
+ * Return values:
+ * %GSS_S_COMPLETE - success, @buf has been updated
+ * %GSS_S_DEFECTIVE_TOKEN - received blob is not valid
+ * %GSS_S_BAD_SIG - computed and received checksums do not match
+ * %GSS_S_FAILURE - @buf could not be unwrapped
+ * %GSS_S_CONTEXT_EXPIRED - Kerberos context is no longer valid
+ */
+static u32 gss_krb5_unwrap(struct gss_ctx *gctx, int offset,
+ int len, struct xdr_buf *buf)
+{
+ struct krb5_ctx *kctx = gctx->internal_ctx_id;
+
+ return kctx->gk5e->unwrap(kctx, offset, len, buf,
+ &gctx->slack, &gctx->align);
+}
+
+static const struct gss_api_ops gss_kerberos_ops = {
+ .gss_import_sec_context = gss_krb5_import_sec_context,
+ .gss_get_mic = gss_krb5_get_mic,
+ .gss_verify_mic = gss_krb5_verify_mic,
+ .gss_wrap = gss_krb5_wrap,
+ .gss_unwrap = gss_krb5_unwrap,
+ .gss_delete_sec_context = gss_krb5_delete_sec_context,
+};
+
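+/*
+ * The three RPCSEC_GSS pseudoflavors correspond to the GSS services:
+ * krb5 authenticates RPC calls only, krb5i adds per-message integrity
+ * protection, and krb5p adds privacy (encryption) as well.
+ */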
+static struct pf_desc gss_kerberos_pfs[] = {
+ [0] = {
+ .pseudoflavor = RPC_AUTH_GSS_KRB5,
+ .qop = GSS_C_QOP_DEFAULT,
+ .service = RPC_GSS_SVC_NONE,
+ .name = "krb5",
+ },
+ [1] = {
+ .pseudoflavor = RPC_AUTH_GSS_KRB5I,
+ .qop = GSS_C_QOP_DEFAULT,
+ .service = RPC_GSS_SVC_INTEGRITY,
+ .name = "krb5i",
+ .datatouch = true,
+ },
+ [2] = {
+ .pseudoflavor = RPC_AUTH_GSS_KRB5P,
+ .qop = GSS_C_QOP_DEFAULT,
+ .service = RPC_GSS_SVC_PRIVACY,
+ .name = "krb5p",
+ .datatouch = true,
+ },
+};
+
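+/* Aliases by pseudoflavor name, pseudoflavor number, and mechanism OID */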
+MODULE_ALIAS("rpc-auth-gss-krb5");
+MODULE_ALIAS("rpc-auth-gss-krb5i");
+MODULE_ALIAS("rpc-auth-gss-krb5p");
+MODULE_ALIAS("rpc-auth-gss-390003");
+MODULE_ALIAS("rpc-auth-gss-390004");
+MODULE_ALIAS("rpc-auth-gss-390005");
+MODULE_ALIAS("rpc-auth-gss-1.2.840.113554.1.2.2");
+
+static struct gss_api_mech gss_kerberos_mech = {
+ .gm_name = "krb5",
+ .gm_owner = THIS_MODULE,
+ .gm_oid = { 9, "\x2a\x86\x48\x86\xf7\x12\x01\x02\x02" },
+ .gm_ops = &gss_kerberos_ops,
+ .gm_pf_num = ARRAY_SIZE(gss_kerberos_pfs),
+ .gm_pfs = gss_kerberos_pfs,
+ .gm_upcall_enctypes = gss_krb5_enctype_priority_list,
+};
+
+static int __init init_kerberos_module(void)
+{
+ int status;
+
+ gss_krb5_prepare_enctype_priority_list();
+ status = gss_mech_register(&gss_kerberos_mech);
+ if (status)
+ printk("Failed to register kerberos gss mechanism!\n");
+ return status;
+}
+
+static void __exit cleanup_kerberos_module(void)
+{
+ gss_mech_unregister(&gss_kerberos_mech);
+}
+
+MODULE_LICENSE("GPL");
+module_init(init_kerberos_module);
+module_exit(cleanup_kerberos_module);
diff --git a/net/sunrpc/auth_gss/gss_krb5_seal.c b/net/sunrpc/auth_gss/gss_krb5_seal.c
new file mode 100644
index 0000000000..ce540df9bc
--- /dev/null
+++ b/net/sunrpc/auth_gss/gss_krb5_seal.c
@@ -0,0 +1,133 @@
+/*
+ * linux/net/sunrpc/gss_krb5_seal.c
+ *
+ * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/krb5/k5seal.c
+ *
+ * Copyright (c) 2000-2008 The Regents of the University of Michigan.
+ * All rights reserved.
+ *
+ * Andy Adamson <andros@umich.edu>
+ * J. Bruce Fields <bfields@umich.edu>
+ */
+
+/*
+ * Copyright 1993 by OpenVision Technologies, Inc.
+ *
+ * Permission to use, copy, modify, distribute, and sell this software
+ * and its documentation for any purpose is hereby granted without fee,
+ * provided that the above copyright notice appears in all copies and
+ * that both that copyright notice and this permission notice appear in
+ * supporting documentation, and that the name of OpenVision not be used
+ * in advertising or publicity pertaining to distribution of the software
+ * without specific, written prior permission. OpenVision makes no
+ * representations about the suitability of this software for any
+ * purpose. It is provided "as is" without express or implied warranty.
+ *
+ * OPENVISION DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
+ * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO
+ * EVENT SHALL OPENVISION BE LIABLE FOR ANY SPECIAL, INDIRECT OR
+ * CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF
+ * USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
+ * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
+ * PERFORMANCE OF THIS SOFTWARE.
+ */
+
+/*
+ * Copyright (C) 1998 by the FundsXpress, INC.
+ *
+ * All rights reserved.
+ *
+ * Export of this software from the United States of America may require
+ * a specific license from the United States Government. It is the
+ * responsibility of any person or organization contemplating export to
+ * obtain such a license before exporting.
+ *
+ * WITHIN THAT CONSTRAINT, permission to use, copy, modify, and
+ * distribute this software and its documentation for any purpose and
+ * without fee is hereby granted, provided that the above copyright
+ * notice appear in all copies and that both that copyright notice and
+ * this permission notice appear in supporting documentation, and that
+ * the name of FundsXpress. not be used in advertising or publicity pertaining
+ * to distribution of the software without specific, written prior
+ * permission. FundsXpress makes no representations about the suitability of
+ * this software for any purpose. It is provided "as is" without express
+ * or implied warranty.
+ *
+ * THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED
+ * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE.
+ */
+
+#include <linux/types.h>
+#include <linux/jiffies.h>
+#include <linux/sunrpc/gss_krb5.h>
+#include <linux/random.h>
+#include <linux/crypto.h>
+#include <linux/atomic.h>
+
+#include "gss_krb5_internal.h"
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+# define RPCDBG_FACILITY RPCDBG_AUTH
+#endif
+
+static void *
+setup_token_v2(struct krb5_ctx *ctx, struct xdr_netobj *token)
+{
+ u16 *ptr;
+ void *krb5_hdr;
+ u8 *p, flags = 0x00;
+
+ if ((ctx->flags & KRB5_CTX_FLAG_INITIATOR) == 0)
+ flags |= 0x01;
+ if (ctx->flags & KRB5_CTX_FLAG_ACCEPTOR_SUBKEY)
+ flags |= 0x04;
+
+	/* Per RFC 4121, sec 4.2.6.1, per-message tokens carry no outer
+	 * GSS framing; the buffer begins directly with the token itself.
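+	 *
+	 * The 16-octet token header is: TOK_ID (0x04 0x04), one flags
+	 * octet, five filler octets of 0xFF, and an 8-octet sequence
+	 * number that the caller fills in before appending the checksum.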
+ */
+ krb5_hdr = (u16 *)token->data;
+ ptr = krb5_hdr;
+
+ *ptr++ = KG2_TOK_MIC;
+ p = (u8 *)ptr;
+ *p++ = flags;
+ *p++ = 0xff;
+ ptr = (u16 *)p;
+ *ptr++ = 0xffff;
+ *ptr = 0xffff;
+
+ token->len = GSS_KRB5_TOK_HDR_LEN + ctx->gk5e->cksumlength;
+ return krb5_hdr;
+}
+
+u32
+gss_krb5_get_mic_v2(struct krb5_ctx *ctx, struct xdr_buf *text,
+ struct xdr_netobj *token)
+{
+ struct crypto_ahash *tfm = ctx->initiate ?
+ ctx->initiator_sign : ctx->acceptor_sign;
+ struct xdr_netobj cksumobj = {
+ .len = ctx->gk5e->cksumlength,
+ };
+ __be64 seq_send_be64;
+ void *krb5_hdr;
+ time64_t now;
+
+ dprintk("RPC: %s\n", __func__);
+
+ krb5_hdr = setup_token_v2(ctx, token);
+
+	/* Set up the sequence number: now 64 bits wide, carried in
+	 * clear text, and without a direction indicator. */
+ seq_send_be64 = cpu_to_be64(atomic64_fetch_inc(&ctx->seq_send64));
+ memcpy(krb5_hdr + 8, (char *) &seq_send_be64, 8);
+
+ cksumobj.data = krb5_hdr + GSS_KRB5_TOK_HDR_LEN;
+ if (gss_krb5_checksum(tfm, krb5_hdr, GSS_KRB5_TOK_HDR_LEN,
+ text, 0, &cksumobj))
+ return GSS_S_FAILURE;
+
+ now = ktime_get_real_seconds();
+ return (ctx->endtime < now) ? GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE;
+}
diff --git a/net/sunrpc/auth_gss/gss_krb5_test.c b/net/sunrpc/auth_gss/gss_krb5_test.c
new file mode 100644
index 0000000000..85625e3f38
--- /dev/null
+++ b/net/sunrpc/auth_gss/gss_krb5_test.c
@@ -0,0 +1,1859 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (c) 2022 Oracle and/or its affiliates.
+ *
+ * KUnit test of SunRPC's GSS Kerberos mechanism. Subsystem
+ * name is "rpcsec_gss_krb5".
+ */
+
+#include <kunit/test.h>
+#include <kunit/visibility.h>
+
+#include <linux/kernel.h>
+#include <crypto/hash.h>
+
+#include <linux/sunrpc/xdr.h>
+#include <linux/sunrpc/gss_krb5.h>
+
+#include "gss_krb5_internal.h"
+
+MODULE_IMPORT_NS(EXPORTED_FOR_KUNIT_TESTING);
+
+struct gss_krb5_test_param {
+ const char *desc;
+ u32 enctype;
+ u32 nfold;
+ u32 constant;
+ const struct xdr_netobj *base_key;
+ const struct xdr_netobj *Ke;
+ const struct xdr_netobj *usage;
+ const struct xdr_netobj *plaintext;
+ const struct xdr_netobj *confounder;
+ const struct xdr_netobj *expected_result;
+ const struct xdr_netobj *expected_hmac;
+ const struct xdr_netobj *next_iv;
+};
+
+static inline void gss_krb5_get_desc(const struct gss_krb5_test_param *param,
+ char *desc)
+{
+ strscpy(desc, param->desc, KUNIT_PARAM_DESC_SIZE);
+}
+
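+/*
+ * Drive the enctype's derive_key method with a published base key
+ * and key-usage vector, then compare the derived subkey against the
+ * expected octets.
+ */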
+static void kdf_case(struct kunit *test)
+{
+ const struct gss_krb5_test_param *param = test->param_value;
+ const struct gss_krb5_enctype *gk5e;
+ struct xdr_netobj derivedkey;
+ int err;
+
+ /* Arrange */
+ gk5e = gss_krb5_lookup_enctype(param->enctype);
+ if (!gk5e)
+ kunit_skip(test, "Encryption type is not available");
+
+ derivedkey.data = kunit_kzalloc(test, param->expected_result->len,
+ GFP_KERNEL);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, derivedkey.data);
+ derivedkey.len = param->expected_result->len;
+
+ /* Act */
+ err = gk5e->derive_key(gk5e, param->base_key, &derivedkey,
+ param->usage, GFP_KERNEL);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ /* Assert */
+ KUNIT_EXPECT_EQ_MSG(test,
+ memcmp(param->expected_result->data,
+ derivedkey.data, derivedkey.len), 0,
+ "key mismatch");
+}
+
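+/*
+ * Derive Kc from the base key, key the enctype's checksum transform
+ * with it, and verify the MAC computed over the plaintext vector.
+ */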
+static void checksum_case(struct kunit *test)
+{
+ const struct gss_krb5_test_param *param = test->param_value;
+ struct xdr_buf buf = {
+ .head[0].iov_len = param->plaintext->len,
+ .len = param->plaintext->len,
+ };
+ const struct gss_krb5_enctype *gk5e;
+ struct xdr_netobj Kc, checksum;
+ struct crypto_ahash *tfm;
+ int err;
+
+ /* Arrange */
+ gk5e = gss_krb5_lookup_enctype(param->enctype);
+ if (!gk5e)
+ kunit_skip(test, "Encryption type is not available");
+
+ Kc.len = gk5e->Kc_length;
+ Kc.data = kunit_kzalloc(test, Kc.len, GFP_KERNEL);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, Kc.data);
+ err = gk5e->derive_key(gk5e, param->base_key, &Kc,
+ param->usage, GFP_KERNEL);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ tfm = crypto_alloc_ahash(gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, tfm);
+ err = crypto_ahash_setkey(tfm, Kc.data, Kc.len);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ buf.head[0].iov_base = kunit_kzalloc(test, buf.head[0].iov_len, GFP_KERNEL);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, buf.head[0].iov_base);
+ memcpy(buf.head[0].iov_base, param->plaintext->data, buf.head[0].iov_len);
+
+ checksum.len = gk5e->cksumlength;
+ checksum.data = kunit_kzalloc(test, checksum.len, GFP_KERNEL);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, checksum.data);
+
+ /* Act */
+ err = gss_krb5_checksum(tfm, NULL, 0, &buf, 0, &checksum);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ /* Assert */
+ KUNIT_EXPECT_EQ_MSG(test,
+ memcmp(param->expected_result->data,
+ checksum.data, checksum.len), 0,
+ "checksum mismatch");
+
+ crypto_free_ahash(tfm);
+}
+
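+/*
+ * These helpers wrap a static octet array (or a string literal minus
+ * its NUL terminator) in a const struct xdr_netobj so the RFC test
+ * vectors below can be declared compactly.
+ */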
+#define DEFINE_HEX_XDR_NETOBJ(name, hex_array...) \
+ static const u8 name ## _data[] = { hex_array }; \
+ static const struct xdr_netobj name = { \
+ .data = (u8 *)name##_data, \
+ .len = sizeof(name##_data), \
+ }
+
+#define DEFINE_STR_XDR_NETOBJ(name, string) \
+ static const u8 name ## _str[] = string; \
+ static const struct xdr_netobj name = { \
+ .data = (u8 *)name##_str, \
+ .len = sizeof(name##_str) - 1, \
+ }
+
+/*
+ * RFC 3961 Appendix A.1. n-fold
+ *
+ * The n-fold function is defined in section 5.1 of RFC 3961.
+ *
+ * This test material is copyright (C) The Internet Society (2005).
+ */
+
+DEFINE_HEX_XDR_NETOBJ(nfold_test1_plaintext,
+ 0x30, 0x31, 0x32, 0x33, 0x34, 0x35
+);
+DEFINE_HEX_XDR_NETOBJ(nfold_test1_expected_result,
+ 0xbe, 0x07, 0x26, 0x31, 0x27, 0x6b, 0x19, 0x55
+);
+
+DEFINE_HEX_XDR_NETOBJ(nfold_test2_plaintext,
+ 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64
+);
+DEFINE_HEX_XDR_NETOBJ(nfold_test2_expected_result,
+ 0x78, 0xa0, 0x7b, 0x6c, 0xaf, 0x85, 0xfa
+);
+
+DEFINE_HEX_XDR_NETOBJ(nfold_test3_plaintext,
+ 0x52, 0x6f, 0x75, 0x67, 0x68, 0x20, 0x43, 0x6f,
+ 0x6e, 0x73, 0x65, 0x6e, 0x73, 0x75, 0x73, 0x2c,
+ 0x20, 0x61, 0x6e, 0x64, 0x20, 0x52, 0x75, 0x6e,
+ 0x6e, 0x69, 0x6e, 0x67, 0x20, 0x43, 0x6f, 0x64,
+ 0x65
+);
+DEFINE_HEX_XDR_NETOBJ(nfold_test3_expected_result,
+ 0xbb, 0x6e, 0xd3, 0x08, 0x70, 0xb7, 0xf0, 0xe0
+);
+
+DEFINE_HEX_XDR_NETOBJ(nfold_test4_plaintext,
+ 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64
+);
+DEFINE_HEX_XDR_NETOBJ(nfold_test4_expected_result,
+ 0x59, 0xe4, 0xa8, 0xca, 0x7c, 0x03, 0x85, 0xc3,
+ 0xc3, 0x7b, 0x3f, 0x6d, 0x20, 0x00, 0x24, 0x7c,
+ 0xb6, 0xe6, 0xbd, 0x5b, 0x3e
+);
+
+DEFINE_HEX_XDR_NETOBJ(nfold_test5_plaintext,
+ 0x4d, 0x41, 0x53, 0x53, 0x41, 0x43, 0x48, 0x56,
+ 0x53, 0x45, 0x54, 0x54, 0x53, 0x20, 0x49, 0x4e,
+ 0x53, 0x54, 0x49, 0x54, 0x56, 0x54, 0x45, 0x20,
+ 0x4f, 0x46, 0x20, 0x54, 0x45, 0x43, 0x48, 0x4e,
+ 0x4f, 0x4c, 0x4f, 0x47, 0x59
+);
+DEFINE_HEX_XDR_NETOBJ(nfold_test5_expected_result,
+ 0xdb, 0x3b, 0x0d, 0x8f, 0x0b, 0x06, 0x1e, 0x60,
+ 0x32, 0x82, 0xb3, 0x08, 0xa5, 0x08, 0x41, 0x22,
+ 0x9a, 0xd7, 0x98, 0xfa, 0xb9, 0x54, 0x0c, 0x1b
+);
+
+DEFINE_HEX_XDR_NETOBJ(nfold_test6_plaintext,
+ 0x51
+);
+DEFINE_HEX_XDR_NETOBJ(nfold_test6_expected_result,
+ 0x51, 0x8a, 0x54, 0xa2, 0x15, 0xa8, 0x45, 0x2a,
+ 0x51, 0x8a, 0x54, 0xa2, 0x15, 0xa8, 0x45, 0x2a,
+ 0x51, 0x8a, 0x54, 0xa2, 0x15
+);
+
+DEFINE_HEX_XDR_NETOBJ(nfold_test7_plaintext,
+ 0x62, 0x61
+);
+DEFINE_HEX_XDR_NETOBJ(nfold_test7_expected_result,
+ 0xfb, 0x25, 0xd5, 0x31, 0xae, 0x89, 0x74, 0x49,
+ 0x9f, 0x52, 0xfd, 0x92, 0xea, 0x98, 0x57, 0xc4,
+ 0xba, 0x24, 0xcf, 0x29, 0x7e
+);
+
+DEFINE_HEX_XDR_NETOBJ(nfold_test_kerberos,
+ 0x6b, 0x65, 0x72, 0x62, 0x65, 0x72, 0x6f, 0x73
+);
+DEFINE_HEX_XDR_NETOBJ(nfold_test8_expected_result,
+ 0x6b, 0x65, 0x72, 0x62, 0x65, 0x72, 0x6f, 0x73
+);
+DEFINE_HEX_XDR_NETOBJ(nfold_test9_expected_result,
+ 0x6b, 0x65, 0x72, 0x62, 0x65, 0x72, 0x6f, 0x73,
+ 0x7b, 0x9b, 0x5b, 0x2b, 0x93, 0x13, 0x2b, 0x93
+);
+DEFINE_HEX_XDR_NETOBJ(nfold_test10_expected_result,
+ 0x83, 0x72, 0xc2, 0x36, 0x34, 0x4e, 0x5f, 0x15,
+ 0x50, 0xcd, 0x07, 0x47, 0xe1, 0x5d, 0x62, 0xca,
+ 0x7a, 0x5a, 0x3b, 0xce, 0xa4
+);
+DEFINE_HEX_XDR_NETOBJ(nfold_test11_expected_result,
+ 0x6b, 0x65, 0x72, 0x62, 0x65, 0x72, 0x6f, 0x73,
+ 0x7b, 0x9b, 0x5b, 0x2b, 0x93, 0x13, 0x2b, 0x93,
+ 0x5c, 0x9b, 0xdc, 0xda, 0xd9, 0x5c, 0x98, 0x99,
+ 0xc4, 0xca, 0xe4, 0xde, 0xe6, 0xd6, 0xca, 0xe4
+);
+
+static const struct gss_krb5_test_param rfc3961_nfold_test_params[] = {
+ {
+ .desc = "64-fold(\"012345\")",
+ .nfold = 64,
+ .plaintext = &nfold_test1_plaintext,
+ .expected_result = &nfold_test1_expected_result,
+ },
+ {
+ .desc = "56-fold(\"password\")",
+ .nfold = 56,
+ .plaintext = &nfold_test2_plaintext,
+ .expected_result = &nfold_test2_expected_result,
+ },
+ {
+ .desc = "64-fold(\"Rough Consensus, and Running Code\")",
+ .nfold = 64,
+ .plaintext = &nfold_test3_plaintext,
+ .expected_result = &nfold_test3_expected_result,
+ },
+ {
+ .desc = "168-fold(\"password\")",
+ .nfold = 168,
+ .plaintext = &nfold_test4_plaintext,
+ .expected_result = &nfold_test4_expected_result,
+ },
+ {
+ .desc = "192-fold(\"MASSACHVSETTS INSTITVTE OF TECHNOLOGY\")",
+ .nfold = 192,
+ .plaintext = &nfold_test5_plaintext,
+ .expected_result = &nfold_test5_expected_result,
+ },
+ {
+ .desc = "168-fold(\"Q\")",
+ .nfold = 168,
+ .plaintext = &nfold_test6_plaintext,
+ .expected_result = &nfold_test6_expected_result,
+ },
+ {
+ .desc = "168-fold(\"ba\")",
+ .nfold = 168,
+ .plaintext = &nfold_test7_plaintext,
+ .expected_result = &nfold_test7_expected_result,
+ },
+ {
+ .desc = "64-fold(\"kerberos\")",
+ .nfold = 64,
+ .plaintext = &nfold_test_kerberos,
+ .expected_result = &nfold_test8_expected_result,
+ },
+ {
+ .desc = "128-fold(\"kerberos\")",
+ .nfold = 128,
+ .plaintext = &nfold_test_kerberos,
+ .expected_result = &nfold_test9_expected_result,
+ },
+ {
+ .desc = "168-fold(\"kerberos\")",
+ .nfold = 168,
+ .plaintext = &nfold_test_kerberos,
+ .expected_result = &nfold_test10_expected_result,
+ },
+ {
+ .desc = "256-fold(\"kerberos\")",
+ .nfold = 256,
+ .plaintext = &nfold_test_kerberos,
+ .expected_result = &nfold_test11_expected_result,
+ },
+};
+
+/* Creates the function rfc3961_nfold_gen_params */
+KUNIT_ARRAY_PARAM(rfc3961_nfold, rfc3961_nfold_test_params, gss_krb5_get_desc);
+
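+/*
+ * Run krb5_nfold() over each RFC 3961 Appendix A.1 vector; input and
+ * output lengths are passed to krb5_nfold() in bits.
+ */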
+static void rfc3961_nfold_case(struct kunit *test)
+{
+ const struct gss_krb5_test_param *param = test->param_value;
+ u8 *result;
+
+ /* Arrange */
+ result = kunit_kzalloc(test, 4096, GFP_KERNEL);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, result);
+
+ /* Act */
+ krb5_nfold(param->plaintext->len * 8, param->plaintext->data,
+ param->expected_result->len * 8, result);
+
+ /* Assert */
+ KUNIT_EXPECT_EQ_MSG(test,
+ memcmp(param->expected_result->data,
+ result, param->expected_result->len), 0,
+ "result mismatch");
+}
+
+static struct kunit_case rfc3961_test_cases[] = {
+ {
+ .name = "RFC 3961 n-fold",
+ .run_case = rfc3961_nfold_case,
+ .generate_params = rfc3961_nfold_gen_params,
+ },
+ {}
+};
+
+static struct kunit_suite rfc3961_suite = {
+ .name = "RFC 3961 tests",
+ .test_cases = rfc3961_test_cases,
+};
+
+/*
+ * From RFC 3962 Appendix B: Sample Test Vectors
+ *
+ * Some test vectors for CBC with ciphertext stealing, using an
+ * initial vector of all-zero.
+ *
+ * This test material is copyright (C) The Internet Society (2005).
+ */
+
+DEFINE_HEX_XDR_NETOBJ(rfc3962_encryption_key,
+ 0x63, 0x68, 0x69, 0x63, 0x6b, 0x65, 0x6e, 0x20,
+ 0x74, 0x65, 0x72, 0x69, 0x79, 0x61, 0x6b, 0x69
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test1_plaintext,
+ 0x49, 0x20, 0x77, 0x6f, 0x75, 0x6c, 0x64, 0x20,
+ 0x6c, 0x69, 0x6b, 0x65, 0x20, 0x74, 0x68, 0x65,
+ 0x20
+);
+DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test1_expected_result,
+ 0xc6, 0x35, 0x35, 0x68, 0xf2, 0xbf, 0x8c, 0xb4,
+ 0xd8, 0xa5, 0x80, 0x36, 0x2d, 0xa7, 0xff, 0x7f,
+ 0x97
+);
+DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test1_next_iv,
+ 0xc6, 0x35, 0x35, 0x68, 0xf2, 0xbf, 0x8c, 0xb4,
+ 0xd8, 0xa5, 0x80, 0x36, 0x2d, 0xa7, 0xff, 0x7f
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test2_plaintext,
+ 0x49, 0x20, 0x77, 0x6f, 0x75, 0x6c, 0x64, 0x20,
+ 0x6c, 0x69, 0x6b, 0x65, 0x20, 0x74, 0x68, 0x65,
+ 0x20, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c,
+ 0x20, 0x47, 0x61, 0x75, 0x27, 0x73, 0x20
+);
+DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test2_expected_result,
+ 0xfc, 0x00, 0x78, 0x3e, 0x0e, 0xfd, 0xb2, 0xc1,
+ 0xd4, 0x45, 0xd4, 0xc8, 0xef, 0xf7, 0xed, 0x22,
+ 0x97, 0x68, 0x72, 0x68, 0xd6, 0xec, 0xcc, 0xc0,
+ 0xc0, 0x7b, 0x25, 0xe2, 0x5e, 0xcf, 0xe5
+);
+DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test2_next_iv,
+ 0xfc, 0x00, 0x78, 0x3e, 0x0e, 0xfd, 0xb2, 0xc1,
+ 0xd4, 0x45, 0xd4, 0xc8, 0xef, 0xf7, 0xed, 0x22
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test3_plaintext,
+ 0x49, 0x20, 0x77, 0x6f, 0x75, 0x6c, 0x64, 0x20,
+ 0x6c, 0x69, 0x6b, 0x65, 0x20, 0x74, 0x68, 0x65,
+ 0x20, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c,
+ 0x20, 0x47, 0x61, 0x75, 0x27, 0x73, 0x20, 0x43
+);
+DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test3_expected_result,
+ 0x39, 0x31, 0x25, 0x23, 0xa7, 0x86, 0x62, 0xd5,
+ 0xbe, 0x7f, 0xcb, 0xcc, 0x98, 0xeb, 0xf5, 0xa8,
+ 0x97, 0x68, 0x72, 0x68, 0xd6, 0xec, 0xcc, 0xc0,
+ 0xc0, 0x7b, 0x25, 0xe2, 0x5e, 0xcf, 0xe5, 0x84
+);
+DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test3_next_iv,
+ 0x39, 0x31, 0x25, 0x23, 0xa7, 0x86, 0x62, 0xd5,
+ 0xbe, 0x7f, 0xcb, 0xcc, 0x98, 0xeb, 0xf5, 0xa8
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test4_plaintext,
+ 0x49, 0x20, 0x77, 0x6f, 0x75, 0x6c, 0x64, 0x20,
+ 0x6c, 0x69, 0x6b, 0x65, 0x20, 0x74, 0x68, 0x65,
+ 0x20, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c,
+ 0x20, 0x47, 0x61, 0x75, 0x27, 0x73, 0x20, 0x43,
+ 0x68, 0x69, 0x63, 0x6b, 0x65, 0x6e, 0x2c, 0x20,
+ 0x70, 0x6c, 0x65, 0x61, 0x73, 0x65, 0x2c
+);
+DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test4_expected_result,
+ 0x97, 0x68, 0x72, 0x68, 0xd6, 0xec, 0xcc, 0xc0,
+ 0xc0, 0x7b, 0x25, 0xe2, 0x5e, 0xcf, 0xe5, 0x84,
+ 0xb3, 0xff, 0xfd, 0x94, 0x0c, 0x16, 0xa1, 0x8c,
+ 0x1b, 0x55, 0x49, 0xd2, 0xf8, 0x38, 0x02, 0x9e,
+ 0x39, 0x31, 0x25, 0x23, 0xa7, 0x86, 0x62, 0xd5,
+ 0xbe, 0x7f, 0xcb, 0xcc, 0x98, 0xeb, 0xf5
+);
+DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test4_next_iv,
+ 0xb3, 0xff, 0xfd, 0x94, 0x0c, 0x16, 0xa1, 0x8c,
+ 0x1b, 0x55, 0x49, 0xd2, 0xf8, 0x38, 0x02, 0x9e
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test5_plaintext,
+ 0x49, 0x20, 0x77, 0x6f, 0x75, 0x6c, 0x64, 0x20,
+ 0x6c, 0x69, 0x6b, 0x65, 0x20, 0x74, 0x68, 0x65,
+ 0x20, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c,
+ 0x20, 0x47, 0x61, 0x75, 0x27, 0x73, 0x20, 0x43,
+ 0x68, 0x69, 0x63, 0x6b, 0x65, 0x6e, 0x2c, 0x20,
+ 0x70, 0x6c, 0x65, 0x61, 0x73, 0x65, 0x2c, 0x20
+);
+DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test5_expected_result,
+ 0x97, 0x68, 0x72, 0x68, 0xd6, 0xec, 0xcc, 0xc0,
+ 0xc0, 0x7b, 0x25, 0xe2, 0x5e, 0xcf, 0xe5, 0x84,
+ 0x9d, 0xad, 0x8b, 0xbb, 0x96, 0xc4, 0xcd, 0xc0,
+ 0x3b, 0xc1, 0x03, 0xe1, 0xa1, 0x94, 0xbb, 0xd8,
+ 0x39, 0x31, 0x25, 0x23, 0xa7, 0x86, 0x62, 0xd5,
+ 0xbe, 0x7f, 0xcb, 0xcc, 0x98, 0xeb, 0xf5, 0xa8
+);
+DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test5_next_iv,
+ 0x9d, 0xad, 0x8b, 0xbb, 0x96, 0xc4, 0xcd, 0xc0,
+ 0x3b, 0xc1, 0x03, 0xe1, 0xa1, 0x94, 0xbb, 0xd8
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test6_plaintext,
+ 0x49, 0x20, 0x77, 0x6f, 0x75, 0x6c, 0x64, 0x20,
+ 0x6c, 0x69, 0x6b, 0x65, 0x20, 0x74, 0x68, 0x65,
+ 0x20, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c,
+ 0x20, 0x47, 0x61, 0x75, 0x27, 0x73, 0x20, 0x43,
+ 0x68, 0x69, 0x63, 0x6b, 0x65, 0x6e, 0x2c, 0x20,
+ 0x70, 0x6c, 0x65, 0x61, 0x73, 0x65, 0x2c, 0x20,
+ 0x61, 0x6e, 0x64, 0x20, 0x77, 0x6f, 0x6e, 0x74,
+ 0x6f, 0x6e, 0x20, 0x73, 0x6f, 0x75, 0x70, 0x2e
+);
+DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test6_expected_result,
+ 0x97, 0x68, 0x72, 0x68, 0xd6, 0xec, 0xcc, 0xc0,
+ 0xc0, 0x7b, 0x25, 0xe2, 0x5e, 0xcf, 0xe5, 0x84,
+ 0x39, 0x31, 0x25, 0x23, 0xa7, 0x86, 0x62, 0xd5,
+ 0xbe, 0x7f, 0xcb, 0xcc, 0x98, 0xeb, 0xf5, 0xa8,
+ 0x48, 0x07, 0xef, 0xe8, 0x36, 0xee, 0x89, 0xa5,
+ 0x26, 0x73, 0x0d, 0xbc, 0x2f, 0x7b, 0xc8, 0x40,
+ 0x9d, 0xad, 0x8b, 0xbb, 0x96, 0xc4, 0xcd, 0xc0,
+ 0x3b, 0xc1, 0x03, 0xe1, 0xa1, 0x94, 0xbb, 0xd8
+);
+DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test6_next_iv,
+ 0x48, 0x07, 0xef, 0xe8, 0x36, 0xee, 0x89, 0xa5,
+ 0x26, 0x73, 0x0d, 0xbc, 0x2f, 0x7b, 0xc8, 0x40
+);
+
+static const struct gss_krb5_test_param rfc3962_encrypt_test_params[] = {
+ {
+ .desc = "Encrypt with aes128-cts-hmac-sha1-96 case 1",
+ .enctype = ENCTYPE_AES128_CTS_HMAC_SHA1_96,
+ .Ke = &rfc3962_encryption_key,
+ .plaintext = &rfc3962_enc_test1_plaintext,
+ .expected_result = &rfc3962_enc_test1_expected_result,
+ .next_iv = &rfc3962_enc_test1_next_iv,
+ },
+ {
+ .desc = "Encrypt with aes128-cts-hmac-sha1-96 case 2",
+ .enctype = ENCTYPE_AES128_CTS_HMAC_SHA1_96,
+ .Ke = &rfc3962_encryption_key,
+ .plaintext = &rfc3962_enc_test2_plaintext,
+ .expected_result = &rfc3962_enc_test2_expected_result,
+ .next_iv = &rfc3962_enc_test2_next_iv,
+ },
+ {
+ .desc = "Encrypt with aes128-cts-hmac-sha1-96 case 3",
+ .enctype = ENCTYPE_AES128_CTS_HMAC_SHA1_96,
+ .Ke = &rfc3962_encryption_key,
+ .plaintext = &rfc3962_enc_test3_plaintext,
+ .expected_result = &rfc3962_enc_test3_expected_result,
+ .next_iv = &rfc3962_enc_test3_next_iv,
+ },
+ {
+ .desc = "Encrypt with aes128-cts-hmac-sha1-96 case 4",
+ .enctype = ENCTYPE_AES128_CTS_HMAC_SHA1_96,
+ .Ke = &rfc3962_encryption_key,
+ .plaintext = &rfc3962_enc_test4_plaintext,
+ .expected_result = &rfc3962_enc_test4_expected_result,
+ .next_iv = &rfc3962_enc_test4_next_iv,
+ },
+ {
+ .desc = "Encrypt with aes128-cts-hmac-sha1-96 case 5",
+ .enctype = ENCTYPE_AES128_CTS_HMAC_SHA1_96,
+ .Ke = &rfc3962_encryption_key,
+ .plaintext = &rfc3962_enc_test5_plaintext,
+ .expected_result = &rfc3962_enc_test5_expected_result,
+ .next_iv = &rfc3962_enc_test5_next_iv,
+ },
+ {
+ .desc = "Encrypt with aes128-cts-hmac-sha1-96 case 6",
+ .enctype = ENCTYPE_AES128_CTS_HMAC_SHA1_96,
+ .Ke = &rfc3962_encryption_key,
+ .plaintext = &rfc3962_enc_test6_plaintext,
+ .expected_result = &rfc3962_enc_test6_expected_result,
+ .next_iv = &rfc3962_enc_test6_next_iv,
+ },
+};
+
+/* Creates the function rfc3962_encrypt_gen_params */
+KUNIT_ARRAY_PARAM(rfc3962_encrypt, rfc3962_encrypt_test_params,
+ gss_krb5_get_desc);
+
+/*
+ * This tests the implementation of the encryption part of the mechanism.
+ * It does not apply a confounder, nor does it verify the HMAC computed
+ * over the plaintext.
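+ *
+ * The chained CBC initialization vector left by each encryption is
+ * also compared against the "next IV" value given in the RFC.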
+ */
+static void rfc3962_encrypt_case(struct kunit *test)
+{
+ const struct gss_krb5_test_param *param = test->param_value;
+ struct crypto_sync_skcipher *cts_tfm, *cbc_tfm;
+ const struct gss_krb5_enctype *gk5e;
+ struct xdr_buf buf;
+ void *iv, *text;
+ u32 err;
+
+ /* Arrange */
+ gk5e = gss_krb5_lookup_enctype(param->enctype);
+ if (!gk5e)
+ kunit_skip(test, "Encryption type is not available");
+
+ cbc_tfm = crypto_alloc_sync_skcipher(gk5e->aux_cipher, 0, 0);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, cbc_tfm);
+ err = crypto_sync_skcipher_setkey(cbc_tfm, param->Ke->data, param->Ke->len);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ cts_tfm = crypto_alloc_sync_skcipher(gk5e->encrypt_name, 0, 0);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, cts_tfm);
+ err = crypto_sync_skcipher_setkey(cts_tfm, param->Ke->data, param->Ke->len);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ iv = kunit_kzalloc(test, crypto_sync_skcipher_ivsize(cts_tfm), GFP_KERNEL);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, iv);
+
+ text = kunit_kzalloc(test, param->plaintext->len, GFP_KERNEL);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, text);
+
+ memcpy(text, param->plaintext->data, param->plaintext->len);
+ memset(&buf, 0, sizeof(buf));
+ buf.head[0].iov_base = text;
+ buf.head[0].iov_len = param->plaintext->len;
+ buf.len = buf.head[0].iov_len;
+
+ /* Act */
+ err = krb5_cbc_cts_encrypt(cts_tfm, cbc_tfm, 0, &buf, NULL,
+ iv, crypto_sync_skcipher_ivsize(cts_tfm));
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ /* Assert */
+ KUNIT_EXPECT_EQ_MSG(test,
+ param->expected_result->len, buf.len,
+ "ciphertext length mismatch");
+ KUNIT_EXPECT_EQ_MSG(test,
+ memcmp(param->expected_result->data,
+ text, param->expected_result->len), 0,
+ "ciphertext mismatch");
+ KUNIT_EXPECT_EQ_MSG(test,
+ memcmp(param->next_iv->data, iv,
+ param->next_iv->len), 0,
+ "IV mismatch");
+
+ crypto_free_sync_skcipher(cts_tfm);
+ crypto_free_sync_skcipher(cbc_tfm);
+}
+
+static struct kunit_case rfc3962_test_cases[] = {
+ {
+ .name = "RFC 3962 encryption",
+ .run_case = rfc3962_encrypt_case,
+ .generate_params = rfc3962_encrypt_gen_params,
+ },
+ {}
+};
+
+static struct kunit_suite rfc3962_suite = {
+ .name = "RFC 3962 suite",
+ .test_cases = rfc3962_test_cases,
+};
+
+/*
+ * From RFC 6803 Section 10. Test vectors
+ *
+ * Sample results for key derivation
+ *
+ * Copyright (c) 2012 IETF Trust and the persons identified as the
+ * document authors. All rights reserved.
+ */
+
+DEFINE_HEX_XDR_NETOBJ(camellia128_cts_cmac_basekey,
+ 0x57, 0xd0, 0x29, 0x72, 0x98, 0xff, 0xd9, 0xd3,
+ 0x5d, 0xe5, 0xa4, 0x7f, 0xb4, 0xbd, 0xe2, 0x4b
+);
+DEFINE_HEX_XDR_NETOBJ(camellia128_cts_cmac_Kc,
+ 0xd1, 0x55, 0x77, 0x5a, 0x20, 0x9d, 0x05, 0xf0,
+ 0x2b, 0x38, 0xd4, 0x2a, 0x38, 0x9e, 0x5a, 0x56
+);
+DEFINE_HEX_XDR_NETOBJ(camellia128_cts_cmac_Ke,
+ 0x64, 0xdf, 0x83, 0xf8, 0x5a, 0x53, 0x2f, 0x17,
+ 0x57, 0x7d, 0x8c, 0x37, 0x03, 0x57, 0x96, 0xab
+);
+DEFINE_HEX_XDR_NETOBJ(camellia128_cts_cmac_Ki,
+ 0x3e, 0x4f, 0xbd, 0xf3, 0x0f, 0xb8, 0x25, 0x9c,
+ 0x42, 0x5c, 0xb6, 0xc9, 0x6f, 0x1f, 0x46, 0x35
+);
+
+DEFINE_HEX_XDR_NETOBJ(camellia256_cts_cmac_basekey,
+ 0xb9, 0xd6, 0x82, 0x8b, 0x20, 0x56, 0xb7, 0xbe,
+ 0x65, 0x6d, 0x88, 0xa1, 0x23, 0xb1, 0xfa, 0xc6,
+ 0x82, 0x14, 0xac, 0x2b, 0x72, 0x7e, 0xcf, 0x5f,
+ 0x69, 0xaf, 0xe0, 0xc4, 0xdf, 0x2a, 0x6d, 0x2c
+);
+DEFINE_HEX_XDR_NETOBJ(camellia256_cts_cmac_Kc,
+ 0xe4, 0x67, 0xf9, 0xa9, 0x55, 0x2b, 0xc7, 0xd3,
+ 0x15, 0x5a, 0x62, 0x20, 0xaf, 0x9c, 0x19, 0x22,
+ 0x0e, 0xee, 0xd4, 0xff, 0x78, 0xb0, 0xd1, 0xe6,
+ 0xa1, 0x54, 0x49, 0x91, 0x46, 0x1a, 0x9e, 0x50
+);
+DEFINE_HEX_XDR_NETOBJ(camellia256_cts_cmac_Ke,
+ 0x41, 0x2a, 0xef, 0xc3, 0x62, 0xa7, 0x28, 0x5f,
+ 0xc3, 0x96, 0x6c, 0x6a, 0x51, 0x81, 0xe7, 0x60,
+ 0x5a, 0xe6, 0x75, 0x23, 0x5b, 0x6d, 0x54, 0x9f,
+ 0xbf, 0xc9, 0xab, 0x66, 0x30, 0xa4, 0xc6, 0x04
+);
+DEFINE_HEX_XDR_NETOBJ(camellia256_cts_cmac_Ki,
+ 0xfa, 0x62, 0x4f, 0xa0, 0xe5, 0x23, 0x99, 0x3f,
+ 0xa3, 0x88, 0xae, 0xfd, 0xc6, 0x7e, 0x67, 0xeb,
+ 0xcd, 0x8c, 0x08, 0xe8, 0xa0, 0x24, 0x6b, 0x1d,
+ 0x73, 0xb0, 0xd1, 0xdd, 0x9f, 0xc5, 0x82, 0xb0
+);
+
+DEFINE_HEX_XDR_NETOBJ(usage_checksum,
+ 0x00, 0x00, 0x00, 0x02, KEY_USAGE_SEED_CHECKSUM
+);
+DEFINE_HEX_XDR_NETOBJ(usage_encryption,
+ 0x00, 0x00, 0x00, 0x02, KEY_USAGE_SEED_ENCRYPTION
+);
+DEFINE_HEX_XDR_NETOBJ(usage_integrity,
+ 0x00, 0x00, 0x00, 0x02, KEY_USAGE_SEED_INTEGRITY
+);
+
+static const struct gss_krb5_test_param rfc6803_kdf_test_params[] = {
+ {
+ .desc = "Derive Kc subkey for camellia128-cts-cmac",
+ .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC,
+ .base_key = &camellia128_cts_cmac_basekey,
+ .usage = &usage_checksum,
+ .expected_result = &camellia128_cts_cmac_Kc,
+ },
+ {
+ .desc = "Derive Ke subkey for camellia128-cts-cmac",
+ .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC,
+ .base_key = &camellia128_cts_cmac_basekey,
+ .usage = &usage_encryption,
+ .expected_result = &camellia128_cts_cmac_Ke,
+ },
+ {
+ .desc = "Derive Ki subkey for camellia128-cts-cmac",
+ .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC,
+ .base_key = &camellia128_cts_cmac_basekey,
+ .usage = &usage_integrity,
+ .expected_result = &camellia128_cts_cmac_Ki,
+ },
+ {
+ .desc = "Derive Kc subkey for camellia256-cts-cmac",
+ .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC,
+ .base_key = &camellia256_cts_cmac_basekey,
+ .usage = &usage_checksum,
+ .expected_result = &camellia256_cts_cmac_Kc,
+ },
+ {
+ .desc = "Derive Ke subkey for camellia256-cts-cmac",
+ .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC,
+ .base_key = &camellia256_cts_cmac_basekey,
+ .usage = &usage_encryption,
+ .expected_result = &camellia256_cts_cmac_Ke,
+ },
+ {
+ .desc = "Derive Ki subkey for camellia256-cts-cmac",
+ .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC,
+ .base_key = &camellia256_cts_cmac_basekey,
+ .usage = &usage_integrity,
+ .expected_result = &camellia256_cts_cmac_Ki,
+ },
+};
+
+/* Creates the function rfc6803_kdf_gen_params */
+KUNIT_ARRAY_PARAM(rfc6803_kdf, rfc6803_kdf_test_params, gss_krb5_get_desc);
+
+/*
+ * From RFC 6803 Section 10. Test vectors
+ *
+ * Sample checksums.
+ *
+ * Copyright (c) 2012 IETF Trust and the persons identified as the
+ * document authors. All rights reserved.
+ *
+ * XXX: These tests are likely to fail on EBCDIC or Unicode platforms.
+ */
+DEFINE_STR_XDR_NETOBJ(rfc6803_checksum_test1_plaintext,
+ "abcdefghijk");
+DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test1_basekey,
+ 0x1d, 0xc4, 0x6a, 0x8d, 0x76, 0x3f, 0x4f, 0x93,
+ 0x74, 0x2b, 0xcb, 0xa3, 0x38, 0x75, 0x76, 0xc3
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test1_usage,
+ 0x00, 0x00, 0x00, 0x07, KEY_USAGE_SEED_CHECKSUM
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test1_expected_result,
+ 0x11, 0x78, 0xe6, 0xc5, 0xc4, 0x7a, 0x8c, 0x1a,
+ 0xe0, 0xc4, 0xb9, 0xc7, 0xd4, 0xeb, 0x7b, 0x6b
+);
+
+DEFINE_STR_XDR_NETOBJ(rfc6803_checksum_test2_plaintext,
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZ");
+DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test2_basekey,
+ 0x50, 0x27, 0xbc, 0x23, 0x1d, 0x0f, 0x3a, 0x9d,
+ 0x23, 0x33, 0x3f, 0x1c, 0xa6, 0xfd, 0xbe, 0x7c
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test2_usage,
+ 0x00, 0x00, 0x00, 0x08, KEY_USAGE_SEED_CHECKSUM
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test2_expected_result,
+ 0xd1, 0xb3, 0x4f, 0x70, 0x04, 0xa7, 0x31, 0xf2,
+ 0x3a, 0x0c, 0x00, 0xbf, 0x6c, 0x3f, 0x75, 0x3a
+);
+
+DEFINE_STR_XDR_NETOBJ(rfc6803_checksum_test3_plaintext,
+ "123456789");
+DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test3_basekey,
+ 0xb6, 0x1c, 0x86, 0xcc, 0x4e, 0x5d, 0x27, 0x57,
+ 0x54, 0x5a, 0xd4, 0x23, 0x39, 0x9f, 0xb7, 0x03,
+ 0x1e, 0xca, 0xb9, 0x13, 0xcb, 0xb9, 0x00, 0xbd,
+ 0x7a, 0x3c, 0x6d, 0xd8, 0xbf, 0x92, 0x01, 0x5b
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test3_usage,
+ 0x00, 0x00, 0x00, 0x09, KEY_USAGE_SEED_CHECKSUM
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test3_expected_result,
+ 0x87, 0xa1, 0x2c, 0xfd, 0x2b, 0x96, 0x21, 0x48,
+ 0x10, 0xf0, 0x1c, 0x82, 0x6e, 0x77, 0x44, 0xb1
+);
+
+DEFINE_STR_XDR_NETOBJ(rfc6803_checksum_test4_plaintext,
+ "!@#$%^&*()!@#$%^&*()!@#$%^&*()");
+DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test4_basekey,
+ 0x32, 0x16, 0x4c, 0x5b, 0x43, 0x4d, 0x1d, 0x15,
+ 0x38, 0xe4, 0xcf, 0xd9, 0xbe, 0x80, 0x40, 0xfe,
+ 0x8c, 0x4a, 0xc7, 0xac, 0xc4, 0xb9, 0x3d, 0x33,
+ 0x14, 0xd2, 0x13, 0x36, 0x68, 0x14, 0x7a, 0x05
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test4_usage,
+ 0x00, 0x00, 0x00, 0x0a, KEY_USAGE_SEED_CHECKSUM
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test4_expected_result,
+ 0x3f, 0xa0, 0xb4, 0x23, 0x55, 0xe5, 0x2b, 0x18,
+ 0x91, 0x87, 0x29, 0x4a, 0xa2, 0x52, 0xab, 0x64
+);
+
+static const struct gss_krb5_test_param rfc6803_checksum_test_params[] = {
+ {
+ .desc = "camellia128-cts-cmac checksum test 1",
+ .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC,
+ .base_key = &rfc6803_checksum_test1_basekey,
+ .usage = &rfc6803_checksum_test1_usage,
+ .plaintext = &rfc6803_checksum_test1_plaintext,
+ .expected_result = &rfc6803_checksum_test1_expected_result,
+ },
+ {
+ .desc = "camellia128-cts-cmac checksum test 2",
+ .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC,
+ .base_key = &rfc6803_checksum_test2_basekey,
+ .usage = &rfc6803_checksum_test2_usage,
+ .plaintext = &rfc6803_checksum_test2_plaintext,
+ .expected_result = &rfc6803_checksum_test2_expected_result,
+ },
+ {
+ .desc = "camellia256-cts-cmac checksum test 3",
+ .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC,
+ .base_key = &rfc6803_checksum_test3_basekey,
+ .usage = &rfc6803_checksum_test3_usage,
+ .plaintext = &rfc6803_checksum_test3_plaintext,
+ .expected_result = &rfc6803_checksum_test3_expected_result,
+ },
+ {
+ .desc = "camellia256-cts-cmac checksum test 4",
+ .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC,
+ .base_key = &rfc6803_checksum_test4_basekey,
+ .usage = &rfc6803_checksum_test4_usage,
+ .plaintext = &rfc6803_checksum_test4_plaintext,
+ .expected_result = &rfc6803_checksum_test4_expected_result,
+ },
+};
+
+/* Creates the function rfc6803_checksum_gen_params */
+KUNIT_ARRAY_PARAM(rfc6803_checksum, rfc6803_checksum_test_params,
+ gss_krb5_get_desc);
+
+/*
+ * From RFC 6803 Section 10. Test vectors
+ *
+ * Sample encryptions (all using the default cipher state)
+ *
+ * Copyright (c) 2012 IETF Trust and the persons identified as the
+ * document authors. All rights reserved.
+ *
+ * Key usage values are from errata 4326 against RFC 6803.
+ */
+
+static const struct xdr_netobj rfc6803_enc_empty_plaintext = {
+ .len = 0,
+};
+
+DEFINE_STR_XDR_NETOBJ(rfc6803_enc_1byte_plaintext, "1");
+DEFINE_STR_XDR_NETOBJ(rfc6803_enc_9byte_plaintext, "9 bytesss");
+DEFINE_STR_XDR_NETOBJ(rfc6803_enc_13byte_plaintext, "13 bytes byte");
+DEFINE_STR_XDR_NETOBJ(rfc6803_enc_30byte_plaintext,
+ "30 bytes bytes bytes bytes byt"
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test1_confounder,
+ 0xb6, 0x98, 0x22, 0xa1, 0x9a, 0x6b, 0x09, 0xc0,
+ 0xeb, 0xc8, 0x55, 0x7d, 0x1f, 0x1b, 0x6c, 0x0a
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test1_basekey,
+ 0x1d, 0xc4, 0x6a, 0x8d, 0x76, 0x3f, 0x4f, 0x93,
+ 0x74, 0x2b, 0xcb, 0xa3, 0x38, 0x75, 0x76, 0xc3
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test1_expected_result,
+ 0xc4, 0x66, 0xf1, 0x87, 0x10, 0x69, 0x92, 0x1e,
+ 0xdb, 0x7c, 0x6f, 0xde, 0x24, 0x4a, 0x52, 0xdb,
+ 0x0b, 0xa1, 0x0e, 0xdc, 0x19, 0x7b, 0xdb, 0x80,
+ 0x06, 0x65, 0x8c, 0xa3, 0xcc, 0xce, 0x6e, 0xb8
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test2_confounder,
+ 0x6f, 0x2f, 0xc3, 0xc2, 0xa1, 0x66, 0xfd, 0x88,
+ 0x98, 0x96, 0x7a, 0x83, 0xde, 0x95, 0x96, 0xd9
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test2_basekey,
+ 0x50, 0x27, 0xbc, 0x23, 0x1d, 0x0f, 0x3a, 0x9d,
+ 0x23, 0x33, 0x3f, 0x1c, 0xa6, 0xfd, 0xbe, 0x7c
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test2_expected_result,
+ 0x84, 0x2d, 0x21, 0xfd, 0x95, 0x03, 0x11, 0xc0,
+ 0xdd, 0x46, 0x4a, 0x3f, 0x4b, 0xe8, 0xd6, 0xda,
+ 0x88, 0xa5, 0x6d, 0x55, 0x9c, 0x9b, 0x47, 0xd3,
+ 0xf9, 0xa8, 0x50, 0x67, 0xaf, 0x66, 0x15, 0x59,
+ 0xb8
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test3_confounder,
+ 0xa5, 0xb4, 0xa7, 0x1e, 0x07, 0x7a, 0xee, 0xf9,
+ 0x3c, 0x87, 0x63, 0xc1, 0x8f, 0xdb, 0x1f, 0x10
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test3_basekey,
+ 0xa1, 0xbb, 0x61, 0xe8, 0x05, 0xf9, 0xba, 0x6d,
+ 0xde, 0x8f, 0xdb, 0xdd, 0xc0, 0x5c, 0xde, 0xa0
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test3_expected_result,
+ 0x61, 0x9f, 0xf0, 0x72, 0xe3, 0x62, 0x86, 0xff,
+ 0x0a, 0x28, 0xde, 0xb3, 0xa3, 0x52, 0xec, 0x0d,
+ 0x0e, 0xdf, 0x5c, 0x51, 0x60, 0xd6, 0x63, 0xc9,
+ 0x01, 0x75, 0x8c, 0xcf, 0x9d, 0x1e, 0xd3, 0x3d,
+ 0x71, 0xdb, 0x8f, 0x23, 0xaa, 0xbf, 0x83, 0x48,
+ 0xa0
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test4_confounder,
+ 0x19, 0xfe, 0xe4, 0x0d, 0x81, 0x0c, 0x52, 0x4b,
+ 0x5b, 0x22, 0xf0, 0x18, 0x74, 0xc6, 0x93, 0xda
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test4_basekey,
+ 0x2c, 0xa2, 0x7a, 0x5f, 0xaf, 0x55, 0x32, 0x24,
+ 0x45, 0x06, 0x43, 0x4e, 0x1c, 0xef, 0x66, 0x76
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test4_expected_result,
+ 0xb8, 0xec, 0xa3, 0x16, 0x7a, 0xe6, 0x31, 0x55,
+ 0x12, 0xe5, 0x9f, 0x98, 0xa7, 0xc5, 0x00, 0x20,
+ 0x5e, 0x5f, 0x63, 0xff, 0x3b, 0xb3, 0x89, 0xaf,
+ 0x1c, 0x41, 0xa2, 0x1d, 0x64, 0x0d, 0x86, 0x15,
+ 0xc9, 0xed, 0x3f, 0xbe, 0xb0, 0x5a, 0xb6, 0xac,
+ 0xb6, 0x76, 0x89, 0xb5, 0xea
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test5_confounder,
+ 0xca, 0x7a, 0x7a, 0xb4, 0xbe, 0x19, 0x2d, 0xab,
+ 0xd6, 0x03, 0x50, 0x6d, 0xb1, 0x9c, 0x39, 0xe2
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test5_basekey,
+ 0x78, 0x24, 0xf8, 0xc1, 0x6f, 0x83, 0xff, 0x35,
+ 0x4c, 0x6b, 0xf7, 0x51, 0x5b, 0x97, 0x3f, 0x43
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test5_expected_result,
+ 0xa2, 0x6a, 0x39, 0x05, 0xa4, 0xff, 0xd5, 0x81,
+ 0x6b, 0x7b, 0x1e, 0x27, 0x38, 0x0d, 0x08, 0x09,
+ 0x0c, 0x8e, 0xc1, 0xf3, 0x04, 0x49, 0x6e, 0x1a,
+ 0xbd, 0xcd, 0x2b, 0xdc, 0xd1, 0xdf, 0xfc, 0x66,
+ 0x09, 0x89, 0xe1, 0x17, 0xa7, 0x13, 0xdd, 0xbb,
+ 0x57, 0xa4, 0x14, 0x6c, 0x15, 0x87, 0xcb, 0xa4,
+ 0x35, 0x66, 0x65, 0x59, 0x1d, 0x22, 0x40, 0x28,
+ 0x2f, 0x58, 0x42, 0xb1, 0x05, 0xa5
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test6_confounder,
+ 0x3c, 0xbb, 0xd2, 0xb4, 0x59, 0x17, 0x94, 0x10,
+ 0x67, 0xf9, 0x65, 0x99, 0xbb, 0x98, 0x92, 0x6c
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test6_basekey,
+ 0xb6, 0x1c, 0x86, 0xcc, 0x4e, 0x5d, 0x27, 0x57,
+ 0x54, 0x5a, 0xd4, 0x23, 0x39, 0x9f, 0xb7, 0x03,
+ 0x1e, 0xca, 0xb9, 0x13, 0xcb, 0xb9, 0x00, 0xbd,
+ 0x7a, 0x3c, 0x6d, 0xd8, 0xbf, 0x92, 0x01, 0x5b
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test6_expected_result,
+ 0x03, 0x88, 0x6d, 0x03, 0x31, 0x0b, 0x47, 0xa6,
+ 0xd8, 0xf0, 0x6d, 0x7b, 0x94, 0xd1, 0xdd, 0x83,
+ 0x7e, 0xcc, 0xe3, 0x15, 0xef, 0x65, 0x2a, 0xff,
+ 0x62, 0x08, 0x59, 0xd9, 0x4a, 0x25, 0x92, 0x66
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test7_confounder,
+ 0xde, 0xf4, 0x87, 0xfc, 0xeb, 0xe6, 0xde, 0x63,
+ 0x46, 0xd4, 0xda, 0x45, 0x21, 0xbb, 0xa2, 0xd2
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test7_basekey,
+ 0x1b, 0x97, 0xfe, 0x0a, 0x19, 0x0e, 0x20, 0x21,
+ 0xeb, 0x30, 0x75, 0x3e, 0x1b, 0x6e, 0x1e, 0x77,
+ 0xb0, 0x75, 0x4b, 0x1d, 0x68, 0x46, 0x10, 0x35,
+ 0x58, 0x64, 0x10, 0x49, 0x63, 0x46, 0x38, 0x33
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test7_expected_result,
+ 0x2c, 0x9c, 0x15, 0x70, 0x13, 0x3c, 0x99, 0xbf,
+ 0x6a, 0x34, 0xbc, 0x1b, 0x02, 0x12, 0x00, 0x2f,
+ 0xd1, 0x94, 0x33, 0x87, 0x49, 0xdb, 0x41, 0x35,
+ 0x49, 0x7a, 0x34, 0x7c, 0xfc, 0xd9, 0xd1, 0x8a,
+ 0x12
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test8_confounder,
+ 0xad, 0x4f, 0xf9, 0x04, 0xd3, 0x4e, 0x55, 0x53,
+ 0x84, 0xb1, 0x41, 0x00, 0xfc, 0x46, 0x5f, 0x88
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test8_basekey,
+ 0x32, 0x16, 0x4c, 0x5b, 0x43, 0x4d, 0x1d, 0x15,
+ 0x38, 0xe4, 0xcf, 0xd9, 0xbe, 0x80, 0x40, 0xfe,
+ 0x8c, 0x4a, 0xc7, 0xac, 0xc4, 0xb9, 0x3d, 0x33,
+ 0x14, 0xd2, 0x13, 0x36, 0x68, 0x14, 0x7a, 0x05
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test8_expected_result,
+ 0x9c, 0x6d, 0xe7, 0x5f, 0x81, 0x2d, 0xe7, 0xed,
+ 0x0d, 0x28, 0xb2, 0x96, 0x35, 0x57, 0xa1, 0x15,
+ 0x64, 0x09, 0x98, 0x27, 0x5b, 0x0a, 0xf5, 0x15,
+ 0x27, 0x09, 0x91, 0x3f, 0xf5, 0x2a, 0x2a, 0x9c,
+ 0x8e, 0x63, 0xb8, 0x72, 0xf9, 0x2e, 0x64, 0xc8,
+ 0x39
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test9_confounder,
+ 0xcf, 0x9b, 0xca, 0x6d, 0xf1, 0x14, 0x4e, 0x0c,
+ 0x0a, 0xf9, 0xb8, 0xf3, 0x4c, 0x90, 0xd5, 0x14
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test9_basekey,
+ 0xb0, 0x38, 0xb1, 0x32, 0xcd, 0x8e, 0x06, 0x61,
+ 0x22, 0x67, 0xfa, 0xb7, 0x17, 0x00, 0x66, 0xd8,
+ 0x8a, 0xec, 0xcb, 0xa0, 0xb7, 0x44, 0xbf, 0xc6,
+ 0x0d, 0xc8, 0x9b, 0xca, 0x18, 0x2d, 0x07, 0x15
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test9_expected_result,
+ 0xee, 0xec, 0x85, 0xa9, 0x81, 0x3c, 0xdc, 0x53,
+ 0x67, 0x72, 0xab, 0x9b, 0x42, 0xde, 0xfc, 0x57,
+ 0x06, 0xf7, 0x26, 0xe9, 0x75, 0xdd, 0xe0, 0x5a,
+ 0x87, 0xeb, 0x54, 0x06, 0xea, 0x32, 0x4c, 0xa1,
+ 0x85, 0xc9, 0x98, 0x6b, 0x42, 0xaa, 0xbe, 0x79,
+ 0x4b, 0x84, 0x82, 0x1b, 0xee
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test10_confounder,
+ 0x64, 0x4d, 0xef, 0x38, 0xda, 0x35, 0x00, 0x72,
+ 0x75, 0x87, 0x8d, 0x21, 0x68, 0x55, 0xe2, 0x28
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test10_basekey,
+ 0xcc, 0xfc, 0xd3, 0x49, 0xbf, 0x4c, 0x66, 0x77,
+ 0xe8, 0x6e, 0x4b, 0x02, 0xb8, 0xea, 0xb9, 0x24,
+ 0xa5, 0x46, 0xac, 0x73, 0x1c, 0xf9, 0xbf, 0x69,
+ 0x89, 0xb9, 0x96, 0xe7, 0xd6, 0xbf, 0xbb, 0xa7
+);
+DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test10_expected_result,
+ 0x0e, 0x44, 0x68, 0x09, 0x85, 0x85, 0x5f, 0x2d,
+ 0x1f, 0x18, 0x12, 0x52, 0x9c, 0xa8, 0x3b, 0xfd,
+ 0x8e, 0x34, 0x9d, 0xe6, 0xfd, 0x9a, 0xda, 0x0b,
+ 0xaa, 0xa0, 0x48, 0xd6, 0x8e, 0x26, 0x5f, 0xeb,
+ 0xf3, 0x4a, 0xd1, 0x25, 0x5a, 0x34, 0x49, 0x99,
+ 0xad, 0x37, 0x14, 0x68, 0x87, 0xa6, 0xc6, 0x84,
+ 0x57, 0x31, 0xac, 0x7f, 0x46, 0x37, 0x6a, 0x05,
+ 0x04, 0xcd, 0x06, 0x57, 0x14, 0x74
+);
+
+static const struct gss_krb5_test_param rfc6803_encrypt_test_params[] = {
+ {
+ .desc = "Encrypt empty plaintext with camellia128-cts-cmac",
+ .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC,
+ .constant = 0,
+ .base_key = &rfc6803_enc_test1_basekey,
+ .plaintext = &rfc6803_enc_empty_plaintext,
+ .confounder = &rfc6803_enc_test1_confounder,
+ .expected_result = &rfc6803_enc_test1_expected_result,
+ },
+ {
+ .desc = "Encrypt 1 byte with camellia128-cts-cmac",
+ .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC,
+ .constant = 1,
+ .base_key = &rfc6803_enc_test2_basekey,
+ .plaintext = &rfc6803_enc_1byte_plaintext,
+ .confounder = &rfc6803_enc_test2_confounder,
+ .expected_result = &rfc6803_enc_test2_expected_result,
+ },
+ {
+ .desc = "Encrypt 9 bytes with camellia128-cts-cmac",
+ .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC,
+ .constant = 2,
+ .base_key = &rfc6803_enc_test3_basekey,
+ .plaintext = &rfc6803_enc_9byte_plaintext,
+ .confounder = &rfc6803_enc_test3_confounder,
+ .expected_result = &rfc6803_enc_test3_expected_result,
+ },
+ {
+ .desc = "Encrypt 13 bytes with camellia128-cts-cmac",
+ .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC,
+ .constant = 3,
+ .base_key = &rfc6803_enc_test4_basekey,
+ .plaintext = &rfc6803_enc_13byte_plaintext,
+ .confounder = &rfc6803_enc_test4_confounder,
+ .expected_result = &rfc6803_enc_test4_expected_result,
+ },
+ {
+ .desc = "Encrypt 30 bytes with camellia128-cts-cmac",
+ .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC,
+ .constant = 4,
+ .base_key = &rfc6803_enc_test5_basekey,
+ .plaintext = &rfc6803_enc_30byte_plaintext,
+ .confounder = &rfc6803_enc_test5_confounder,
+ .expected_result = &rfc6803_enc_test5_expected_result,
+ },
+ {
+ .desc = "Encrypt empty plaintext with camellia256-cts-cmac",
+ .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC,
+ .constant = 0,
+ .base_key = &rfc6803_enc_test6_basekey,
+ .plaintext = &rfc6803_enc_empty_plaintext,
+ .confounder = &rfc6803_enc_test6_confounder,
+ .expected_result = &rfc6803_enc_test6_expected_result,
+ },
+ {
+ .desc = "Encrypt 1 byte with camellia256-cts-cmac",
+ .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC,
+ .constant = 1,
+ .base_key = &rfc6803_enc_test7_basekey,
+ .plaintext = &rfc6803_enc_1byte_plaintext,
+ .confounder = &rfc6803_enc_test7_confounder,
+ .expected_result = &rfc6803_enc_test7_expected_result,
+ },
+ {
+ .desc = "Encrypt 9 bytes with camellia256-cts-cmac",
+ .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC,
+ .constant = 2,
+ .base_key = &rfc6803_enc_test8_basekey,
+ .plaintext = &rfc6803_enc_9byte_plaintext,
+ .confounder = &rfc6803_enc_test8_confounder,
+ .expected_result = &rfc6803_enc_test8_expected_result,
+ },
+ {
+ .desc = "Encrypt 13 bytes with camellia256-cts-cmac",
+ .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC,
+ .constant = 3,
+ .base_key = &rfc6803_enc_test9_basekey,
+ .plaintext = &rfc6803_enc_13byte_plaintext,
+ .confounder = &rfc6803_enc_test9_confounder,
+ .expected_result = &rfc6803_enc_test9_expected_result,
+ },
+ {
+ .desc = "Encrypt 30 bytes with camellia256-cts-cmac",
+ .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC,
+ .constant = 4,
+ .base_key = &rfc6803_enc_test10_basekey,
+ .plaintext = &rfc6803_enc_30byte_plaintext,
+ .confounder = &rfc6803_enc_test10_confounder,
+ .expected_result = &rfc6803_enc_test10_expected_result,
+ },
+};
+
+/* Creates the function rfc6803_encrypt_gen_params */
+KUNIT_ARRAY_PARAM(rfc6803_encrypt, rfc6803_encrypt_test_params,
+ gss_krb5_get_desc);
+
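+/*
+ * Exercise the RFC 6803 encryption profile by hand: derive Ke and Ki
+ * from the base key, compute the CMAC of confounder-plus-plaintext
+ * with Ki, CBC-CTS-encrypt the same bytes with Ke, and check that the
+ * ciphertext followed by the checksum matches the published vector.
+ */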
+static void rfc6803_encrypt_case(struct kunit *test)
+{
+ const struct gss_krb5_test_param *param = test->param_value;
+ struct crypto_sync_skcipher *cts_tfm, *cbc_tfm;
+ const struct gss_krb5_enctype *gk5e;
+ struct xdr_netobj Ke, Ki, checksum;
+ u8 usage_data[GSS_KRB5_K5CLENGTH];
+ struct xdr_netobj usage = {
+ .data = usage_data,
+ .len = sizeof(usage_data),
+ };
+ struct crypto_ahash *ahash_tfm;
+ unsigned int blocksize;
+ struct xdr_buf buf;
+ void *text;
+ size_t len;
+ u32 err;
+
+ /* Arrange */
+ gk5e = gss_krb5_lookup_enctype(param->enctype);
+ if (!gk5e)
+ kunit_skip(test, "Encryption type is not available");
+
+ memset(usage_data, 0, sizeof(usage_data));
+ usage.data[3] = param->constant;
+
+ Ke.len = gk5e->Ke_length;
+ Ke.data = kunit_kzalloc(test, Ke.len, GFP_KERNEL);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, Ke.data);
+ usage.data[4] = KEY_USAGE_SEED_ENCRYPTION;
+ err = gk5e->derive_key(gk5e, param->base_key, &Ke, &usage, GFP_KERNEL);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ cbc_tfm = crypto_alloc_sync_skcipher(gk5e->aux_cipher, 0, 0);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, cbc_tfm);
+ err = crypto_sync_skcipher_setkey(cbc_tfm, Ke.data, Ke.len);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ cts_tfm = crypto_alloc_sync_skcipher(gk5e->encrypt_name, 0, 0);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, cts_tfm);
+ err = crypto_sync_skcipher_setkey(cts_tfm, Ke.data, Ke.len);
+ KUNIT_ASSERT_EQ(test, err, 0);
+ blocksize = crypto_sync_skcipher_blocksize(cts_tfm);
+
+ len = param->confounder->len + param->plaintext->len + blocksize;
+ text = kunit_kzalloc(test, len, GFP_KERNEL);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, text);
+ memcpy(text, param->confounder->data, param->confounder->len);
+ memcpy(text + param->confounder->len, param->plaintext->data,
+ param->plaintext->len);
+
+ memset(&buf, 0, sizeof(buf));
+ buf.head[0].iov_base = text;
+ buf.head[0].iov_len = param->confounder->len + param->plaintext->len;
+ buf.len = buf.head[0].iov_len;
+
+ checksum.len = gk5e->cksumlength;
+ checksum.data = kunit_kzalloc(test, checksum.len, GFP_KERNEL);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, checksum.data);
+
+ Ki.len = gk5e->Ki_length;
+ Ki.data = kunit_kzalloc(test, Ki.len, GFP_KERNEL);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, Ki.data);
+ usage.data[4] = KEY_USAGE_SEED_INTEGRITY;
+ err = gk5e->derive_key(gk5e, param->base_key, &Ki,
+ &usage, GFP_KERNEL);
+ KUNIT_ASSERT_EQ(test, err, 0);
+ ahash_tfm = crypto_alloc_ahash(gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, ahash_tfm);
+ err = crypto_ahash_setkey(ahash_tfm, Ki.data, Ki.len);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ /* Act */
+ err = gss_krb5_checksum(ahash_tfm, NULL, 0, &buf, 0, &checksum);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ err = krb5_cbc_cts_encrypt(cts_tfm, cbc_tfm, 0, &buf, NULL, NULL, 0);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ /* Assert */
+ KUNIT_EXPECT_EQ_MSG(test, param->expected_result->len,
+ buf.len + checksum.len,
+ "ciphertext length mismatch");
+ KUNIT_EXPECT_EQ_MSG(test,
+ memcmp(param->expected_result->data,
+ buf.head[0].iov_base, buf.len), 0,
+ "encrypted result mismatch");
+ KUNIT_EXPECT_EQ_MSG(test,
+ memcmp(param->expected_result->data +
+ (param->expected_result->len - checksum.len),
+ checksum.data, checksum.len), 0,
+ "HMAC mismatch");
+
+ crypto_free_ahash(ahash_tfm);
+ crypto_free_sync_skcipher(cts_tfm);
+ crypto_free_sync_skcipher(cbc_tfm);
+}
+
+static struct kunit_case rfc6803_test_cases[] = {
+ {
+ .name = "RFC 6803 key derivation",
+ .run_case = kdf_case,
+ .generate_params = rfc6803_kdf_gen_params,
+ },
+ {
+ .name = "RFC 6803 checksum",
+ .run_case = checksum_case,
+ .generate_params = rfc6803_checksum_gen_params,
+ },
+ {
+ .name = "RFC 6803 encryption",
+ .run_case = rfc6803_encrypt_case,
+ .generate_params = rfc6803_encrypt_gen_params,
+ },
+ {}
+};
+
+static struct kunit_suite rfc6803_suite = {
+ .name = "RFC 6803 suite",
+ .test_cases = rfc6803_test_cases,
+};
+
+/*
+ * From RFC 8009 Appendix A. Test Vectors
+ *
+ * Sample results for SHA-2 enctype key derivation
+ *
+ * This test material is copyright (c) 2016 IETF Trust and the
+ * persons identified as the document authors. All rights reserved.
+ */
+
+DEFINE_HEX_XDR_NETOBJ(aes128_cts_hmac_sha256_128_basekey,
+ 0x37, 0x05, 0xd9, 0x60, 0x80, 0xc1, 0x77, 0x28,
+ 0xa0, 0xe8, 0x00, 0xea, 0xb6, 0xe0, 0xd2, 0x3c
+);
+DEFINE_HEX_XDR_NETOBJ(aes128_cts_hmac_sha256_128_Kc,
+ 0xb3, 0x1a, 0x01, 0x8a, 0x48, 0xf5, 0x47, 0x76,
+ 0xf4, 0x03, 0xe9, 0xa3, 0x96, 0x32, 0x5d, 0xc3
+);
+DEFINE_HEX_XDR_NETOBJ(aes128_cts_hmac_sha256_128_Ke,
+ 0x9b, 0x19, 0x7d, 0xd1, 0xe8, 0xc5, 0x60, 0x9d,
+ 0x6e, 0x67, 0xc3, 0xe3, 0x7c, 0x62, 0xc7, 0x2e
+);
+DEFINE_HEX_XDR_NETOBJ(aes128_cts_hmac_sha256_128_Ki,
+ 0x9f, 0xda, 0x0e, 0x56, 0xab, 0x2d, 0x85, 0xe1,
+ 0x56, 0x9a, 0x68, 0x86, 0x96, 0xc2, 0x6a, 0x6c
+);
+
+DEFINE_HEX_XDR_NETOBJ(aes256_cts_hmac_sha384_192_basekey,
+ 0x6d, 0x40, 0x4d, 0x37, 0xfa, 0xf7, 0x9f, 0x9d,
+ 0xf0, 0xd3, 0x35, 0x68, 0xd3, 0x20, 0x66, 0x98,
+ 0x00, 0xeb, 0x48, 0x36, 0x47, 0x2e, 0xa8, 0xa0,
+ 0x26, 0xd1, 0x6b, 0x71, 0x82, 0x46, 0x0c, 0x52
+);
+DEFINE_HEX_XDR_NETOBJ(aes256_cts_hmac_sha384_192_Kc,
+ 0xef, 0x57, 0x18, 0xbe, 0x86, 0xcc, 0x84, 0x96,
+ 0x3d, 0x8b, 0xbb, 0x50, 0x31, 0xe9, 0xf5, 0xc4,
+ 0xba, 0x41, 0xf2, 0x8f, 0xaf, 0x69, 0xe7, 0x3d
+);
+DEFINE_HEX_XDR_NETOBJ(aes256_cts_hmac_sha384_192_Ke,
+ 0x56, 0xab, 0x22, 0xbe, 0xe6, 0x3d, 0x82, 0xd7,
+ 0xbc, 0x52, 0x27, 0xf6, 0x77, 0x3f, 0x8e, 0xa7,
+ 0xa5, 0xeb, 0x1c, 0x82, 0x51, 0x60, 0xc3, 0x83,
+ 0x12, 0x98, 0x0c, 0x44, 0x2e, 0x5c, 0x7e, 0x49
+);
+DEFINE_HEX_XDR_NETOBJ(aes256_cts_hmac_sha384_192_Ki,
+ 0x69, 0xb1, 0x65, 0x14, 0xe3, 0xcd, 0x8e, 0x56,
+ 0xb8, 0x20, 0x10, 0xd5, 0xc7, 0x30, 0x12, 0xb6,
+ 0x22, 0xc4, 0xd0, 0x0f, 0xfc, 0x23, 0xed, 0x1f
+);
+
+static const struct gss_krb5_test_param rfc8009_kdf_test_params[] = {
+ {
+ .desc = "Derive Kc subkey for aes128-cts-hmac-sha256-128",
+ .enctype = ENCTYPE_AES128_CTS_HMAC_SHA256_128,
+ .base_key = &aes128_cts_hmac_sha256_128_basekey,
+ .usage = &usage_checksum,
+ .expected_result = &aes128_cts_hmac_sha256_128_Kc,
+ },
+ {
+ .desc = "Derive Ke subkey for aes128-cts-hmac-sha256-128",
+ .enctype = ENCTYPE_AES128_CTS_HMAC_SHA256_128,
+ .base_key = &aes128_cts_hmac_sha256_128_basekey,
+ .usage = &usage_encryption,
+ .expected_result = &aes128_cts_hmac_sha256_128_Ke,
+ },
+ {
+ .desc = "Derive Ki subkey for aes128-cts-hmac-sha256-128",
+ .enctype = ENCTYPE_AES128_CTS_HMAC_SHA256_128,
+ .base_key = &aes128_cts_hmac_sha256_128_basekey,
+ .usage = &usage_integrity,
+ .expected_result = &aes128_cts_hmac_sha256_128_Ki,
+ },
+ {
+ .desc = "Derive Kc subkey for aes256-cts-hmac-sha384-192",
+ .enctype = ENCTYPE_AES256_CTS_HMAC_SHA384_192,
+ .base_key = &aes256_cts_hmac_sha384_192_basekey,
+ .usage = &usage_checksum,
+ .expected_result = &aes256_cts_hmac_sha384_192_Kc,
+ },
+ {
+ .desc = "Derive Ke subkey for aes256-cts-hmac-sha384-192",
+ .enctype = ENCTYPE_AES256_CTS_HMAC_SHA384_192,
+ .base_key = &aes256_cts_hmac_sha384_192_basekey,
+ .usage = &usage_encryption,
+ .expected_result = &aes256_cts_hmac_sha384_192_Ke,
+ },
+ {
+ .desc = "Derive Ki subkey for aes256-cts-hmac-sha384-192",
+ .enctype = ENCTYPE_AES256_CTS_HMAC_SHA384_192,
+ .base_key = &aes256_cts_hmac_sha384_192_basekey,
+ .usage = &usage_integrity,
+ .expected_result = &aes256_cts_hmac_sha384_192_Ki,
+ },
+};
+
+/* Creates the function rfc8009_kdf_gen_params */
+KUNIT_ARRAY_PARAM(rfc8009_kdf, rfc8009_kdf_test_params, gss_krb5_get_desc);
+
+/*
+ * From RFC 8009 Appendix A. Test Vectors
+ *
+ * These sample checksums use the above sample key derivation results,
+ * including use of the same base-key and key usage values.
+ *
+ * This test material is copyright (c) 2016 IETF Trust and the
+ * persons identified as the document authors. All rights reserved.
+ */
+
+DEFINE_HEX_XDR_NETOBJ(rfc8009_checksum_plaintext,
+ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
+ 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
+ 0x10, 0x11, 0x12, 0x13, 0x14
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_checksum_test1_expected_result,
+ 0xd7, 0x83, 0x67, 0x18, 0x66, 0x43, 0xd6, 0x7b,
+ 0x41, 0x1c, 0xba, 0x91, 0x39, 0xfc, 0x1d, 0xee
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_checksum_test2_expected_result,
+ 0x45, 0xee, 0x79, 0x15, 0x67, 0xee, 0xfc, 0xa3,
+ 0x7f, 0x4a, 0xc1, 0xe0, 0x22, 0x2d, 0xe8, 0x0d,
+ 0x43, 0xc3, 0xbf, 0xa0, 0x66, 0x99, 0x67, 0x2a
+);
+
+static const struct gss_krb5_test_param rfc8009_checksum_test_params[] = {
+ {
+ .desc = "Checksum with aes128-cts-hmac-sha256-128",
+ .enctype = ENCTYPE_AES128_CTS_HMAC_SHA256_128,
+ .base_key = &aes128_cts_hmac_sha256_128_basekey,
+ .usage = &usage_checksum,
+ .plaintext = &rfc8009_checksum_plaintext,
+ .expected_result = &rfc8009_checksum_test1_expected_result,
+ },
+ {
+ .desc = "Checksum with aes256-cts-hmac-sha384-192",
+ .enctype = ENCTYPE_AES256_CTS_HMAC_SHA384_192,
+ .base_key = &aes256_cts_hmac_sha384_192_basekey,
+ .usage = &usage_checksum,
+ .plaintext = &rfc8009_checksum_plaintext,
+ .expected_result = &rfc8009_checksum_test2_expected_result,
+ },
+};
+
+/* Creates the function rfc8009_checksum_gen_params */
+KUNIT_ARRAY_PARAM(rfc8009_checksum, rfc8009_checksum_test_params,
+ gss_krb5_get_desc);
+
+/*
+ * From RFC 8009 Appendix A. Test Vectors
+ *
+ * Sample encryptions (all using the default cipher state):
+ * --------------------------------------------------------
+ *
+ * These sample encryptions use the above sample key derivation results,
+ * including use of the same base-key and key usage values.
+ *
+ * This test material is copyright (c) 2016 IETF Trust and the
+ * persons identified as the document authors. All rights reserved.
+ */
+
+static const struct xdr_netobj rfc8009_enc_empty_plaintext = {
+ .len = 0,
+};
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_short_plaintext,
+ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_block_plaintext,
+ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
+ 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_long_plaintext,
+ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
+ 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
+ 0x10, 0x11, 0x12, 0x13, 0x14
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test1_confounder,
+ 0x7e, 0x58, 0x95, 0xea, 0xf2, 0x67, 0x24, 0x35,
+ 0xba, 0xd8, 0x17, 0xf5, 0x45, 0xa3, 0x71, 0x48
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test1_expected_result,
+ 0xef, 0x85, 0xfb, 0x89, 0x0b, 0xb8, 0x47, 0x2f,
+ 0x4d, 0xab, 0x20, 0x39, 0x4d, 0xca, 0x78, 0x1d
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test1_expected_hmac,
+ 0xad, 0x87, 0x7e, 0xda, 0x39, 0xd5, 0x0c, 0x87,
+ 0x0c, 0x0d, 0x5a, 0x0a, 0x8e, 0x48, 0xc7, 0x18
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test2_confounder,
+ 0x7b, 0xca, 0x28, 0x5e, 0x2f, 0xd4, 0x13, 0x0f,
+ 0xb5, 0x5b, 0x1a, 0x5c, 0x83, 0xbc, 0x5b, 0x24
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test2_expected_result,
+ 0x84, 0xd7, 0xf3, 0x07, 0x54, 0xed, 0x98, 0x7b,
+ 0xab, 0x0b, 0xf3, 0x50, 0x6b, 0xeb, 0x09, 0xcf,
+ 0xb5, 0x54, 0x02, 0xce, 0xf7, 0xe6
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test2_expected_hmac,
+ 0x87, 0x7c, 0xe9, 0x9e, 0x24, 0x7e, 0x52, 0xd1,
+ 0x6e, 0xd4, 0x42, 0x1d, 0xfd, 0xf8, 0x97, 0x6c
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test3_confounder,
+ 0x56, 0xab, 0x21, 0x71, 0x3f, 0xf6, 0x2c, 0x0a,
+ 0x14, 0x57, 0x20, 0x0f, 0x6f, 0xa9, 0x94, 0x8f
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test3_expected_result,
+ 0x35, 0x17, 0xd6, 0x40, 0xf5, 0x0d, 0xdc, 0x8a,
+ 0xd3, 0x62, 0x87, 0x22, 0xb3, 0x56, 0x9d, 0x2a,
+ 0xe0, 0x74, 0x93, 0xfa, 0x82, 0x63, 0x25, 0x40,
+ 0x80, 0xea, 0x65, 0xc1, 0x00, 0x8e, 0x8f, 0xc2
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test3_expected_hmac,
+ 0x95, 0xfb, 0x48, 0x52, 0xe7, 0xd8, 0x3e, 0x1e,
+ 0x7c, 0x48, 0xc3, 0x7e, 0xeb, 0xe6, 0xb0, 0xd3
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test4_confounder,
+ 0xa7, 0xa4, 0xe2, 0x9a, 0x47, 0x28, 0xce, 0x10,
+ 0x66, 0x4f, 0xb6, 0x4e, 0x49, 0xad, 0x3f, 0xac
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test4_expected_result,
+ 0x72, 0x0f, 0x73, 0xb1, 0x8d, 0x98, 0x59, 0xcd,
+ 0x6c, 0xcb, 0x43, 0x46, 0x11, 0x5c, 0xd3, 0x36,
+ 0xc7, 0x0f, 0x58, 0xed, 0xc0, 0xc4, 0x43, 0x7c,
+ 0x55, 0x73, 0x54, 0x4c, 0x31, 0xc8, 0x13, 0xbc,
+ 0xe1, 0xe6, 0xd0, 0x72, 0xc1
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test4_expected_hmac,
+ 0x86, 0xb3, 0x9a, 0x41, 0x3c, 0x2f, 0x92, 0xca,
+ 0x9b, 0x83, 0x34, 0xa2, 0x87, 0xff, 0xcb, 0xfc
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test5_confounder,
+ 0xf7, 0x64, 0xe9, 0xfa, 0x15, 0xc2, 0x76, 0x47,
+ 0x8b, 0x2c, 0x7d, 0x0c, 0x4e, 0x5f, 0x58, 0xe4
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test5_expected_result,
+ 0x41, 0xf5, 0x3f, 0xa5, 0xbf, 0xe7, 0x02, 0x6d,
+ 0x91, 0xfa, 0xf9, 0xbe, 0x95, 0x91, 0x95, 0xa0
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test5_expected_hmac,
+ 0x58, 0x70, 0x72, 0x73, 0xa9, 0x6a, 0x40, 0xf0,
+ 0xa0, 0x19, 0x60, 0x62, 0x1a, 0xc6, 0x12, 0x74,
+ 0x8b, 0x9b, 0xbf, 0xbe, 0x7e, 0xb4, 0xce, 0x3c
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test6_confounder,
+ 0xb8, 0x0d, 0x32, 0x51, 0xc1, 0xf6, 0x47, 0x14,
+ 0x94, 0x25, 0x6f, 0xfe, 0x71, 0x2d, 0x0b, 0x9a
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test6_expected_result,
+ 0x4e, 0xd7, 0xb3, 0x7c, 0x2b, 0xca, 0xc8, 0xf7,
+ 0x4f, 0x23, 0xc1, 0xcf, 0x07, 0xe6, 0x2b, 0xc7,
+ 0xb7, 0x5f, 0xb3, 0xf6, 0x37, 0xb9
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test6_expected_hmac,
+ 0xf5, 0x59, 0xc7, 0xf6, 0x64, 0xf6, 0x9e, 0xab,
+ 0x7b, 0x60, 0x92, 0x23, 0x75, 0x26, 0xea, 0x0d,
+ 0x1f, 0x61, 0xcb, 0x20, 0xd6, 0x9d, 0x10, 0xf2
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test7_confounder,
+ 0x53, 0xbf, 0x8a, 0x0d, 0x10, 0x52, 0x65, 0xd4,
+ 0xe2, 0x76, 0x42, 0x86, 0x24, 0xce, 0x5e, 0x63
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test7_expected_result,
+ 0xbc, 0x47, 0xff, 0xec, 0x79, 0x98, 0xeb, 0x91,
+ 0xe8, 0x11, 0x5c, 0xf8, 0xd1, 0x9d, 0xac, 0x4b,
+ 0xbb, 0xe2, 0xe1, 0x63, 0xe8, 0x7d, 0xd3, 0x7f,
+ 0x49, 0xbe, 0xca, 0x92, 0x02, 0x77, 0x64, 0xf6
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test7_expected_hmac,
+ 0x8c, 0xf5, 0x1f, 0x14, 0xd7, 0x98, 0xc2, 0x27,
+ 0x3f, 0x35, 0xdf, 0x57, 0x4d, 0x1f, 0x93, 0x2e,
+ 0x40, 0xc4, 0xff, 0x25, 0x5b, 0x36, 0xa2, 0x66
+);
+
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test8_confounder,
+ 0x76, 0x3e, 0x65, 0x36, 0x7e, 0x86, 0x4f, 0x02,
+ 0xf5, 0x51, 0x53, 0xc7, 0xe3, 0xb5, 0x8a, 0xf1
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test8_expected_result,
+ 0x40, 0x01, 0x3e, 0x2d, 0xf5, 0x8e, 0x87, 0x51,
+ 0x95, 0x7d, 0x28, 0x78, 0xbc, 0xd2, 0xd6, 0xfe,
+ 0x10, 0x1c, 0xcf, 0xd5, 0x56, 0xcb, 0x1e, 0xae,
+ 0x79, 0xdb, 0x3c, 0x3e, 0xe8, 0x64, 0x29, 0xf2,
+ 0xb2, 0xa6, 0x02, 0xac, 0x86
+);
+DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test8_expected_hmac,
+ 0xfe, 0xf6, 0xec, 0xb6, 0x47, 0xd6, 0x29, 0x5f,
+ 0xae, 0x07, 0x7a, 0x1f, 0xeb, 0x51, 0x75, 0x08,
+ 0xd2, 0xc1, 0x6b, 0x41, 0x92, 0xe0, 0x1f, 0x62
+);
+
+static const struct gss_krb5_test_param rfc8009_encrypt_test_params[] = {
+ {
+ .desc = "Encrypt empty plaintext with aes128-cts-hmac-sha256-128",
+ .enctype = ENCTYPE_AES128_CTS_HMAC_SHA256_128,
+ .plaintext = &rfc8009_enc_empty_plaintext,
+ .confounder = &rfc8009_enc_test1_confounder,
+ .base_key = &aes128_cts_hmac_sha256_128_basekey,
+ .expected_result = &rfc8009_enc_test1_expected_result,
+ .expected_hmac = &rfc8009_enc_test1_expected_hmac,
+ },
+ {
+ .desc = "Encrypt short plaintext with aes128-cts-hmac-sha256-128",
+ .enctype = ENCTYPE_AES128_CTS_HMAC_SHA256_128,
+ .plaintext = &rfc8009_enc_short_plaintext,
+ .confounder = &rfc8009_enc_test2_confounder,
+ .base_key = &aes128_cts_hmac_sha256_128_basekey,
+ .expected_result = &rfc8009_enc_test2_expected_result,
+ .expected_hmac = &rfc8009_enc_test2_expected_hmac,
+ },
+ {
+ .desc = "Encrypt block plaintext with aes128-cts-hmac-sha256-128",
+ .enctype = ENCTYPE_AES128_CTS_HMAC_SHA256_128,
+ .plaintext = &rfc8009_enc_block_plaintext,
+ .confounder = &rfc8009_enc_test3_confounder,
+ .base_key = &aes128_cts_hmac_sha256_128_basekey,
+ .expected_result = &rfc8009_enc_test3_expected_result,
+ .expected_hmac = &rfc8009_enc_test3_expected_hmac,
+ },
+ {
+ .desc = "Encrypt long plaintext with aes128-cts-hmac-sha256-128",
+ .enctype = ENCTYPE_AES128_CTS_HMAC_SHA256_128,
+ .plaintext = &rfc8009_enc_long_plaintext,
+ .confounder = &rfc8009_enc_test4_confounder,
+ .base_key = &aes128_cts_hmac_sha256_128_basekey,
+ .expected_result = &rfc8009_enc_test4_expected_result,
+ .expected_hmac = &rfc8009_enc_test4_expected_hmac,
+ },
+ {
+ .desc = "Encrypt empty plaintext with aes256-cts-hmac-sha384-192",
+ .enctype = ENCTYPE_AES256_CTS_HMAC_SHA384_192,
+ .plaintext = &rfc8009_enc_empty_plaintext,
+ .confounder = &rfc8009_enc_test5_confounder,
+ .base_key = &aes256_cts_hmac_sha384_192_basekey,
+ .expected_result = &rfc8009_enc_test5_expected_result,
+ .expected_hmac = &rfc8009_enc_test5_expected_hmac,
+ },
+ {
+ .desc = "Encrypt short plaintext with aes256-cts-hmac-sha384-192",
+ .enctype = ENCTYPE_AES256_CTS_HMAC_SHA384_192,
+ .plaintext = &rfc8009_enc_short_plaintext,
+ .confounder = &rfc8009_enc_test6_confounder,
+ .base_key = &aes256_cts_hmac_sha384_192_basekey,
+ .expected_result = &rfc8009_enc_test6_expected_result,
+ .expected_hmac = &rfc8009_enc_test6_expected_hmac,
+ },
+ {
+ .desc = "Encrypt block plaintext with aes256-cts-hmac-sha384-192",
+ .enctype = ENCTYPE_AES256_CTS_HMAC_SHA384_192,
+ .plaintext = &rfc8009_enc_block_plaintext,
+ .confounder = &rfc8009_enc_test7_confounder,
+ .base_key = &aes256_cts_hmac_sha384_192_basekey,
+ .expected_result = &rfc8009_enc_test7_expected_result,
+ .expected_hmac = &rfc8009_enc_test7_expected_hmac,
+ },
+ {
+ .desc = "Encrypt long plaintext with aes256-cts-hmac-sha384-192",
+ .enctype = ENCTYPE_AES256_CTS_HMAC_SHA384_192,
+ .plaintext = &rfc8009_enc_long_plaintext,
+ .confounder = &rfc8009_enc_test8_confounder,
+ .base_key = &aes256_cts_hmac_sha384_192_basekey,
+ .expected_result = &rfc8009_enc_test8_expected_result,
+ .expected_hmac = &rfc8009_enc_test8_expected_hmac,
+ },
+};
+
+/* Creates the function rfc8009_encrypt_gen_params */
+KUNIT_ARRAY_PARAM(rfc8009_encrypt, rfc8009_encrypt_test_params,
+ gss_krb5_get_desc);
+
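+/*
+ * The RFC 8009 enctypes are encrypt-then-MAC: the confounder and
+ * plaintext are encrypted with Ke first, then the truncated HMAC is
+ * computed over the ciphertext with Ki (krb5_etm_checksum). The
+ * ciphertext and the HMAC are checked separately against the
+ * RFC 8009 expected values.
+ */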
+static void rfc8009_encrypt_case(struct kunit *test)
+{
+ const struct gss_krb5_test_param *param = test->param_value;
+ struct crypto_sync_skcipher *cts_tfm, *cbc_tfm;
+ const struct gss_krb5_enctype *gk5e;
+ struct xdr_netobj Ke, Ki, checksum;
+ u8 usage_data[GSS_KRB5_K5CLENGTH];
+ struct xdr_netobj usage = {
+ .data = usage_data,
+ .len = sizeof(usage_data),
+ };
+ struct crypto_ahash *ahash_tfm;
+ struct xdr_buf buf;
+ void *text;
+ size_t len;
+ u32 err;
+
+ /* Arrange */
+ gk5e = gss_krb5_lookup_enctype(param->enctype);
+ if (!gk5e)
+ kunit_skip(test, "Encryption type is not available");
+
+ *(__be32 *)usage.data = cpu_to_be32(2);
+
+ Ke.len = gk5e->Ke_length;
+ Ke.data = kunit_kzalloc(test, Ke.len, GFP_KERNEL);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, Ke.data);
+ usage.data[4] = KEY_USAGE_SEED_ENCRYPTION;
+ err = gk5e->derive_key(gk5e, param->base_key, &Ke,
+ &usage, GFP_KERNEL);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ cbc_tfm = crypto_alloc_sync_skcipher(gk5e->aux_cipher, 0, 0);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, cbc_tfm);
+ err = crypto_sync_skcipher_setkey(cbc_tfm, Ke.data, Ke.len);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ cts_tfm = crypto_alloc_sync_skcipher(gk5e->encrypt_name, 0, 0);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, cts_tfm);
+ err = crypto_sync_skcipher_setkey(cts_tfm, Ke.data, Ke.len);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ len = param->confounder->len + param->plaintext->len;
+ text = kunit_kzalloc(test, len, GFP_KERNEL);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, text);
+ memcpy(text, param->confounder->data, param->confounder->len);
+ memcpy(text + param->confounder->len, param->plaintext->data,
+ param->plaintext->len);
+
+ memset(&buf, 0, sizeof(buf));
+ buf.head[0].iov_base = text;
+ buf.head[0].iov_len = param->confounder->len + param->plaintext->len;
+ buf.len = buf.head[0].iov_len;
+
+ checksum.len = gk5e->cksumlength;
+ checksum.data = kunit_kzalloc(test, checksum.len, GFP_KERNEL);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, checksum.data);
+
+ Ki.len = gk5e->Ki_length;
+ Ki.data = kunit_kzalloc(test, Ki.len, GFP_KERNEL);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, Ki.data);
+ usage.data[4] = KEY_USAGE_SEED_INTEGRITY;
+ err = gk5e->derive_key(gk5e, param->base_key, &Ki,
+ &usage, GFP_KERNEL);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ ahash_tfm = crypto_alloc_ahash(gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, ahash_tfm);
+ err = crypto_ahash_setkey(ahash_tfm, Ki.data, Ki.len);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ /* Act */
+ err = krb5_cbc_cts_encrypt(cts_tfm, cbc_tfm, 0, &buf, NULL, NULL, 0);
+ KUNIT_ASSERT_EQ(test, err, 0);
+ err = krb5_etm_checksum(cts_tfm, ahash_tfm, &buf, 0, &checksum);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ /* Assert */
+ KUNIT_EXPECT_EQ_MSG(test,
+ param->expected_result->len, buf.len,
+ "ciphertext length mismatch");
+ KUNIT_EXPECT_EQ_MSG(test,
+ memcmp(param->expected_result->data,
+ buf.head[0].iov_base,
+ param->expected_result->len), 0,
+ "ciphertext mismatch");
+ KUNIT_EXPECT_EQ_MSG(test, memcmp(param->expected_hmac->data,
+ checksum.data,
+ checksum.len), 0,
+ "HMAC mismatch");
+
+ crypto_free_ahash(ahash_tfm);
+ crypto_free_sync_skcipher(cts_tfm);
+ crypto_free_sync_skcipher(cbc_tfm);
+}
+
+static struct kunit_case rfc8009_test_cases[] = {
+ {
+ .name = "RFC 8009 key derivation",
+ .run_case = kdf_case,
+ .generate_params = rfc8009_kdf_gen_params,
+ },
+ {
+ .name = "RFC 8009 checksum",
+ .run_case = checksum_case,
+ .generate_params = rfc8009_checksum_gen_params,
+ },
+ {
+ .name = "RFC 8009 encryption",
+ .run_case = rfc8009_encrypt_case,
+ .generate_params = rfc8009_encrypt_gen_params,
+ },
+ {}
+};
+
+static struct kunit_suite rfc8009_suite = {
+ .name = "RFC 8009 suite",
+ .test_cases = rfc8009_test_cases,
+};
+
+/*
+ * Encryption self-tests
+ */
+
+DEFINE_STR_XDR_NETOBJ(encrypt_selftest_plaintext,
+ "This is the plaintext for the encryption self-test.");
+
+static const struct gss_krb5_test_param encrypt_selftest_params[] = {
+ {
+ .desc = "aes128-cts-hmac-sha1-96 encryption self-test",
+ .enctype = ENCTYPE_AES128_CTS_HMAC_SHA1_96,
+ .Ke = &rfc3962_encryption_key,
+ .plaintext = &encrypt_selftest_plaintext,
+ },
+ {
+ .desc = "aes256-cts-hmac-sha1-96 encryption self-test",
+ .enctype = ENCTYPE_AES256_CTS_HMAC_SHA1_96,
+ .Ke = &rfc3962_encryption_key,
+ .plaintext = &encrypt_selftest_plaintext,
+ },
+ {
+ .desc = "camellia128-cts-cmac encryption self-test",
+ .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC,
+ .Ke = &camellia128_cts_cmac_Ke,
+ .plaintext = &encrypt_selftest_plaintext,
+ },
+ {
+ .desc = "camellia256-cts-cmac encryption self-test",
+ .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC,
+ .Ke = &camellia256_cts_cmac_Ke,
+ .plaintext = &encrypt_selftest_plaintext,
+ },
+ {
+ .desc = "aes128-cts-hmac-sha256-128 encryption self-test",
+ .enctype = ENCTYPE_AES128_CTS_HMAC_SHA256_128,
+ .Ke = &aes128_cts_hmac_sha256_128_Ke,
+ .plaintext = &encrypt_selftest_plaintext,
+ },
+ {
+ .desc = "aes256-cts-hmac-sha384-192 encryption self-test",
+ .enctype = ENCTYPE_AES256_CTS_HMAC_SHA384_192,
+ .Ke = &aes256_cts_hmac_sha384_192_Ke,
+ .plaintext = &encrypt_selftest_plaintext,
+ },
+};
+
+/* Creates the function encrypt_selftest_gen_params */
+KUNIT_ARRAY_PARAM(encrypt_selftest, encrypt_selftest_params,
+ gss_krb5_get_desc);
+
+/*
+ * Encrypt and decrypt plaintext, and ensure the input plaintext
+ * matches the output plaintext. A confounder is not added in this
+ * case.
+ */
+static void encrypt_selftest_case(struct kunit *test)
+{
+ const struct gss_krb5_test_param *param = test->param_value;
+ struct crypto_sync_skcipher *cts_tfm, *cbc_tfm;
+ const struct gss_krb5_enctype *gk5e;
+ struct xdr_buf buf;
+ void *text;
+ int err;
+
+ /* Arrange */
+ gk5e = gss_krb5_lookup_enctype(param->enctype);
+ if (!gk5e)
+ kunit_skip(test, "Encryption type is not available");
+
+ cbc_tfm = crypto_alloc_sync_skcipher(gk5e->aux_cipher, 0, 0);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, cbc_tfm);
+ err = crypto_sync_skcipher_setkey(cbc_tfm, param->Ke->data, param->Ke->len);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ cts_tfm = crypto_alloc_sync_skcipher(gk5e->encrypt_name, 0, 0);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, cts_tfm);
+ err = crypto_sync_skcipher_setkey(cts_tfm, param->Ke->data, param->Ke->len);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ text = kunit_kzalloc(test, roundup(param->plaintext->len,
+ crypto_sync_skcipher_blocksize(cbc_tfm)),
+ GFP_KERNEL);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, text);
+
+ memcpy(text, param->plaintext->data, param->plaintext->len);
+ memset(&buf, 0, sizeof(buf));
+ buf.head[0].iov_base = text;
+ buf.head[0].iov_len = param->plaintext->len;
+ buf.len = buf.head[0].iov_len;
+
+ /* Act */
+ err = krb5_cbc_cts_encrypt(cts_tfm, cbc_tfm, 0, &buf, NULL, NULL, 0);
+ KUNIT_ASSERT_EQ(test, err, 0);
+ err = krb5_cbc_cts_decrypt(cts_tfm, cbc_tfm, 0, &buf);
+ KUNIT_ASSERT_EQ(test, err, 0);
+
+ /* Assert */
+ KUNIT_EXPECT_EQ_MSG(test,
+ param->plaintext->len, buf.len,
+ "length mismatch");
+ KUNIT_EXPECT_EQ_MSG(test,
+ memcmp(param->plaintext->data,
+ buf.head[0].iov_base, buf.len), 0,
+ "plaintext mismatch");
+
+ crypto_free_sync_skcipher(cts_tfm);
+ crypto_free_sync_skcipher(cbc_tfm);
+}
+
+static struct kunit_case encryption_test_cases[] = {
+ {
+ .name = "Encryption self-tests",
+ .run_case = encrypt_selftest_case,
+ .generate_params = encrypt_selftest_gen_params,
+ },
+ {}
+};
+
+static struct kunit_suite encryption_test_suite = {
+ .name = "Encryption test suite",
+ .test_cases = encryption_test_cases,
+};
+
+kunit_test_suites(&rfc3961_suite,
+ &rfc3962_suite,
+ &rfc6803_suite,
+ &rfc8009_suite,
+ &encryption_test_suite);
+
+MODULE_DESCRIPTION("Test RPCSEC GSS Kerberos 5 functions");
+MODULE_LICENSE("GPL");
diff --git a/net/sunrpc/auth_gss/gss_krb5_unseal.c b/net/sunrpc/auth_gss/gss_krb5_unseal.c
new file mode 100644
index 0000000000..4fbc50a0a2
--- /dev/null
+++ b/net/sunrpc/auth_gss/gss_krb5_unseal.c
@@ -0,0 +1,128 @@
+/*
+ * linux/net/sunrpc/gss_krb5_unseal.c
+ *
+ * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/krb5/k5unseal.c
+ *
+ * Copyright (c) 2000-2008 The Regents of the University of Michigan.
+ * All rights reserved.
+ *
+ * Andy Adamson <andros@umich.edu>
+ */
+
+/*
+ * Copyright 1993 by OpenVision Technologies, Inc.
+ *
+ * Permission to use, copy, modify, distribute, and sell this software
+ * and its documentation for any purpose is hereby granted without fee,
+ * provided that the above copyright notice appears in all copies and
+ * that both that copyright notice and this permission notice appear in
+ * supporting documentation, and that the name of OpenVision not be used
+ * in advertising or publicity pertaining to distribution of the software
+ * without specific, written prior permission. OpenVision makes no
+ * representations about the suitability of this software for any
+ * purpose. It is provided "as is" without express or implied warranty.
+ *
+ * OPENVISION DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
+ * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO
+ * EVENT SHALL OPENVISION BE LIABLE FOR ANY SPECIAL, INDIRECT OR
+ * CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF
+ * USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
+ * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
+ * PERFORMANCE OF THIS SOFTWARE.
+ */
+
+/*
+ * Copyright (C) 1998 by the FundsXpress, INC.
+ *
+ * All rights reserved.
+ *
+ * Export of this software from the United States of America may require
+ * a specific license from the United States Government. It is the
+ * responsibility of any person or organization contemplating export to
+ * obtain such a license before exporting.
+ *
+ * WITHIN THAT CONSTRAINT, permission to use, copy, modify, and
+ * distribute this software and its documentation for any purpose and
+ * without fee is hereby granted, provided that the above copyright
+ * notice appear in all copies and that both that copyright notice and
+ * this permission notice appear in supporting documentation, and that
+ * the name of FundsXpress. not be used in advertising or publicity pertaining
+ * to distribution of the software without specific, written prior
+ * permission. FundsXpress makes no representations about the suitability of
+ * this software for any purpose. It is provided "as is" without express
+ * or implied warranty.
+ *
+ * THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED
+ * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE.
+ */
+
+#include <crypto/algapi.h>
+#include <linux/types.h>
+#include <linux/jiffies.h>
+#include <linux/sunrpc/gss_krb5.h>
+#include <linux/crypto.h>
+
+#include "gss_krb5_internal.h"
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+# define RPCDBG_FACILITY RPCDBG_AUTH
+#endif
+
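+/*
+ * gss_krb5_verify_mic_v2() - verify an RFC 4121 MIC token
+ *
+ * Validates the token identifier, direction flag, and filler octets,
+ * recomputes the checksum over the message and token header, and
+ * compares it with the checksum carried in the token. Also confirms
+ * that the security context has not expired.
+ *
+ * Returns GSS_S_COMPLETE on success or an appropriate GSS_S_* status
+ * code on failure.
+ */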
+u32
+gss_krb5_verify_mic_v2(struct krb5_ctx *ctx, struct xdr_buf *message_buffer,
+ struct xdr_netobj *read_token)
+{
+ struct crypto_ahash *tfm = ctx->initiate ?
+ ctx->acceptor_sign : ctx->initiator_sign;
+ char cksumdata[GSS_KRB5_MAX_CKSUM_LEN];
+ struct xdr_netobj cksumobj = {
+ .len = ctx->gk5e->cksumlength,
+ .data = cksumdata,
+ };
+ u8 *ptr = read_token->data;
+ __be16 be16_ptr;
+ time64_t now;
+ u8 flags;
+ int i;
+
+ dprintk("RPC: %s\n", __func__);
+
+ memcpy(&be16_ptr, (char *) ptr, 2);
+ if (be16_to_cpu(be16_ptr) != KG2_TOK_MIC)
+ return GSS_S_DEFECTIVE_TOKEN;
+
+ flags = ptr[2];
+ if ((!ctx->initiate && (flags & KG2_TOKEN_FLAG_SENTBYACCEPTOR)) ||
+ (ctx->initiate && !(flags & KG2_TOKEN_FLAG_SENTBYACCEPTOR)))
+ return GSS_S_BAD_SIG;
+
+ if (flags & KG2_TOKEN_FLAG_SEALED) {
+ dprintk("%s: token has unexpected sealed flag\n", __func__);
+ return GSS_S_FAILURE;
+ }
+
+ for (i = 3; i < 8; i++)
+ if (ptr[i] != 0xff)
+ return GSS_S_DEFECTIVE_TOKEN;
+
+ if (gss_krb5_checksum(tfm, ptr, GSS_KRB5_TOK_HDR_LEN,
+ message_buffer, 0, &cksumobj))
+ return GSS_S_FAILURE;
+
+ if (memcmp(cksumobj.data, ptr + GSS_KRB5_TOK_HDR_LEN,
+ ctx->gk5e->cksumlength))
+ return GSS_S_BAD_SIG;
+
+ /* it got through unscathed. Make sure the context is unexpired */
+ now = ktime_get_real_seconds();
+ if (now > ctx->endtime)
+ return GSS_S_CONTEXT_EXPIRED;
+
+ /*
+ * NOTE: the sequence number at ptr + 8 is skipped, rpcsec_gss
+ * doesn't want it checked; see page 6 of rfc 2203.
+ */
+
+ return GSS_S_COMPLETE;
+}
diff --git a/net/sunrpc/auth_gss/gss_krb5_wrap.c b/net/sunrpc/auth_gss/gss_krb5_wrap.c
new file mode 100644
index 0000000000..b3e1738ff6
--- /dev/null
+++ b/net/sunrpc/auth_gss/gss_krb5_wrap.c
@@ -0,0 +1,237 @@
+/*
+ * COPYRIGHT (c) 2008
+ * The Regents of the University of Michigan
+ * ALL RIGHTS RESERVED
+ *
+ * Permission is granted to use, copy, create derivative works
+ * and redistribute this software and such derivative works
+ * for any purpose, so long as the name of The University of
+ * Michigan is not used in any advertising or publicity
+ * pertaining to the use of distribution of this software
+ * without specific, written prior authorization. If the
+ * above copyright notice or any other identification of the
+ * University of Michigan is included in any copy of any
+ * portion of this software, then the disclaimer below must
+ * also be included.
+ *
+ * THIS SOFTWARE IS PROVIDED AS IS, WITHOUT REPRESENTATION
+ * FROM THE UNIVERSITY OF MICHIGAN AS TO ITS FITNESS FOR ANY
+ * PURPOSE, AND WITHOUT WARRANTY BY THE UNIVERSITY OF
+ * MICHIGAN OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
+ * WITHOUT LIMITATION THE IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE
+ * REGENTS OF THE UNIVERSITY OF MICHIGAN SHALL NOT BE LIABLE
+ * FOR ANY DAMAGES, INCLUDING SPECIAL, INDIRECT, INCIDENTAL, OR
+ * CONSEQUENTIAL DAMAGES, WITH RESPECT TO ANY CLAIM ARISING
+ * OUT OF OR IN CONNECTION WITH THE USE OF THE SOFTWARE, EVEN
+ * IF IT HAS BEEN OR IS HEREAFTER ADVISED OF THE POSSIBILITY OF
+ * SUCH DAMAGES.
+ */
+
+#include <crypto/skcipher.h>
+#include <linux/types.h>
+#include <linux/jiffies.h>
+#include <linux/sunrpc/gss_krb5.h>
+#include <linux/pagemap.h>
+
+#include "gss_krb5_internal.h"
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+# define RPCDBG_FACILITY RPCDBG_AUTH
+#endif
+
+/*
+ * We can shift data by up to LOCAL_BUF_LEN bytes in a pass. If we need
+ * to do more than that, we shift repeatedly. Kevin Coffman reports
+ * seeing 28 bytes as the value used by Microsoft clients and servers
+ * with AES, so this constant is chosen to allow handling 28 in one pass
+ * without using too much stack space.
+ *
+ * If that proves to be a problem, perhaps we could use a more clever
+ * algorithm.
+ */
+#define LOCAL_BUF_LEN 32u
+
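+/*
+ * rotate_buf_a_little() rotates the buffer left by @shift bytes: the
+ * first @shift bytes are saved, the remaining data is moved toward
+ * the front in LOCAL_BUF_LEN-sized chunks, and the saved head is
+ * written back at the end. For example, a six-byte buffer holding
+ * "abcdef" rotated by 2 becomes "cdefab". _rotate_left() repeats
+ * this until the requested shift has been applied.
+ */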
+static void rotate_buf_a_little(struct xdr_buf *buf, unsigned int shift)
+{
+ char head[LOCAL_BUF_LEN];
+ char tmp[LOCAL_BUF_LEN];
+ unsigned int this_len, i;
+
+ BUG_ON(shift > LOCAL_BUF_LEN);
+
+ read_bytes_from_xdr_buf(buf, 0, head, shift);
+ for (i = 0; i + shift < buf->len; i += LOCAL_BUF_LEN) {
+ this_len = min(LOCAL_BUF_LEN, buf->len - (i + shift));
+ read_bytes_from_xdr_buf(buf, i+shift, tmp, this_len);
+ write_bytes_to_xdr_buf(buf, i, tmp, this_len);
+ }
+ write_bytes_to_xdr_buf(buf, buf->len - shift, head, shift);
+}
+
+static void _rotate_left(struct xdr_buf *buf, unsigned int shift)
+{
+ int shifted = 0;
+ int this_shift;
+
+ shift %= buf->len;
+ while (shifted < shift) {
+ this_shift = min(shift - shifted, LOCAL_BUF_LEN);
+ rotate_buf_a_little(buf, this_shift);
+ shifted += this_shift;
+ }
+}
+
+static void rotate_left(u32 base, struct xdr_buf *buf, unsigned int shift)
+{
+ struct xdr_buf subbuf;
+
+ xdr_buf_subsegment(buf, &subbuf, base, buf->len - base);
+ _rotate_left(&subbuf, shift);
+}
+
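+/*
+ * gss_krb5_wrap_v2() - seal a message into an RFC 4121 Wrap token
+ *
+ * Prepends the Wrap token header at @offset (confidentiality is
+ * always used, so the "sealed" flag is set), fills in the per-context
+ * send sequence number, and invokes the enctype's encrypt method to
+ * protect the payload.
+ *
+ * Returns GSS_S_COMPLETE on success, GSS_S_CONTEXT_EXPIRED if the
+ * context lifetime has elapsed, or a GSS_S_* error status otherwise.
+ */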
+u32
+gss_krb5_wrap_v2(struct krb5_ctx *kctx, int offset,
+ struct xdr_buf *buf, struct page **pages)
+{
+ u8 *ptr;
+ time64_t now;
+ u8 flags = 0x00;
+ __be16 *be16ptr;
+ __be64 *be64ptr;
+ u32 err;
+
+ dprintk("RPC: %s\n", __func__);
+
+ /* make room for gss token header */
+ if (xdr_extend_head(buf, offset, GSS_KRB5_TOK_HDR_LEN))
+ return GSS_S_FAILURE;
+
+ /* construct gss token header */
+ ptr = buf->head[0].iov_base + offset;
+ *ptr++ = (unsigned char) ((KG2_TOK_WRAP>>8) & 0xff);
+ *ptr++ = (unsigned char) (KG2_TOK_WRAP & 0xff);
+
+ if ((kctx->flags & KRB5_CTX_FLAG_INITIATOR) == 0)
+ flags |= KG2_TOKEN_FLAG_SENTBYACCEPTOR;
+ if ((kctx->flags & KRB5_CTX_FLAG_ACCEPTOR_SUBKEY) != 0)
+ flags |= KG2_TOKEN_FLAG_ACCEPTORSUBKEY;
+ /* We always do confidentiality in wrap tokens */
+ flags |= KG2_TOKEN_FLAG_SEALED;
+
+ *ptr++ = flags;
+ *ptr++ = 0xff;
+ be16ptr = (__be16 *)ptr;
+
+ *be16ptr++ = 0;
+ /* "inner" token header always uses 0 for RRC */
+ *be16ptr++ = 0;
+
+ be64ptr = (__be64 *)be16ptr;
+ *be64ptr = cpu_to_be64(atomic64_fetch_inc(&kctx->seq_send64));
+
+ err = (*kctx->gk5e->encrypt)(kctx, offset, buf, pages);
+ if (err)
+ return err;
+
+ now = ktime_get_real_seconds();
+ return (kctx->endtime < now) ? GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE;
+}
+
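+/*
+ * gss_krb5_unwrap_v2() - unseal an RFC 4121 Wrap token
+ *
+ * Validates the token header, undoes the sender's right rotation by
+ * rotating the buffer left by the RRC value, decrypts the payload,
+ * and verifies that the header copy embedded in the plaintext matches
+ * the clear-text header. The decrypted data is then moved back into
+ * place, the "extra count" bytes and checksum blob are trimmed off,
+ * and @align and @slack report (in XDR quad units) how much head and
+ * tail space the token occupied.
+ */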
+u32
+gss_krb5_unwrap_v2(struct krb5_ctx *kctx, int offset, int len,
+ struct xdr_buf *buf, unsigned int *slack,
+ unsigned int *align)
+{
+ time64_t now;
+ u8 *ptr;
+ u8 flags = 0x00;
+ u16 ec, rrc;
+ int err;
+ u32 headskip, tailskip;
+ u8 decrypted_hdr[GSS_KRB5_TOK_HDR_LEN];
+ unsigned int movelen;
+
+ dprintk("RPC: %s\n", __func__);
+
+ ptr = buf->head[0].iov_base + offset;
+
+ if (be16_to_cpu(*((__be16 *)ptr)) != KG2_TOK_WRAP)
+ return GSS_S_DEFECTIVE_TOKEN;
+
+ flags = ptr[2];
+ if ((!kctx->initiate && (flags & KG2_TOKEN_FLAG_SENTBYACCEPTOR)) ||
+ (kctx->initiate && !(flags & KG2_TOKEN_FLAG_SENTBYACCEPTOR)))
+ return GSS_S_BAD_SIG;
+
+ if ((flags & KG2_TOKEN_FLAG_SEALED) == 0) {
+ dprintk("%s: token missing expected sealed flag\n", __func__);
+ return GSS_S_DEFECTIVE_TOKEN;
+ }
+
+ if (ptr[3] != 0xff)
+ return GSS_S_DEFECTIVE_TOKEN;
+
+ ec = be16_to_cpup((__be16 *)(ptr + 4));
+ rrc = be16_to_cpup((__be16 *)(ptr + 6));
+
+ /*
+ * NOTE: the sequence number at ptr + 8 is skipped, rpcsec_gss
+ * doesn't want it checked; see page 6 of rfc 2203.
+ */
+
+ if (rrc != 0)
+ rotate_left(offset + 16, buf, rrc);
+
+ err = (*kctx->gk5e->decrypt)(kctx, offset, len, buf,
+ &headskip, &tailskip);
+ if (err)
+ return GSS_S_FAILURE;
+
+ /*
+ * Retrieve the decrypted gss token header and verify
+ * it against the original
+ */
+ err = read_bytes_from_xdr_buf(buf,
+ len - GSS_KRB5_TOK_HDR_LEN - tailskip,
+ decrypted_hdr, GSS_KRB5_TOK_HDR_LEN);
+ if (err) {
+ dprintk("%s: error %u getting decrypted_hdr\n", __func__, err);
+ return GSS_S_FAILURE;
+ }
+ if (memcmp(ptr, decrypted_hdr, 6)
+ || memcmp(ptr + 8, decrypted_hdr + 8, 8)) {
+ dprintk("%s: token hdr, plaintext hdr mismatch!\n", __func__);
+ return GSS_S_FAILURE;
+ }
+
+ /* do sequencing checks */
+
+ /* it got through unscathed. Make sure the context is unexpired */
+ now = ktime_get_real_seconds();
+ if (now > kctx->endtime)
+ return GSS_S_CONTEXT_EXPIRED;
+
+ /*
+ * Move the head data back to the right position in xdr_buf.
+ * We ignore any "ec" data since it might be in the head or
+ * the tail, and we really don't need to deal with it.
+ * Note that buf->head[0].iov_len may indicate the available
+ * head buffer space rather than that actually occupied.
+ */
+ movelen = min_t(unsigned int, buf->head[0].iov_len, len);
+ movelen -= offset + GSS_KRB5_TOK_HDR_LEN + headskip;
+ BUG_ON(offset + GSS_KRB5_TOK_HDR_LEN + headskip + movelen >
+ buf->head[0].iov_len);
+ memmove(ptr, ptr + GSS_KRB5_TOK_HDR_LEN + headskip, movelen);
+ buf->head[0].iov_len -= GSS_KRB5_TOK_HDR_LEN + headskip;
+ buf->len = len - (GSS_KRB5_TOK_HDR_LEN + headskip);
+
+ /* Trim off the trailing "extra count" and checksum blob */
+ xdr_buf_trim(buf, ec + GSS_KRB5_TOK_HDR_LEN + tailskip);
+
+ *align = XDR_QUADLEN(GSS_KRB5_TOK_HDR_LEN + headskip);
+ *slack = *align + XDR_QUADLEN(ec + GSS_KRB5_TOK_HDR_LEN + tailskip);
+ return GSS_S_COMPLETE;
+}
diff --git a/net/sunrpc/auth_gss/gss_mech_switch.c b/net/sunrpc/auth_gss/gss_mech_switch.c
new file mode 100644
index 0000000000..fae632da10
--- /dev/null
+++ b/net/sunrpc/auth_gss/gss_mech_switch.c
@@ -0,0 +1,448 @@
+// SPDX-License-Identifier: BSD-3-Clause
+/*
+ * linux/net/sunrpc/gss_mech_switch.c
+ *
+ * Copyright (c) 2001 The Regents of the University of Michigan.
+ * All rights reserved.
+ *
+ * J. Bruce Fields <bfields@umich.edu>
+ */
+
+#include <linux/types.h>
+#include <linux/slab.h>
+#include <linux/module.h>
+#include <linux/oid_registry.h>
+#include <linux/sunrpc/msg_prot.h>
+#include <linux/sunrpc/gss_asn1.h>
+#include <linux/sunrpc/auth_gss.h>
+#include <linux/sunrpc/svcauth_gss.h>
+#include <linux/sunrpc/gss_err.h>
+#include <linux/sunrpc/sched.h>
+#include <linux/sunrpc/gss_api.h>
+#include <linux/sunrpc/clnt.h>
+#include <trace/events/rpcgss.h>
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+# define RPCDBG_FACILITY RPCDBG_AUTH
+#endif
+
+static LIST_HEAD(registered_mechs);
+static DEFINE_SPINLOCK(registered_mechs_lock);
+
+static void
+gss_mech_free(struct gss_api_mech *gm)
+{
+ struct pf_desc *pf;
+ int i;
+
+ for (i = 0; i < gm->gm_pf_num; i++) {
+ pf = &gm->gm_pfs[i];
+ if (pf->domain)
+ auth_domain_put(pf->domain);
+ kfree(pf->auth_domain_name);
+ pf->auth_domain_name = NULL;
+ }
+}
+
+static inline char *
+make_auth_domain_name(char *name)
+{
+ static char *prefix = "gss/";
+ char *new;
+
+ new = kmalloc(strlen(name) + strlen(prefix) + 1, GFP_KERNEL);
+ if (new) {
+ strcpy(new, prefix);
+ strcat(new, name);
+ }
+ return new;
+}
+
+static int
+gss_mech_svc_setup(struct gss_api_mech *gm)
+{
+ struct auth_domain *dom;
+ struct pf_desc *pf;
+ int i, status;
+
+ for (i = 0; i < gm->gm_pf_num; i++) {
+ pf = &gm->gm_pfs[i];
+ pf->auth_domain_name = make_auth_domain_name(pf->name);
+ status = -ENOMEM;
+ if (pf->auth_domain_name == NULL)
+ goto out;
+ dom = svcauth_gss_register_pseudoflavor(
+ pf->pseudoflavor, pf->auth_domain_name);
+ if (IS_ERR(dom)) {
+ status = PTR_ERR(dom);
+ goto out;
+ }
+ pf->domain = dom;
+ }
+ return 0;
+out:
+ gss_mech_free(gm);
+ return status;
+}
+
+/**
+ * gss_mech_register - register a GSS mechanism
+ * @gm: GSS mechanism handle
+ *
+ * Returns zero if successful, or a negative errno.
+ */
+int gss_mech_register(struct gss_api_mech *gm)
+{
+ int status;
+
+ status = gss_mech_svc_setup(gm);
+ if (status)
+ return status;
+ spin_lock(&registered_mechs_lock);
+ list_add_rcu(&gm->gm_list, &registered_mechs);
+ spin_unlock(&registered_mechs_lock);
+ dprintk("RPC: registered gss mechanism %s\n", gm->gm_name);
+ return 0;
+}
+EXPORT_SYMBOL_GPL(gss_mech_register);
+
+/**
+ * gss_mech_unregister - release a GSS mechanism
+ * @gm: GSS mechanism handle
+ *
+ */
+void gss_mech_unregister(struct gss_api_mech *gm)
+{
+ spin_lock(&registered_mechs_lock);
+ list_del_rcu(&gm->gm_list);
+ spin_unlock(&registered_mechs_lock);
+ dprintk("RPC: unregistered gss mechanism %s\n", gm->gm_name);
+ gss_mech_free(gm);
+}
+EXPORT_SYMBOL_GPL(gss_mech_unregister);
+
+struct gss_api_mech *gss_mech_get(struct gss_api_mech *gm)
+{
+ __module_get(gm->gm_owner);
+ return gm;
+}
+EXPORT_SYMBOL(gss_mech_get);
+
+static struct gss_api_mech *
+_gss_mech_get_by_name(const char *name)
+{
+ struct gss_api_mech *pos, *gm = NULL;
+
+ rcu_read_lock();
+ list_for_each_entry_rcu(pos, &registered_mechs, gm_list) {
+ if (0 == strcmp(name, pos->gm_name)) {
+ if (try_module_get(pos->gm_owner))
+ gm = pos;
+ break;
+ }
+ }
+ rcu_read_unlock();
+ return gm;
+}
+
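+/**
+ * gss_mech_get_by_name - find a registered GSS mechanism by name
+ * @name: name of the mechanism to look up
+ *
+ * If no matching mechanism is registered, an attempt is made to load
+ * the corresponding "rpc-auth-gss-<name>" module before searching
+ * again. Returns a referenced mechanism handle, or NULL if none is
+ * found; the caller releases the reference with gss_mech_put().
+ */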
+struct gss_api_mech *gss_mech_get_by_name(const char *name)
+{
+ struct gss_api_mech *gm = NULL;
+
+ gm = _gss_mech_get_by_name(name);
+ if (!gm) {
+ request_module("rpc-auth-gss-%s", name);
+ gm = _gss_mech_get_by_name(name);
+ }
+ return gm;
+}
+
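+/**
+ * gss_mech_get_by_OID - find a registered GSS mechanism by OID
+ * @obj: mechanism OID to match
+ *
+ * The OID is printed into module-alias form so that the matching
+ * "rpc-auth-gss-<oid>" module can be requested before the registered
+ * list is searched. Returns a referenced mechanism handle, or NULL
+ * if no registered mechanism uses @obj; release it with
+ * gss_mech_put().
+ */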
+struct gss_api_mech *gss_mech_get_by_OID(struct rpcsec_gss_oid *obj)
+{
+ struct gss_api_mech *pos, *gm = NULL;
+ char buf[32];
+
+ if (sprint_oid(obj->data, obj->len, buf, sizeof(buf)) < 0)
+ return NULL;
+ request_module("rpc-auth-gss-%s", buf);
+
+ rcu_read_lock();
+ list_for_each_entry_rcu(pos, &registered_mechs, gm_list) {
+ if (obj->len == pos->gm_oid.len) {
+ if (0 == memcmp(obj->data, pos->gm_oid.data, obj->len)) {
+ if (try_module_get(pos->gm_owner))
+ gm = pos;
+ break;
+ }
+ }
+ }
+ rcu_read_unlock();
+ if (!gm)
+ trace_rpcgss_oid_to_mech(buf);
+ return gm;
+}
+
+static inline int
+mech_supports_pseudoflavor(struct gss_api_mech *gm, u32 pseudoflavor)
+{
+ int i;
+
+ for (i = 0; i < gm->gm_pf_num; i++) {
+ if (gm->gm_pfs[i].pseudoflavor == pseudoflavor)
+ return 1;
+ }
+ return 0;
+}
+
+static struct gss_api_mech *_gss_mech_get_by_pseudoflavor(u32 pseudoflavor)
+{
+ struct gss_api_mech *gm = NULL, *pos;
+
+ rcu_read_lock();
+ list_for_each_entry_rcu(pos, &registered_mechs, gm_list) {
+ if (!mech_supports_pseudoflavor(pos, pseudoflavor))
+ continue;
+ if (try_module_get(pos->gm_owner))
+ gm = pos;
+ break;
+ }
+ rcu_read_unlock();
+ return gm;
+}
+
+struct gss_api_mech *
+gss_mech_get_by_pseudoflavor(u32 pseudoflavor)
+{
+ struct gss_api_mech *gm;
+
+ gm = _gss_mech_get_by_pseudoflavor(pseudoflavor);
+
+ if (!gm) {
+ request_module("rpc-auth-gss-%u", pseudoflavor);
+ gm = _gss_mech_get_by_pseudoflavor(pseudoflavor);
+ }
+ return gm;
+}
+
+/**
+ * gss_svc_to_pseudoflavor - map a GSS service number to a pseudoflavor
+ * @gm: GSS mechanism handle
+ * @qop: GSS quality-of-protection value
+ * @service: GSS service value
+ *
+ * Returns a matching security flavor, or RPC_AUTH_MAXFLAVOR if none is found.
+ */
+rpc_authflavor_t gss_svc_to_pseudoflavor(struct gss_api_mech *gm, u32 qop,
+ u32 service)
+{
+ int i;
+
+ for (i = 0; i < gm->gm_pf_num; i++) {
+ if (gm->gm_pfs[i].qop == qop &&
+ gm->gm_pfs[i].service == service) {
+ return gm->gm_pfs[i].pseudoflavor;
+ }
+ }
+ return RPC_AUTH_MAXFLAVOR;
+}
+
+/**
+ * gss_mech_info2flavor - look up a pseudoflavor given a GSS tuple
+ * @info: a GSS mech OID, quality of protection, and service value
+ *
+ * Returns a matching pseudoflavor, or RPC_AUTH_MAXFLAVOR if the tuple is
+ * not supported.
+ */
+rpc_authflavor_t gss_mech_info2flavor(struct rpcsec_gss_info *info)
+{
+ rpc_authflavor_t pseudoflavor;
+ struct gss_api_mech *gm;
+
+ gm = gss_mech_get_by_OID(&info->oid);
+ if (gm == NULL)
+ return RPC_AUTH_MAXFLAVOR;
+
+ pseudoflavor = gss_svc_to_pseudoflavor(gm, info->qop, info->service);
+
+ gss_mech_put(gm);
+ return pseudoflavor;
+}
+
+/**
+ * gss_mech_flavor2info - look up a GSS tuple for a given pseudoflavor
+ * @pseudoflavor: GSS pseudoflavor to match
+ * @info: rpcsec_gss_info structure to fill in
+ *
+ * Returns zero and fills in "info" if pseudoflavor matches a
+ * supported mechanism. Otherwise a negative errno is returned.
+ */
+int gss_mech_flavor2info(rpc_authflavor_t pseudoflavor,
+ struct rpcsec_gss_info *info)
+{
+ struct gss_api_mech *gm;
+ int i;
+
+ gm = gss_mech_get_by_pseudoflavor(pseudoflavor);
+ if (gm == NULL)
+ return -ENOENT;
+
+ for (i = 0; i < gm->gm_pf_num; i++) {
+ if (gm->gm_pfs[i].pseudoflavor == pseudoflavor) {
+ memcpy(info->oid.data, gm->gm_oid.data, gm->gm_oid.len);
+ info->oid.len = gm->gm_oid.len;
+ info->qop = gm->gm_pfs[i].qop;
+ info->service = gm->gm_pfs[i].service;
+ gss_mech_put(gm);
+ return 0;
+ }
+ }
+
+ gss_mech_put(gm);
+ return -ENOENT;
+}
+
+u32
+gss_pseudoflavor_to_service(struct gss_api_mech *gm, u32 pseudoflavor)
+{
+ int i;
+
+ for (i = 0; i < gm->gm_pf_num; i++) {
+ if (gm->gm_pfs[i].pseudoflavor == pseudoflavor)
+ return gm->gm_pfs[i].service;
+ }
+ return 0;
+}
+EXPORT_SYMBOL(gss_pseudoflavor_to_service);
+
+bool
+gss_pseudoflavor_to_datatouch(struct gss_api_mech *gm, u32 pseudoflavor)
+{
+ int i;
+
+ for (i = 0; i < gm->gm_pf_num; i++) {
+ if (gm->gm_pfs[i].pseudoflavor == pseudoflavor)
+ return gm->gm_pfs[i].datatouch;
+ }
+ return false;
+}
+
+char *
+gss_service_to_auth_domain_name(struct gss_api_mech *gm, u32 service)
+{
+ int i;
+
+ for (i = 0; i < gm->gm_pf_num; i++) {
+ if (gm->gm_pfs[i].service == service)
+ return gm->gm_pfs[i].auth_domain_name;
+ }
+ return NULL;
+}
+
+void
+gss_mech_put(struct gss_api_mech * gm)
+{
+ if (gm)
+ module_put(gm->gm_owner);
+}
+EXPORT_SYMBOL(gss_mech_put);
+
+/* The mech could probably be determined from the token instead, but it's just
+ * as easy for now to pass it in. */
+int
+gss_import_sec_context(const void *input_token, size_t bufsize,
+ struct gss_api_mech *mech,
+ struct gss_ctx **ctx_id,
+ time64_t *endtime,
+ gfp_t gfp_mask)
+{
+ if (!(*ctx_id = kzalloc(sizeof(**ctx_id), gfp_mask)))
+ return -ENOMEM;
+ (*ctx_id)->mech_type = gss_mech_get(mech);
+
+ return mech->gm_ops->gss_import_sec_context(input_token, bufsize,
+ *ctx_id, endtime, gfp_mask);
+}
+
+/* gss_get_mic: compute a mic over message and return mic_token. */
+
+u32
+gss_get_mic(struct gss_ctx *context_handle,
+ struct xdr_buf *message,
+ struct xdr_netobj *mic_token)
+{
+ return context_handle->mech_type->gm_ops
+ ->gss_get_mic(context_handle,
+ message,
+ mic_token);
+}
+
+/* gss_verify_mic: check whether the provided mic_token verifies message. */
+
+u32
+gss_verify_mic(struct gss_ctx *context_handle,
+ struct xdr_buf *message,
+ struct xdr_netobj *mic_token)
+{
+ return context_handle->mech_type->gm_ops
+ ->gss_verify_mic(context_handle,
+ message,
+ mic_token);
+}
+
+/*
+ * This function is called from both the client and server code.
+ * Each makes guarantees about how much "slack" space is available
+ * for the underlying function in "buf"'s head and tail while
+ * performing the wrap.
+ *
+ * The client and server code allocate RPC_MAX_AUTH_SIZE extra
+ * space in both the head and tail which is available for use by
+ * the wrap function.
+ *
+ * Underlying functions should verify they do not use more than
+ * RPC_MAX_AUTH_SIZE of extra space in either the head or tail
+ * when performing the wrap.
+ */
+u32
+gss_wrap(struct gss_ctx *ctx_id,
+ int offset,
+ struct xdr_buf *buf,
+ struct page **inpages)
+{
+ return ctx_id->mech_type->gm_ops
+ ->gss_wrap(ctx_id, offset, buf, inpages);
+}
+
+u32
+gss_unwrap(struct gss_ctx *ctx_id,
+ int offset,
+ int len,
+ struct xdr_buf *buf)
+{
+ return ctx_id->mech_type->gm_ops
+ ->gss_unwrap(ctx_id, offset, len, buf);
+}
+
+
+/* gss_delete_sec_context: free all resources associated with context_handle.
+ * Note this differs from the RFC 2744-specified prototype in that we don't
+ * bother returning an output token, since it would never be used anyway. */
+
+u32
+gss_delete_sec_context(struct gss_ctx **context_handle)
+{
+ dprintk("RPC: gss_delete_sec_context deleting %p\n",
+ *context_handle);
+
+ if (!*context_handle)
+ return GSS_S_NO_CONTEXT;
+ if ((*context_handle)->internal_ctx_id)
+ (*context_handle)->mech_type->gm_ops
+ ->gss_delete_sec_context((*context_handle)
+ ->internal_ctx_id);
+ gss_mech_put((*context_handle)->mech_type);
+ kfree(*context_handle);
+	*context_handle = NULL;
+ return GSS_S_COMPLETE;
+}
diff --git a/net/sunrpc/auth_gss/gss_rpc_upcall.c b/net/sunrpc/auth_gss/gss_rpc_upcall.c
new file mode 100644
index 0000000000..f549e4c05d
--- /dev/null
+++ b/net/sunrpc/auth_gss/gss_rpc_upcall.c
@@ -0,0 +1,403 @@
+// SPDX-License-Identifier: GPL-2.0+
+/*
+ * linux/net/sunrpc/gss_rpc_upcall.c
+ *
+ * Copyright (C) 2012 Simo Sorce <simo@redhat.com>
+ */
+
+#include <linux/types.h>
+#include <linux/un.h>
+
+#include <linux/sunrpc/svcauth.h>
+#include "gss_rpc_upcall.h"
+
+#define GSSPROXY_SOCK_PATHNAME "/var/run/gssproxy.sock"
+
+#define GSSPROXY_PROGRAM (400112u)
+#define GSSPROXY_VERS_1 (1u)
+
+/*
+ * Encoding/Decoding functions
+ */
+
+enum {
+ GSSX_NULL = 0, /* Unused */
+ GSSX_INDICATE_MECHS = 1,
+ GSSX_GET_CALL_CONTEXT = 2,
+ GSSX_IMPORT_AND_CANON_NAME = 3,
+ GSSX_EXPORT_CRED = 4,
+ GSSX_IMPORT_CRED = 5,
+ GSSX_ACQUIRE_CRED = 6,
+ GSSX_STORE_CRED = 7,
+ GSSX_INIT_SEC_CONTEXT = 8,
+ GSSX_ACCEPT_SEC_CONTEXT = 9,
+ GSSX_RELEASE_HANDLE = 10,
+ GSSX_GET_MIC = 11,
+ GSSX_VERIFY = 12,
+ GSSX_WRAP = 13,
+ GSSX_UNWRAP = 14,
+ GSSX_WRAP_SIZE_LIMIT = 15,
+};
+
+#define PROC(proc, name) \
+[GSSX_##proc] = { \
+ .p_proc = GSSX_##proc, \
+ .p_encode = gssx_enc_##name, \
+ .p_decode = gssx_dec_##name, \
+ .p_arglen = GSSX_ARG_##name##_sz, \
+ .p_replen = GSSX_RES_##name##_sz, \
+ .p_statidx = GSSX_##proc, \
+ .p_name = #proc, \
+}
+
+static const struct rpc_procinfo gssp_procedures[] = {
+ PROC(INDICATE_MECHS, indicate_mechs),
+ PROC(GET_CALL_CONTEXT, get_call_context),
+ PROC(IMPORT_AND_CANON_NAME, import_and_canon_name),
+ PROC(EXPORT_CRED, export_cred),
+ PROC(IMPORT_CRED, import_cred),
+ PROC(ACQUIRE_CRED, acquire_cred),
+ PROC(STORE_CRED, store_cred),
+ PROC(INIT_SEC_CONTEXT, init_sec_context),
+ PROC(ACCEPT_SEC_CONTEXT, accept_sec_context),
+ PROC(RELEASE_HANDLE, release_handle),
+ PROC(GET_MIC, get_mic),
+ PROC(VERIFY, verify),
+ PROC(WRAP, wrap),
+ PROC(UNWRAP, unwrap),
+ PROC(WRAP_SIZE_LIMIT, wrap_size_limit),
+};
+
+
+
+/*
+ * Common transport functions
+ */
+
+static const struct rpc_program gssp_program;
+
+static int gssp_rpc_create(struct net *net, struct rpc_clnt **_clnt)
+{
+ static const struct sockaddr_un gssp_localaddr = {
+ .sun_family = AF_LOCAL,
+ .sun_path = GSSPROXY_SOCK_PATHNAME,
+ };
+ struct rpc_create_args args = {
+ .net = net,
+ .protocol = XPRT_TRANSPORT_LOCAL,
+ .address = (struct sockaddr *)&gssp_localaddr,
+ .addrsize = sizeof(gssp_localaddr),
+ .servername = "localhost",
+ .program = &gssp_program,
+ .version = GSSPROXY_VERS_1,
+ .authflavor = RPC_AUTH_NULL,
+ /*
+ * Note we want connection to be done in the caller's
+ * filesystem namespace. We therefore turn off the idle
+ * timeout, which would result in reconnections being
+ * done without the correct namespace:
+ */
+ .flags = RPC_CLNT_CREATE_NOPING |
+ RPC_CLNT_CREATE_CONNECTED |
+ RPC_CLNT_CREATE_NO_IDLE_TIMEOUT
+ };
+ struct rpc_clnt *clnt;
+ int result = 0;
+
+ clnt = rpc_create(&args);
+ if (IS_ERR(clnt)) {
+ dprintk("RPC: failed to create AF_LOCAL gssproxy "
+ "client (errno %ld).\n", PTR_ERR(clnt));
+ result = PTR_ERR(clnt);
+ *_clnt = NULL;
+ goto out;
+ }
+
+ dprintk("RPC: created new gssp local client (gssp_local_clnt: "
+ "%p)\n", clnt);
+ *_clnt = clnt;
+
+out:
+ return result;
+}
+
+void init_gssp_clnt(struct sunrpc_net *sn)
+{
+ mutex_init(&sn->gssp_lock);
+ sn->gssp_clnt = NULL;
+}
+
+int set_gssp_clnt(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ struct rpc_clnt *clnt;
+ int ret;
+
+ mutex_lock(&sn->gssp_lock);
+ ret = gssp_rpc_create(net, &clnt);
+ if (!ret) {
+ if (sn->gssp_clnt)
+ rpc_shutdown_client(sn->gssp_clnt);
+ sn->gssp_clnt = clnt;
+ }
+ mutex_unlock(&sn->gssp_lock);
+ return ret;
+}
+
+void clear_gssp_clnt(struct sunrpc_net *sn)
+{
+ mutex_lock(&sn->gssp_lock);
+ if (sn->gssp_clnt) {
+ rpc_shutdown_client(sn->gssp_clnt);
+ sn->gssp_clnt = NULL;
+ }
+ mutex_unlock(&sn->gssp_lock);
+}
+
+static struct rpc_clnt *get_gssp_clnt(struct sunrpc_net *sn)
+{
+ struct rpc_clnt *clnt;
+
+ mutex_lock(&sn->gssp_lock);
+ clnt = sn->gssp_clnt;
+ if (clnt)
+ refcount_inc(&clnt->cl_count);
+ mutex_unlock(&sn->gssp_lock);
+ return clnt;
+}
+
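+/*
+ * Perform a synchronous call to the gssproxy daemon. Transport-level
+ * errors are mapped to values the upcall path can act on: an
+ * unsupported protocol becomes -EINVAL, connection failures become
+ * -EAGAIN so the request can be retried, and an interrupted wait
+ * becomes -EINTR.
+ */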
+static int gssp_call(struct net *net, struct rpc_message *msg)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ struct rpc_clnt *clnt;
+ int status;
+
+ clnt = get_gssp_clnt(sn);
+ if (!clnt)
+ return -EIO;
+ status = rpc_call_sync(clnt, msg, 0);
+ if (status < 0) {
+ dprintk("gssp: rpc_call returned error %d\n", -status);
+ switch (status) {
+ case -EPROTONOSUPPORT:
+ status = -EINVAL;
+ break;
+ case -ECONNREFUSED:
+ case -ETIMEDOUT:
+ case -ENOTCONN:
+ status = -EAGAIN;
+ break;
+ case -ERESTARTSYS:
+			if (signalled())
+ status = -EINTR;
+ break;
+ default:
+ break;
+ }
+ }
+ rpc_release_client(clnt);
+ return status;
+}
+
+static void gssp_free_receive_pages(struct gssx_arg_accept_sec_context *arg)
+{
+ unsigned int i;
+
+ for (i = 0; i < arg->npages && arg->pages[i]; i++)
+ __free_page(arg->pages[i]);
+
+ kfree(arg->pages);
+}
+
+static int gssp_alloc_receive_pages(struct gssx_arg_accept_sec_context *arg)
+{
+ unsigned int i;
+
+ arg->npages = DIV_ROUND_UP(NGROUPS_MAX * 4, PAGE_SIZE);
+ arg->pages = kcalloc(arg->npages, sizeof(struct page *), GFP_KERNEL);
+ if (!arg->pages)
+ return -ENOMEM;
+ for (i = 0; i < arg->npages; i++) {
+ arg->pages[i] = alloc_page(GFP_KERNEL);
+ if (!arg->pages[i]) {
+ gssp_free_receive_pages(arg);
+ return -ENOMEM;
+ }
+ }
+ return 0;
+}
+
+static char *gssp_stringify(struct xdr_netobj *netobj)
+{
+ return kmemdup_nul(netobj->data, netobj->len, GFP_KERNEL);
+}
+
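+/*
+ * Convert a Kerberos service principal to GSS_NT_HOSTBASED_SERVICE
+ * form in place: the realm is stripped and the service/hostname
+ * separator becomes '@', so that (for example) a principal of the
+ * form "nfs/host@REALM" becomes "nfs@host". A principal without a
+ * '/' separator is not a service principal; it is freed and the
+ * pointer is cleared.
+ */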
+static void gssp_hostbased_service(char **principal)
+{
+ char *c;
+
+ if (!*principal)
+ return;
+
+ /* terminate and remove realm part */
+ c = strchr(*principal, '@');
+ if (c) {
+ *c = '\0';
+
+ /* change service-hostname delimiter */
+ c = strchr(*principal, '/');
+ if (c)
+ *c = '@';
+ }
+ if (!c) {
+ /* not a service principal */
+ kfree(*principal);
+ *principal = NULL;
+ }
+}
+
+/*
+ * Public functions
+ */
+
+/* numbers somewhat arbitrary but large enough for current needs */
+#define GSSX_MAX_OUT_HANDLE 128
+#define GSSX_MAX_SRC_PRINC 256
+#define GSSX_KMEMBUF (GSSX_max_output_handle_sz + \
+ GSSX_max_oid_sz + \
+ GSSX_max_princ_sz + \
+ sizeof(struct svc_cred))
+
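+/*
+ * Marshal an ACCEPT_SEC_CONTEXT request, call gssproxy, and unpack
+ * the reply into @data: the exported context token, mechanism OID,
+ * client and target principal names, and, when gssproxy supplies
+ * them, the credentials for the authenticated user.
+ */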
+int gssp_accept_sec_context_upcall(struct net *net,
+ struct gssp_upcall_data *data)
+{
+ struct gssx_ctx ctxh = {
+ .state = data->in_handle
+ };
+ struct gssx_arg_accept_sec_context arg = {
+ .input_token = data->in_token,
+ };
+ struct gssx_ctx rctxh = {
+ /*
+ * pass in the max length we expect for each of these
+ * buffers but let the xdr code kmalloc them:
+ */
+ .exported_context_token.len = GSSX_max_output_handle_sz,
+ .mech.len = GSS_OID_MAX_LEN,
+ .targ_name.display_name.len = GSSX_max_princ_sz,
+ .src_name.display_name.len = GSSX_max_princ_sz
+ };
+ struct gssx_res_accept_sec_context res = {
+ .context_handle = &rctxh,
+ .output_token = &data->out_token
+ };
+ struct rpc_message msg = {
+ .rpc_proc = &gssp_procedures[GSSX_ACCEPT_SEC_CONTEXT],
+ .rpc_argp = &arg,
+ .rpc_resp = &res,
+ .rpc_cred = NULL, /* FIXME ? */
+ };
+	struct xdr_netobj client_name = { 0, NULL };
+ struct xdr_netobj target_name = { 0, NULL };
+ int ret;
+
+ if (data->in_handle.len != 0)
+ arg.context_handle = &ctxh;
+ res.output_token->len = GSSX_max_output_token_sz;
+
+ ret = gssp_alloc_receive_pages(&arg);
+ if (ret)
+ return ret;
+
+ ret = gssp_call(net, &msg);
+
+ gssp_free_receive_pages(&arg);
+
+	/* we need to fetch all data even in case of error so
+	 * that we can free special structures if they have been allocated */
+ data->major_status = res.status.major_status;
+ data->minor_status = res.status.minor_status;
+ if (res.context_handle) {
+ data->out_handle = rctxh.exported_context_token;
+ data->mech_oid.len = rctxh.mech.len;
+ if (rctxh.mech.data) {
+ memcpy(data->mech_oid.data, rctxh.mech.data,
+ data->mech_oid.len);
+ kfree(rctxh.mech.data);
+ }
+ client_name = rctxh.src_name.display_name;
+ target_name = rctxh.targ_name.display_name;
+ }
+
+ if (res.options.count == 1) {
+ gssx_buffer *value = &res.options.data[0].value;
+		/* Currently we only decode CREDS_VALUE; if we add
+		 * anything else we'll have to loop and match on the
+		 * option name */
+ if (value->len == 1) {
+ /* steal group info from struct svc_cred */
+ data->creds = *(struct svc_cred *)value->data;
+ data->found_creds = 1;
+ }
+ /* whether we use it or not, free data */
+ kfree(value->data);
+ }
+
+ if (res.options.count != 0) {
+ kfree(res.options.data);
+ }
+
+ /* convert to GSS_NT_HOSTBASED_SERVICE form and set into creds */
+ if (data->found_creds) {
+ if (client_name.data) {
+ data->creds.cr_raw_principal =
+ gssp_stringify(&client_name);
+ data->creds.cr_principal =
+ gssp_stringify(&client_name);
+ gssp_hostbased_service(&data->creds.cr_principal);
+ }
+ if (target_name.data) {
+ data->creds.cr_targ_princ =
+ gssp_stringify(&target_name);
+ gssp_hostbased_service(&data->creds.cr_targ_princ);
+ }
+ }
+ kfree(client_name.data);
+ kfree(target_name.data);
+
+ return ret;
+}
+
+void gssp_free_upcall_data(struct gssp_upcall_data *data)
+{
+ kfree(data->in_handle.data);
+ kfree(data->out_handle.data);
+ kfree(data->out_token.data);
+ free_svc_cred(&data->creds);
+}
+
+/*
+ * Initialization stuff
+ */
+static unsigned int gssp_version1_counts[ARRAY_SIZE(gssp_procedures)];
+static const struct rpc_version gssp_version1 = {
+ .number = GSSPROXY_VERS_1,
+ .nrprocs = ARRAY_SIZE(gssp_procedures),
+ .procs = gssp_procedures,
+ .counts = gssp_version1_counts,
+};
+
+static const struct rpc_version *gssp_version[] = {
+ NULL,
+ &gssp_version1,
+};
+
+static struct rpc_stat gssp_stats;
+
+static const struct rpc_program gssp_program = {
+ .name = "gssproxy",
+ .number = GSSPROXY_PROGRAM,
+ .nrvers = ARRAY_SIZE(gssp_version),
+ .version = gssp_version,
+ .stats = &gssp_stats,
+};
diff --git a/net/sunrpc/auth_gss/gss_rpc_upcall.h b/net/sunrpc/auth_gss/gss_rpc_upcall.h
new file mode 100644
index 0000000000..31e9634416
--- /dev/null
+++ b/net/sunrpc/auth_gss/gss_rpc_upcall.h
@@ -0,0 +1,36 @@
+/* SPDX-License-Identifier: GPL-2.0+ */
+/*
+ * linux/net/sunrpc/gss_rpc_upcall.h
+ *
+ * Copyright (C) 2012 Simo Sorce <simo@redhat.com>
+ */
+
+#ifndef _GSS_RPC_UPCALL_H
+#define _GSS_RPC_UPCALL_H
+
+#include <linux/sunrpc/gss_api.h>
+#include <linux/sunrpc/auth_gss.h>
+#include "gss_rpc_xdr.h"
+#include "../netns.h"
+
+struct gssp_upcall_data {
+ struct xdr_netobj in_handle;
+ struct gssp_in_token in_token;
+ struct xdr_netobj out_handle;
+ struct xdr_netobj out_token;
+ struct rpcsec_gss_oid mech_oid;
+ struct svc_cred creds;
+ int found_creds;
+ int major_status;
+ int minor_status;
+};
+
+int gssp_accept_sec_context_upcall(struct net *net,
+ struct gssp_upcall_data *data);
+void gssp_free_upcall_data(struct gssp_upcall_data *data);
+
+void init_gssp_clnt(struct sunrpc_net *);
+int set_gssp_clnt(struct net *);
+void clear_gssp_clnt(struct sunrpc_net *);
+
+#endif /* _GSS_RPC_UPCALL_H */
diff --git a/net/sunrpc/auth_gss/gss_rpc_xdr.c b/net/sunrpc/auth_gss/gss_rpc_xdr.c
new file mode 100644
index 0000000000..d79f12c255
--- /dev/null
+++ b/net/sunrpc/auth_gss/gss_rpc_xdr.c
@@ -0,0 +1,838 @@
+// SPDX-License-Identifier: GPL-2.0+
+/*
+ * GSS Proxy upcall module
+ *
+ * Copyright (C) 2012 Simo Sorce <simo@redhat.com>
+ */
+
+#include <linux/sunrpc/svcauth.h>
+#include "gss_rpc_xdr.h"
+
+static int gssx_enc_bool(struct xdr_stream *xdr, int v)
+{
+ __be32 *p;
+
+ p = xdr_reserve_space(xdr, 4);
+ if (unlikely(p == NULL))
+ return -ENOSPC;
+ *p = v ? xdr_one : xdr_zero;
+ return 0;
+}
+
+static int gssx_dec_bool(struct xdr_stream *xdr, u32 *v)
+{
+ __be32 *p;
+
+ p = xdr_inline_decode(xdr, 4);
+ if (unlikely(p == NULL))
+ return -ENOSPC;
+ *v = be32_to_cpu(*p);
+ return 0;
+}
+
+static int gssx_enc_buffer(struct xdr_stream *xdr,
+ const gssx_buffer *buf)
+{
+ __be32 *p;
+
+ p = xdr_reserve_space(xdr, sizeof(u32) + buf->len);
+ if (!p)
+ return -ENOSPC;
+ xdr_encode_opaque(p, buf->data, buf->len);
+ return 0;
+}
+
+static int gssx_enc_in_token(struct xdr_stream *xdr,
+ const struct gssp_in_token *in)
+{
+ __be32 *p;
+
+ p = xdr_reserve_space(xdr, 4);
+ if (!p)
+ return -ENOSPC;
+ *p = cpu_to_be32(in->page_len);
+
+ /* all we need to do is to write pages */
+ xdr_write_pages(xdr, in->pages, in->page_base, in->page_len);
+
+ return 0;
+}
+
+
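+/*
+ * gssx_dec_buffer() supports three caller conventions: buf->len == 0 means
+ * the caller is not interested and the XDR item is simply consumed;
+ * buf->data == NULL (with buf->len set to the maximum acceptable size)
+ * makes the decoder allocate a copy with kmemdup(); otherwise the bytes
+ * are copied into the caller-supplied buffer.  In the latter two cases
+ * buf->len is updated to the actual decoded length.
+ */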
+static int gssx_dec_buffer(struct xdr_stream *xdr,
+ gssx_buffer *buf)
+{
+ u32 length;
+ __be32 *p;
+
+ p = xdr_inline_decode(xdr, 4);
+ if (unlikely(p == NULL))
+ return -ENOSPC;
+
+ length = be32_to_cpup(p);
+ p = xdr_inline_decode(xdr, length);
+ if (unlikely(p == NULL))
+ return -ENOSPC;
+
+ if (buf->len == 0) {
+ /* we intentionally are not interested in this buffer */
+ return 0;
+ }
+ if (length > buf->len)
+ return -ENOSPC;
+
+ if (!buf->data) {
+ buf->data = kmemdup(p, length, GFP_KERNEL);
+ if (!buf->data)
+ return -ENOMEM;
+ } else {
+ memcpy(buf->data, p, length);
+ }
+ buf->len = length;
+ return 0;
+}
+
+static int gssx_enc_option(struct xdr_stream *xdr,
+ struct gssx_option *opt)
+{
+ int err;
+
+ err = gssx_enc_buffer(xdr, &opt->option);
+ if (err)
+ return err;
+ err = gssx_enc_buffer(xdr, &opt->value);
+ return err;
+}
+
+static int gssx_dec_option(struct xdr_stream *xdr,
+ struct gssx_option *opt)
+{
+ int err;
+
+ err = gssx_dec_buffer(xdr, &opt->option);
+ if (err)
+ return err;
+ err = gssx_dec_buffer(xdr, &opt->value);
+ return err;
+}
+
+static int dummy_enc_opt_array(struct xdr_stream *xdr,
+ const struct gssx_option_array *oa)
+{
+ __be32 *p;
+
+ if (oa->count != 0)
+ return -EINVAL;
+
+ p = xdr_reserve_space(xdr, 4);
+ if (!p)
+ return -ENOSPC;
+ *p = 0;
+
+ return 0;
+}
+
+static int dummy_dec_opt_array(struct xdr_stream *xdr,
+ struct gssx_option_array *oa)
+{
+ struct gssx_option dummy;
+ u32 count, i;
+ __be32 *p;
+
+ p = xdr_inline_decode(xdr, 4);
+ if (unlikely(p == NULL))
+ return -ENOSPC;
+ count = be32_to_cpup(p++);
+ memset(&dummy, 0, sizeof(dummy));
+ for (i = 0; i < count; i++) {
+ gssx_dec_option(xdr, &dummy);
+ }
+
+ oa->count = 0;
+ oa->data = NULL;
+ return 0;
+}
+
+static int get_host_u32(struct xdr_stream *xdr, u32 *res)
+{
+ __be32 *p;
+
+ p = xdr_inline_decode(xdr, 4);
+ if (!p)
+ return -EINVAL;
+ /* Contents of linux creds are all host-endian: */
+ memcpy(res, p, sizeof(u32));
+ return 0;
+}
+
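+/*
+ * The "linux_creds_v1" option value decoded below is an XDR opaque whose
+ * contents are host-endian 32-bit words:
+ *
+ *	uid, gid, ngroups, followed by ngroups supplementary gids,
+ *
+ * so the opaque's length must equal (3 + ngroups) * sizeof(u32).
+ */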
+static int gssx_dec_linux_creds(struct xdr_stream *xdr,
+ struct svc_cred *creds)
+{
+ u32 length;
+ __be32 *p;
+ u32 tmp;
+ u32 N;
+ int i, err;
+
+ p = xdr_inline_decode(xdr, 4);
+ if (unlikely(p == NULL))
+ return -ENOSPC;
+
+ length = be32_to_cpup(p);
+
+ if (length > (3 + NGROUPS_MAX) * sizeof(u32))
+ return -ENOSPC;
+
+ /* uid */
+ err = get_host_u32(xdr, &tmp);
+ if (err)
+ return err;
+ creds->cr_uid = make_kuid(&init_user_ns, tmp);
+
+ /* gid */
+ err = get_host_u32(xdr, &tmp);
+ if (err)
+ return err;
+ creds->cr_gid = make_kgid(&init_user_ns, tmp);
+
+ /* number of additional gid's */
+ err = get_host_u32(xdr, &tmp);
+ if (err)
+ return err;
+ N = tmp;
+ if ((3 + N) * sizeof(u32) != length)
+ return -EINVAL;
+ creds->cr_group_info = groups_alloc(N);
+ if (creds->cr_group_info == NULL)
+ return -ENOMEM;
+
+ /* gid's */
+ for (i = 0; i < N; i++) {
+ kgid_t kgid;
+ err = get_host_u32(xdr, &tmp);
+ if (err)
+ goto out_free_groups;
+ err = -EINVAL;
+ kgid = make_kgid(&init_user_ns, tmp);
+ if (!gid_valid(kgid))
+ goto out_free_groups;
+ creds->cr_group_info->gid[i] = kgid;
+ }
+ groups_sort(creds->cr_group_info);
+
+ return 0;
+out_free_groups:
+ groups_free(creds->cr_group_info);
+ return err;
+}
+
+static int gssx_dec_option_array(struct xdr_stream *xdr,
+ struct gssx_option_array *oa)
+{
+ struct svc_cred *creds;
+ u32 count, i;
+ __be32 *p;
+ int err;
+
+ p = xdr_inline_decode(xdr, 4);
+ if (unlikely(p == NULL))
+ return -ENOSPC;
+ count = be32_to_cpup(p++);
+ if (!count)
+ return 0;
+
+ /* we recognize only 1 currently: CREDS_VALUE */
+ oa->count = 1;
+
+ oa->data = kmalloc(sizeof(struct gssx_option), GFP_KERNEL);
+ if (!oa->data)
+ return -ENOMEM;
+
+ creds = kzalloc(sizeof(struct svc_cred), GFP_KERNEL);
+ if (!creds) {
+ kfree(oa->data);
+ return -ENOMEM;
+ }
+
+ oa->data[0].option.data = CREDS_VALUE;
+ oa->data[0].option.len = sizeof(CREDS_VALUE);
+ oa->data[0].value.data = (void *)creds;
+ oa->data[0].value.len = 0;
+
+ for (i = 0; i < count; i++) {
+ gssx_buffer dummy = { 0, NULL };
+ u32 length;
+
+ /* option buffer */
+ p = xdr_inline_decode(xdr, 4);
+ if (unlikely(p == NULL))
+ return -ENOSPC;
+
+ length = be32_to_cpup(p);
+ p = xdr_inline_decode(xdr, length);
+ if (unlikely(p == NULL))
+ return -ENOSPC;
+
+ if (length == sizeof(CREDS_VALUE) &&
+ memcmp(p, CREDS_VALUE, sizeof(CREDS_VALUE)) == 0) {
+ /* We have creds here. parse them */
+ err = gssx_dec_linux_creds(xdr, creds);
+ if (err)
+ return err;
+ oa->data[0].value.len = 1; /* presence */
+ } else {
+ /* consume uninteresting buffer */
+ err = gssx_dec_buffer(xdr, &dummy);
+ if (err)
+ return err;
+ }
+ }
+ return 0;
+}
+
+static int gssx_dec_status(struct xdr_stream *xdr,
+ struct gssx_status *status)
+{
+ __be32 *p;
+ int err;
+
+ /* status->major_status */
+ p = xdr_inline_decode(xdr, 8);
+ if (unlikely(p == NULL))
+ return -ENOSPC;
+ p = xdr_decode_hyper(p, &status->major_status);
+
+ /* status->mech */
+ err = gssx_dec_buffer(xdr, &status->mech);
+ if (err)
+ return err;
+
+ /* status->minor_status */
+ p = xdr_inline_decode(xdr, 8);
+ if (unlikely(p == NULL))
+ return -ENOSPC;
+ p = xdr_decode_hyper(p, &status->minor_status);
+
+ /* status->major_status_string */
+ err = gssx_dec_buffer(xdr, &status->major_status_string);
+ if (err)
+ return err;
+
+ /* status->minor_status_string */
+ err = gssx_dec_buffer(xdr, &status->minor_status_string);
+ if (err)
+ return err;
+
+ /* status->server_ctx */
+ err = gssx_dec_buffer(xdr, &status->server_ctx);
+ if (err)
+ return err;
+
+ /* we assume we have no options for now, so simply consume them */
+ /* status->options */
+ err = dummy_dec_opt_array(xdr, &status->options);
+
+ return err;
+}
+
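+/*
+ * gssx_enc_call_ctx() always asks gss-proxy for two options: a lucid v1
+ * exported context and Linux-format creds.  The option count encoded
+ * below (2) must stay in sync with the options actually sent.
+ */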
+static int gssx_enc_call_ctx(struct xdr_stream *xdr,
+ const struct gssx_call_ctx *ctx)
+{
+ struct gssx_option opt;
+ __be32 *p;
+ int err;
+
+ /* ctx->locale */
+ err = gssx_enc_buffer(xdr, &ctx->locale);
+ if (err)
+ return err;
+
+ /* ctx->server_ctx */
+ err = gssx_enc_buffer(xdr, &ctx->server_ctx);
+ if (err)
+ return err;
+
+ /* we always want to ask for lucid contexts */
+ /* ctx->options */
+	p = xdr_reserve_space(xdr, 4);
+	if (!p)
+		return -ENOSPC;
+	*p = cpu_to_be32(2);
+
+ /* we want a lucid_v1 context */
+ opt.option.data = LUCID_OPTION;
+ opt.option.len = sizeof(LUCID_OPTION);
+ opt.value.data = LUCID_VALUE;
+ opt.value.len = sizeof(LUCID_VALUE);
+ err = gssx_enc_option(xdr, &opt);
+
+ /* ..and user creds */
+ opt.option.data = CREDS_OPTION;
+ opt.option.len = sizeof(CREDS_OPTION);
+ opt.value.data = CREDS_VALUE;
+ opt.value.len = sizeof(CREDS_VALUE);
+ err = gssx_enc_option(xdr, &opt);
+
+ return err;
+}
+
+static int gssx_dec_name_attr(struct xdr_stream *xdr,
+ struct gssx_name_attr *attr)
+{
+ int err;
+
+ /* attr->attr */
+ err = gssx_dec_buffer(xdr, &attr->attr);
+ if (err)
+ return err;
+
+ /* attr->value */
+ err = gssx_dec_buffer(xdr, &attr->value);
+ if (err)
+ return err;
+
+ /* attr->extensions */
+ err = dummy_dec_opt_array(xdr, &attr->extensions);
+
+ return err;
+}
+
+static int dummy_enc_nameattr_array(struct xdr_stream *xdr,
+ struct gssx_name_attr_array *naa)
+{
+ __be32 *p;
+
+ if (naa->count != 0)
+ return -EINVAL;
+
+ p = xdr_reserve_space(xdr, 4);
+ if (!p)
+ return -ENOSPC;
+ *p = 0;
+
+ return 0;
+}
+
+static int dummy_dec_nameattr_array(struct xdr_stream *xdr,
+ struct gssx_name_attr_array *naa)
+{
+ struct gssx_name_attr dummy = { .attr = {.len = 0} };
+ u32 count, i;
+ __be32 *p;
+
+ p = xdr_inline_decode(xdr, 4);
+ if (unlikely(p == NULL))
+ return -ENOSPC;
+ count = be32_to_cpup(p++);
+ for (i = 0; i < count; i++) {
+ gssx_dec_name_attr(xdr, &dummy);
+ }
+
+ naa->count = 0;
+ naa->data = NULL;
+ return 0;
+}
+
+static struct xdr_netobj zero_netobj = {};
+
+static struct gssx_name_attr_array zero_name_attr_array = {};
+
+static struct gssx_option_array zero_option_array = {};
+
+static int gssx_enc_name(struct xdr_stream *xdr,
+ struct gssx_name *name)
+{
+ int err;
+
+ /* name->display_name */
+ err = gssx_enc_buffer(xdr, &name->display_name);
+ if (err)
+ return err;
+
+ /* name->name_type */
+ err = gssx_enc_buffer(xdr, &zero_netobj);
+ if (err)
+ return err;
+
+ /* name->exported_name */
+ err = gssx_enc_buffer(xdr, &zero_netobj);
+ if (err)
+ return err;
+
+ /* name->exported_composite_name */
+ err = gssx_enc_buffer(xdr, &zero_netobj);
+ if (err)
+ return err;
+
+ /* leave name_attributes empty for now, will add once we have any
+ * to pass up at all */
+ /* name->name_attributes */
+ err = dummy_enc_nameattr_array(xdr, &zero_name_attr_array);
+ if (err)
+ return err;
+
+ /* leave options empty for now, will add once we have any options
+ * to pass up at all */
+ /* name->extensions */
+ err = dummy_enc_opt_array(xdr, &zero_option_array);
+
+ return err;
+}
+
+
+static int gssx_dec_name(struct xdr_stream *xdr,
+ struct gssx_name *name)
+{
+ struct xdr_netobj dummy_netobj = { .len = 0 };
+ struct gssx_name_attr_array dummy_name_attr_array = { .count = 0 };
+ struct gssx_option_array dummy_option_array = { .count = 0 };
+ int err;
+
+ /* name->display_name */
+ err = gssx_dec_buffer(xdr, &name->display_name);
+ if (err)
+ return err;
+
+ /* name->name_type */
+ err = gssx_dec_buffer(xdr, &dummy_netobj);
+ if (err)
+ return err;
+
+ /* name->exported_name */
+ err = gssx_dec_buffer(xdr, &dummy_netobj);
+ if (err)
+ return err;
+
+ /* name->exported_composite_name */
+ err = gssx_dec_buffer(xdr, &dummy_netobj);
+ if (err)
+ return err;
+
+ /* we assume we have no attributes for now, so simply consume them */
+ /* name->name_attributes */
+ err = dummy_dec_nameattr_array(xdr, &dummy_name_attr_array);
+ if (err)
+ return err;
+
+ /* we assume we have no options for now, so simply consume them */
+ /* name->extensions */
+ err = dummy_dec_opt_array(xdr, &dummy_option_array);
+
+ return err;
+}
+
+static int dummy_enc_credel_array(struct xdr_stream *xdr,
+ struct gssx_cred_element_array *cea)
+{
+ __be32 *p;
+
+ if (cea->count != 0)
+ return -EINVAL;
+
+ p = xdr_reserve_space(xdr, 4);
+ if (!p)
+ return -ENOSPC;
+ *p = 0;
+
+ return 0;
+}
+
+static int gssx_enc_cred(struct xdr_stream *xdr,
+ struct gssx_cred *cred)
+{
+ int err;
+
+ /* cred->desired_name */
+ err = gssx_enc_name(xdr, &cred->desired_name);
+ if (err)
+ return err;
+
+ /* cred->elements */
+ err = dummy_enc_credel_array(xdr, &cred->elements);
+ if (err)
+ return err;
+
+ /* cred->cred_handle_reference */
+ err = gssx_enc_buffer(xdr, &cred->cred_handle_reference);
+ if (err)
+ return err;
+
+ /* cred->needs_release */
+ err = gssx_enc_bool(xdr, cred->needs_release);
+
+ return err;
+}
+
+static int gssx_enc_ctx(struct xdr_stream *xdr,
+ struct gssx_ctx *ctx)
+{
+ __be32 *p;
+ int err;
+
+ /* ctx->exported_context_token */
+ err = gssx_enc_buffer(xdr, &ctx->exported_context_token);
+ if (err)
+ return err;
+
+ /* ctx->state */
+ err = gssx_enc_buffer(xdr, &ctx->state);
+ if (err)
+ return err;
+
+ /* ctx->need_release */
+ err = gssx_enc_bool(xdr, ctx->need_release);
+ if (err)
+ return err;
+
+ /* ctx->mech */
+ err = gssx_enc_buffer(xdr, &ctx->mech);
+ if (err)
+ return err;
+
+ /* ctx->src_name */
+ err = gssx_enc_name(xdr, &ctx->src_name);
+ if (err)
+ return err;
+
+ /* ctx->targ_name */
+ err = gssx_enc_name(xdr, &ctx->targ_name);
+ if (err)
+ return err;
+
+ /* ctx->lifetime */
+ p = xdr_reserve_space(xdr, 8+8);
+ if (!p)
+ return -ENOSPC;
+ p = xdr_encode_hyper(p, ctx->lifetime);
+
+ /* ctx->ctx_flags */
+ p = xdr_encode_hyper(p, ctx->ctx_flags);
+
+ /* ctx->locally_initiated */
+ err = gssx_enc_bool(xdr, ctx->locally_initiated);
+ if (err)
+ return err;
+
+ /* ctx->open */
+ err = gssx_enc_bool(xdr, ctx->open);
+ if (err)
+ return err;
+
+ /* leave options empty for now, will add once we have any options
+ * to pass up at all */
+ /* ctx->options */
+ err = dummy_enc_opt_array(xdr, &ctx->options);
+
+ return err;
+}
+
+static int gssx_dec_ctx(struct xdr_stream *xdr,
+ struct gssx_ctx *ctx)
+{
+ __be32 *p;
+ int err;
+
+ /* ctx->exported_context_token */
+ err = gssx_dec_buffer(xdr, &ctx->exported_context_token);
+ if (err)
+ return err;
+
+ /* ctx->state */
+ err = gssx_dec_buffer(xdr, &ctx->state);
+ if (err)
+ return err;
+
+ /* ctx->need_release */
+ err = gssx_dec_bool(xdr, &ctx->need_release);
+ if (err)
+ return err;
+
+ /* ctx->mech */
+ err = gssx_dec_buffer(xdr, &ctx->mech);
+ if (err)
+ return err;
+
+ /* ctx->src_name */
+ err = gssx_dec_name(xdr, &ctx->src_name);
+ if (err)
+ return err;
+
+ /* ctx->targ_name */
+ err = gssx_dec_name(xdr, &ctx->targ_name);
+ if (err)
+ return err;
+
+ /* ctx->lifetime */
+ p = xdr_inline_decode(xdr, 8+8);
+ if (unlikely(p == NULL))
+ return -ENOSPC;
+ p = xdr_decode_hyper(p, &ctx->lifetime);
+
+ /* ctx->ctx_flags */
+ p = xdr_decode_hyper(p, &ctx->ctx_flags);
+
+ /* ctx->locally_initiated */
+ err = gssx_dec_bool(xdr, &ctx->locally_initiated);
+ if (err)
+ return err;
+
+ /* ctx->open */
+ err = gssx_dec_bool(xdr, &ctx->open);
+ if (err)
+ return err;
+
+ /* we assume we have no options for now, so simply consume them */
+ /* ctx->options */
+ err = dummy_dec_opt_array(xdr, &ctx->options);
+
+ return err;
+}
+
+static int gssx_enc_cb(struct xdr_stream *xdr, struct gssx_cb *cb)
+{
+ __be32 *p;
+ int err;
+
+ /* cb->initiator_addrtype */
+ p = xdr_reserve_space(xdr, 8);
+ if (!p)
+ return -ENOSPC;
+ p = xdr_encode_hyper(p, cb->initiator_addrtype);
+
+ /* cb->initiator_address */
+ err = gssx_enc_buffer(xdr, &cb->initiator_address);
+ if (err)
+ return err;
+
+ /* cb->acceptor_addrtype */
+ p = xdr_reserve_space(xdr, 8);
+ if (!p)
+ return -ENOSPC;
+ p = xdr_encode_hyper(p, cb->acceptor_addrtype);
+
+ /* cb->acceptor_address */
+ err = gssx_enc_buffer(xdr, &cb->acceptor_address);
+ if (err)
+ return err;
+
+ /* cb->application_data */
+ err = gssx_enc_buffer(xdr, &cb->application_data);
+
+ return err;
+}
+
+void gssx_enc_accept_sec_context(struct rpc_rqst *req,
+ struct xdr_stream *xdr,
+ const void *data)
+{
+ const struct gssx_arg_accept_sec_context *arg = data;
+ int err;
+
+ err = gssx_enc_call_ctx(xdr, &arg->call_ctx);
+ if (err)
+ goto done;
+
+ /* arg->context_handle */
+ if (arg->context_handle)
+ err = gssx_enc_ctx(xdr, arg->context_handle);
+ else
+ err = gssx_enc_bool(xdr, 0);
+ if (err)
+ goto done;
+
+ /* arg->cred_handle */
+ if (arg->cred_handle)
+ err = gssx_enc_cred(xdr, arg->cred_handle);
+ else
+ err = gssx_enc_bool(xdr, 0);
+ if (err)
+ goto done;
+
+ /* arg->input_token */
+ err = gssx_enc_in_token(xdr, &arg->input_token);
+ if (err)
+ goto done;
+
+ /* arg->input_cb */
+ if (arg->input_cb)
+ err = gssx_enc_cb(xdr, arg->input_cb);
+ else
+ err = gssx_enc_bool(xdr, 0);
+ if (err)
+ goto done;
+
+ err = gssx_enc_bool(xdr, arg->ret_deleg_cred);
+ if (err)
+ goto done;
+
+ /* leave options empty for now, will add once we have any options
+ * to pass up at all */
+ /* arg->options */
+ err = dummy_enc_opt_array(xdr, &arg->options);
+
+ xdr_inline_pages(&req->rq_rcv_buf,
+ PAGE_SIZE/2 /* pretty arbitrary */,
+ arg->pages, 0 /* page base */, arg->npages * PAGE_SIZE);
+done:
+ if (err)
+ dprintk("RPC: gssx_enc_accept_sec_context: %d\n", err);
+}
+
+int gssx_dec_accept_sec_context(struct rpc_rqst *rqstp,
+ struct xdr_stream *xdr,
+ void *data)
+{
+ struct gssx_res_accept_sec_context *res = data;
+ u32 value_follows;
+ int err;
+ struct page *scratch;
+
+ scratch = alloc_page(GFP_KERNEL);
+ if (!scratch)
+ return -ENOMEM;
+ xdr_set_scratch_page(xdr, scratch);
+
+ /* res->status */
+ err = gssx_dec_status(xdr, &res->status);
+ if (err)
+ goto out_free;
+
+ /* res->context_handle */
+ err = gssx_dec_bool(xdr, &value_follows);
+ if (err)
+ goto out_free;
+ if (value_follows) {
+ err = gssx_dec_ctx(xdr, res->context_handle);
+ if (err)
+ goto out_free;
+ } else {
+ res->context_handle = NULL;
+ }
+
+ /* res->output_token */
+ err = gssx_dec_bool(xdr, &value_follows);
+ if (err)
+ goto out_free;
+ if (value_follows) {
+ err = gssx_dec_buffer(xdr, res->output_token);
+ if (err)
+ goto out_free;
+ } else {
+ res->output_token = NULL;
+ }
+
+ /* res->delegated_cred_handle */
+ err = gssx_dec_bool(xdr, &value_follows);
+ if (err)
+ goto out_free;
+ if (value_follows) {
+ /* we do not support upcall servers sending this data. */
+ err = -EINVAL;
+ goto out_free;
+ }
+
+ /* res->options */
+ err = gssx_dec_option_array(xdr, &res->options);
+
+out_free:
+ __free_page(scratch);
+ return err;
+}
diff --git a/net/sunrpc/auth_gss/gss_rpc_xdr.h b/net/sunrpc/auth_gss/gss_rpc_xdr.h
new file mode 100644
index 0000000000..3f17411b7e
--- /dev/null
+++ b/net/sunrpc/auth_gss/gss_rpc_xdr.h
@@ -0,0 +1,252 @@
+/* SPDX-License-Identifier: GPL-2.0+ */
+/*
+ * GSS Proxy upcall module
+ *
+ * Copyright (C) 2012 Simo Sorce <simo@redhat.com>
+ */
+
+#ifndef _LINUX_GSS_RPC_XDR_H
+#define _LINUX_GSS_RPC_XDR_H
+
+#include <linux/sunrpc/xdr.h>
+#include <linux/sunrpc/clnt.h>
+#include <linux/sunrpc/xprtsock.h>
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+# define RPCDBG_FACILITY RPCDBG_AUTH
+#endif
+
+#define LUCID_OPTION "exported_context_type"
+#define LUCID_VALUE "linux_lucid_v1"
+#define CREDS_OPTION "exported_creds_type"
+#define CREDS_VALUE "linux_creds_v1"
+
+typedef struct xdr_netobj gssx_buffer;
+typedef struct xdr_netobj utf8string;
+typedef struct xdr_netobj gssx_OID;
+
+enum gssx_cred_usage {
+ GSSX_C_INITIATE = 1,
+ GSSX_C_ACCEPT = 2,
+ GSSX_C_BOTH = 3,
+};
+
+struct gssx_option {
+ gssx_buffer option;
+ gssx_buffer value;
+};
+
+struct gssx_option_array {
+ u32 count;
+ struct gssx_option *data;
+};
+
+struct gssx_status {
+ u64 major_status;
+ gssx_OID mech;
+ u64 minor_status;
+ utf8string major_status_string;
+ utf8string minor_status_string;
+ gssx_buffer server_ctx;
+ struct gssx_option_array options;
+};
+
+struct gssx_call_ctx {
+ utf8string locale;
+ gssx_buffer server_ctx;
+ struct gssx_option_array options;
+};
+
+struct gssx_name_attr {
+ gssx_buffer attr;
+ gssx_buffer value;
+ struct gssx_option_array extensions;
+};
+
+struct gssx_name_attr_array {
+ u32 count;
+ struct gssx_name_attr *data;
+};
+
+struct gssx_name {
+ gssx_buffer display_name;
+};
+typedef struct gssx_name gssx_name;
+
+struct gssx_cred_element {
+ gssx_name MN;
+ gssx_OID mech;
+ u32 cred_usage;
+ u64 initiator_time_rec;
+ u64 acceptor_time_rec;
+ struct gssx_option_array options;
+};
+
+struct gssx_cred_element_array {
+ u32 count;
+ struct gssx_cred_element *data;
+};
+
+struct gssx_cred {
+ gssx_name desired_name;
+ struct gssx_cred_element_array elements;
+ gssx_buffer cred_handle_reference;
+ u32 needs_release;
+};
+
+struct gssx_ctx {
+ gssx_buffer exported_context_token;
+ gssx_buffer state;
+ u32 need_release;
+ gssx_OID mech;
+ gssx_name src_name;
+ gssx_name targ_name;
+ u64 lifetime;
+ u64 ctx_flags;
+ u32 locally_initiated;
+ u32 open;
+ struct gssx_option_array options;
+};
+
+struct gssx_cb {
+ u64 initiator_addrtype;
+ gssx_buffer initiator_address;
+ u64 acceptor_addrtype;
+ gssx_buffer acceptor_address;
+ gssx_buffer application_data;
+};
+
+
+/* This structure is not defined in the protocol.
+ * It is used in the kernel to carry around a big buffer
+ * as a set of pages */
+struct gssp_in_token {
+ struct page **pages; /* Array of contiguous pages */
+ unsigned int page_base; /* Start of page data */
+ unsigned int page_len; /* Length of page data */
+};
+
+struct gssx_arg_accept_sec_context {
+ struct gssx_call_ctx call_ctx;
+ struct gssx_ctx *context_handle;
+ struct gssx_cred *cred_handle;
+ struct gssp_in_token input_token;
+ struct gssx_cb *input_cb;
+ u32 ret_deleg_cred;
+ struct gssx_option_array options;
+ struct page **pages;
+ unsigned int npages;
+};
+
+struct gssx_res_accept_sec_context {
+ struct gssx_status status;
+ struct gssx_ctx *context_handle;
+ gssx_buffer *output_token;
+ /* struct gssx_cred *delegated_cred_handle; not used in kernel */
+ struct gssx_option_array options;
+};
+
+
+
+#define gssx_enc_indicate_mechs NULL
+#define gssx_dec_indicate_mechs NULL
+#define gssx_enc_get_call_context NULL
+#define gssx_dec_get_call_context NULL
+#define gssx_enc_import_and_canon_name NULL
+#define gssx_dec_import_and_canon_name NULL
+#define gssx_enc_export_cred NULL
+#define gssx_dec_export_cred NULL
+#define gssx_enc_import_cred NULL
+#define gssx_dec_import_cred NULL
+#define gssx_enc_acquire_cred NULL
+#define gssx_dec_acquire_cred NULL
+#define gssx_enc_store_cred NULL
+#define gssx_dec_store_cred NULL
+#define gssx_enc_init_sec_context NULL
+#define gssx_dec_init_sec_context NULL
+void gssx_enc_accept_sec_context(struct rpc_rqst *req,
+ struct xdr_stream *xdr,
+ const void *data);
+int gssx_dec_accept_sec_context(struct rpc_rqst *rqstp,
+ struct xdr_stream *xdr,
+ void *data);
+#define gssx_enc_release_handle NULL
+#define gssx_dec_release_handle NULL
+#define gssx_enc_get_mic NULL
+#define gssx_dec_get_mic NULL
+#define gssx_enc_verify NULL
+#define gssx_dec_verify NULL
+#define gssx_enc_wrap NULL
+#define gssx_dec_wrap NULL
+#define gssx_enc_unwrap NULL
+#define gssx_dec_unwrap NULL
+#define gssx_enc_wrap_size_limit NULL
+#define gssx_dec_wrap_size_limit NULL
+
+/* non implemented calls are set to 0 size */
+#define GSSX_ARG_indicate_mechs_sz 0
+#define GSSX_RES_indicate_mechs_sz 0
+#define GSSX_ARG_get_call_context_sz 0
+#define GSSX_RES_get_call_context_sz 0
+#define GSSX_ARG_import_and_canon_name_sz 0
+#define GSSX_RES_import_and_canon_name_sz 0
+#define GSSX_ARG_export_cred_sz 0
+#define GSSX_RES_export_cred_sz 0
+#define GSSX_ARG_import_cred_sz 0
+#define GSSX_RES_import_cred_sz 0
+#define GSSX_ARG_acquire_cred_sz 0
+#define GSSX_RES_acquire_cred_sz 0
+#define GSSX_ARG_store_cred_sz 0
+#define GSSX_RES_store_cred_sz 0
+#define GSSX_ARG_init_sec_context_sz 0
+#define GSSX_RES_init_sec_context_sz 0
+
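+/* GSSX_default_in_call_ctx_sz covers an empty locale, an empty server_ctx,
+ * the option count, and the two options (LUCID and CREDS) that
+ * gssx_enc_call_ctx() always sends, each as a pair of length-prefixed
+ * opaques; the sizeof()s include the strings' trailing NULs, matching the
+ * lengths actually encoded. */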
+#define GSSX_default_in_call_ctx_sz (4 + 4 + 4 + \
+ 8 + sizeof(LUCID_OPTION) + sizeof(LUCID_VALUE) + \
+ 8 + sizeof(CREDS_OPTION) + sizeof(CREDS_VALUE))
+#define GSSX_default_in_ctx_hndl_sz (4 + 4+8 + 4 + 4 + 6*4 + 6*4 + 8 + 8 + \
+ 4 + 4 + 4)
+#define GSSX_default_in_cred_sz 4 /* we send in no cred_handle */
+#define GSSX_default_in_token_sz 4 /* does *not* include token data */
+#define GSSX_default_in_cb_sz 4 /* we do not use channel bindings */
+#define GSSX_ARG_accept_sec_context_sz (GSSX_default_in_call_ctx_sz + \
+ GSSX_default_in_ctx_hndl_sz + \
+ GSSX_default_in_cred_sz + \
+ GSSX_default_in_token_sz + \
+ GSSX_default_in_cb_sz + \
+ 4 /* no deleg creds boolean */ + \
+ 4) /* empty options */
+
+/* somewhat arbitrary numbers but large enough (we ignore some of the data
+ * sent down, but it is part of the protocol so we need enough space to take
+ * it in) */
+#define GSSX_default_status_sz 8 + 24 + 8 + 256 + 256 + 16 + 4
+#define GSSX_max_output_handle_sz 128
+#define GSSX_max_oid_sz 16
+#define GSSX_max_princ_sz 256
+#define GSSX_default_ctx_sz (GSSX_max_output_handle_sz + \
+ 16 + 4 + GSSX_max_oid_sz + \
+ 2 * GSSX_max_princ_sz + \
+ 8 + 8 + 4 + 4 + 4)
+#define GSSX_max_output_token_sz 1024
+/* grouplist not included; we allocate separate pages for that: */
+#define GSSX_max_creds_sz (4 + 4 + 4 /* + NGROUPS_MAX*4 */)
+#define GSSX_RES_accept_sec_context_sz (GSSX_default_status_sz + \
+ GSSX_default_ctx_sz + \
+ GSSX_max_output_token_sz + \
+ 4 + GSSX_max_creds_sz)
+
+#define GSSX_ARG_release_handle_sz 0
+#define GSSX_RES_release_handle_sz 0
+#define GSSX_ARG_get_mic_sz 0
+#define GSSX_RES_get_mic_sz 0
+#define GSSX_ARG_verify_sz 0
+#define GSSX_RES_verify_sz 0
+#define GSSX_ARG_wrap_sz 0
+#define GSSX_RES_wrap_sz 0
+#define GSSX_ARG_unwrap_sz 0
+#define GSSX_RES_unwrap_sz 0
+#define GSSX_ARG_wrap_size_limit_sz 0
+#define GSSX_RES_wrap_size_limit_sz 0
+
+#endif /* _LINUX_GSS_RPC_XDR_H */
diff --git a/net/sunrpc/auth_gss/svcauth_gss.c b/net/sunrpc/auth_gss/svcauth_gss.c
new file mode 100644
index 0000000000..18734e70c5
--- /dev/null
+++ b/net/sunrpc/auth_gss/svcauth_gss.c
@@ -0,0 +1,2134 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Neil Brown <neilb@cse.unsw.edu.au>
+ * J. Bruce Fields <bfields@umich.edu>
+ * Andy Adamson <andros@umich.edu>
+ * Dug Song <dugsong@monkey.org>
+ *
+ * RPCSEC_GSS server authentication.
+ * This implements RPCSEC_GSS as defined in rfc2203 (rpcsec_gss) and rfc2078
+ * (gssapi)
+ *
+ * RPCSEC_GSS involves three stages:
+ * 1/ context creation
+ * 2/ data exchange
+ * 3/ context destruction
+ *
+ * Context creation is handled largely by upcalls to user-space.
+ * In particular, GSS_Accept_sec_context is handled by an upcall.
+ * Data exchange is handled entirely within the kernel.
+ * In particular, GSS_GetMIC, GSS_VerifyMIC, GSS_Seal, GSS_Unseal are in-kernel.
+ * Context destruction is handled in-kernel.
+ * GSS_Delete_sec_context is in-kernel.
+ *
+ * Context creation is initiated by a RPCSEC_GSS_INIT request arriving.
+ * The context handle and gss_token are used as a key into the rpcsec_init cache.
+ * The content of this cache includes some of the outputs of GSS_Accept_sec_context,
+ * being major_status, minor_status, context_handle, reply_token.
+ * These are sent back to the client.
+ * Sequence window management is handled by the kernel. The window size is currently
+ * a compile time constant.
+ *
+ * When user-space is happy that a context is established, it places an entry
+ * in the rpcsec_context cache. The key for this cache is the context_handle.
+ * The content includes:
+ * uid/gidlist - for determining access rights
+ * mechanism type
+ * mechanism specific information, such as a key
+ *
+ */
+
+#include <linux/slab.h>
+#include <linux/types.h>
+#include <linux/module.h>
+#include <linux/pagemap.h>
+#include <linux/user_namespace.h>
+
+#include <linux/sunrpc/auth_gss.h>
+#include <linux/sunrpc/gss_err.h>
+#include <linux/sunrpc/svcauth.h>
+#include <linux/sunrpc/svcauth_gss.h>
+#include <linux/sunrpc/cache.h>
+#include <linux/sunrpc/gss_krb5.h>
+
+#include <trace/events/rpcgss.h>
+
+#include "gss_rpc_upcall.h"
+
+/*
+ * Unfortunately there isn't a maximum checksum size exported via the
+ * GSS API. Manufacture one based on GSS mechanisms supported by this
+ * implementation.
+ */
+#define GSS_MAX_CKSUMSIZE (GSS_KRB5_TOK_HDR_LEN + GSS_KRB5_MAX_CKSUM_LEN)
+
+/*
+ * This value may be increased in the future to accommodate other
+ * usage of the scratch buffer.
+ */
+#define GSS_SCRATCH_SIZE GSS_MAX_CKSUMSIZE
+
+struct gss_svc_data {
+ /* decoded gss client cred: */
+ struct rpc_gss_wire_cred clcred;
+ u32 gsd_databody_offset;
+ struct rsc *rsci;
+
+ /* for temporary results */
+ __be32 gsd_seq_num;
+ u8 gsd_scratch[GSS_SCRATCH_SIZE];
+};
+
+/* The rpcsec_init cache is used for mapping RPCSEC_GSS_{,CONT_}INIT requests
+ * into replies.
+ *
+ * Key is context handle (\x if empty) and gss_token.
+ * Content is major_status minor_status (integers) context_handle, reply_token.
+ *
+ */
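+/* The kernel writes the key (in_handle and in_token, hex-encoded by
+ * rsi_request()) to the cache channel; gssd answers with a line that
+ * rsi_parse() reads back as:
+ *	<in_handle> <in_token> <expiry> <major> <minor> <out_handle> <out_token>
+ */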
+
+static int netobj_equal(struct xdr_netobj *a, struct xdr_netobj *b)
+{
+ return a->len == b->len && 0 == memcmp(a->data, b->data, a->len);
+}
+
+#define RSI_HASHBITS 6
+#define RSI_HASHMAX (1<<RSI_HASHBITS)
+
+struct rsi {
+ struct cache_head h;
+ struct xdr_netobj in_handle, in_token;
+ struct xdr_netobj out_handle, out_token;
+ int major_status, minor_status;
+ struct rcu_head rcu_head;
+};
+
+static struct rsi *rsi_update(struct cache_detail *cd, struct rsi *new, struct rsi *old);
+static struct rsi *rsi_lookup(struct cache_detail *cd, struct rsi *item);
+
+static void rsi_free(struct rsi *rsii)
+{
+ kfree(rsii->in_handle.data);
+ kfree(rsii->in_token.data);
+ kfree(rsii->out_handle.data);
+ kfree(rsii->out_token.data);
+}
+
+static void rsi_free_rcu(struct rcu_head *head)
+{
+ struct rsi *rsii = container_of(head, struct rsi, rcu_head);
+
+ rsi_free(rsii);
+ kfree(rsii);
+}
+
+static void rsi_put(struct kref *ref)
+{
+ struct rsi *rsii = container_of(ref, struct rsi, h.ref);
+
+ call_rcu(&rsii->rcu_head, rsi_free_rcu);
+}
+
+static inline int rsi_hash(struct rsi *item)
+{
+ return hash_mem(item->in_handle.data, item->in_handle.len, RSI_HASHBITS)
+ ^ hash_mem(item->in_token.data, item->in_token.len, RSI_HASHBITS);
+}
+
+static int rsi_match(struct cache_head *a, struct cache_head *b)
+{
+ struct rsi *item = container_of(a, struct rsi, h);
+ struct rsi *tmp = container_of(b, struct rsi, h);
+ return netobj_equal(&item->in_handle, &tmp->in_handle) &&
+ netobj_equal(&item->in_token, &tmp->in_token);
+}
+
+static int dup_to_netobj(struct xdr_netobj *dst, char *src, int len)
+{
+ dst->len = len;
+ dst->data = (len ? kmemdup(src, len, GFP_KERNEL) : NULL);
+ if (len && !dst->data)
+ return -ENOMEM;
+ return 0;
+}
+
+static inline int dup_netobj(struct xdr_netobj *dst, struct xdr_netobj *src)
+{
+ return dup_to_netobj(dst, src->data, src->len);
+}
+
+static void rsi_init(struct cache_head *cnew, struct cache_head *citem)
+{
+ struct rsi *new = container_of(cnew, struct rsi, h);
+ struct rsi *item = container_of(citem, struct rsi, h);
+
+ new->out_handle.data = NULL;
+ new->out_handle.len = 0;
+ new->out_token.data = NULL;
+ new->out_token.len = 0;
+ new->in_handle.len = item->in_handle.len;
+ item->in_handle.len = 0;
+ new->in_token.len = item->in_token.len;
+ item->in_token.len = 0;
+ new->in_handle.data = item->in_handle.data;
+ item->in_handle.data = NULL;
+ new->in_token.data = item->in_token.data;
+ item->in_token.data = NULL;
+}
+
+static void update_rsi(struct cache_head *cnew, struct cache_head *citem)
+{
+ struct rsi *new = container_of(cnew, struct rsi, h);
+ struct rsi *item = container_of(citem, struct rsi, h);
+
+ BUG_ON(new->out_handle.data || new->out_token.data);
+ new->out_handle.len = item->out_handle.len;
+ item->out_handle.len = 0;
+ new->out_token.len = item->out_token.len;
+ item->out_token.len = 0;
+ new->out_handle.data = item->out_handle.data;
+ item->out_handle.data = NULL;
+ new->out_token.data = item->out_token.data;
+ item->out_token.data = NULL;
+
+ new->major_status = item->major_status;
+ new->minor_status = item->minor_status;
+}
+
+static struct cache_head *rsi_alloc(void)
+{
+ struct rsi *rsii = kmalloc(sizeof(*rsii), GFP_KERNEL);
+ if (rsii)
+ return &rsii->h;
+ else
+ return NULL;
+}
+
+static int rsi_upcall(struct cache_detail *cd, struct cache_head *h)
+{
+ return sunrpc_cache_pipe_upcall_timeout(cd, h);
+}
+
+static void rsi_request(struct cache_detail *cd,
+ struct cache_head *h,
+ char **bpp, int *blen)
+{
+ struct rsi *rsii = container_of(h, struct rsi, h);
+
+ qword_addhex(bpp, blen, rsii->in_handle.data, rsii->in_handle.len);
+ qword_addhex(bpp, blen, rsii->in_token.data, rsii->in_token.len);
+ (*bpp)[-1] = '\n';
+ WARN_ONCE(*blen < 0,
+ "RPCSEC/GSS credential too large - please use gssproxy\n");
+}
+
+static int rsi_parse(struct cache_detail *cd,
+ char *mesg, int mlen)
+{
+ /* context token expiry major minor context token */
+ char *buf = mesg;
+ char *ep;
+ int len;
+ struct rsi rsii, *rsip = NULL;
+ time64_t expiry;
+ int status = -EINVAL;
+
+ memset(&rsii, 0, sizeof(rsii));
+ /* handle */
+ len = qword_get(&mesg, buf, mlen);
+ if (len < 0)
+ goto out;
+ status = -ENOMEM;
+ if (dup_to_netobj(&rsii.in_handle, buf, len))
+ goto out;
+
+ /* token */
+ len = qword_get(&mesg, buf, mlen);
+ status = -EINVAL;
+ if (len < 0)
+ goto out;
+ status = -ENOMEM;
+ if (dup_to_netobj(&rsii.in_token, buf, len))
+ goto out;
+
+ rsip = rsi_lookup(cd, &rsii);
+ if (!rsip)
+ goto out;
+
+ rsii.h.flags = 0;
+ /* expiry */
+ status = get_expiry(&mesg, &expiry);
+ if (status)
+ goto out;
+
+ status = -EINVAL;
+ /* major/minor */
+ len = qword_get(&mesg, buf, mlen);
+ if (len <= 0)
+ goto out;
+ rsii.major_status = simple_strtoul(buf, &ep, 10);
+ if (*ep)
+ goto out;
+ len = qword_get(&mesg, buf, mlen);
+ if (len <= 0)
+ goto out;
+ rsii.minor_status = simple_strtoul(buf, &ep, 10);
+ if (*ep)
+ goto out;
+
+ /* out_handle */
+ len = qword_get(&mesg, buf, mlen);
+ if (len < 0)
+ goto out;
+ status = -ENOMEM;
+ if (dup_to_netobj(&rsii.out_handle, buf, len))
+ goto out;
+
+ /* out_token */
+ len = qword_get(&mesg, buf, mlen);
+ status = -EINVAL;
+ if (len < 0)
+ goto out;
+ status = -ENOMEM;
+ if (dup_to_netobj(&rsii.out_token, buf, len))
+ goto out;
+ rsii.h.expiry_time = expiry;
+ rsip = rsi_update(cd, &rsii, rsip);
+ status = 0;
+out:
+ rsi_free(&rsii);
+ if (rsip)
+ cache_put(&rsip->h, cd);
+ else
+ status = -ENOMEM;
+ return status;
+}
+
+static const struct cache_detail rsi_cache_template = {
+ .owner = THIS_MODULE,
+ .hash_size = RSI_HASHMAX,
+ .name = "auth.rpcsec.init",
+ .cache_put = rsi_put,
+ .cache_upcall = rsi_upcall,
+ .cache_request = rsi_request,
+ .cache_parse = rsi_parse,
+ .match = rsi_match,
+ .init = rsi_init,
+ .update = update_rsi,
+ .alloc = rsi_alloc,
+};
+
+static struct rsi *rsi_lookup(struct cache_detail *cd, struct rsi *item)
+{
+ struct cache_head *ch;
+ int hash = rsi_hash(item);
+
+ ch = sunrpc_cache_lookup_rcu(cd, &item->h, hash);
+ if (ch)
+ return container_of(ch, struct rsi, h);
+ else
+ return NULL;
+}
+
+static struct rsi *rsi_update(struct cache_detail *cd, struct rsi *new, struct rsi *old)
+{
+ struct cache_head *ch;
+ int hash = rsi_hash(new);
+
+ ch = sunrpc_cache_update(cd, &new->h,
+ &old->h, hash);
+ if (ch)
+ return container_of(ch, struct rsi, h);
+ else
+ return NULL;
+}
+
+
+/*
+ * The rpcsec_context cache is used to store a context that is
+ * used in data exchange.
+ * The key is a context handle. The content is:
+ * uid, gidlist, mechanism, service-set, mech-specific-data
+ */
+
+#define RSC_HASHBITS 10
+#define RSC_HASHMAX (1<<RSC_HASHBITS)
+
+#define GSS_SEQ_WIN 128
+
+struct gss_svc_seq_data {
+ /* highest seq number seen so far: */
+ u32 sd_max;
+ /* for i such that sd_max-GSS_SEQ_WIN < i <= sd_max, the i-th bit of
+ * sd_win is nonzero iff sequence number i has been seen already: */
+ unsigned long sd_win[GSS_SEQ_WIN/BITS_PER_LONG];
+ spinlock_t sd_lock;
+};
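+/*
+ * Example: with GSS_SEQ_WIN == 128 and sd_max == 200, sequence numbers
+ * 73..200 are inside the window; bit (i % GSS_SEQ_WIN) of sd_win records
+ * whether i has been seen.  Anything at or below 72 is dropped as too
+ * old, and a number above 200 advances sd_max, clearing the bits for the
+ * sequence numbers that fall out of the window.
+ */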
+
+struct rsc {
+ struct cache_head h;
+ struct xdr_netobj handle;
+ struct svc_cred cred;
+ struct gss_svc_seq_data seqdata;
+ struct gss_ctx *mechctx;
+ struct rcu_head rcu_head;
+};
+
+static struct rsc *rsc_update(struct cache_detail *cd, struct rsc *new, struct rsc *old);
+static struct rsc *rsc_lookup(struct cache_detail *cd, struct rsc *item);
+
+static void rsc_free(struct rsc *rsci)
+{
+ kfree(rsci->handle.data);
+ if (rsci->mechctx)
+ gss_delete_sec_context(&rsci->mechctx);
+ free_svc_cred(&rsci->cred);
+}
+
+static void rsc_free_rcu(struct rcu_head *head)
+{
+ struct rsc *rsci = container_of(head, struct rsc, rcu_head);
+
+ kfree(rsci->handle.data);
+ kfree(rsci);
+}
+
+static void rsc_put(struct kref *ref)
+{
+ struct rsc *rsci = container_of(ref, struct rsc, h.ref);
+
+ if (rsci->mechctx)
+ gss_delete_sec_context(&rsci->mechctx);
+ free_svc_cred(&rsci->cred);
+ call_rcu(&rsci->rcu_head, rsc_free_rcu);
+}
+
+static inline int
+rsc_hash(struct rsc *rsci)
+{
+ return hash_mem(rsci->handle.data, rsci->handle.len, RSC_HASHBITS);
+}
+
+static int
+rsc_match(struct cache_head *a, struct cache_head *b)
+{
+ struct rsc *new = container_of(a, struct rsc, h);
+ struct rsc *tmp = container_of(b, struct rsc, h);
+
+ return netobj_equal(&new->handle, &tmp->handle);
+}
+
+static void
+rsc_init(struct cache_head *cnew, struct cache_head *ctmp)
+{
+ struct rsc *new = container_of(cnew, struct rsc, h);
+ struct rsc *tmp = container_of(ctmp, struct rsc, h);
+
+ new->handle.len = tmp->handle.len;
+ tmp->handle.len = 0;
+ new->handle.data = tmp->handle.data;
+ tmp->handle.data = NULL;
+ new->mechctx = NULL;
+ init_svc_cred(&new->cred);
+}
+
+static void
+update_rsc(struct cache_head *cnew, struct cache_head *ctmp)
+{
+ struct rsc *new = container_of(cnew, struct rsc, h);
+ struct rsc *tmp = container_of(ctmp, struct rsc, h);
+
+ new->mechctx = tmp->mechctx;
+ tmp->mechctx = NULL;
+ memset(&new->seqdata, 0, sizeof(new->seqdata));
+ spin_lock_init(&new->seqdata.sd_lock);
+ new->cred = tmp->cred;
+ init_svc_cred(&tmp->cred);
+}
+
+static struct cache_head *
+rsc_alloc(void)
+{
+ struct rsc *rsci = kmalloc(sizeof(*rsci), GFP_KERNEL);
+ if (rsci)
+ return &rsci->h;
+ else
+ return NULL;
+}
+
+static int rsc_upcall(struct cache_detail *cd, struct cache_head *h)
+{
+ return -EINVAL;
+}
+
+static int rsc_parse(struct cache_detail *cd,
+ char *mesg, int mlen)
+{
+ /* contexthandle expiry [ uid gid N <n gids> mechname ...mechdata... ] */
+ char *buf = mesg;
+ int id;
+ int len, rv;
+ struct rsc rsci, *rscp = NULL;
+ time64_t expiry;
+ int status = -EINVAL;
+ struct gss_api_mech *gm = NULL;
+
+ memset(&rsci, 0, sizeof(rsci));
+ /* context handle */
+ len = qword_get(&mesg, buf, mlen);
+ if (len < 0) goto out;
+ status = -ENOMEM;
+ if (dup_to_netobj(&rsci.handle, buf, len))
+ goto out;
+
+ rsci.h.flags = 0;
+ /* expiry */
+ status = get_expiry(&mesg, &expiry);
+ if (status)
+ goto out;
+
+ status = -EINVAL;
+ rscp = rsc_lookup(cd, &rsci);
+ if (!rscp)
+ goto out;
+
+ /* uid, or NEGATIVE */
+ rv = get_int(&mesg, &id);
+ if (rv == -EINVAL)
+ goto out;
+ if (rv == -ENOENT)
+ set_bit(CACHE_NEGATIVE, &rsci.h.flags);
+ else {
+ int N, i;
+
+ /*
+ * NOTE: we skip uid_valid()/gid_valid() checks here:
+		 * instead, -1 ids are later mapped to the
+ * (export-specific) anonymous id by nfsd_setuser.
+ *
+ * (But supplementary gid's get no such special
+ * treatment so are checked for validity here.)
+ */
+ /* uid */
+ rsci.cred.cr_uid = make_kuid(current_user_ns(), id);
+
+ /* gid */
+ if (get_int(&mesg, &id))
+ goto out;
+ rsci.cred.cr_gid = make_kgid(current_user_ns(), id);
+
+ /* number of additional gid's */
+ if (get_int(&mesg, &N))
+ goto out;
+ if (N < 0 || N > NGROUPS_MAX)
+ goto out;
+ status = -ENOMEM;
+ rsci.cred.cr_group_info = groups_alloc(N);
+ if (rsci.cred.cr_group_info == NULL)
+ goto out;
+
+ /* gid's */
+ status = -EINVAL;
+ for (i=0; i<N; i++) {
+ kgid_t kgid;
+ if (get_int(&mesg, &id))
+ goto out;
+ kgid = make_kgid(current_user_ns(), id);
+ if (!gid_valid(kgid))
+ goto out;
+ rsci.cred.cr_group_info->gid[i] = kgid;
+ }
+ groups_sort(rsci.cred.cr_group_info);
+
+ /* mech name */
+ len = qword_get(&mesg, buf, mlen);
+ if (len < 0)
+ goto out;
+ gm = rsci.cred.cr_gss_mech = gss_mech_get_by_name(buf);
+ status = -EOPNOTSUPP;
+ if (!gm)
+ goto out;
+
+ status = -EINVAL;
+ /* mech-specific data: */
+ len = qword_get(&mesg, buf, mlen);
+ if (len < 0)
+ goto out;
+ status = gss_import_sec_context(buf, len, gm, &rsci.mechctx,
+ NULL, GFP_KERNEL);
+ if (status)
+ goto out;
+
+ /* get client name */
+ len = qword_get(&mesg, buf, mlen);
+ if (len > 0) {
+ rsci.cred.cr_principal = kstrdup(buf, GFP_KERNEL);
+ if (!rsci.cred.cr_principal) {
+ status = -ENOMEM;
+ goto out;
+ }
+ }
+
+ }
+ rsci.h.expiry_time = expiry;
+ rscp = rsc_update(cd, &rsci, rscp);
+ status = 0;
+out:
+ rsc_free(&rsci);
+ if (rscp)
+ cache_put(&rscp->h, cd);
+ else
+ status = -ENOMEM;
+ return status;
+}
+
+static const struct cache_detail rsc_cache_template = {
+ .owner = THIS_MODULE,
+ .hash_size = RSC_HASHMAX,
+ .name = "auth.rpcsec.context",
+ .cache_put = rsc_put,
+ .cache_upcall = rsc_upcall,
+ .cache_parse = rsc_parse,
+ .match = rsc_match,
+ .init = rsc_init,
+ .update = update_rsc,
+ .alloc = rsc_alloc,
+};
+
+static struct rsc *rsc_lookup(struct cache_detail *cd, struct rsc *item)
+{
+ struct cache_head *ch;
+ int hash = rsc_hash(item);
+
+ ch = sunrpc_cache_lookup_rcu(cd, &item->h, hash);
+ if (ch)
+ return container_of(ch, struct rsc, h);
+ else
+ return NULL;
+}
+
+static struct rsc *rsc_update(struct cache_detail *cd, struct rsc *new, struct rsc *old)
+{
+ struct cache_head *ch;
+ int hash = rsc_hash(new);
+
+ ch = sunrpc_cache_update(cd, &new->h,
+ &old->h, hash);
+ if (ch)
+ return container_of(ch, struct rsc, h);
+ else
+ return NULL;
+}
+
+
+static struct rsc *
+gss_svc_searchbyctx(struct cache_detail *cd, struct xdr_netobj *handle)
+{
+ struct rsc rsci;
+ struct rsc *found;
+
+ memset(&rsci, 0, sizeof(rsci));
+ if (dup_to_netobj(&rsci.handle, handle->data, handle->len))
+ return NULL;
+ found = rsc_lookup(cd, &rsci);
+ rsc_free(&rsci);
+ if (!found)
+ return NULL;
+ if (cache_check(cd, &found->h, NULL))
+ return NULL;
+ return found;
+}
+
+/**
+ * gss_check_seq_num - GSS sequence number window check
+ * @rqstp: RPC Call to use when reporting errors
+ * @rsci: cached GSS context state (updated on return)
+ * @seq_num: sequence number to check
+ *
+ * Implements sequence number algorithm as specified in
+ * RFC 2203, Section 5.3.3.1. "Context Management".
+ *
+ * Return values:
+ * %true: @rqstp's GSS sequence number is inside the window
+ * %false: @rqstp's GSS sequence number is outside the window
+ */
+static bool gss_check_seq_num(const struct svc_rqst *rqstp, struct rsc *rsci,
+ u32 seq_num)
+{
+ struct gss_svc_seq_data *sd = &rsci->seqdata;
+ bool result = false;
+
+ spin_lock(&sd->sd_lock);
+ if (seq_num > sd->sd_max) {
+ if (seq_num >= sd->sd_max + GSS_SEQ_WIN) {
+ memset(sd->sd_win, 0, sizeof(sd->sd_win));
+ sd->sd_max = seq_num;
+ } else while (sd->sd_max < seq_num) {
+ sd->sd_max++;
+ __clear_bit(sd->sd_max % GSS_SEQ_WIN, sd->sd_win);
+ }
+ __set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win);
+ goto ok;
+ } else if (seq_num + GSS_SEQ_WIN <= sd->sd_max) {
+ goto toolow;
+ }
+ if (__test_and_set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win))
+ goto alreadyseen;
+
+ok:
+ result = true;
+out:
+ spin_unlock(&sd->sd_lock);
+ return result;
+
+toolow:
+ trace_rpcgss_svc_seqno_low(rqstp, seq_num,
+ sd->sd_max - GSS_SEQ_WIN,
+ sd->sd_max);
+ goto out;
+alreadyseen:
+ trace_rpcgss_svc_seqno_seen(rqstp, seq_num);
+ goto out;
+}
+
+/*
+ * Decode and verify a Call's verifier field. For RPC_AUTH_GSS Calls,
+ * the body of this field contains a variable length checksum.
+ *
+ * GSS-specific auth_stat values are mandated by RFC 2203 Section
+ * 5.3.3.3.
+ */
+static int
+svcauth_gss_verify_header(struct svc_rqst *rqstp, struct rsc *rsci,
+ __be32 *rpcstart, struct rpc_gss_wire_cred *gc)
+{
+ struct xdr_stream *xdr = &rqstp->rq_arg_stream;
+ struct gss_ctx *ctx_id = rsci->mechctx;
+ u32 flavor, maj_stat;
+ struct xdr_buf rpchdr;
+ struct xdr_netobj checksum;
+ struct kvec iov;
+
+ /*
+ * Compute the checksum of the incoming Call from the
+ * XID field to credential field:
+ */
+ iov.iov_base = rpcstart;
+ iov.iov_len = (u8 *)xdr->p - (u8 *)rpcstart;
+ xdr_buf_from_iov(&iov, &rpchdr);
+
+ /* Call's verf field: */
+ if (xdr_stream_decode_opaque_auth(xdr, &flavor,
+ (void **)&checksum.data,
+ &checksum.len) < 0) {
+ rqstp->rq_auth_stat = rpc_autherr_badverf;
+ return SVC_DENIED;
+ }
+ if (flavor != RPC_AUTH_GSS) {
+ rqstp->rq_auth_stat = rpc_autherr_badverf;
+ return SVC_DENIED;
+ }
+
+ if (rqstp->rq_deferred)
+ return SVC_OK;
+ maj_stat = gss_verify_mic(ctx_id, &rpchdr, &checksum);
+ if (maj_stat != GSS_S_COMPLETE) {
+ trace_rpcgss_svc_mic(rqstp, maj_stat);
+ rqstp->rq_auth_stat = rpcsec_gsserr_credproblem;
+ return SVC_DENIED;
+ }
+
+ if (gc->gc_seq > MAXSEQ) {
+ trace_rpcgss_svc_seqno_large(rqstp, gc->gc_seq);
+ rqstp->rq_auth_stat = rpcsec_gsserr_ctxproblem;
+ return SVC_DENIED;
+ }
+ if (!gss_check_seq_num(rqstp, rsci, gc->gc_seq))
+ return SVC_DROP;
+ return SVC_OK;
+}
+
+/*
+ * Construct and encode a Reply's verifier field. The verifier's body
+ * field contains a variable-length checksum of the GSS sequence
+ * number.
+ */
+static bool
+svcauth_gss_encode_verf(struct svc_rqst *rqstp, struct gss_ctx *ctx_id, u32 seq)
+{
+ struct gss_svc_data *gsd = rqstp->rq_auth_data;
+ u32 maj_stat;
+ struct xdr_buf verf_data;
+ struct xdr_netobj checksum;
+ struct kvec iov;
+
+ gsd->gsd_seq_num = cpu_to_be32(seq);
+ iov.iov_base = &gsd->gsd_seq_num;
+ iov.iov_len = XDR_UNIT;
+ xdr_buf_from_iov(&iov, &verf_data);
+
+ checksum.data = gsd->gsd_scratch;
+ maj_stat = gss_get_mic(ctx_id, &verf_data, &checksum);
+ if (maj_stat != GSS_S_COMPLETE)
+ goto bad_mic;
+
+ return xdr_stream_encode_opaque_auth(&rqstp->rq_res_stream, RPC_AUTH_GSS,
+ checksum.data, checksum.len) > 0;
+
+bad_mic:
+ trace_rpcgss_svc_get_mic(rqstp, maj_stat);
+ return false;
+}
+
+struct gss_domain {
+ struct auth_domain h;
+ u32 pseudoflavor;
+};
+
+static struct auth_domain *
+find_gss_auth_domain(struct gss_ctx *ctx, u32 svc)
+{
+ char *name;
+
+ name = gss_service_to_auth_domain_name(ctx->mech_type, svc);
+ if (!name)
+ return NULL;
+ return auth_domain_find(name);
+}
+
+static struct auth_ops svcauthops_gss;
+
+u32 svcauth_gss_flavor(struct auth_domain *dom)
+{
+ struct gss_domain *gd = container_of(dom, struct gss_domain, h);
+
+ return gd->pseudoflavor;
+}
+
+EXPORT_SYMBOL_GPL(svcauth_gss_flavor);
+
+struct auth_domain *
+svcauth_gss_register_pseudoflavor(u32 pseudoflavor, char * name)
+{
+ struct gss_domain *new;
+ struct auth_domain *test;
+ int stat = -ENOMEM;
+
+ new = kmalloc(sizeof(*new), GFP_KERNEL);
+ if (!new)
+ goto out;
+ kref_init(&new->h.ref);
+ new->h.name = kstrdup(name, GFP_KERNEL);
+ if (!new->h.name)
+ goto out_free_dom;
+ new->h.flavour = &svcauthops_gss;
+ new->pseudoflavor = pseudoflavor;
+
+ test = auth_domain_lookup(name, &new->h);
+ if (test != &new->h) {
+ pr_warn("svc: duplicate registration of gss pseudo flavour %s.\n",
+ name);
+ stat = -EADDRINUSE;
+ auth_domain_put(test);
+ goto out_free_name;
+ }
+ return test;
+
+out_free_name:
+ kfree(new->h.name);
+out_free_dom:
+ kfree(new);
+out:
+ return ERR_PTR(stat);
+}
+EXPORT_SYMBOL_GPL(svcauth_gss_register_pseudoflavor);
+
+/*
+ * RFC 2203, Section 5.3.2.2
+ *
+ * struct rpc_gss_integ_data {
+ * opaque databody_integ<>;
+ * opaque checksum<>;
+ * };
+ *
+ * struct rpc_gss_data_t {
+ * unsigned int seq_num;
+ * proc_req_arg_t arg;
+ * };
+ */
+static noinline_for_stack int
+svcauth_gss_unwrap_integ(struct svc_rqst *rqstp, u32 seq, struct gss_ctx *ctx)
+{
+ struct gss_svc_data *gsd = rqstp->rq_auth_data;
+ struct xdr_stream *xdr = &rqstp->rq_arg_stream;
+ u32 len, offset, seq_num, maj_stat;
+ struct xdr_buf *buf = xdr->buf;
+ struct xdr_buf databody_integ;
+ struct xdr_netobj checksum;
+
+ /* NFS READ normally uses splice to send data in-place. However
+ * the data in cache can change after the reply's MIC is computed
+ * but before the RPC reply is sent. To prevent the client from
+ * rejecting the server-computed MIC in this somewhat rare case,
+ * do not use splice with the GSS integrity service.
+ */
+ clear_bit(RQ_SPLICE_OK, &rqstp->rq_flags);
+
+ /* Did we already verify the signature on the original pass through? */
+ if (rqstp->rq_deferred)
+ return 0;
+
+ if (xdr_stream_decode_u32(xdr, &len) < 0)
+ goto unwrap_failed;
+ if (len & 3)
+ goto unwrap_failed;
+ offset = xdr_stream_pos(xdr);
+ if (xdr_buf_subsegment(buf, &databody_integ, offset, len))
+ goto unwrap_failed;
+
+ /*
+ * The xdr_stream now points to the @seq_num field. The next
+ * XDR data item is the @arg field, which contains the clear
+ * text RPC program payload. The checksum, which follows the
+ * @arg field, is located and decoded without updating the
+ * xdr_stream.
+ */
+
+ offset += len;
+ if (xdr_decode_word(buf, offset, &checksum.len))
+ goto unwrap_failed;
+ if (checksum.len > sizeof(gsd->gsd_scratch))
+ goto unwrap_failed;
+ checksum.data = gsd->gsd_scratch;
+ if (read_bytes_from_xdr_buf(buf, offset + XDR_UNIT, checksum.data,
+ checksum.len))
+ goto unwrap_failed;
+
+ maj_stat = gss_verify_mic(ctx, &databody_integ, &checksum);
+ if (maj_stat != GSS_S_COMPLETE)
+ goto bad_mic;
+
+ /* The received seqno is protected by the checksum. */
+ if (xdr_stream_decode_u32(xdr, &seq_num) < 0)
+ goto unwrap_failed;
+ if (seq_num != seq)
+ goto bad_seqno;
+
+ xdr_truncate_decode(xdr, XDR_UNIT + checksum.len);
+ return 0;
+
+unwrap_failed:
+ trace_rpcgss_svc_unwrap_failed(rqstp);
+ return -EINVAL;
+bad_seqno:
+ trace_rpcgss_svc_seqno_bad(rqstp, seq, seq_num);
+ return -EINVAL;
+bad_mic:
+ trace_rpcgss_svc_mic(rqstp, maj_stat);
+ return -EINVAL;
+}
+
+/*
+ * RFC 2203, Section 5.3.2.3
+ *
+ * struct rpc_gss_priv_data {
+ * opaque databody_priv<>
+ * };
+ *
+ * struct rpc_gss_data_t {
+ * unsigned int seq_num;
+ * proc_req_arg_t arg;
+ * };
+ */
+static noinline_for_stack int
+svcauth_gss_unwrap_priv(struct svc_rqst *rqstp, u32 seq, struct gss_ctx *ctx)
+{
+ struct xdr_stream *xdr = &rqstp->rq_arg_stream;
+ u32 len, maj_stat, seq_num, offset;
+ struct xdr_buf *buf = xdr->buf;
+ unsigned int saved_len;
+
+ clear_bit(RQ_SPLICE_OK, &rqstp->rq_flags);
+
+ if (xdr_stream_decode_u32(xdr, &len) < 0)
+ goto unwrap_failed;
+ if (rqstp->rq_deferred) {
+ /* Already decrypted last time through! The sequence number
+ * check at out_seq is unnecessary but harmless: */
+ goto out_seq;
+ }
+ if (len > xdr_stream_remaining(xdr))
+ goto unwrap_failed;
+ offset = xdr_stream_pos(xdr);
+
+ saved_len = buf->len;
+ maj_stat = gss_unwrap(ctx, offset, offset + len, buf);
+ if (maj_stat != GSS_S_COMPLETE)
+ goto bad_unwrap;
+ xdr->nwords -= XDR_QUADLEN(saved_len - buf->len);
+
+out_seq:
+ /* gss_unwrap() decrypted the sequence number. */
+ if (xdr_stream_decode_u32(xdr, &seq_num) < 0)
+ goto unwrap_failed;
+ if (seq_num != seq)
+ goto bad_seqno;
+ return 0;
+
+unwrap_failed:
+ trace_rpcgss_svc_unwrap_failed(rqstp);
+ return -EINVAL;
+bad_seqno:
+ trace_rpcgss_svc_seqno_bad(rqstp, seq, seq_num);
+ return -EINVAL;
+bad_unwrap:
+ trace_rpcgss_svc_unwrap(rqstp, maj_stat);
+ return -EINVAL;
+}
+
+static enum svc_auth_status
+svcauth_gss_set_client(struct svc_rqst *rqstp)
+{
+ struct gss_svc_data *svcdata = rqstp->rq_auth_data;
+ struct rsc *rsci = svcdata->rsci;
+ struct rpc_gss_wire_cred *gc = &svcdata->clcred;
+ int stat;
+
+ rqstp->rq_auth_stat = rpc_autherr_badcred;
+
+ /*
+ * A gss export can be specified either by:
+ * export *(sec=krb5,rw)
+ * or by
+ * export gss/krb5(rw)
+ * The latter is deprecated; but for backwards compatibility reasons
+ * the nfsd code will still fall back on trying it if the former
+ * doesn't work; so we try to make both available to nfsd, below.
+ */
+ rqstp->rq_gssclient = find_gss_auth_domain(rsci->mechctx, gc->gc_svc);
+ if (rqstp->rq_gssclient == NULL)
+ return SVC_DENIED;
+ stat = svcauth_unix_set_client(rqstp);
+ if (stat == SVC_DROP || stat == SVC_CLOSE)
+ return stat;
+
+ rqstp->rq_auth_stat = rpc_auth_ok;
+ return SVC_OK;
+}
+
+static bool
+svcauth_gss_proc_init_verf(struct cache_detail *cd, struct svc_rqst *rqstp,
+ struct xdr_netobj *out_handle, int *major_status,
+ u32 seq_num)
+{
+ struct xdr_stream *xdr = &rqstp->rq_res_stream;
+ struct rsc *rsci;
+ bool rc;
+
+ if (*major_status != GSS_S_COMPLETE)
+ goto null_verifier;
+ rsci = gss_svc_searchbyctx(cd, out_handle);
+ if (rsci == NULL) {
+ *major_status = GSS_S_NO_CONTEXT;
+ goto null_verifier;
+ }
+
+ rc = svcauth_gss_encode_verf(rqstp, rsci->mechctx, seq_num);
+ cache_put(&rsci->h, cd);
+ return rc;
+
+null_verifier:
+ return xdr_stream_encode_opaque_auth(xdr, RPC_AUTH_NULL, NULL, 0) > 0;
+}
+
+static void gss_free_in_token_pages(struct gssp_in_token *in_token)
+{
+ u32 inlen;
+ int i;
+
+ i = 0;
+ inlen = in_token->page_len;
+ while (inlen) {
+ if (in_token->pages[i])
+ put_page(in_token->pages[i]);
+ inlen -= inlen > PAGE_SIZE ? PAGE_SIZE : inlen;
+ }
+
+ kfree(in_token->pages);
+ in_token->pages = NULL;
+}
+
+static int gss_read_proxy_verf(struct svc_rqst *rqstp,
+ struct rpc_gss_wire_cred *gc,
+ struct xdr_netobj *in_handle,
+ struct gssp_in_token *in_token)
+{
+ struct xdr_stream *xdr = &rqstp->rq_arg_stream;
+ unsigned int length, pgto_offs, pgfrom_offs;
+ int pages, i, pgto, pgfrom;
+ size_t to_offs, from_offs;
+ u32 inlen;
+
+ if (dup_netobj(in_handle, &gc->gc_ctx))
+ return SVC_CLOSE;
+
+ /*
+ * RFC 2203 Section 5.2.2
+ *
+ * struct rpc_gss_init_arg {
+ * opaque gss_token<>;
+ * };
+ */
+ if (xdr_stream_decode_u32(xdr, &inlen) < 0)
+ goto out_denied_free;
+ if (inlen > xdr_stream_remaining(xdr))
+ goto out_denied_free;
+
+ pages = DIV_ROUND_UP(inlen, PAGE_SIZE);
+ in_token->pages = kcalloc(pages, sizeof(struct page *), GFP_KERNEL);
+ if (!in_token->pages)
+ goto out_denied_free;
+ in_token->page_base = 0;
+ in_token->page_len = inlen;
+ for (i = 0; i < pages; i++) {
+ in_token->pages[i] = alloc_page(GFP_KERNEL);
+ if (!in_token->pages[i]) {
+ gss_free_in_token_pages(in_token);
+ goto out_denied_free;
+ }
+ }
+
+ length = min_t(unsigned int, inlen, (char *)xdr->end - (char *)xdr->p);
+ memcpy(page_address(in_token->pages[0]), xdr->p, length);
+ inlen -= length;
+
+ to_offs = length;
+ from_offs = rqstp->rq_arg.page_base;
+ while (inlen) {
+ pgto = to_offs >> PAGE_SHIFT;
+ pgfrom = from_offs >> PAGE_SHIFT;
+ pgto_offs = to_offs & ~PAGE_MASK;
+ pgfrom_offs = from_offs & ~PAGE_MASK;
+
+ length = min_t(unsigned int, inlen,
+ min_t(unsigned int, PAGE_SIZE - pgto_offs,
+ PAGE_SIZE - pgfrom_offs));
+ memcpy(page_address(in_token->pages[pgto]) + pgto_offs,
+ page_address(rqstp->rq_arg.pages[pgfrom]) + pgfrom_offs,
+ length);
+
+ to_offs += length;
+ from_offs += length;
+ inlen -= length;
+ }
+ return 0;
+
+out_denied_free:
+ kfree(in_handle->data);
+ return SVC_DENIED;
+}
+
+/*
+ * RFC 2203, Section 5.2.3.1.
+ *
+ * struct rpc_gss_init_res {
+ * opaque handle<>;
+ * unsigned int gss_major;
+ * unsigned int gss_minor;
+ * unsigned int seq_window;
+ * opaque gss_token<>;
+ * };
+ */
+static bool
+svcxdr_encode_gss_init_res(struct xdr_stream *xdr,
+ struct xdr_netobj *handle,
+ struct xdr_netobj *gss_token,
+ unsigned int major_status,
+ unsigned int minor_status, u32 seq_num)
+{
+ if (xdr_stream_encode_opaque(xdr, handle->data, handle->len) < 0)
+ return false;
+ if (xdr_stream_encode_u32(xdr, major_status) < 0)
+ return false;
+ if (xdr_stream_encode_u32(xdr, minor_status) < 0)
+ return false;
+ if (xdr_stream_encode_u32(xdr, seq_num) < 0)
+ return false;
+ if (xdr_stream_encode_opaque(xdr, gss_token->data, gss_token->len) < 0)
+ return false;
+ return true;
+}
+
+/*
+ * Having read the cred already and found we're in the context
+ * initiation case, read the verifier and initiate (or check the results
+ * of) upcalls to userspace for help with context initiation. If
+ * the upcall results are available, write the verifier and result.
+ * Otherwise, drop the request pending an answer to the upcall.
+ */
+static int
+svcauth_gss_legacy_init(struct svc_rqst *rqstp,
+ struct rpc_gss_wire_cred *gc)
+{
+ struct xdr_stream *xdr = &rqstp->rq_arg_stream;
+ struct rsi *rsip, rsikey;
+ __be32 *p;
+ u32 len;
+ int ret;
+ struct sunrpc_net *sn = net_generic(SVC_NET(rqstp), sunrpc_net_id);
+
+ memset(&rsikey, 0, sizeof(rsikey));
+ if (dup_netobj(&rsikey.in_handle, &gc->gc_ctx))
+ return SVC_CLOSE;
+
+ /*
+ * RFC 2203 Section 5.2.2
+ *
+ * struct rpc_gss_init_arg {
+ * opaque gss_token<>;
+ * };
+ */
+ if (xdr_stream_decode_u32(xdr, &len) < 0) {
+ kfree(rsikey.in_handle.data);
+ return SVC_DENIED;
+ }
+ p = xdr_inline_decode(xdr, len);
+ if (!p) {
+ kfree(rsikey.in_handle.data);
+ return SVC_DENIED;
+ }
+ rsikey.in_token.data = kmalloc(len, GFP_KERNEL);
+ if (ZERO_OR_NULL_PTR(rsikey.in_token.data)) {
+ kfree(rsikey.in_handle.data);
+ return SVC_CLOSE;
+ }
+ memcpy(rsikey.in_token.data, p, len);
+ rsikey.in_token.len = len;
+
+ /* Perform upcall, or find upcall result: */
+ rsip = rsi_lookup(sn->rsi_cache, &rsikey);
+ rsi_free(&rsikey);
+ if (!rsip)
+ return SVC_CLOSE;
+ if (cache_check(sn->rsi_cache, &rsip->h, &rqstp->rq_chandle) < 0)
+ /* No upcall result: */
+ return SVC_CLOSE;
+
+ ret = SVC_CLOSE;
+ if (!svcauth_gss_proc_init_verf(sn->rsc_cache, rqstp, &rsip->out_handle,
+ &rsip->major_status, GSS_SEQ_WIN))
+ goto out;
+ if (!svcxdr_set_accept_stat(rqstp))
+ goto out;
+ if (!svcxdr_encode_gss_init_res(&rqstp->rq_res_stream, &rsip->out_handle,
+ &rsip->out_token, rsip->major_status,
+ rsip->minor_status, GSS_SEQ_WIN))
+ goto out;
+
+ ret = SVC_COMPLETE;
+out:
+ cache_put(&rsip->h, sn->rsi_cache);
+ return ret;
+}
+
+static int gss_proxy_save_rsc(struct cache_detail *cd,
+ struct gssp_upcall_data *ud,
+ uint64_t *handle)
+{
+ struct rsc rsci, *rscp = NULL;
+ static atomic64_t ctxhctr;
+ long long ctxh;
+ struct gss_api_mech *gm = NULL;
+ time64_t expiry;
+ int status;
+
+ memset(&rsci, 0, sizeof(rsci));
+ /* context handle */
+ status = -ENOMEM;
+	/* the handle just needs to be a unique id,
+	 * so use a static counter */
+ ctxh = atomic64_inc_return(&ctxhctr);
+
+ /* make a copy for the caller */
+ *handle = ctxh;
+
+ /* make a copy for the rsc cache */
+ if (dup_to_netobj(&rsci.handle, (char *)handle, sizeof(uint64_t)))
+ goto out;
+ rscp = rsc_lookup(cd, &rsci);
+ if (!rscp)
+ goto out;
+
+ /* creds */
+ if (!ud->found_creds) {
+		/* userspace seems buggy; we should always get at least a
+		 * mapping to nobody */
+ goto out;
+ } else {
+ struct timespec64 boot;
+
+ /* steal creds */
+ rsci.cred = ud->creds;
+ memset(&ud->creds, 0, sizeof(struct svc_cred));
+
+ status = -EOPNOTSUPP;
+ /* get mech handle from OID */
+ gm = gss_mech_get_by_OID(&ud->mech_oid);
+ if (!gm)
+ goto out;
+ rsci.cred.cr_gss_mech = gm;
+
+ status = -EINVAL;
+ /* mech-specific data: */
+ status = gss_import_sec_context(ud->out_handle.data,
+ ud->out_handle.len,
+ gm, &rsci.mechctx,
+ &expiry, GFP_KERNEL);
+ if (status)
+ goto out;
+
+ getboottime64(&boot);
+ expiry -= boot.tv_sec;
+ }
+
+ rsci.h.expiry_time = expiry;
+ rscp = rsc_update(cd, &rsci, rscp);
+ status = 0;
+out:
+ rsc_free(&rsci);
+ if (rscp)
+ cache_put(&rscp->h, cd);
+ else
+ status = -ENOMEM;
+ return status;
+}
+
+static int svcauth_gss_proxy_init(struct svc_rqst *rqstp,
+ struct rpc_gss_wire_cred *gc)
+{
+ struct xdr_netobj cli_handle;
+ struct gssp_upcall_data ud;
+ uint64_t handle;
+ int status;
+ int ret;
+ struct net *net = SVC_NET(rqstp);
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ memset(&ud, 0, sizeof(ud));
+ ret = gss_read_proxy_verf(rqstp, gc, &ud.in_handle, &ud.in_token);
+ if (ret)
+ return ret;
+
+ ret = SVC_CLOSE;
+
+ /* Perform synchronous upcall to gss-proxy */
+ status = gssp_accept_sec_context_upcall(net, &ud);
+ if (status)
+ goto out;
+
+ trace_rpcgss_svc_accept_upcall(rqstp, ud.major_status, ud.minor_status);
+
+ switch (ud.major_status) {
+ case GSS_S_CONTINUE_NEEDED:
+ cli_handle = ud.out_handle;
+ break;
+ case GSS_S_COMPLETE:
+ status = gss_proxy_save_rsc(sn->rsc_cache, &ud, &handle);
+ if (status)
+ goto out;
+ cli_handle.data = (u8 *)&handle;
+ cli_handle.len = sizeof(handle);
+ break;
+ default:
+ goto out;
+ }
+
+ if (!svcauth_gss_proc_init_verf(sn->rsc_cache, rqstp, &cli_handle,
+ &ud.major_status, GSS_SEQ_WIN))
+ goto out;
+ if (!svcxdr_set_accept_stat(rqstp))
+ goto out;
+ if (!svcxdr_encode_gss_init_res(&rqstp->rq_res_stream, &cli_handle,
+ &ud.out_token, ud.major_status,
+ ud.minor_status, GSS_SEQ_WIN))
+ goto out;
+
+ ret = SVC_COMPLETE;
+out:
+ gss_free_in_token_pages(&ud.in_token);
+ gssp_free_upcall_data(&ud);
+ return ret;
+}
+
+/*
+ * Try to set the sn->use_gss_proxy variable to a new value. We only allow
+ * it to be changed if it's currently undefined (-1). If it's any other value
+ * then return -EBUSY unless the type wouldn't have changed anyway.
+ */
+static int set_gss_proxy(struct net *net, int type)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ int ret;
+
+ WARN_ON_ONCE(type != 0 && type != 1);
+ ret = cmpxchg(&sn->use_gss_proxy, -1, type);
+ if (ret != -1 && ret != type)
+ return -EBUSY;
+ return 0;
+}
+
+static bool use_gss_proxy(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ /* If use_gss_proxy is still undefined, then try to disable it */
+ if (sn->use_gss_proxy == -1)
+ set_gss_proxy(net, 0);
+ return sn->use_gss_proxy;
+}
+
+static noinline_for_stack int
+svcauth_gss_proc_init(struct svc_rqst *rqstp, struct rpc_gss_wire_cred *gc)
+{
+ struct xdr_stream *xdr = &rqstp->rq_arg_stream;
+ u32 flavor, len;
+ void *body;
+
+ /* Call's verf field: */
+ if (xdr_stream_decode_opaque_auth(xdr, &flavor, &body, &len) < 0)
+ return SVC_GARBAGE;
+ if (flavor != RPC_AUTH_NULL || len != 0) {
+ rqstp->rq_auth_stat = rpc_autherr_badverf;
+ return SVC_DENIED;
+ }
+
+ if (gc->gc_proc == RPC_GSS_PROC_INIT && gc->gc_ctx.len != 0) {
+ rqstp->rq_auth_stat = rpc_autherr_badcred;
+ return SVC_DENIED;
+ }
+
+ if (!use_gss_proxy(SVC_NET(rqstp)))
+ return svcauth_gss_legacy_init(rqstp, gc);
+ return svcauth_gss_proxy_init(rqstp, gc);
+}
+
+#ifdef CONFIG_PROC_FS
+
+static ssize_t write_gssp(struct file *file, const char __user *buf,
+ size_t count, loff_t *ppos)
+{
+ struct net *net = pde_data(file_inode(file));
+ char tbuf[20];
+ unsigned long i;
+ int res;
+
+ if (*ppos || count > sizeof(tbuf)-1)
+ return -EINVAL;
+ if (copy_from_user(tbuf, buf, count))
+ return -EFAULT;
+
+ tbuf[count] = 0;
+ res = kstrtoul(tbuf, 0, &i);
+ if (res)
+ return res;
+ if (i != 1)
+ return -EINVAL;
+ res = set_gssp_clnt(net);
+ if (res)
+ return res;
+ res = set_gss_proxy(net, 1);
+ if (res)
+ return res;
+ return count;
+}
+
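+/*
+ * Illustrative userspace sketch (not part of this kernel file): a
+ * daemon such as gssproxy switches the server to GSS-proxy mode by
+ * writing "1" to /proc/net/rpc/use-gss-proxy, which lands in
+ * write_gssp() above.  Hypothetical example; error handling minimal.
+ *
+ *	#include <fcntl.h>
+ *	#include <unistd.h>
+ *
+ *	static int enable_gss_proxy(void)
+ *	{
+ *		int fd = open("/proc/net/rpc/use-gss-proxy", O_WRONLY);
+ *
+ *		if (fd < 0)
+ *			return -1;
+ *		if (write(fd, "1", 1) != 1) {
+ *			close(fd);
+ *			return -1;
+ *		}
+ *		return close(fd);
+ *	}
+ */
+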
+static ssize_t read_gssp(struct file *file, char __user *buf,
+ size_t count, loff_t *ppos)
+{
+ struct net *net = pde_data(file_inode(file));
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ unsigned long p = *ppos;
+ char tbuf[10];
+ size_t len;
+
+ snprintf(tbuf, sizeof(tbuf), "%d\n", sn->use_gss_proxy);
+ len = strlen(tbuf);
+ if (p >= len)
+ return 0;
+ len -= p;
+ if (len > count)
+ len = count;
+ if (copy_to_user(buf, (void *)(tbuf+p), len))
+ return -EFAULT;
+ *ppos += len;
+ return len;
+}
+
+static const struct proc_ops use_gss_proxy_proc_ops = {
+ .proc_open = nonseekable_open,
+ .proc_write = write_gssp,
+ .proc_read = read_gssp,
+};
+
+static int create_use_gss_proxy_proc_entry(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ struct proc_dir_entry **p = &sn->use_gssp_proc;
+
+ sn->use_gss_proxy = -1;
+ *p = proc_create_data("use-gss-proxy", S_IFREG | 0600,
+ sn->proc_net_rpc,
+ &use_gss_proxy_proc_ops, net);
+ if (!*p)
+ return -ENOMEM;
+ init_gssp_clnt(sn);
+ return 0;
+}
+
+static void destroy_use_gss_proxy_proc_entry(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ if (sn->use_gssp_proc) {
+ remove_proc_entry("use-gss-proxy", sn->proc_net_rpc);
+ clear_gssp_clnt(sn);
+ }
+}
+
+static ssize_t read_gss_krb5_enctypes(struct file *file, char __user *buf,
+ size_t count, loff_t *ppos)
+{
+ struct rpcsec_gss_oid oid = {
+ .len = 9,
+ .data = "\x2a\x86\x48\x86\xf7\x12\x01\x02\x02",
+ };
+ struct gss_api_mech *mech;
+ ssize_t ret;
+
+ mech = gss_mech_get_by_OID(&oid);
+ if (!mech)
+ return 0;
+ if (!mech->gm_upcall_enctypes) {
+ gss_mech_put(mech);
+ return 0;
+ }
+
+ ret = simple_read_from_buffer(buf, count, ppos,
+ mech->gm_upcall_enctypes,
+ strlen(mech->gm_upcall_enctypes));
+ gss_mech_put(mech);
+ return ret;
+}
+
+static const struct proc_ops gss_krb5_enctypes_proc_ops = {
+ .proc_open = nonseekable_open,
+ .proc_read = read_gss_krb5_enctypes,
+};
+
+static int create_krb5_enctypes_proc_entry(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ sn->gss_krb5_enctypes =
+ proc_create_data("gss_krb5_enctypes", S_IFREG | 0444,
+ sn->proc_net_rpc, &gss_krb5_enctypes_proc_ops,
+ net);
+ return sn->gss_krb5_enctypes ? 0 : -ENOMEM;
+}
+
+static void destroy_krb5_enctypes_proc_entry(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ if (sn->gss_krb5_enctypes)
+ remove_proc_entry("gss_krb5_enctypes", sn->proc_net_rpc);
+}
+
+#else /* CONFIG_PROC_FS */
+
+static int create_use_gss_proxy_proc_entry(struct net *net)
+{
+ return 0;
+}
+
+static void destroy_use_gss_proxy_proc_entry(struct net *net) {}
+
+static int create_krb5_enctypes_proc_entry(struct net *net)
+{
+ return 0;
+}
+
+static void destroy_krb5_enctypes_proc_entry(struct net *net) {}
+
+#endif /* CONFIG_PROC_FS */
+
+/*
+ * The Call's credential body should contain a struct rpc_gss_cred_t.
+ *
+ * RFC 2203 Section 5
+ *
+ * struct rpc_gss_cred_t {
+ * union switch (unsigned int version) {
+ * case RPCSEC_GSS_VERS_1:
+ * struct {
+ * rpc_gss_proc_t gss_proc;
+ * unsigned int seq_num;
+ * rpc_gss_service_t service;
+ * opaque handle<>;
+ * } rpc_gss_cred_vers_1_t;
+ * }
+ * };
+ */
+static bool
+svcauth_gss_decode_credbody(struct xdr_stream *xdr,
+ struct rpc_gss_wire_cred *gc,
+ __be32 **rpcstart)
+{
+ ssize_t handle_len;
+ u32 body_len;
+ __be32 *p;
+
+ p = xdr_inline_decode(xdr, XDR_UNIT);
+ if (!p)
+ return false;
+ /*
+ * start of rpc packet is 7 u32's back from here:
+ * xid direction rpcversion prog vers proc flavour
+ */
+ *rpcstart = p - 7;
+ body_len = be32_to_cpup(p);
+ if (body_len > RPC_MAX_AUTH_SIZE)
+ return false;
+
+ /* struct rpc_gss_cred_t */
+ if (xdr_stream_decode_u32(xdr, &gc->gc_v) < 0)
+ return false;
+ if (xdr_stream_decode_u32(xdr, &gc->gc_proc) < 0)
+ return false;
+ if (xdr_stream_decode_u32(xdr, &gc->gc_seq) < 0)
+ return false;
+ if (xdr_stream_decode_u32(xdr, &gc->gc_svc) < 0)
+ return false;
+ handle_len = xdr_stream_decode_opaque_inline(xdr,
+ (void **)&gc->gc_ctx.data,
+ body_len);
+ if (handle_len < 0)
+ return false;
+ if (body_len != XDR_UNIT * 5 + xdr_align_size(handle_len))
+ return false;
+
+ gc->gc_ctx.len = handle_len;
+ return true;
+}
+
+/**
+ * svcauth_gss_accept - Decode and validate incoming RPC_AUTH_GSS credential
+ * @rqstp: RPC transaction
+ *
+ * Return values:
+ * %SVC_OK: Success
+ * %SVC_COMPLETE: GSS context lifetime event
+ * %SVC_DENIED: Credential or verifier is not valid
+ * %SVC_GARBAGE: Failed to decode credential or verifier
+ * %SVC_CLOSE: Temporary failure
+ *
+ * The rqstp->rq_auth_stat field is also set (see RFCs 2203 and 5531).
+ */
+static enum svc_auth_status
+svcauth_gss_accept(struct svc_rqst *rqstp)
+{
+ struct gss_svc_data *svcdata = rqstp->rq_auth_data;
+ __be32 *rpcstart;
+ struct rpc_gss_wire_cred *gc;
+ struct rsc *rsci = NULL;
+ int ret;
+ struct sunrpc_net *sn = net_generic(SVC_NET(rqstp), sunrpc_net_id);
+
+ rqstp->rq_auth_stat = rpc_autherr_badcred;
+ if (!svcdata)
+ svcdata = kmalloc(sizeof(*svcdata), GFP_KERNEL);
+ if (!svcdata)
+ goto auth_err;
+ rqstp->rq_auth_data = svcdata;
+ svcdata->gsd_databody_offset = 0;
+ svcdata->rsci = NULL;
+ gc = &svcdata->clcred;
+
+ if (!svcauth_gss_decode_credbody(&rqstp->rq_arg_stream, gc, &rpcstart))
+ goto auth_err;
+ if (gc->gc_v != RPC_GSS_VERSION)
+ goto auth_err;
+
+ switch (gc->gc_proc) {
+ case RPC_GSS_PROC_INIT:
+ case RPC_GSS_PROC_CONTINUE_INIT:
+ if (rqstp->rq_proc != 0)
+ goto auth_err;
+ return svcauth_gss_proc_init(rqstp, gc);
+ case RPC_GSS_PROC_DESTROY:
+ if (rqstp->rq_proc != 0)
+ goto auth_err;
+ fallthrough;
+ case RPC_GSS_PROC_DATA:
+ rqstp->rq_auth_stat = rpcsec_gsserr_credproblem;
+ rsci = gss_svc_searchbyctx(sn->rsc_cache, &gc->gc_ctx);
+ if (!rsci)
+ goto auth_err;
+ switch (svcauth_gss_verify_header(rqstp, rsci, rpcstart, gc)) {
+ case SVC_OK:
+ break;
+ case SVC_DENIED:
+ goto auth_err;
+ case SVC_DROP:
+ goto drop;
+ }
+ break;
+ default:
+ if (rqstp->rq_proc != 0)
+ goto auth_err;
+ rqstp->rq_auth_stat = rpc_autherr_rejectedcred;
+ goto auth_err;
+ }
+
+ /* now act upon the command: */
+ switch (gc->gc_proc) {
+ case RPC_GSS_PROC_DESTROY:
+ if (!svcauth_gss_encode_verf(rqstp, rsci->mechctx, gc->gc_seq))
+ goto auth_err;
+ if (!svcxdr_set_accept_stat(rqstp))
+ goto auth_err;
+ /* Delete the entry from the cache_list and call cache_put */
+ sunrpc_cache_unhash(sn->rsc_cache, &rsci->h);
+ goto complete;
+ case RPC_GSS_PROC_DATA:
+ rqstp->rq_auth_stat = rpcsec_gsserr_ctxproblem;
+ if (!svcauth_gss_encode_verf(rqstp, rsci->mechctx, gc->gc_seq))
+ goto auth_err;
+ if (!svcxdr_set_accept_stat(rqstp))
+ goto auth_err;
+ svcdata->gsd_databody_offset = xdr_stream_pos(&rqstp->rq_res_stream);
+ rqstp->rq_cred = rsci->cred;
+ get_group_info(rsci->cred.cr_group_info);
+ rqstp->rq_auth_stat = rpc_autherr_badcred;
+ switch (gc->gc_svc) {
+ case RPC_GSS_SVC_NONE:
+ break;
+ case RPC_GSS_SVC_INTEGRITY:
+ /* placeholders for body length and seq. number: */
+ xdr_reserve_space(&rqstp->rq_res_stream, XDR_UNIT * 2);
+ if (svcauth_gss_unwrap_integ(rqstp, gc->gc_seq,
+ rsci->mechctx))
+ goto garbage_args;
+ svcxdr_set_auth_slack(rqstp, RPC_MAX_AUTH_SIZE);
+ break;
+ case RPC_GSS_SVC_PRIVACY:
+ /* placeholders for body length and seq. number: */
+ xdr_reserve_space(&rqstp->rq_res_stream, XDR_UNIT * 2);
+ if (svcauth_gss_unwrap_priv(rqstp, gc->gc_seq,
+ rsci->mechctx))
+ goto garbage_args;
+ svcxdr_set_auth_slack(rqstp, RPC_MAX_AUTH_SIZE * 2);
+ break;
+ default:
+ goto auth_err;
+ }
+ svcdata->rsci = rsci;
+ cache_get(&rsci->h);
+ rqstp->rq_cred.cr_flavor = gss_svc_to_pseudoflavor(
+ rsci->mechctx->mech_type,
+ GSS_C_QOP_DEFAULT,
+ gc->gc_svc);
+ ret = SVC_OK;
+ trace_rpcgss_svc_authenticate(rqstp, gc);
+ goto out;
+ }
+garbage_args:
+ ret = SVC_GARBAGE;
+ goto out;
+auth_err:
+ xdr_truncate_encode(&rqstp->rq_res_stream, XDR_UNIT * 2);
+ ret = SVC_DENIED;
+ goto out;
+complete:
+ ret = SVC_COMPLETE;
+ goto out;
+drop:
+ ret = SVC_CLOSE;
+out:
+ if (rsci)
+ cache_put(&rsci->h, sn->rsc_cache);
+ return ret;
+}
+
+static u32
+svcauth_gss_prepare_to_wrap(struct svc_rqst *rqstp, struct gss_svc_data *gsd)
+{
+ u32 offset;
+
+ /* Release can be called twice, but we only wrap once. */
+ offset = gsd->gsd_databody_offset;
+ gsd->gsd_databody_offset = 0;
+
+ /* AUTH_ERROR replies are not wrapped. */
+ if (rqstp->rq_auth_stat != rpc_auth_ok)
+ return 0;
+
+ /* Also don't wrap if the accept_stat is nonzero: */
+ if (*rqstp->rq_accept_statp != rpc_success)
+ return 0;
+
+ return offset;
+}
+
+/*
+ * RFC 2203, Section 5.3.2.2
+ *
+ * struct rpc_gss_integ_data {
+ * opaque databody_integ<>;
+ * opaque checksum<>;
+ * };
+ *
+ * struct rpc_gss_data_t {
+ * unsigned int seq_num;
+ * proc_req_arg_t arg;
+ * };
+ *
+ * The RPC Reply message has already been XDR-encoded. rq_res_stream
+ * is now positioned so that the checksum can be written just past
+ * the RPC Reply message.
+ */
+static int svcauth_gss_wrap_integ(struct svc_rqst *rqstp)
+{
+ struct gss_svc_data *gsd = rqstp->rq_auth_data;
+ struct xdr_stream *xdr = &rqstp->rq_res_stream;
+ struct rpc_gss_wire_cred *gc = &gsd->clcred;
+ struct xdr_buf *buf = xdr->buf;
+ struct xdr_buf databody_integ;
+ struct xdr_netobj checksum;
+ u32 offset, maj_stat;
+
+ offset = svcauth_gss_prepare_to_wrap(rqstp, gsd);
+ if (!offset)
+ goto out;
+
+ if (xdr_buf_subsegment(buf, &databody_integ, offset + XDR_UNIT,
+ buf->len - offset - XDR_UNIT))
+ goto wrap_failed;
+ /* Buffer space for these has already been reserved in
+ * svcauth_gss_accept(). */
+ if (xdr_encode_word(buf, offset, databody_integ.len))
+ goto wrap_failed;
+ if (xdr_encode_word(buf, offset + XDR_UNIT, gc->gc_seq))
+ goto wrap_failed;
+
+ checksum.data = gsd->gsd_scratch;
+ maj_stat = gss_get_mic(gsd->rsci->mechctx, &databody_integ, &checksum);
+ if (maj_stat != GSS_S_COMPLETE)
+ goto bad_mic;
+
+ if (xdr_stream_encode_opaque(xdr, checksum.data, checksum.len) < 0)
+ goto wrap_failed;
+ xdr_commit_encode(xdr);
+
+out:
+ return 0;
+
+bad_mic:
+ trace_rpcgss_svc_get_mic(rqstp, maj_stat);
+ return -EINVAL;
+wrap_failed:
+ trace_rpcgss_svc_wrap_failed(rqstp);
+ return -EINVAL;
+}
+
+/*
+ * RFC 2203, Section 5.3.2.3
+ *
+ * struct rpc_gss_priv_data {
+ * opaque databody_priv<>
+ * };
+ *
+ * struct rpc_gss_data_t {
+ * unsigned int seq_num;
+ * proc_req_arg_t arg;
+ * };
+ *
+ * gss_wrap() expands the size of the RPC message payload in the
+ * response buffer. The main purpose of svcauth_gss_wrap_priv()
+ * is to ensure there is adequate space in the response buffer to
+ * avoid overflow during the wrap.
+ */
+static int svcauth_gss_wrap_priv(struct svc_rqst *rqstp)
+{
+ struct gss_svc_data *gsd = rqstp->rq_auth_data;
+ struct rpc_gss_wire_cred *gc = &gsd->clcred;
+ struct xdr_buf *buf = &rqstp->rq_res;
+ struct kvec *head = buf->head;
+ struct kvec *tail = buf->tail;
+ u32 offset, pad, maj_stat;
+ __be32 *p;
+
+ offset = svcauth_gss_prepare_to_wrap(rqstp, gsd);
+ if (!offset)
+ return 0;
+
+ /*
+ * Buffer space for this field has already been reserved
+ * in svcauth_gss_accept(). Note that the GSS sequence
+ * number is encrypted along with the RPC reply payload.
+ */
+ if (xdr_encode_word(buf, offset + XDR_UNIT, gc->gc_seq))
+ goto wrap_failed;
+
+ /*
+ * If there is currently tail data, make sure there is
+ * room for the head, tail, and 2 * RPC_MAX_AUTH_SIZE in
+ * the page, and move the current tail data such that
+ * there is RPC_MAX_AUTH_SIZE slack space available in
+ * both the head and tail.
+ */
+ if (tail->iov_base) {
+ if (tail->iov_base >= head->iov_base + PAGE_SIZE)
+ goto wrap_failed;
+ if (tail->iov_base < head->iov_base)
+ goto wrap_failed;
+ if (tail->iov_len + head->iov_len
+ + 2 * RPC_MAX_AUTH_SIZE > PAGE_SIZE)
+ goto wrap_failed;
+ memmove(tail->iov_base + RPC_MAX_AUTH_SIZE, tail->iov_base,
+ tail->iov_len);
+ tail->iov_base += RPC_MAX_AUTH_SIZE;
+ }
+ /*
+ * If there is no current tail data, make sure there is
+ * room for the head data, and 2 * RPC_MAX_AUTH_SIZE in the
+ * allotted page, and set up tail information such that there
+ * is RPC_MAX_AUTH_SIZE slack space available in both the
+ * head and tail.
+ */
+ if (!tail->iov_base) {
+ if (head->iov_len + 2 * RPC_MAX_AUTH_SIZE > PAGE_SIZE)
+ goto wrap_failed;
+ tail->iov_base = head->iov_base
+ + head->iov_len + RPC_MAX_AUTH_SIZE;
+ tail->iov_len = 0;
+ }
+
+ maj_stat = gss_wrap(gsd->rsci->mechctx, offset + XDR_UNIT, buf,
+ buf->pages);
+ if (maj_stat != GSS_S_COMPLETE)
+ goto bad_wrap;
+
+ /* Wrapping can change the size of databody_priv. */
+ if (xdr_encode_word(buf, offset, buf->len - offset - XDR_UNIT))
+ goto wrap_failed;
+ pad = xdr_pad_size(buf->len - offset - XDR_UNIT);
+ p = (__be32 *)(tail->iov_base + tail->iov_len);
+ memset(p, 0, pad);
+ tail->iov_len += pad;
+ buf->len += pad;
+
+ return 0;
+wrap_failed:
+ trace_rpcgss_svc_wrap_failed(rqstp);
+ return -EINVAL;
+bad_wrap:
+ trace_rpcgss_svc_wrap(rqstp, maj_stat);
+ return -ENOMEM;
+}
+
+/**
+ * svcauth_gss_release - Wrap payload and release resources
+ * @rqstp: RPC transaction context
+ *
+ * Return values:
+ * %0: the Reply is ready to be sent
+ * %-ENOMEM: failed to allocate memory
+ * %-EINVAL: encoding error
+ */
+static int
+svcauth_gss_release(struct svc_rqst *rqstp)
+{
+ struct sunrpc_net *sn = net_generic(SVC_NET(rqstp), sunrpc_net_id);
+ struct gss_svc_data *gsd = rqstp->rq_auth_data;
+ struct rpc_gss_wire_cred *gc;
+ int stat;
+
+ if (!gsd)
+ goto out;
+ gc = &gsd->clcred;
+ if (gc->gc_proc != RPC_GSS_PROC_DATA)
+ goto out;
+
+ switch (gc->gc_svc) {
+ case RPC_GSS_SVC_NONE:
+ break;
+ case RPC_GSS_SVC_INTEGRITY:
+ stat = svcauth_gss_wrap_integ(rqstp);
+ if (stat)
+ goto out_err;
+ break;
+ case RPC_GSS_SVC_PRIVACY:
+ stat = svcauth_gss_wrap_priv(rqstp);
+ if (stat)
+ goto out_err;
+ break;
+ /*
+ * For any other gc_svc value, svcauth_gss_accept() already set
+ * the auth_error appropriately; just fall through:
+ */
+ }
+
+out:
+ stat = 0;
+out_err:
+ if (rqstp->rq_client)
+ auth_domain_put(rqstp->rq_client);
+ rqstp->rq_client = NULL;
+ if (rqstp->rq_gssclient)
+ auth_domain_put(rqstp->rq_gssclient);
+ rqstp->rq_gssclient = NULL;
+ if (rqstp->rq_cred.cr_group_info)
+ put_group_info(rqstp->rq_cred.cr_group_info);
+ rqstp->rq_cred.cr_group_info = NULL;
+ if (gsd && gsd->rsci) {
+ cache_put(&gsd->rsci->h, sn->rsc_cache);
+ gsd->rsci = NULL;
+ }
+ return stat;
+}
+
+static void
+svcauth_gss_domain_release_rcu(struct rcu_head *head)
+{
+ struct auth_domain *dom = container_of(head, struct auth_domain, rcu_head);
+ struct gss_domain *gd = container_of(dom, struct gss_domain, h);
+
+ kfree(dom->name);
+ kfree(gd);
+}
+
+static void
+svcauth_gss_domain_release(struct auth_domain *dom)
+{
+ call_rcu(&dom->rcu_head, svcauth_gss_domain_release_rcu);
+}
+
+static struct auth_ops svcauthops_gss = {
+ .name = "rpcsec_gss",
+ .owner = THIS_MODULE,
+ .flavour = RPC_AUTH_GSS,
+ .accept = svcauth_gss_accept,
+ .release = svcauth_gss_release,
+ .domain_release = svcauth_gss_domain_release,
+ .set_client = svcauth_gss_set_client,
+};
+
+static int rsi_cache_create_net(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ struct cache_detail *cd;
+ int err;
+
+ cd = cache_create_net(&rsi_cache_template, net);
+ if (IS_ERR(cd))
+ return PTR_ERR(cd);
+ err = cache_register_net(cd, net);
+ if (err) {
+ cache_destroy_net(cd, net);
+ return err;
+ }
+ sn->rsi_cache = cd;
+ return 0;
+}
+
+static void rsi_cache_destroy_net(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ struct cache_detail *cd = sn->rsi_cache;
+
+ sn->rsi_cache = NULL;
+ cache_purge(cd);
+ cache_unregister_net(cd, net);
+ cache_destroy_net(cd, net);
+}
+
+static int rsc_cache_create_net(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ struct cache_detail *cd;
+ int err;
+
+ cd = cache_create_net(&rsc_cache_template, net);
+ if (IS_ERR(cd))
+ return PTR_ERR(cd);
+ err = cache_register_net(cd, net);
+ if (err) {
+ cache_destroy_net(cd, net);
+ return err;
+ }
+ sn->rsc_cache = cd;
+ return 0;
+}
+
+static void rsc_cache_destroy_net(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ struct cache_detail *cd = sn->rsc_cache;
+
+ sn->rsc_cache = NULL;
+ cache_purge(cd);
+ cache_unregister_net(cd, net);
+ cache_destroy_net(cd, net);
+}
+
+int
+gss_svc_init_net(struct net *net)
+{
+ int rv;
+
+ rv = rsc_cache_create_net(net);
+ if (rv)
+ return rv;
+ rv = rsi_cache_create_net(net);
+ if (rv)
+ goto out1;
+ rv = create_use_gss_proxy_proc_entry(net);
+ if (rv)
+ goto out2;
+
+ rv = create_krb5_enctypes_proc_entry(net);
+ if (rv)
+ goto out3;
+
+ return 0;
+
+out3:
+ destroy_use_gss_proxy_proc_entry(net);
+out2:
+ rsi_cache_destroy_net(net);
+out1:
+ rsc_cache_destroy_net(net);
+ return rv;
+}
+
+void
+gss_svc_shutdown_net(struct net *net)
+{
+ destroy_krb5_enctypes_proc_entry(net);
+ destroy_use_gss_proxy_proc_entry(net);
+ rsi_cache_destroy_net(net);
+ rsc_cache_destroy_net(net);
+}
+
+int
+gss_svc_init(void)
+{
+ return svc_auth_register(RPC_AUTH_GSS, &svcauthops_gss);
+}
+
+void
+gss_svc_shutdown(void)
+{
+ svc_auth_unregister(RPC_AUTH_GSS);
+}
diff --git a/net/sunrpc/auth_gss/trace.c b/net/sunrpc/auth_gss/trace.c
new file mode 100644
index 0000000000..76685abba6
--- /dev/null
+++ b/net/sunrpc/auth_gss/trace.c
@@ -0,0 +1,14 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (c) 2018, 2019 Oracle. All rights reserved.
+ */
+
+#include <linux/sunrpc/clnt.h>
+#include <linux/sunrpc/sched.h>
+#include <linux/sunrpc/svc.h>
+#include <linux/sunrpc/svc_xprt.h>
+#include <linux/sunrpc/auth_gss.h>
+#include <linux/sunrpc/gss_err.h>
+
+#define CREATE_TRACE_POINTS
+#include <trace/events/rpcgss.h>
diff --git a/net/sunrpc/auth_null.c b/net/sunrpc/auth_null.c
new file mode 100644
index 0000000000..41a633a404
--- /dev/null
+++ b/net/sunrpc/auth_null.c
@@ -0,0 +1,143 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * linux/net/sunrpc/auth_null.c
+ *
+ * AUTH_NULL authentication. Really :-)
+ *
+ * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de>
+ */
+
+#include <linux/types.h>
+#include <linux/module.h>
+#include <linux/sunrpc/clnt.h>
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+# define RPCDBG_FACILITY RPCDBG_AUTH
+#endif
+
+static struct rpc_auth null_auth;
+static struct rpc_cred null_cred;
+
+static struct rpc_auth *
+nul_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt)
+{
+ refcount_inc(&null_auth.au_count);
+ return &null_auth;
+}
+
+static void
+nul_destroy(struct rpc_auth *auth)
+{
+}
+
+/*
+ * Lookup NULL creds for current process
+ */
+static struct rpc_cred *
+nul_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
+{
+ return get_rpccred(&null_cred);
+}
+
+/*
+ * Destroy cred handle.
+ */
+static void
+nul_destroy_cred(struct rpc_cred *cred)
+{
+}
+
+/*
+ * Match cred handle against current process
+ */
+static int
+nul_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags)
+{
+ return 1;
+}
+
+/*
+ * Marshal credential.
+ */
+static int
+nul_marshal(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ __be32 *p;
+
+ p = xdr_reserve_space(xdr, 4 * sizeof(*p));
+ if (!p)
+ return -EMSGSIZE;
+ /* Credential */
+ *p++ = rpc_auth_null;
+ *p++ = xdr_zero;
+ /* Verifier */
+ *p++ = rpc_auth_null;
+ *p = xdr_zero;
+ return 0;
+}
+
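+/*
+ * For illustration (not used by the code): since rpc_auth_null and
+ * xdr_zero both encode as zero, the credential + verifier emitted by
+ * nul_marshal() is simply four zero XDR words, i.e. 16 bytes of zeroes:
+ *
+ *	0 0	credential: flavor AUTH_NULL, zero-length body
+ *	0 0	verifier:   flavor AUTH_NULL, zero-length body
+ */
+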
+/*
+ * Refresh credential. This is a no-op for AUTH_NULL
+ */
+static int
+nul_refresh(struct rpc_task *task)
+{
+ set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags);
+ return 0;
+}
+
+static int
+nul_validate(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ __be32 *p;
+
+ p = xdr_inline_decode(xdr, 2 * sizeof(*p));
+ if (!p)
+ return -EIO;
+ if (*p++ != rpc_auth_null)
+ return -EIO;
+ if (*p != xdr_zero)
+ return -EIO;
+ return 0;
+}
+
+const struct rpc_authops authnull_ops = {
+ .owner = THIS_MODULE,
+ .au_flavor = RPC_AUTH_NULL,
+ .au_name = "NULL",
+ .create = nul_create,
+ .destroy = nul_destroy,
+ .lookup_cred = nul_lookup_cred,
+};
+
+static
+struct rpc_auth null_auth = {
+ .au_cslack = NUL_CALLSLACK,
+ .au_rslack = NUL_REPLYSLACK,
+ .au_verfsize = NUL_REPLYSLACK,
+ .au_ralign = NUL_REPLYSLACK,
+ .au_ops = &authnull_ops,
+ .au_flavor = RPC_AUTH_NULL,
+ .au_count = REFCOUNT_INIT(1),
+};
+
+static
+const struct rpc_credops null_credops = {
+ .cr_name = "AUTH_NULL",
+ .crdestroy = nul_destroy_cred,
+ .crmatch = nul_match,
+ .crmarshal = nul_marshal,
+ .crwrap_req = rpcauth_wrap_req_encode,
+ .crrefresh = nul_refresh,
+ .crvalidate = nul_validate,
+ .crunwrap_resp = rpcauth_unwrap_resp_decode,
+};
+
+static
+struct rpc_cred null_cred = {
+ .cr_lru = LIST_HEAD_INIT(null_cred.cr_lru),
+ .cr_auth = &null_auth,
+ .cr_ops = &null_credops,
+ .cr_count = REFCOUNT_INIT(2),
+ .cr_flags = 1UL << RPCAUTH_CRED_UPTODATE,
+};
diff --git a/net/sunrpc/auth_tls.c b/net/sunrpc/auth_tls.c
new file mode 100644
index 0000000000..87f570fd3b
--- /dev/null
+++ b/net/sunrpc/auth_tls.c
@@ -0,0 +1,175 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * Copyright (c) 2021, 2022 Oracle. All rights reserved.
+ *
+ * The AUTH_TLS credential is used only to probe a remote peer
+ * for RPC-over-TLS support.
+ */
+
+#include <linux/types.h>
+#include <linux/module.h>
+#include <linux/sunrpc/clnt.h>
+
+static const char *starttls_token = "STARTTLS";
+static const size_t starttls_len = 8;
+
+static struct rpc_auth tls_auth;
+static struct rpc_cred tls_cred;
+
+static void tls_encode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
+ const void *obj)
+{
+}
+
+static int tls_decode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
+ void *obj)
+{
+ return 0;
+}
+
+static const struct rpc_procinfo rpcproc_tls_probe = {
+ .p_encode = tls_encode_probe,
+ .p_decode = tls_decode_probe,
+};
+
+static void rpc_tls_probe_call_prepare(struct rpc_task *task, void *data)
+{
+ task->tk_flags &= ~RPC_TASK_NO_RETRANS_TIMEOUT;
+ rpc_call_start(task);
+}
+
+static void rpc_tls_probe_call_done(struct rpc_task *task, void *data)
+{
+}
+
+static const struct rpc_call_ops rpc_tls_probe_ops = {
+ .rpc_call_prepare = rpc_tls_probe_call_prepare,
+ .rpc_call_done = rpc_tls_probe_call_done,
+};
+
+static int tls_probe(struct rpc_clnt *clnt)
+{
+ struct rpc_message msg = {
+ .rpc_proc = &rpcproc_tls_probe,
+ };
+ struct rpc_task_setup task_setup_data = {
+ .rpc_client = clnt,
+ .rpc_message = &msg,
+ .rpc_op_cred = &tls_cred,
+ .callback_ops = &rpc_tls_probe_ops,
+ .flags = RPC_TASK_SOFT | RPC_TASK_SOFTCONN,
+ };
+ struct rpc_task *task;
+ int status;
+
+ task = rpc_run_task(&task_setup_data);
+ if (IS_ERR(task))
+ return PTR_ERR(task);
+ status = task->tk_status;
+ rpc_put_task(task);
+ return status;
+}
+
+static struct rpc_auth *tls_create(const struct rpc_auth_create_args *args,
+ struct rpc_clnt *clnt)
+{
+ refcount_inc(&tls_auth.au_count);
+ return &tls_auth;
+}
+
+static void tls_destroy(struct rpc_auth *auth)
+{
+}
+
+static struct rpc_cred *tls_lookup_cred(struct rpc_auth *auth,
+ struct auth_cred *acred, int flags)
+{
+ return get_rpccred(&tls_cred);
+}
+
+static void tls_destroy_cred(struct rpc_cred *cred)
+{
+}
+
+static int tls_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags)
+{
+ return 1;
+}
+
+static int tls_marshal(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ __be32 *p;
+
+ p = xdr_reserve_space(xdr, 4 * XDR_UNIT);
+ if (!p)
+ return -EMSGSIZE;
+ /* Credential */
+ *p++ = rpc_auth_tls;
+ *p++ = xdr_zero;
+ /* Verifier */
+ *p++ = rpc_auth_null;
+ *p = xdr_zero;
+ return 0;
+}
+
+static int tls_refresh(struct rpc_task *task)
+{
+ set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags);
+ return 0;
+}
+
+static int tls_validate(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ __be32 *p;
+ void *str;
+
+ p = xdr_inline_decode(xdr, XDR_UNIT);
+ if (!p)
+ return -EIO;
+ if (*p != rpc_auth_null)
+ return -EIO;
+ if (xdr_stream_decode_opaque_inline(xdr, &str, starttls_len) != starttls_len)
+ return -EPROTONOSUPPORT;
+ if (memcmp(str, starttls_token, starttls_len))
+ return -EPROTONOSUPPORT;
+ return 0;
+}
+
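+/*
+ * For illustration (not part of the code): tls_validate() expects the
+ * server's reply verifier to be AUTH_NULL carrying the 8-byte opaque
+ * body "STARTTLS", i.e. on the wire:
+ *
+ *	AUTH_NULL (0) | length 8 | "STARTTLS"
+ *
+ * A missing or different token is reported as -EPROTONOSUPPORT, meaning
+ * the peer does not support RPC-over-TLS.
+ */
+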
+const struct rpc_authops authtls_ops = {
+ .owner = THIS_MODULE,
+ .au_flavor = RPC_AUTH_TLS,
+ .au_name = "NULL",
+ .create = tls_create,
+ .destroy = tls_destroy,
+ .lookup_cred = tls_lookup_cred,
+ .ping = tls_probe,
+};
+
+static struct rpc_auth tls_auth = {
+ .au_cslack = NUL_CALLSLACK,
+ .au_rslack = NUL_REPLYSLACK,
+ .au_verfsize = NUL_REPLYSLACK,
+ .au_ralign = NUL_REPLYSLACK,
+ .au_ops = &authtls_ops,
+ .au_flavor = RPC_AUTH_TLS,
+ .au_count = REFCOUNT_INIT(1),
+};
+
+static const struct rpc_credops tls_credops = {
+ .cr_name = "AUTH_TLS",
+ .crdestroy = tls_destroy_cred,
+ .crmatch = tls_match,
+ .crmarshal = tls_marshal,
+ .crwrap_req = rpcauth_wrap_req_encode,
+ .crrefresh = tls_refresh,
+ .crvalidate = tls_validate,
+ .crunwrap_resp = rpcauth_unwrap_resp_decode,
+};
+
+static struct rpc_cred tls_cred = {
+ .cr_lru = LIST_HEAD_INIT(tls_cred.cr_lru),
+ .cr_auth = &tls_auth,
+ .cr_ops = &tls_credops,
+ .cr_count = REFCOUNT_INIT(2),
+ .cr_flags = 1UL << RPCAUTH_CRED_UPTODATE,
+};
diff --git a/net/sunrpc/auth_unix.c b/net/sunrpc/auth_unix.c
new file mode 100644
index 0000000000..1e091d3fa6
--- /dev/null
+++ b/net/sunrpc/auth_unix.c
@@ -0,0 +1,243 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * linux/net/sunrpc/auth_unix.c
+ *
+ * UNIX-style authentication; no AUTH_SHORT support
+ *
+ * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de>
+ */
+
+#include <linux/slab.h>
+#include <linux/types.h>
+#include <linux/sched.h>
+#include <linux/module.h>
+#include <linux/mempool.h>
+#include <linux/sunrpc/clnt.h>
+#include <linux/sunrpc/auth.h>
+#include <linux/user_namespace.h>
+
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+# define RPCDBG_FACILITY RPCDBG_AUTH
+#endif
+
+static struct rpc_auth unix_auth;
+static const struct rpc_credops unix_credops;
+static mempool_t *unix_pool;
+
+static struct rpc_auth *
+unx_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt)
+{
+ refcount_inc(&unix_auth.au_count);
+ return &unix_auth;
+}
+
+static void
+unx_destroy(struct rpc_auth *auth)
+{
+}
+
+/*
+ * Lookup AUTH_UNIX creds for current process
+ */
+static struct rpc_cred *unx_lookup_cred(struct rpc_auth *auth,
+ struct auth_cred *acred, int flags)
+{
+ struct rpc_cred *ret;
+
+ ret = kmalloc(sizeof(*ret), rpc_task_gfp_mask());
+ if (!ret) {
+ if (!(flags & RPCAUTH_LOOKUP_ASYNC))
+ return ERR_PTR(-ENOMEM);
+ ret = mempool_alloc(unix_pool, GFP_NOWAIT);
+ if (!ret)
+ return ERR_PTR(-ENOMEM);
+ }
+ rpcauth_init_cred(ret, acred, auth, &unix_credops);
+ ret->cr_flags = 1UL << RPCAUTH_CRED_UPTODATE;
+ return ret;
+}
+
+static void
+unx_free_cred_callback(struct rcu_head *head)
+{
+ struct rpc_cred *rpc_cred = container_of(head, struct rpc_cred, cr_rcu);
+
+ put_cred(rpc_cred->cr_cred);
+ mempool_free(rpc_cred, unix_pool);
+}
+
+static void
+unx_destroy_cred(struct rpc_cred *cred)
+{
+ call_rcu(&cred->cr_rcu, unx_free_cred_callback);
+}
+
+/*
+ * Match credentials against the current auth_cred.
+ */
+static int
+unx_match(struct auth_cred *acred, struct rpc_cred *cred, int flags)
+{
+ unsigned int groups = 0;
+ unsigned int i;
+
+ if (cred->cr_cred == acred->cred)
+ return 1;
+
+ if (!uid_eq(cred->cr_cred->fsuid, acred->cred->fsuid) || !gid_eq(cred->cr_cred->fsgid, acred->cred->fsgid))
+ return 0;
+
+ if (acred->cred->group_info != NULL)
+ groups = acred->cred->group_info->ngroups;
+ if (groups > UNX_NGROUPS)
+ groups = UNX_NGROUPS;
+ if (cred->cr_cred->group_info == NULL)
+ return groups == 0;
+ if (groups != cred->cr_cred->group_info->ngroups)
+ return 0;
+
+ for (i = 0; i < groups ; i++)
+ if (!gid_eq(cred->cr_cred->group_info->gid[i], acred->cred->group_info->gid[i]))
+ return 0;
+ return 1;
+}
+
+/*
+ * Marshal credentials.
+ * Maybe we should keep a cached credential for performance reasons.
+ */
+static int
+unx_marshal(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ struct rpc_clnt *clnt = task->tk_client;
+ struct rpc_cred *cred = task->tk_rqstp->rq_cred;
+ __be32 *p, *cred_len, *gidarr_len;
+ int i;
+ struct group_info *gi = cred->cr_cred->group_info;
+ struct user_namespace *userns = clnt->cl_cred ?
+ clnt->cl_cred->user_ns : &init_user_ns;
+
+ /* Credential */
+
+ p = xdr_reserve_space(xdr, 3 * sizeof(*p));
+ if (!p)
+ goto marshal_failed;
+ *p++ = rpc_auth_unix;
+ cred_len = p++;
+ *p++ = xdr_zero; /* stamp */
+ if (xdr_stream_encode_opaque(xdr, clnt->cl_nodename,
+ clnt->cl_nodelen) < 0)
+ goto marshal_failed;
+ p = xdr_reserve_space(xdr, 3 * sizeof(*p));
+ if (!p)
+ goto marshal_failed;
+ *p++ = cpu_to_be32(from_kuid_munged(userns, cred->cr_cred->fsuid));
+ *p++ = cpu_to_be32(from_kgid_munged(userns, cred->cr_cred->fsgid));
+
+ gidarr_len = p++;
+ if (gi)
+ for (i = 0; i < UNX_NGROUPS && i < gi->ngroups; i++)
+ *p++ = cpu_to_be32(from_kgid_munged(userns, gi->gid[i]));
+ *gidarr_len = cpu_to_be32(p - gidarr_len - 1);
+ *cred_len = cpu_to_be32((p - cred_len - 1) << 2);
+ p = xdr_reserve_space(xdr, (p - gidarr_len - 1) << 2);
+ if (!p)
+ goto marshal_failed;
+
+ /* Verifier */
+
+ p = xdr_reserve_space(xdr, 2 * sizeof(*p));
+ if (!p)
+ goto marshal_failed;
+ *p++ = rpc_auth_null;
+ *p = xdr_zero;
+
+ return 0;
+
+marshal_failed:
+ return -EMSGSIZE;
+}
+
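+/*
+ * Worked example (illustration only, hypothetical values): for nodename
+ * "client", uid 1000, gid 1000 and one supplementary gid 10,
+ * unx_marshal() produces
+ *
+ *	AUTH_UNIX (1) | body length |
+ *	stamp (0) | 6 | "client" + 2 pad bytes |
+ *	1000 | 1000 | 1 | 10 |
+ *	AUTH_NULL (0) | 0			(verifier)
+ *
+ * where the body length counts everything from the stamp through the
+ * gid array, in octets.
+ */
+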
+/*
+ * Refresh credentials. This is a no-op for AUTH_UNIX
+ */
+static int
+unx_refresh(struct rpc_task *task)
+{
+ set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags);
+ return 0;
+}
+
+static int
+unx_validate(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ struct rpc_auth *auth = task->tk_rqstp->rq_cred->cr_auth;
+ __be32 *p;
+ u32 size;
+
+ p = xdr_inline_decode(xdr, 2 * sizeof(*p));
+ if (!p)
+ return -EIO;
+ switch (*p++) {
+ case rpc_auth_null:
+ case rpc_auth_unix:
+ case rpc_auth_short:
+ break;
+ default:
+ return -EIO;
+ }
+ size = be32_to_cpup(p);
+ if (size > RPC_MAX_AUTH_SIZE)
+ return -EIO;
+ p = xdr_inline_decode(xdr, size);
+ if (!p)
+ return -EIO;
+
+ auth->au_verfsize = XDR_QUADLEN(size) + 2;
+ auth->au_rslack = XDR_QUADLEN(size) + 2;
+ auth->au_ralign = XDR_QUADLEN(size) + 2;
+ return 0;
+}
+
+int __init rpc_init_authunix(void)
+{
+ unix_pool = mempool_create_kmalloc_pool(16, sizeof(struct rpc_cred));
+ return unix_pool ? 0 : -ENOMEM;
+}
+
+void rpc_destroy_authunix(void)
+{
+ mempool_destroy(unix_pool);
+}
+
+const struct rpc_authops authunix_ops = {
+ .owner = THIS_MODULE,
+ .au_flavor = RPC_AUTH_UNIX,
+ .au_name = "UNIX",
+ .create = unx_create,
+ .destroy = unx_destroy,
+ .lookup_cred = unx_lookup_cred,
+};
+
+static
+struct rpc_auth unix_auth = {
+ .au_cslack = UNX_CALLSLACK,
+ .au_rslack = NUL_REPLYSLACK,
+ .au_verfsize = NUL_REPLYSLACK,
+ .au_ops = &authunix_ops,
+ .au_flavor = RPC_AUTH_UNIX,
+ .au_count = REFCOUNT_INIT(1),
+};
+
+static
+const struct rpc_credops unix_credops = {
+ .cr_name = "AUTH_UNIX",
+ .crdestroy = unx_destroy_cred,
+ .crmatch = unx_match,
+ .crmarshal = unx_marshal,
+ .crwrap_req = rpcauth_wrap_req_encode,
+ .crrefresh = unx_refresh,
+ .crvalidate = unx_validate,
+ .crunwrap_resp = rpcauth_unwrap_resp_decode,
+};
diff --git a/net/sunrpc/backchannel_rqst.c b/net/sunrpc/backchannel_rqst.c
new file mode 100644
index 0000000000..65a6c6429a
--- /dev/null
+++ b/net/sunrpc/backchannel_rqst.c
@@ -0,0 +1,376 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/******************************************************************************
+
+(c) 2007 Network Appliance, Inc. All Rights Reserved.
+(c) 2009 NetApp. All Rights Reserved.
+
+
+******************************************************************************/
+
+#include <linux/tcp.h>
+#include <linux/slab.h>
+#include <linux/sunrpc/xprt.h>
+#include <linux/export.h>
+#include <linux/sunrpc/bc_xprt.h>
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+#define RPCDBG_FACILITY RPCDBG_TRANS
+#endif
+
+#define BC_MAX_SLOTS 64U
+
+unsigned int xprt_bc_max_slots(struct rpc_xprt *xprt)
+{
+ return BC_MAX_SLOTS;
+}
+
+/*
+ * Helper routines that track the number of preallocation elements
+ * on the transport.
+ */
+static inline int xprt_need_to_requeue(struct rpc_xprt *xprt)
+{
+ return xprt->bc_alloc_count < xprt->bc_alloc_max;
+}
+
+/*
+ * Free the preallocated rpc_rqst structure and the memory
+ * buffers hanging off of it.
+ */
+static void xprt_free_allocation(struct rpc_rqst *req)
+{
+ struct xdr_buf *xbufp;
+
+ dprintk("RPC: free allocations for req= %p\n", req);
+ WARN_ON_ONCE(test_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state));
+ xbufp = &req->rq_rcv_buf;
+ free_page((unsigned long)xbufp->head[0].iov_base);
+ xbufp = &req->rq_snd_buf;
+ free_page((unsigned long)xbufp->head[0].iov_base);
+ kfree(req);
+}
+
+static void xprt_bc_reinit_xdr_buf(struct xdr_buf *buf)
+{
+ buf->head[0].iov_len = PAGE_SIZE;
+ buf->tail[0].iov_len = 0;
+ buf->pages = NULL;
+ buf->page_len = 0;
+ buf->flags = 0;
+ buf->len = 0;
+ buf->buflen = PAGE_SIZE;
+}
+
+static int xprt_alloc_xdr_buf(struct xdr_buf *buf, gfp_t gfp_flags)
+{
+ struct page *page;
+ /* Preallocate one XDR receive buffer */
+ page = alloc_page(gfp_flags);
+ if (page == NULL)
+ return -ENOMEM;
+ xdr_buf_init(buf, page_address(page), PAGE_SIZE);
+ return 0;
+}
+
+static struct rpc_rqst *xprt_alloc_bc_req(struct rpc_xprt *xprt)
+{
+ gfp_t gfp_flags = GFP_KERNEL | __GFP_NORETRY | __GFP_NOWARN;
+ struct rpc_rqst *req;
+
+ /* Pre-allocate one backchannel rpc_rqst */
+ req = kzalloc(sizeof(*req), gfp_flags);
+ if (req == NULL)
+ return NULL;
+
+ req->rq_xprt = xprt;
+ INIT_LIST_HEAD(&req->rq_bc_list);
+
+ /* Preallocate one XDR receive buffer */
+ if (xprt_alloc_xdr_buf(&req->rq_rcv_buf, gfp_flags) < 0) {
+ printk(KERN_ERR "Failed to create bc receive xbuf\n");
+ goto out_free;
+ }
+ req->rq_rcv_buf.len = PAGE_SIZE;
+
+ /* Preallocate one XDR send buffer */
+ if (xprt_alloc_xdr_buf(&req->rq_snd_buf, gfp_flags) < 0) {
+ printk(KERN_ERR "Failed to create bc snd xbuf\n");
+ goto out_free;
+ }
+ return req;
+out_free:
+ xprt_free_allocation(req);
+ return NULL;
+}
+
+/*
+ * Preallocate up to min_reqs structures and related buffers for use
+ * by the backchannel. This function can be called multiple times
+ * when creating new sessions that use the same rpc_xprt. The
+ * preallocated buffers are added to the pool of resources used by
+ * the rpc_xprt. Any one of these resources may be used by an
+ * incoming callback request. It's up to the higher levels in the
+ * stack to enforce that the maximum number of session slots is not
+ * being exceeded.
+ *
+ * Some callback arguments can be large. For example, a pNFS server
+ * using multiple deviceids.  The list can be unbounded, but the client
+ * has the ability to tell the server the maximum size of the callback
+ * requests. Each deviceID is 16 bytes, so allocate one page
+ * for the arguments to have enough room to receive a number of these
+ * deviceIDs. The NFS client indicates to the pNFS server that its
+ * callback requests can be up to 4096 bytes in size.
+ */
+int xprt_setup_backchannel(struct rpc_xprt *xprt, unsigned int min_reqs)
+{
+ if (!xprt->ops->bc_setup)
+ return 0;
+ return xprt->ops->bc_setup(xprt, min_reqs);
+}
+EXPORT_SYMBOL_GPL(xprt_setup_backchannel);
+
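+/*
+ * Illustrative sketch only (hypothetical caller, not used anywhere in
+ * this file): a transport consumer such as an NFSv4.1 client would
+ * typically bracket its session lifetime with these two calls.
+ */
+static inline int example_bc_session(struct rpc_xprt *xprt,
+				     unsigned int slots)
+{
+	int err;
+
+	/* reserve 'slots' rpc_rqsts and buffers for callback requests */
+	err = xprt_setup_backchannel(xprt, slots);
+	if (err)
+		return err;
+
+	/* ... run the session; callbacks arrive via the bc_pa_list ... */
+
+	/* release up to 'slots' preallocated requests again */
+	xprt_destroy_backchannel(xprt, slots);
+	return 0;
+}
+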
+int xprt_setup_bc(struct rpc_xprt *xprt, unsigned int min_reqs)
+{
+ struct rpc_rqst *req;
+ struct list_head tmp_list;
+ int i;
+
+ dprintk("RPC: setup backchannel transport\n");
+
+ if (min_reqs > BC_MAX_SLOTS)
+ min_reqs = BC_MAX_SLOTS;
+
+ /*
+ * We use a temporary list to keep track of the preallocated
+ * buffers. Once we're done building the list we splice it
+ * into the backchannel preallocation list off of the rpc_xprt
+ * struct. This helps minimize the amount of time the list
+ * lock is held on the rpc_xprt struct. It also makes cleanup
+ * easier in case of memory allocation errors.
+ */
+ INIT_LIST_HEAD(&tmp_list);
+ for (i = 0; i < min_reqs; i++) {
+ /* Pre-allocate one backchannel rpc_rqst */
+ req = xprt_alloc_bc_req(xprt);
+ if (req == NULL) {
+ printk(KERN_ERR "Failed to create bc rpc_rqst\n");
+ goto out_free;
+ }
+
+ /* Add the allocated buffer to the tmp list */
+ dprintk("RPC: adding req= %p\n", req);
+ list_add(&req->rq_bc_pa_list, &tmp_list);
+ }
+
+ /*
+ * Add the temporary list to the backchannel preallocation list
+ */
+ spin_lock(&xprt->bc_pa_lock);
+ list_splice(&tmp_list, &xprt->bc_pa_list);
+ xprt->bc_alloc_count += min_reqs;
+ xprt->bc_alloc_max += min_reqs;
+ atomic_add(min_reqs, &xprt->bc_slot_count);
+ spin_unlock(&xprt->bc_pa_lock);
+
+ dprintk("RPC: setup backchannel transport done\n");
+ return 0;
+
+out_free:
+ /*
+ * Memory allocation failed, free the temporary list
+ */
+ while (!list_empty(&tmp_list)) {
+ req = list_first_entry(&tmp_list,
+ struct rpc_rqst,
+ rq_bc_pa_list);
+ list_del(&req->rq_bc_pa_list);
+ xprt_free_allocation(req);
+ }
+
+ dprintk("RPC: setup backchannel transport failed\n");
+ return -ENOMEM;
+}
+
+/**
+ * xprt_destroy_backchannel - Destroys the backchannel preallocated structures.
+ * @xprt: the transport holding the preallocated structures
+ * @max_reqs: the maximum number of preallocated structures to destroy
+ *
+ * Since these structures may have been allocated by multiple calls
+ * to xprt_setup_backchannel, we only destroy up to the maximum number
+ * of reqs specified by the caller.
+ */
+void xprt_destroy_backchannel(struct rpc_xprt *xprt, unsigned int max_reqs)
+{
+ if (xprt->ops->bc_destroy)
+ xprt->ops->bc_destroy(xprt, max_reqs);
+}
+EXPORT_SYMBOL_GPL(xprt_destroy_backchannel);
+
+void xprt_destroy_bc(struct rpc_xprt *xprt, unsigned int max_reqs)
+{
+ struct rpc_rqst *req = NULL, *tmp = NULL;
+
+ dprintk("RPC: destroy backchannel transport\n");
+
+ if (max_reqs == 0)
+ goto out;
+
+ spin_lock_bh(&xprt->bc_pa_lock);
+ xprt->bc_alloc_max -= min(max_reqs, xprt->bc_alloc_max);
+ list_for_each_entry_safe(req, tmp, &xprt->bc_pa_list, rq_bc_pa_list) {
+ dprintk("RPC: req=%p\n", req);
+ list_del(&req->rq_bc_pa_list);
+ xprt_free_allocation(req);
+ xprt->bc_alloc_count--;
+ atomic_dec(&xprt->bc_slot_count);
+ if (--max_reqs == 0)
+ break;
+ }
+ spin_unlock_bh(&xprt->bc_pa_lock);
+
+out:
+ dprintk("RPC: backchannel list empty= %s\n",
+ list_empty(&xprt->bc_pa_list) ? "true" : "false");
+}
+
+static struct rpc_rqst *xprt_get_bc_request(struct rpc_xprt *xprt, __be32 xid,
+ struct rpc_rqst *new)
+{
+ struct rpc_rqst *req = NULL;
+
+ dprintk("RPC: allocate a backchannel request\n");
+ if (list_empty(&xprt->bc_pa_list)) {
+ if (!new)
+ goto not_found;
+ if (atomic_read(&xprt->bc_slot_count) >= BC_MAX_SLOTS)
+ goto not_found;
+ list_add_tail(&new->rq_bc_pa_list, &xprt->bc_pa_list);
+ xprt->bc_alloc_count++;
+ atomic_inc(&xprt->bc_slot_count);
+ }
+ req = list_first_entry(&xprt->bc_pa_list, struct rpc_rqst,
+ rq_bc_pa_list);
+ req->rq_reply_bytes_recvd = 0;
+ memcpy(&req->rq_private_buf, &req->rq_rcv_buf,
+ sizeof(req->rq_private_buf));
+ req->rq_xid = xid;
+ req->rq_connect_cookie = xprt->connect_cookie;
+ dprintk("RPC: backchannel req=%p\n", req);
+not_found:
+ return req;
+}
+
+/*
+ * Return the preallocated rpc_rqst structure and XDR buffers
+ * associated with this rpc_task.
+ */
+void xprt_free_bc_request(struct rpc_rqst *req)
+{
+ struct rpc_xprt *xprt = req->rq_xprt;
+
+ xprt->ops->bc_free_rqst(req);
+}
+
+void xprt_free_bc_rqst(struct rpc_rqst *req)
+{
+ struct rpc_xprt *xprt = req->rq_xprt;
+
+ dprintk("RPC: free backchannel req=%p\n", req);
+
+ req->rq_connect_cookie = xprt->connect_cookie - 1;
+ smp_mb__before_atomic();
+ clear_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state);
+ smp_mb__after_atomic();
+
+ /*
+ * Return it to the list of preallocations so that it
+ * may be reused by a new callback request.
+ */
+ spin_lock_bh(&xprt->bc_pa_lock);
+ if (xprt_need_to_requeue(xprt)) {
+ xprt_bc_reinit_xdr_buf(&req->rq_snd_buf);
+ xprt_bc_reinit_xdr_buf(&req->rq_rcv_buf);
+ req->rq_rcv_buf.len = PAGE_SIZE;
+ list_add_tail(&req->rq_bc_pa_list, &xprt->bc_pa_list);
+ xprt->bc_alloc_count++;
+ atomic_inc(&xprt->bc_slot_count);
+ req = NULL;
+ }
+ spin_unlock_bh(&xprt->bc_pa_lock);
+ if (req != NULL) {
+ /*
+ * The last remaining session was destroyed while this
+ * entry was in use. Free the entry and don't attempt
+ * to add back to the list because there is no need to
+		 * have any more preallocated entries.
+ */
+ dprintk("RPC: Last session removed req=%p\n", req);
+ xprt_free_allocation(req);
+ }
+ xprt_put(xprt);
+}
+
+/*
+ * One or more rpc_rqst structures have been preallocated during the
+ * backchannel setup.  Buffer space for the send and private XDR buffers
+ * has been preallocated as well.  Use xprt_lookup_bc_request() to obtain
+ * one of these requests and xprt_free_bc_request() to return it.
+ *
+ * We know that we're called in soft interrupt context, so take the plain
+ * spin_lock since there is no need for the bottom-half (_bh) variant.
+ *
+ * Return an available rpc_rqst, or NULL if none are available.
+ */
+struct rpc_rqst *xprt_lookup_bc_request(struct rpc_xprt *xprt, __be32 xid)
+{
+ struct rpc_rqst *req, *new = NULL;
+
+ do {
+ spin_lock(&xprt->bc_pa_lock);
+ list_for_each_entry(req, &xprt->bc_pa_list, rq_bc_pa_list) {
+ if (req->rq_connect_cookie != xprt->connect_cookie)
+ continue;
+ if (req->rq_xid == xid)
+ goto found;
+ }
+ req = xprt_get_bc_request(xprt, xid, new);
+found:
+ spin_unlock(&xprt->bc_pa_lock);
+ if (new) {
+ if (req != new)
+ xprt_free_allocation(new);
+ break;
+ } else if (req)
+ break;
+ new = xprt_alloc_bc_req(xprt);
+ } while (new);
+ return req;
+}
+
+/*
+ * Add callback request to callback list. The callback
+ * service sleeps on the sv_cb_waitq waiting for new
+ * requests.  Wake it up after enqueuing the request.
+ */
+void xprt_complete_bc_request(struct rpc_rqst *req, uint32_t copied)
+{
+ struct rpc_xprt *xprt = req->rq_xprt;
+ struct svc_serv *bc_serv = xprt->bc_serv;
+
+ spin_lock(&xprt->bc_pa_lock);
+ list_del(&req->rq_bc_pa_list);
+ xprt->bc_alloc_count--;
+ spin_unlock(&xprt->bc_pa_lock);
+
+ req->rq_private_buf.len = copied;
+ set_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state);
+
+ dprintk("RPC: add callback request to list\n");
+ xprt_get(xprt);
+ spin_lock(&bc_serv->sv_cb_lock);
+ list_add(&req->rq_bc_list, &bc_serv->sv_cb_list);
+ wake_up(&bc_serv->sv_cb_waitq);
+ spin_unlock(&bc_serv->sv_cb_lock);
+}
diff --git a/net/sunrpc/cache.c b/net/sunrpc/cache.c
new file mode 100644
index 0000000000..95ff747061
--- /dev/null
+++ b/net/sunrpc/cache.c
@@ -0,0 +1,1918 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * net/sunrpc/cache.c
+ *
+ * Generic code for various authentication-related caches
+ * used by sunrpc clients and servers.
+ *
+ * Copyright (C) 2002 Neil Brown <neilb@cse.unsw.edu.au>
+ */
+
+#include <linux/types.h>
+#include <linux/fs.h>
+#include <linux/file.h>
+#include <linux/slab.h>
+#include <linux/signal.h>
+#include <linux/sched.h>
+#include <linux/kmod.h>
+#include <linux/list.h>
+#include <linux/module.h>
+#include <linux/ctype.h>
+#include <linux/string_helpers.h>
+#include <linux/uaccess.h>
+#include <linux/poll.h>
+#include <linux/seq_file.h>
+#include <linux/proc_fs.h>
+#include <linux/net.h>
+#include <linux/workqueue.h>
+#include <linux/mutex.h>
+#include <linux/pagemap.h>
+#include <asm/ioctls.h>
+#include <linux/sunrpc/types.h>
+#include <linux/sunrpc/cache.h>
+#include <linux/sunrpc/stats.h>
+#include <linux/sunrpc/rpc_pipe_fs.h>
+#include <trace/events/sunrpc.h>
+
+#include "netns.h"
+#include "fail.h"
+
+#define RPCDBG_FACILITY RPCDBG_CACHE
+
+static bool cache_defer_req(struct cache_req *req, struct cache_head *item);
+static void cache_revisit_request(struct cache_head *item);
+
+static void cache_init(struct cache_head *h, struct cache_detail *detail)
+{
+ time64_t now = seconds_since_boot();
+ INIT_HLIST_NODE(&h->cache_list);
+ h->flags = 0;
+ kref_init(&h->ref);
+ h->expiry_time = now + CACHE_NEW_EXPIRY;
+ if (now <= detail->flush_time)
+ /* ensure it isn't already expired */
+ now = detail->flush_time + 1;
+ h->last_refresh = now;
+}
+
+static void cache_fresh_unlocked(struct cache_head *head,
+ struct cache_detail *detail);
+
+static struct cache_head *sunrpc_cache_find_rcu(struct cache_detail *detail,
+ struct cache_head *key,
+ int hash)
+{
+ struct hlist_head *head = &detail->hash_table[hash];
+ struct cache_head *tmp;
+
+ rcu_read_lock();
+ hlist_for_each_entry_rcu(tmp, head, cache_list) {
+ if (!detail->match(tmp, key))
+ continue;
+ if (test_bit(CACHE_VALID, &tmp->flags) &&
+ cache_is_expired(detail, tmp))
+ continue;
+ tmp = cache_get_rcu(tmp);
+ rcu_read_unlock();
+ return tmp;
+ }
+ rcu_read_unlock();
+ return NULL;
+}
+
+static void sunrpc_begin_cache_remove_entry(struct cache_head *ch,
+ struct cache_detail *cd)
+{
+ /* Must be called under cd->hash_lock */
+ hlist_del_init_rcu(&ch->cache_list);
+ set_bit(CACHE_CLEANED, &ch->flags);
+ cd->entries --;
+}
+
+static void sunrpc_end_cache_remove_entry(struct cache_head *ch,
+ struct cache_detail *cd)
+{
+ cache_fresh_unlocked(ch, cd);
+ cache_put(ch, cd);
+}
+
+static struct cache_head *sunrpc_cache_add_entry(struct cache_detail *detail,
+ struct cache_head *key,
+ int hash)
+{
+ struct cache_head *new, *tmp, *freeme = NULL;
+ struct hlist_head *head = &detail->hash_table[hash];
+
+ new = detail->alloc();
+ if (!new)
+ return NULL;
+ /* must fully initialise 'new', else
+	 * we might lose it if we need to
+ * cache_put it soon.
+ */
+ cache_init(new, detail);
+ detail->init(new, key);
+
+ spin_lock(&detail->hash_lock);
+
+ /* check if entry appeared while we slept */
+ hlist_for_each_entry_rcu(tmp, head, cache_list,
+ lockdep_is_held(&detail->hash_lock)) {
+ if (!detail->match(tmp, key))
+ continue;
+ if (test_bit(CACHE_VALID, &tmp->flags) &&
+ cache_is_expired(detail, tmp)) {
+ sunrpc_begin_cache_remove_entry(tmp, detail);
+ trace_cache_entry_expired(detail, tmp);
+ freeme = tmp;
+ break;
+ }
+ cache_get(tmp);
+ spin_unlock(&detail->hash_lock);
+ cache_put(new, detail);
+ return tmp;
+ }
+
+ hlist_add_head_rcu(&new->cache_list, head);
+ detail->entries++;
+ cache_get(new);
+ spin_unlock(&detail->hash_lock);
+
+ if (freeme)
+ sunrpc_end_cache_remove_entry(freeme, detail);
+ return new;
+}
+
+struct cache_head *sunrpc_cache_lookup_rcu(struct cache_detail *detail,
+ struct cache_head *key, int hash)
+{
+ struct cache_head *ret;
+
+ ret = sunrpc_cache_find_rcu(detail, key, hash);
+ if (ret)
+ return ret;
+ /* Didn't find anything, insert an empty entry */
+ return sunrpc_cache_add_entry(detail, key, hash);
+}
+EXPORT_SYMBOL_GPL(sunrpc_cache_lookup_rcu);
+
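+/*
+ * Illustrative sketch only (hypothetical cache type, not used in this
+ * file): callers embed a struct cache_head in their own entry and wrap
+ * the lookup like this -- compare rsi_lookup()/rsc_lookup() in
+ * auth_gss/svcauth_gss.c.
+ */
+struct example_entry {
+	struct cache_head	h;
+	int			key;
+};
+
+static inline struct example_entry *
+example_lookup(struct cache_detail *cd, struct example_entry *key, int hash)
+{
+	struct cache_head *ch;
+
+	ch = sunrpc_cache_lookup_rcu(cd, &key->h, hash);
+	if (!ch)
+		return NULL;
+	return container_of(ch, struct example_entry, h);
+}
+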
+static void cache_dequeue(struct cache_detail *detail, struct cache_head *ch);
+
+static void cache_fresh_locked(struct cache_head *head, time64_t expiry,
+ struct cache_detail *detail)
+{
+ time64_t now = seconds_since_boot();
+ if (now <= detail->flush_time)
+ /* ensure it isn't immediately treated as expired */
+ now = detail->flush_time + 1;
+ head->expiry_time = expiry;
+ head->last_refresh = now;
+ smp_wmb(); /* paired with smp_rmb() in cache_is_valid() */
+ set_bit(CACHE_VALID, &head->flags);
+}
+
+static void cache_fresh_unlocked(struct cache_head *head,
+ struct cache_detail *detail)
+{
+ if (test_and_clear_bit(CACHE_PENDING, &head->flags)) {
+ cache_revisit_request(head);
+ cache_dequeue(detail, head);
+ }
+}
+
+static void cache_make_negative(struct cache_detail *detail,
+ struct cache_head *h)
+{
+ set_bit(CACHE_NEGATIVE, &h->flags);
+ trace_cache_entry_make_negative(detail, h);
+}
+
+static void cache_entry_update(struct cache_detail *detail,
+ struct cache_head *h,
+ struct cache_head *new)
+{
+ if (!test_bit(CACHE_NEGATIVE, &new->flags)) {
+ detail->update(h, new);
+ trace_cache_entry_update(detail, h);
+ } else {
+ cache_make_negative(detail, h);
+ }
+}
+
+struct cache_head *sunrpc_cache_update(struct cache_detail *detail,
+ struct cache_head *new, struct cache_head *old, int hash)
+{
+ /* The 'old' entry is to be replaced by 'new'.
+ * If 'old' is not VALID, we update it directly,
+ * otherwise we need to replace it
+ */
+ struct cache_head *tmp;
+
+ if (!test_bit(CACHE_VALID, &old->flags)) {
+ spin_lock(&detail->hash_lock);
+ if (!test_bit(CACHE_VALID, &old->flags)) {
+ cache_entry_update(detail, old, new);
+ cache_fresh_locked(old, new->expiry_time, detail);
+ spin_unlock(&detail->hash_lock);
+ cache_fresh_unlocked(old, detail);
+ return old;
+ }
+ spin_unlock(&detail->hash_lock);
+ }
+ /* We need to insert a new entry */
+ tmp = detail->alloc();
+ if (!tmp) {
+ cache_put(old, detail);
+ return NULL;
+ }
+ cache_init(tmp, detail);
+ detail->init(tmp, old);
+
+ spin_lock(&detail->hash_lock);
+ cache_entry_update(detail, tmp, new);
+ hlist_add_head(&tmp->cache_list, &detail->hash_table[hash]);
+ detail->entries++;
+ cache_get(tmp);
+ cache_fresh_locked(tmp, new->expiry_time, detail);
+ cache_fresh_locked(old, 0, detail);
+ spin_unlock(&detail->hash_lock);
+ cache_fresh_unlocked(tmp, detail);
+ cache_fresh_unlocked(old, detail);
+ cache_put(old, detail);
+ return tmp;
+}
+EXPORT_SYMBOL_GPL(sunrpc_cache_update);
+
+static inline int cache_is_valid(struct cache_head *h)
+{
+ if (!test_bit(CACHE_VALID, &h->flags))
+ return -EAGAIN;
+ else {
+ /* entry is valid */
+ if (test_bit(CACHE_NEGATIVE, &h->flags))
+ return -ENOENT;
+ else {
+ /*
+ * In combination with write barrier in
+ * sunrpc_cache_update, ensures that anyone
+ * using the cache entry after this sees the
+ * updated contents:
+ */
+ smp_rmb();
+ return 0;
+ }
+ }
+}
+
+static int try_to_negate_entry(struct cache_detail *detail, struct cache_head *h)
+{
+ int rv;
+
+ spin_lock(&detail->hash_lock);
+ rv = cache_is_valid(h);
+ if (rv == -EAGAIN) {
+ cache_make_negative(detail, h);
+ cache_fresh_locked(h, seconds_since_boot()+CACHE_NEW_EXPIRY,
+ detail);
+ rv = -ENOENT;
+ }
+ spin_unlock(&detail->hash_lock);
+ cache_fresh_unlocked(h, detail);
+ return rv;
+}
+
+/*
+ * This is the generic cache management routine for all
+ * the authentication caches.
+ * It checks the currency of a cache item and will (later)
+ * initiate an upcall to fill it if needed.
+ *
+ *
+ * Returns 0 if the cache_head can be used, or cache_puts it and returns
+ * -EAGAIN if upcall is pending and request has been queued
+ *	-ETIMEDOUT if upcall failed or request could not be queued or
+ * upcall completed but item is still invalid (implying that
+ * the cache item has been replaced with a newer one).
+ * -ENOENT if cache entry was negative
+ */
+int cache_check(struct cache_detail *detail,
+ struct cache_head *h, struct cache_req *rqstp)
+{
+ int rv;
+ time64_t refresh_age, age;
+
+ /* First decide return status as best we can */
+ rv = cache_is_valid(h);
+
+ /* now see if we want to start an upcall */
+ refresh_age = (h->expiry_time - h->last_refresh);
+ age = seconds_since_boot() - h->last_refresh;
+
+ if (rqstp == NULL) {
+ if (rv == -EAGAIN)
+ rv = -ENOENT;
+ } else if (rv == -EAGAIN ||
+ (h->expiry_time != 0 && age > refresh_age/2)) {
+ dprintk("RPC: Want update, refage=%lld, age=%lld\n",
+ refresh_age, age);
+ switch (detail->cache_upcall(detail, h)) {
+ case -EINVAL:
+ rv = try_to_negate_entry(detail, h);
+ break;
+ case -EAGAIN:
+ cache_fresh_unlocked(h, detail);
+ break;
+ }
+ }
+
+ if (rv == -EAGAIN) {
+ if (!cache_defer_req(rqstp, h)) {
+ /*
+ * Request was not deferred; handle it as best
+ * we can ourselves:
+ */
+ rv = cache_is_valid(h);
+ if (rv == -EAGAIN)
+ rv = -ETIMEDOUT;
+ }
+ }
+ if (rv)
+ cache_put(h, detail);
+ return rv;
+}
+EXPORT_SYMBOL_GPL(cache_check);
+
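+/*
+ * Illustrative sketch only (hypothetical caller): the reference handed
+ * to cache_check() is consumed on failure, so only the success path
+ * needs an explicit cache_put(), as documented above.
+ */
+static inline int example_check_and_use(struct cache_detail *cd,
+					struct cache_head *h,
+					struct cache_req *rq)
+{
+	int err = cache_check(cd, h, rq);
+
+	if (err)	/* -EAGAIN, -ETIMEDOUT or -ENOENT; 'h' already put */
+		return err;
+
+	/* ... use the valid entry ... */
+
+	cache_put(h, cd);
+	return 0;
+}
+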
+/*
+ * caches need to be periodically cleaned.
+ * For this we maintain a list of cache_detail and
+ * a current pointer into that list and into the table
+ * for that entry.
+ *
+ * Each time cache_clean is called it finds the next non-empty entry
+ * in the current table and walks the list in that entry
+ * looking for entries that can be removed.
+ *
+ * An entry gets removed if:
+ * - The expiry is before current time
+ * - The last_refresh time is before the flush_time for that cache
+ *
+ * later we might drop old entries with non-NEVER expiry if that table
+ * is getting 'full' for some definition of 'full'
+ *
+ * The question of "how often to scan a table" is an interesting one
+ * and is answered in part by the use of the "nextcheck" field in the
+ * cache_detail.
+ * When a scan of a table begins, the nextcheck field is set to a time
+ * that is well into the future.
+ * While scanning, if an expiry time is found that is earlier than the
+ * current nextcheck time, nextcheck is set to that expiry time.
+ * If the flush_time is ever set to a time earlier than the nextcheck
+ * time, the nextcheck time is then set to that flush_time.
+ *
+ * A table is then only scanned if the current time is at least
+ * the nextcheck time.
+ *
+ */
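+
+/*
+ * Worked example of the nextcheck logic: a scan starting at t=1000 first
+ * sets nextcheck to t=2800 (30 minutes ahead).  Walking entries that expire
+ * at t=1500 and t=1200 lowers nextcheck to 1501 and then 1201, so the table
+ * is not rescanned before t=1201 unless a flush pulls nextcheck forward to
+ * the flush time.
+ */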
+
+static LIST_HEAD(cache_list);
+static DEFINE_SPINLOCK(cache_list_lock);
+static struct cache_detail *current_detail;
+static int current_index;
+
+static void do_cache_clean(struct work_struct *work);
+static struct delayed_work cache_cleaner;
+
+void sunrpc_init_cache_detail(struct cache_detail *cd)
+{
+ spin_lock_init(&cd->hash_lock);
+ INIT_LIST_HEAD(&cd->queue);
+ spin_lock(&cache_list_lock);
+ cd->nextcheck = 0;
+ cd->entries = 0;
+ atomic_set(&cd->writers, 0);
+ cd->last_close = 0;
+ cd->last_warn = -1;
+ list_add(&cd->others, &cache_list);
+ spin_unlock(&cache_list_lock);
+
+ /* start the cleaning process */
+ queue_delayed_work(system_power_efficient_wq, &cache_cleaner, 0);
+}
+EXPORT_SYMBOL_GPL(sunrpc_init_cache_detail);
+
+void sunrpc_destroy_cache_detail(struct cache_detail *cd)
+{
+ cache_purge(cd);
+ spin_lock(&cache_list_lock);
+ spin_lock(&cd->hash_lock);
+ if (current_detail == cd)
+ current_detail = NULL;
+ list_del_init(&cd->others);
+ spin_unlock(&cd->hash_lock);
+ spin_unlock(&cache_list_lock);
+ if (list_empty(&cache_list)) {
+		/* the module must be being unloaded, so it's safe to kill the worker */
+ cancel_delayed_work_sync(&cache_cleaner);
+ }
+}
+EXPORT_SYMBOL_GPL(sunrpc_destroy_cache_detail);
+
+/* cache_clean() tries to find something to clean
+ * and cleans it.
+ * It returns 1 if it cleaned something,
+ * 0 if it didn't find anything this time
+ * -1 if it fell off the end of the list.
+ */
+static int cache_clean(void)
+{
+ int rv = 0;
+ struct list_head *next;
+
+ spin_lock(&cache_list_lock);
+
+ /* find a suitable table if we don't already have one */
+ while (current_detail == NULL ||
+ current_index >= current_detail->hash_size) {
+ if (current_detail)
+ next = current_detail->others.next;
+ else
+ next = cache_list.next;
+ if (next == &cache_list) {
+ current_detail = NULL;
+ spin_unlock(&cache_list_lock);
+ return -1;
+ }
+ current_detail = list_entry(next, struct cache_detail, others);
+ if (current_detail->nextcheck > seconds_since_boot())
+ current_index = current_detail->hash_size;
+ else {
+ current_index = 0;
+ current_detail->nextcheck = seconds_since_boot()+30*60;
+ }
+ }
+
+ /* find a non-empty bucket in the table */
+ while (current_detail &&
+ current_index < current_detail->hash_size &&
+ hlist_empty(&current_detail->hash_table[current_index]))
+ current_index++;
+
+ /* find a cleanable entry in the bucket and clean it, or set to next bucket */
+
+ if (current_detail && current_index < current_detail->hash_size) {
+ struct cache_head *ch = NULL;
+ struct cache_detail *d;
+ struct hlist_head *head;
+ struct hlist_node *tmp;
+
+ spin_lock(&current_detail->hash_lock);
+
+ /* Ok, now to clean this strand */
+
+ head = &current_detail->hash_table[current_index];
+ hlist_for_each_entry_safe(ch, tmp, head, cache_list) {
+ if (current_detail->nextcheck > ch->expiry_time)
+ current_detail->nextcheck = ch->expiry_time+1;
+ if (!cache_is_expired(current_detail, ch))
+ continue;
+
+ sunrpc_begin_cache_remove_entry(ch, current_detail);
+ trace_cache_entry_expired(current_detail, ch);
+ rv = 1;
+ break;
+ }
+
+ spin_unlock(&current_detail->hash_lock);
+ d = current_detail;
+ if (!ch)
+			current_index++;
+ spin_unlock(&cache_list_lock);
+ if (ch)
+ sunrpc_end_cache_remove_entry(ch, d);
+ } else
+ spin_unlock(&cache_list_lock);
+
+ return rv;
+}
+
+/*
+ * We want to regularly clean the cache, so we need to schedule some work ...
+ */
+static void do_cache_clean(struct work_struct *work)
+{
+ int delay;
+
+ if (list_empty(&cache_list))
+ return;
+
+ if (cache_clean() == -1)
+ delay = round_jiffies_relative(30*HZ);
+ else
+ delay = 5;
+
+ queue_delayed_work(system_power_efficient_wq, &cache_cleaner, delay);
+}
+
+
+/*
+ * Clean all caches promptly. This just calls cache_clean
+ * repeatedly until we are sure that every cache has had a chance to
+ * be fully cleaned
+ */
+void cache_flush(void)
+{
+ while (cache_clean() != -1)
+ cond_resched();
+ while (cache_clean() != -1)
+ cond_resched();
+}
+EXPORT_SYMBOL_GPL(cache_flush);
+
+void cache_purge(struct cache_detail *detail)
+{
+ struct cache_head *ch = NULL;
+ struct hlist_head *head = NULL;
+ int i = 0;
+
+ spin_lock(&detail->hash_lock);
+ if (!detail->entries) {
+ spin_unlock(&detail->hash_lock);
+ return;
+ }
+
+ dprintk("RPC: %d entries in %s cache\n", detail->entries, detail->name);
+ for (i = 0; i < detail->hash_size; i++) {
+ head = &detail->hash_table[i];
+ while (!hlist_empty(head)) {
+ ch = hlist_entry(head->first, struct cache_head,
+ cache_list);
+ sunrpc_begin_cache_remove_entry(ch, detail);
+ spin_unlock(&detail->hash_lock);
+ sunrpc_end_cache_remove_entry(ch, detail);
+ spin_lock(&detail->hash_lock);
+ }
+ }
+ spin_unlock(&detail->hash_lock);
+}
+EXPORT_SYMBOL_GPL(cache_purge);
+
+
+/*
+ * Deferral and Revisiting of Requests.
+ *
+ * If a cache lookup finds a pending entry, we
+ * need to defer the request and revisit it later.
+ * All deferred requests are stored in a hash table,
+ * indexed by "struct cache_head *".
+ * As it may be wasteful to store a whole request
+ * structure, we allow the request to provide a
+ * deferred form, which must contain a
+ * 'struct cache_deferred_req'
+ * This cache_deferred_req contains a method to allow
+ * it to be revisited when cache info is available
+ */
+
+#define DFR_HASHSIZE (PAGE_SIZE/sizeof(struct list_head))
+#define DFR_HASH(item) ((((long)item)>>4 ^ (((long)item)>>13)) % DFR_HASHSIZE)
+
+#define DFR_MAX 300 /* ??? */
+
+static DEFINE_SPINLOCK(cache_defer_lock);
+static LIST_HEAD(cache_defer_list);
+static struct hlist_head cache_defer_hash[DFR_HASHSIZE];
+static int cache_defer_cnt;
+
+static void __unhash_deferred_req(struct cache_deferred_req *dreq)
+{
+ hlist_del_init(&dreq->hash);
+ if (!list_empty(&dreq->recent)) {
+ list_del_init(&dreq->recent);
+ cache_defer_cnt--;
+ }
+}
+
+static void __hash_deferred_req(struct cache_deferred_req *dreq, struct cache_head *item)
+{
+ int hash = DFR_HASH(item);
+
+ INIT_LIST_HEAD(&dreq->recent);
+ hlist_add_head(&dreq->hash, &cache_defer_hash[hash]);
+}
+
+static void setup_deferral(struct cache_deferred_req *dreq,
+ struct cache_head *item,
+ int count_me)
+{
+
+ dreq->item = item;
+
+ spin_lock(&cache_defer_lock);
+
+ __hash_deferred_req(dreq, item);
+
+ if (count_me) {
+ cache_defer_cnt++;
+ list_add(&dreq->recent, &cache_defer_list);
+ }
+
+ spin_unlock(&cache_defer_lock);
+
+}
+
+struct thread_deferred_req {
+ struct cache_deferred_req handle;
+ struct completion completion;
+};
+
+static void cache_restart_thread(struct cache_deferred_req *dreq, int too_many)
+{
+ struct thread_deferred_req *dr =
+ container_of(dreq, struct thread_deferred_req, handle);
+ complete(&dr->completion);
+}
+
+static void cache_wait_req(struct cache_req *req, struct cache_head *item)
+{
+ struct thread_deferred_req sleeper;
+ struct cache_deferred_req *dreq = &sleeper.handle;
+
+ sleeper.completion = COMPLETION_INITIALIZER_ONSTACK(sleeper.completion);
+ dreq->revisit = cache_restart_thread;
+
+ setup_deferral(dreq, item, 0);
+
+ if (!test_bit(CACHE_PENDING, &item->flags) ||
+ wait_for_completion_interruptible_timeout(
+ &sleeper.completion, req->thread_wait) <= 0) {
+		/* The completion was never signalled, so we
+		 * need to clean up
+		 */
+ spin_lock(&cache_defer_lock);
+ if (!hlist_unhashed(&sleeper.handle.hash)) {
+ __unhash_deferred_req(&sleeper.handle);
+ spin_unlock(&cache_defer_lock);
+ } else {
+ /* cache_revisit_request already removed
+ * this from the hash table, but hasn't
+ * called ->revisit yet. It will very soon
+ * and we need to wait for it.
+ */
+ spin_unlock(&cache_defer_lock);
+ wait_for_completion(&sleeper.completion);
+ }
+ }
+}
+
+static void cache_limit_defers(void)
+{
+	/* Make sure we haven't exceeded the limit of allowed deferred
+ * requests.
+ */
+ struct cache_deferred_req *discard = NULL;
+
+ if (cache_defer_cnt <= DFR_MAX)
+ return;
+
+ spin_lock(&cache_defer_lock);
+
+ /* Consider removing either the first or the last */
+ if (cache_defer_cnt > DFR_MAX) {
+ if (get_random_u32_below(2))
+ discard = list_entry(cache_defer_list.next,
+ struct cache_deferred_req, recent);
+ else
+ discard = list_entry(cache_defer_list.prev,
+ struct cache_deferred_req, recent);
+ __unhash_deferred_req(discard);
+ }
+ spin_unlock(&cache_defer_lock);
+ if (discard)
+ discard->revisit(discard, 1);
+}
+
+#if IS_ENABLED(CONFIG_FAIL_SUNRPC)
+static inline bool cache_defer_immediately(void)
+{
+ return !fail_sunrpc.ignore_cache_wait &&
+ should_fail(&fail_sunrpc.attr, 1);
+}
+#else
+static inline bool cache_defer_immediately(void)
+{
+ return false;
+}
+#endif
+
+/* Return true if and only if a deferred request is queued. */
+static bool cache_defer_req(struct cache_req *req, struct cache_head *item)
+{
+ struct cache_deferred_req *dreq;
+
+ if (!cache_defer_immediately()) {
+ cache_wait_req(req, item);
+ if (!test_bit(CACHE_PENDING, &item->flags))
+ return false;
+ }
+
+ dreq = req->defer(req);
+ if (dreq == NULL)
+ return false;
+ setup_deferral(dreq, item, 1);
+ if (!test_bit(CACHE_PENDING, &item->flags))
+ /* Bit could have been cleared before we managed to
+		 * set up the deferral, so we need to revisit just in case
+ */
+ cache_revisit_request(item);
+
+ cache_limit_defers();
+ return true;
+}
+
+static void cache_revisit_request(struct cache_head *item)
+{
+ struct cache_deferred_req *dreq;
+ struct list_head pending;
+ struct hlist_node *tmp;
+ int hash = DFR_HASH(item);
+
+ INIT_LIST_HEAD(&pending);
+ spin_lock(&cache_defer_lock);
+
+ hlist_for_each_entry_safe(dreq, tmp, &cache_defer_hash[hash], hash)
+ if (dreq->item == item) {
+ __unhash_deferred_req(dreq);
+ list_add(&dreq->recent, &pending);
+ }
+
+ spin_unlock(&cache_defer_lock);
+
+ while (!list_empty(&pending)) {
+ dreq = list_entry(pending.next, struct cache_deferred_req, recent);
+ list_del_init(&dreq->recent);
+ dreq->revisit(dreq, 0);
+ }
+}
+
+void cache_clean_deferred(void *owner)
+{
+ struct cache_deferred_req *dreq, *tmp;
+ struct list_head pending;
+
+
+ INIT_LIST_HEAD(&pending);
+ spin_lock(&cache_defer_lock);
+
+ list_for_each_entry_safe(dreq, tmp, &cache_defer_list, recent) {
+ if (dreq->owner == owner) {
+ __unhash_deferred_req(dreq);
+ list_add(&dreq->recent, &pending);
+ }
+ }
+ spin_unlock(&cache_defer_lock);
+
+ while (!list_empty(&pending)) {
+ dreq = list_entry(pending.next, struct cache_deferred_req, recent);
+ list_del_init(&dreq->recent);
+ dreq->revisit(dreq, 1);
+ }
+}
+
+/*
+ * communicate with user-space
+ *
+ * We have a magic /proc file - /proc/net/rpc/<cachename>/channel.
+ * On read, you get a full request, or block.
+ * On write, an update request is processed.
+ * Poll works if anything to read, and always allows write.
+ *
+ * Implemented by a linked list of requests. Each open file has
+ * a ->private that also exists in this list. New requests are added
+ * to the end and may wake up any preceding readers.
+ * New readers are added to the head. If, on read, an item is found with
+ * CACHE_UPCALLING clear, we free it from the list.
+ *
+ */
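+
+/*
+ * A minimal user-space sketch, assuming a hypothetical cache published as
+ * /proc/net/rpc/example/channel.  The request and reply line formats are
+ * defined by that cache's cache_request and cache_parse methods; with a
+ * buffer of at least one page, each read returns one complete request and
+ * each write submits one newline-terminated reply:
+ *
+ *	int fd = open("/proc/net/rpc/example/channel", O_RDWR);
+ *	struct pollfd pfd = { .fd = fd, .events = POLLIN };
+ *	char buf[8192], reply[8192];
+ *
+ *	while (poll(&pfd, 1, -1) > 0) {
+ *		ssize_t n = read(fd, buf, sizeof(buf));
+ *		if (n <= 0)
+ *			continue;
+ *		... parse buf, look up the answer, build "reply" ...
+ *		write(fd, reply, strlen(reply));
+ *	}
+ */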
+
+static DEFINE_SPINLOCK(queue_lock);
+
+struct cache_queue {
+ struct list_head list;
+ int reader; /* if 0, then request */
+};
+struct cache_request {
+ struct cache_queue q;
+ struct cache_head *item;
+ char * buf;
+ int len;
+ int readers;
+};
+struct cache_reader {
+ struct cache_queue q;
+ int offset; /* if non-0, we have a refcnt on next request */
+};
+
+static int cache_request(struct cache_detail *detail,
+ struct cache_request *crq)
+{
+ char *bp = crq->buf;
+ int len = PAGE_SIZE;
+
+ detail->cache_request(detail, crq->item, &bp, &len);
+ if (len < 0)
+ return -E2BIG;
+ return PAGE_SIZE - len;
+}
+
+static ssize_t cache_read(struct file *filp, char __user *buf, size_t count,
+ loff_t *ppos, struct cache_detail *cd)
+{
+ struct cache_reader *rp = filp->private_data;
+ struct cache_request *rq;
+ struct inode *inode = file_inode(filp);
+ int err;
+
+ if (count == 0)
+ return 0;
+
+ inode_lock(inode); /* protect against multiple concurrent
+ * readers on this file */
+ again:
+ spin_lock(&queue_lock);
+ /* need to find next request */
+ while (rp->q.list.next != &cd->queue &&
+ list_entry(rp->q.list.next, struct cache_queue, list)
+ ->reader) {
+ struct list_head *next = rp->q.list.next;
+ list_move(&rp->q.list, next);
+ }
+ if (rp->q.list.next == &cd->queue) {
+ spin_unlock(&queue_lock);
+ inode_unlock(inode);
+ WARN_ON_ONCE(rp->offset);
+ return 0;
+ }
+ rq = container_of(rp->q.list.next, struct cache_request, q.list);
+ WARN_ON_ONCE(rq->q.reader);
+ if (rp->offset == 0)
+ rq->readers++;
+ spin_unlock(&queue_lock);
+
+ if (rq->len == 0) {
+ err = cache_request(cd, rq);
+ if (err < 0)
+ goto out;
+ rq->len = err;
+ }
+
+ if (rp->offset == 0 && !test_bit(CACHE_PENDING, &rq->item->flags)) {
+ err = -EAGAIN;
+ spin_lock(&queue_lock);
+ list_move(&rp->q.list, &rq->q.list);
+ spin_unlock(&queue_lock);
+ } else {
+ if (rp->offset + count > rq->len)
+ count = rq->len - rp->offset;
+ err = -EFAULT;
+ if (copy_to_user(buf, rq->buf + rp->offset, count))
+ goto out;
+ rp->offset += count;
+ if (rp->offset >= rq->len) {
+ rp->offset = 0;
+ spin_lock(&queue_lock);
+ list_move(&rp->q.list, &rq->q.list);
+ spin_unlock(&queue_lock);
+ }
+ err = 0;
+ }
+ out:
+ if (rp->offset == 0) {
+ /* need to release rq */
+ spin_lock(&queue_lock);
+ rq->readers--;
+ if (rq->readers == 0 &&
+ !test_bit(CACHE_PENDING, &rq->item->flags)) {
+ list_del(&rq->q.list);
+ spin_unlock(&queue_lock);
+ cache_put(rq->item, cd);
+ kfree(rq->buf);
+ kfree(rq);
+ } else
+ spin_unlock(&queue_lock);
+ }
+ if (err == -EAGAIN)
+ goto again;
+ inode_unlock(inode);
+ return err ? err : count;
+}
+
+static ssize_t cache_do_downcall(char *kaddr, const char __user *buf,
+ size_t count, struct cache_detail *cd)
+{
+ ssize_t ret;
+
+ if (count == 0)
+ return -EINVAL;
+ if (copy_from_user(kaddr, buf, count))
+ return -EFAULT;
+ kaddr[count] = '\0';
+ ret = cd->cache_parse(cd, kaddr, count);
+ if (!ret)
+ ret = count;
+ return ret;
+}
+
+static ssize_t cache_downcall(struct address_space *mapping,
+ const char __user *buf,
+ size_t count, struct cache_detail *cd)
+{
+ char *write_buf;
+ ssize_t ret = -ENOMEM;
+
+	if (count >= 32768) { /* 32k is max userland buffer, let's check anyway */
+ ret = -EINVAL;
+ goto out;
+ }
+
+ write_buf = kvmalloc(count + 1, GFP_KERNEL);
+ if (!write_buf)
+ goto out;
+
+ ret = cache_do_downcall(write_buf, buf, count, cd);
+ kvfree(write_buf);
+out:
+ return ret;
+}
+
+static ssize_t cache_write(struct file *filp, const char __user *buf,
+ size_t count, loff_t *ppos,
+ struct cache_detail *cd)
+{
+ struct address_space *mapping = filp->f_mapping;
+ struct inode *inode = file_inode(filp);
+ ssize_t ret = -EINVAL;
+
+ if (!cd->cache_parse)
+ goto out;
+
+ inode_lock(inode);
+ ret = cache_downcall(mapping, buf, count, cd);
+ inode_unlock(inode);
+out:
+ return ret;
+}
+
+static DECLARE_WAIT_QUEUE_HEAD(queue_wait);
+
+static __poll_t cache_poll(struct file *filp, poll_table *wait,
+ struct cache_detail *cd)
+{
+ __poll_t mask;
+ struct cache_reader *rp = filp->private_data;
+ struct cache_queue *cq;
+
+ poll_wait(filp, &queue_wait, wait);
+
+	/* always allow write */
+ mask = EPOLLOUT | EPOLLWRNORM;
+
+ if (!rp)
+ return mask;
+
+ spin_lock(&queue_lock);
+
+ for (cq= &rp->q; &cq->list != &cd->queue;
+ cq = list_entry(cq->list.next, struct cache_queue, list))
+ if (!cq->reader) {
+ mask |= EPOLLIN | EPOLLRDNORM;
+ break;
+ }
+ spin_unlock(&queue_lock);
+ return mask;
+}
+
+static int cache_ioctl(struct inode *ino, struct file *filp,
+ unsigned int cmd, unsigned long arg,
+ struct cache_detail *cd)
+{
+ int len = 0;
+ struct cache_reader *rp = filp->private_data;
+ struct cache_queue *cq;
+
+ if (cmd != FIONREAD || !rp)
+ return -EINVAL;
+
+ spin_lock(&queue_lock);
+
+ /* only find the length remaining in current request,
+ * or the length of the next request
+ */
+ for (cq= &rp->q; &cq->list != &cd->queue;
+ cq = list_entry(cq->list.next, struct cache_queue, list))
+ if (!cq->reader) {
+ struct cache_request *cr =
+ container_of(cq, struct cache_request, q);
+ len = cr->len - rp->offset;
+ break;
+ }
+ spin_unlock(&queue_lock);
+
+ return put_user(len, (int __user *)arg);
+}
+
+static int cache_open(struct inode *inode, struct file *filp,
+ struct cache_detail *cd)
+{
+ struct cache_reader *rp = NULL;
+
+ if (!cd || !try_module_get(cd->owner))
+ return -EACCES;
+ nonseekable_open(inode, filp);
+ if (filp->f_mode & FMODE_READ) {
+ rp = kmalloc(sizeof(*rp), GFP_KERNEL);
+ if (!rp) {
+ module_put(cd->owner);
+ return -ENOMEM;
+ }
+ rp->offset = 0;
+ rp->q.reader = 1;
+
+ spin_lock(&queue_lock);
+ list_add(&rp->q.list, &cd->queue);
+ spin_unlock(&queue_lock);
+ }
+ if (filp->f_mode & FMODE_WRITE)
+ atomic_inc(&cd->writers);
+ filp->private_data = rp;
+ return 0;
+}
+
+static int cache_release(struct inode *inode, struct file *filp,
+ struct cache_detail *cd)
+{
+ struct cache_reader *rp = filp->private_data;
+
+ if (rp) {
+ spin_lock(&queue_lock);
+ if (rp->offset) {
+ struct cache_queue *cq;
+ for (cq= &rp->q; &cq->list != &cd->queue;
+ cq = list_entry(cq->list.next, struct cache_queue, list))
+ if (!cq->reader) {
+ container_of(cq, struct cache_request, q)
+ ->readers--;
+ break;
+ }
+ rp->offset = 0;
+ }
+ list_del(&rp->q.list);
+ spin_unlock(&queue_lock);
+
+ filp->private_data = NULL;
+ kfree(rp);
+
+ }
+ if (filp->f_mode & FMODE_WRITE) {
+ atomic_dec(&cd->writers);
+ cd->last_close = seconds_since_boot();
+ }
+ module_put(cd->owner);
+ return 0;
+}
+
+
+
+static void cache_dequeue(struct cache_detail *detail, struct cache_head *ch)
+{
+ struct cache_queue *cq, *tmp;
+ struct cache_request *cr;
+ struct list_head dequeued;
+
+ INIT_LIST_HEAD(&dequeued);
+ spin_lock(&queue_lock);
+ list_for_each_entry_safe(cq, tmp, &detail->queue, list)
+ if (!cq->reader) {
+ cr = container_of(cq, struct cache_request, q);
+ if (cr->item != ch)
+ continue;
+ if (test_bit(CACHE_PENDING, &ch->flags))
+ /* Lost a race and it is pending again */
+ break;
+ if (cr->readers != 0)
+ continue;
+ list_move(&cr->q.list, &dequeued);
+ }
+ spin_unlock(&queue_lock);
+ while (!list_empty(&dequeued)) {
+ cr = list_entry(dequeued.next, struct cache_request, q.list);
+ list_del(&cr->q.list);
+ cache_put(cr->item, detail);
+ kfree(cr->buf);
+ kfree(cr);
+ }
+}
+
+/*
+ * Support routines for text-based upcalls.
+ * Fields are separated by spaces.
+ * Fields are either mangled to quote space, tab, newline and backslash
+ * ("slosh") with a backslash, or hexified with a leading \x.
+ * Each record is terminated with a newline.
+ *
+ */
+
+void qword_add(char **bpp, int *lp, char *str)
+{
+ char *bp = *bpp;
+ int len = *lp;
+ int ret;
+
+ if (len < 0) return;
+
+ ret = string_escape_str(str, bp, len, ESCAPE_OCTAL, "\\ \n\t");
+ if (ret >= len) {
+ bp += len;
+ len = -1;
+ } else {
+ bp += ret;
+ len -= ret;
+ *bp++ = ' ';
+ len--;
+ }
+ *bpp = bp;
+ *lp = len;
+}
+EXPORT_SYMBOL_GPL(qword_add);
+
+void qword_addhex(char **bpp, int *lp, char *buf, int blen)
+{
+ char *bp = *bpp;
+ int len = *lp;
+
+ if (len < 0) return;
+
+ if (len > 2) {
+ *bp++ = '\\';
+ *bp++ = 'x';
+ len -= 2;
+ while (blen && len >= 2) {
+ bp = hex_byte_pack(bp, *buf++);
+ len -= 2;
+ blen--;
+ }
+ }
+ if (blen || len<1) len = -1;
+ else {
+ *bp++ = ' ';
+ len--;
+ }
+ *bpp = bp;
+ *lp = len;
+}
+EXPORT_SYMBOL_GPL(qword_addhex);
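+
+/*
+ * Example of building one upcall line with the helpers above, assuming a
+ * caller-supplied page "buf" (the usual pattern in ->cache_request):
+ *
+ *	char *bp = buf;
+ *	int len = PAGE_SIZE;
+ *
+ *	qword_add(&bp, &len, "nfsd");		appends "nfsd "
+ *	qword_add(&bp, &len, "a b");		appends "a\040b " (space quoted)
+ *	qword_addhex(&bp, &len, key, 4);	appends "\xdeadbeef " for the
+ *						4-byte binary key de ad be ef
+ *
+ * A negative "len" on return signals that the buffer overflowed.
+ */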
+
+static void warn_no_listener(struct cache_detail *detail)
+{
+ if (detail->last_warn != detail->last_close) {
+ detail->last_warn = detail->last_close;
+ if (detail->warn_no_listener)
+ detail->warn_no_listener(detail, detail->last_close != 0);
+ }
+}
+
+static bool cache_listeners_exist(struct cache_detail *detail)
+{
+ if (atomic_read(&detail->writers))
+ return true;
+ if (detail->last_close == 0)
+ /* This cache was never opened */
+ return false;
+ if (detail->last_close < seconds_since_boot() - 30)
+ /*
+ * We allow for the possibility that someone might
+ * restart a userspace daemon without restarting the
+ * server; but after 30 seconds, we give up.
+ */
+ return false;
+ return true;
+}
+
+/*
+ * register an upcall request to user-space and queue it up for read() by the
+ * upcall daemon.
+ *
+ * Each request is at most one page long.
+ */
+static int cache_pipe_upcall(struct cache_detail *detail, struct cache_head *h)
+{
+ char *buf;
+ struct cache_request *crq;
+ int ret = 0;
+
+ if (test_bit(CACHE_CLEANED, &h->flags))
+ /* Too late to make an upcall */
+ return -EAGAIN;
+
+ buf = kmalloc(PAGE_SIZE, GFP_KERNEL);
+ if (!buf)
+ return -EAGAIN;
+
+ crq = kmalloc(sizeof (*crq), GFP_KERNEL);
+ if (!crq) {
+ kfree(buf);
+ return -EAGAIN;
+ }
+
+ crq->q.reader = 0;
+ crq->buf = buf;
+ crq->len = 0;
+ crq->readers = 0;
+ spin_lock(&queue_lock);
+ if (test_bit(CACHE_PENDING, &h->flags)) {
+ crq->item = cache_get(h);
+ list_add_tail(&crq->q.list, &detail->queue);
+ trace_cache_entry_upcall(detail, h);
+ } else
+ /* Lost a race, no longer PENDING, so don't enqueue */
+ ret = -EAGAIN;
+ spin_unlock(&queue_lock);
+ wake_up(&queue_wait);
+ if (ret == -EAGAIN) {
+ kfree(buf);
+ kfree(crq);
+ }
+ return ret;
+}
+
+int sunrpc_cache_pipe_upcall(struct cache_detail *detail, struct cache_head *h)
+{
+ if (test_and_set_bit(CACHE_PENDING, &h->flags))
+ return 0;
+ return cache_pipe_upcall(detail, h);
+}
+EXPORT_SYMBOL_GPL(sunrpc_cache_pipe_upcall);
+
+int sunrpc_cache_pipe_upcall_timeout(struct cache_detail *detail,
+ struct cache_head *h)
+{
+ if (!cache_listeners_exist(detail)) {
+ warn_no_listener(detail);
+ trace_cache_entry_no_listener(detail, h);
+ return -EINVAL;
+ }
+ return sunrpc_cache_pipe_upcall(detail, h);
+}
+EXPORT_SYMBOL_GPL(sunrpc_cache_pipe_upcall_timeout);
+
+/*
+ * Parse a message from user-space and pass it
+ * to the appropriate cache.
+ * Messages are, like requests, separated into fields by
+ * spaces and dequoted from \xHEXSTRING or embedded \nnn octal.
+ *
+ * Message is
+ * reply cachename expiry key ... content....
+ *
+ * key and content are both parsed by cache
+ */
+
+int qword_get(char **bpp, char *dest, int bufsize)
+{
+ /* return bytes copied, or -1 on error */
+ char *bp = *bpp;
+ int len = 0;
+
+ while (*bp == ' ') bp++;
+
+ if (bp[0] == '\\' && bp[1] == 'x') {
+ /* HEX STRING */
+ bp += 2;
+ while (len < bufsize - 1) {
+ int h, l;
+
+ h = hex_to_bin(bp[0]);
+ if (h < 0)
+ break;
+
+ l = hex_to_bin(bp[1]);
+ if (l < 0)
+ break;
+
+ *dest++ = (h << 4) | l;
+ bp += 2;
+ len++;
+ }
+ } else {
+ /* text with \nnn octal quoting */
+ while (*bp != ' ' && *bp != '\n' && *bp && len < bufsize-1) {
+ if (*bp == '\\' &&
+ isodigit(bp[1]) && (bp[1] <= '3') &&
+ isodigit(bp[2]) &&
+ isodigit(bp[3])) {
+ int byte = (*++bp -'0');
+ bp++;
+ byte = (byte << 3) | (*bp++ - '0');
+ byte = (byte << 3) | (*bp++ - '0');
+ *dest++ = byte;
+ len++;
+ } else {
+ *dest++ = *bp++;
+ len++;
+ }
+ }
+ }
+
+ if (*bp != ' ' && *bp != '\n' && *bp != '\0')
+ return -1;
+ while (*bp == ' ') bp++;
+ *bpp = bp;
+ *dest = '\0';
+ return len;
+}
+EXPORT_SYMBOL_GPL(qword_get);
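+
+/*
+ * Dequoting examples for qword_get(), assuming a message pointer "mesg"
+ * and a destination buffer "buf" of at least PAGE_SIZE bytes:
+ *
+ *	field "\x64656d6f" is stored as "demo" (return value 4)
+ *	field "a\040b" is stored as "a b" (return value 3)
+ *
+ *	while (qword_get(&mesg, buf, PAGE_SIZE) > 0)
+ *		... one field per iteration; 0 at end of line, -1 on error ...
+ */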
+
+
+/*
+ * support /proc/net/rpc/$CACHENAME/content
+ * as a seqfile.
+ * We call ->cache_show passing NULL for the item to
+ * get a header, then pass each real item in the cache
+ */
+
+static void *__cache_seq_start(struct seq_file *m, loff_t *pos)
+{
+ loff_t n = *pos;
+ unsigned int hash, entry;
+ struct cache_head *ch;
+ struct cache_detail *cd = m->private;
+
+ if (!n--)
+ return SEQ_START_TOKEN;
+ hash = n >> 32;
+ entry = n & ((1LL<<32) - 1);
+
+ hlist_for_each_entry_rcu(ch, &cd->hash_table[hash], cache_list)
+ if (!entry--)
+ return ch;
+ n &= ~((1LL<<32) - 1);
+ do {
+ hash++;
+ n += 1LL<<32;
+ } while(hash < cd->hash_size &&
+ hlist_empty(&cd->hash_table[hash]));
+ if (hash >= cd->hash_size)
+ return NULL;
+ *pos = n+1;
+ return hlist_entry_safe(rcu_dereference_raw(
+ hlist_first_rcu(&cd->hash_table[hash])),
+ struct cache_head, cache_list);
+}
+
+static void *cache_seq_next(struct seq_file *m, void *p, loff_t *pos)
+{
+ struct cache_head *ch = p;
+ int hash = (*pos >> 32);
+ struct cache_detail *cd = m->private;
+
+ if (p == SEQ_START_TOKEN)
+ hash = 0;
+ else if (ch->cache_list.next == NULL) {
+ hash++;
+ *pos += 1LL<<32;
+ } else {
+ ++*pos;
+ return hlist_entry_safe(rcu_dereference_raw(
+ hlist_next_rcu(&ch->cache_list)),
+ struct cache_head, cache_list);
+ }
+ *pos &= ~((1LL<<32) - 1);
+ while (hash < cd->hash_size &&
+ hlist_empty(&cd->hash_table[hash])) {
+ hash++;
+ *pos += 1LL<<32;
+ }
+ if (hash >= cd->hash_size)
+ return NULL;
+ ++*pos;
+ return hlist_entry_safe(rcu_dereference_raw(
+ hlist_first_rcu(&cd->hash_table[hash])),
+ struct cache_head, cache_list);
+}
+
+void *cache_seq_start_rcu(struct seq_file *m, loff_t *pos)
+ __acquires(RCU)
+{
+ rcu_read_lock();
+ return __cache_seq_start(m, pos);
+}
+EXPORT_SYMBOL_GPL(cache_seq_start_rcu);
+
+void *cache_seq_next_rcu(struct seq_file *file, void *p, loff_t *pos)
+{
+ return cache_seq_next(file, p, pos);
+}
+EXPORT_SYMBOL_GPL(cache_seq_next_rcu);
+
+void cache_seq_stop_rcu(struct seq_file *m, void *p)
+ __releases(RCU)
+{
+ rcu_read_unlock();
+}
+EXPORT_SYMBOL_GPL(cache_seq_stop_rcu);
+
+static int c_show(struct seq_file *m, void *p)
+{
+ struct cache_head *cp = p;
+ struct cache_detail *cd = m->private;
+
+ if (p == SEQ_START_TOKEN)
+ return cd->cache_show(m, cd, NULL);
+
+ ifdebug(CACHE)
+ seq_printf(m, "# expiry=%lld refcnt=%d flags=%lx\n",
+ convert_to_wallclock(cp->expiry_time),
+ kref_read(&cp->ref), cp->flags);
+ cache_get(cp);
+ if (cache_check(cd, cp, NULL))
+ /* cache_check does a cache_put on failure */
+ seq_puts(m, "# ");
+ else {
+ if (cache_is_expired(cd, cp))
+ seq_puts(m, "# ");
+ cache_put(cp, cd);
+ }
+
+ return cd->cache_show(m, cd, cp);
+}
+
+static const struct seq_operations cache_content_op = {
+ .start = cache_seq_start_rcu,
+ .next = cache_seq_next_rcu,
+ .stop = cache_seq_stop_rcu,
+ .show = c_show,
+};
+
+static int content_open(struct inode *inode, struct file *file,
+ struct cache_detail *cd)
+{
+ struct seq_file *seq;
+ int err;
+
+ if (!cd || !try_module_get(cd->owner))
+ return -EACCES;
+
+ err = seq_open(file, &cache_content_op);
+ if (err) {
+ module_put(cd->owner);
+ return err;
+ }
+
+ seq = file->private_data;
+ seq->private = cd;
+ return 0;
+}
+
+static int content_release(struct inode *inode, struct file *file,
+ struct cache_detail *cd)
+{
+ int ret = seq_release(inode, file);
+ module_put(cd->owner);
+ return ret;
+}
+
+static int open_flush(struct inode *inode, struct file *file,
+ struct cache_detail *cd)
+{
+ if (!cd || !try_module_get(cd->owner))
+ return -EACCES;
+ return nonseekable_open(inode, file);
+}
+
+static int release_flush(struct inode *inode, struct file *file,
+ struct cache_detail *cd)
+{
+ module_put(cd->owner);
+ return 0;
+}
+
+static ssize_t read_flush(struct file *file, char __user *buf,
+ size_t count, loff_t *ppos,
+ struct cache_detail *cd)
+{
+ char tbuf[22];
+ size_t len;
+
+ len = snprintf(tbuf, sizeof(tbuf), "%llu\n",
+ convert_to_wallclock(cd->flush_time));
+ return simple_read_from_buffer(buf, count, ppos, tbuf, len);
+}
+
+static ssize_t write_flush(struct file *file, const char __user *buf,
+ size_t count, loff_t *ppos,
+ struct cache_detail *cd)
+{
+ char tbuf[20];
+ char *ep;
+ time64_t now;
+
+ if (*ppos || count > sizeof(tbuf)-1)
+ return -EINVAL;
+ if (copy_from_user(tbuf, buf, count))
+ return -EFAULT;
+ tbuf[count] = 0;
+ simple_strtoul(tbuf, &ep, 0);
+ if (*ep && *ep != '\n')
+ return -EINVAL;
+ /* Note that while we check that 'buf' holds a valid number,
+ * we always ignore the value and just flush everything.
+ * Making use of the number leads to races.
+ */
+
+ now = seconds_since_boot();
+	/* Always flush everything, so behave like cache_purge().
+	 * Do this by advancing flush_time to the current time,
+	 * or by one second if it has already reached the current time.
+	 * Newly added cache entries will always have ->last_refresh greater
+	 * than ->flush_time, so they don't get flushed prematurely.
+ */
+
+ if (cd->flush_time >= now)
+ now = cd->flush_time + 1;
+
+ cd->flush_time = now;
+ cd->nextcheck = now;
+ cache_flush();
+
+ if (cd->flush)
+ cd->flush();
+
+ *ppos += count;
+ return count;
+}
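+
+/*
+ * From user space, writing any number to the per-cache "flush" file
+ * discards every entry; reading it reports the current flush time in
+ * wall-clock seconds.  A sketch for a hypothetical cache "example":
+ *
+ *	int fd = open("/proc/net/rpc/example/flush", O_WRONLY);
+ *	write(fd, "1\n", 2);
+ *	close(fd);
+ */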
+
+static ssize_t cache_read_procfs(struct file *filp, char __user *buf,
+ size_t count, loff_t *ppos)
+{
+ struct cache_detail *cd = pde_data(file_inode(filp));
+
+ return cache_read(filp, buf, count, ppos, cd);
+}
+
+static ssize_t cache_write_procfs(struct file *filp, const char __user *buf,
+ size_t count, loff_t *ppos)
+{
+ struct cache_detail *cd = pde_data(file_inode(filp));
+
+ return cache_write(filp, buf, count, ppos, cd);
+}
+
+static __poll_t cache_poll_procfs(struct file *filp, poll_table *wait)
+{
+ struct cache_detail *cd = pde_data(file_inode(filp));
+
+ return cache_poll(filp, wait, cd);
+}
+
+static long cache_ioctl_procfs(struct file *filp,
+ unsigned int cmd, unsigned long arg)
+{
+ struct inode *inode = file_inode(filp);
+ struct cache_detail *cd = pde_data(inode);
+
+ return cache_ioctl(inode, filp, cmd, arg, cd);
+}
+
+static int cache_open_procfs(struct inode *inode, struct file *filp)
+{
+ struct cache_detail *cd = pde_data(inode);
+
+ return cache_open(inode, filp, cd);
+}
+
+static int cache_release_procfs(struct inode *inode, struct file *filp)
+{
+ struct cache_detail *cd = pde_data(inode);
+
+ return cache_release(inode, filp, cd);
+}
+
+static const struct proc_ops cache_channel_proc_ops = {
+ .proc_lseek = no_llseek,
+ .proc_read = cache_read_procfs,
+ .proc_write = cache_write_procfs,
+ .proc_poll = cache_poll_procfs,
+ .proc_ioctl = cache_ioctl_procfs, /* for FIONREAD */
+ .proc_open = cache_open_procfs,
+ .proc_release = cache_release_procfs,
+};
+
+static int content_open_procfs(struct inode *inode, struct file *filp)
+{
+ struct cache_detail *cd = pde_data(inode);
+
+ return content_open(inode, filp, cd);
+}
+
+static int content_release_procfs(struct inode *inode, struct file *filp)
+{
+ struct cache_detail *cd = pde_data(inode);
+
+ return content_release(inode, filp, cd);
+}
+
+static const struct proc_ops content_proc_ops = {
+ .proc_open = content_open_procfs,
+ .proc_read = seq_read,
+ .proc_lseek = seq_lseek,
+ .proc_release = content_release_procfs,
+};
+
+static int open_flush_procfs(struct inode *inode, struct file *filp)
+{
+ struct cache_detail *cd = pde_data(inode);
+
+ return open_flush(inode, filp, cd);
+}
+
+static int release_flush_procfs(struct inode *inode, struct file *filp)
+{
+ struct cache_detail *cd = pde_data(inode);
+
+ return release_flush(inode, filp, cd);
+}
+
+static ssize_t read_flush_procfs(struct file *filp, char __user *buf,
+ size_t count, loff_t *ppos)
+{
+ struct cache_detail *cd = pde_data(file_inode(filp));
+
+ return read_flush(filp, buf, count, ppos, cd);
+}
+
+static ssize_t write_flush_procfs(struct file *filp,
+ const char __user *buf,
+ size_t count, loff_t *ppos)
+{
+ struct cache_detail *cd = pde_data(file_inode(filp));
+
+ return write_flush(filp, buf, count, ppos, cd);
+}
+
+static const struct proc_ops cache_flush_proc_ops = {
+ .proc_open = open_flush_procfs,
+ .proc_read = read_flush_procfs,
+ .proc_write = write_flush_procfs,
+ .proc_release = release_flush_procfs,
+ .proc_lseek = no_llseek,
+};
+
+static void remove_cache_proc_entries(struct cache_detail *cd)
+{
+ if (cd->procfs) {
+ proc_remove(cd->procfs);
+ cd->procfs = NULL;
+ }
+}
+
+#ifdef CONFIG_PROC_FS
+static int create_cache_proc_entries(struct cache_detail *cd, struct net *net)
+{
+ struct proc_dir_entry *p;
+ struct sunrpc_net *sn;
+
+ sn = net_generic(net, sunrpc_net_id);
+ cd->procfs = proc_mkdir(cd->name, sn->proc_net_rpc);
+ if (cd->procfs == NULL)
+ goto out_nomem;
+
+ p = proc_create_data("flush", S_IFREG | 0600,
+ cd->procfs, &cache_flush_proc_ops, cd);
+ if (p == NULL)
+ goto out_nomem;
+
+ if (cd->cache_request || cd->cache_parse) {
+ p = proc_create_data("channel", S_IFREG | 0600, cd->procfs,
+ &cache_channel_proc_ops, cd);
+ if (p == NULL)
+ goto out_nomem;
+ }
+ if (cd->cache_show) {
+ p = proc_create_data("content", S_IFREG | 0400, cd->procfs,
+ &content_proc_ops, cd);
+ if (p == NULL)
+ goto out_nomem;
+ }
+ return 0;
+out_nomem:
+ remove_cache_proc_entries(cd);
+ return -ENOMEM;
+}
+#else /* CONFIG_PROC_FS */
+static int create_cache_proc_entries(struct cache_detail *cd, struct net *net)
+{
+ return 0;
+}
+#endif
+
+void __init cache_initialize(void)
+{
+ INIT_DEFERRABLE_WORK(&cache_cleaner, do_cache_clean);
+}
+
+int cache_register_net(struct cache_detail *cd, struct net *net)
+{
+ int ret;
+
+ sunrpc_init_cache_detail(cd);
+ ret = create_cache_proc_entries(cd, net);
+ if (ret)
+ sunrpc_destroy_cache_detail(cd);
+ return ret;
+}
+EXPORT_SYMBOL_GPL(cache_register_net);
+
+void cache_unregister_net(struct cache_detail *cd, struct net *net)
+{
+ remove_cache_proc_entries(cd);
+ sunrpc_destroy_cache_detail(cd);
+}
+EXPORT_SYMBOL_GPL(cache_unregister_net);
+
+struct cache_detail *cache_create_net(const struct cache_detail *tmpl, struct net *net)
+{
+ struct cache_detail *cd;
+ int i;
+
+ cd = kmemdup(tmpl, sizeof(struct cache_detail), GFP_KERNEL);
+ if (cd == NULL)
+ return ERR_PTR(-ENOMEM);
+
+ cd->hash_table = kcalloc(cd->hash_size, sizeof(struct hlist_head),
+ GFP_KERNEL);
+ if (cd->hash_table == NULL) {
+ kfree(cd);
+ return ERR_PTR(-ENOMEM);
+ }
+
+ for (i = 0; i < cd->hash_size; i++)
+ INIT_HLIST_HEAD(&cd->hash_table[i]);
+ cd->net = net;
+ return cd;
+}
+EXPORT_SYMBOL_GPL(cache_create_net);
+
+void cache_destroy_net(struct cache_detail *cd, struct net *net)
+{
+ kfree(cd->hash_table);
+ kfree(cd);
+}
+EXPORT_SYMBOL_GPL(cache_destroy_net);
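+
+/*
+ * Typical per-network-namespace lifecycle, assuming a hypothetical
+ * template "example_cache_template" with its name, hash_size and methods
+ * filled in:
+ *
+ *	struct cache_detail *cd = cache_create_net(&example_cache_template, net);
+ *	if (IS_ERR(cd))
+ *		return PTR_ERR(cd);
+ *	err = cache_register_net(cd, net);	creates /proc/net/rpc/<name>/
+ *	if (err)
+ *		cache_destroy_net(cd, net);
+ *	...
+ *	cache_unregister_net(cd, net);
+ *	cache_destroy_net(cd, net);
+ */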
+
+static ssize_t cache_read_pipefs(struct file *filp, char __user *buf,
+ size_t count, loff_t *ppos)
+{
+ struct cache_detail *cd = RPC_I(file_inode(filp))->private;
+
+ return cache_read(filp, buf, count, ppos, cd);
+}
+
+static ssize_t cache_write_pipefs(struct file *filp, const char __user *buf,
+ size_t count, loff_t *ppos)
+{
+ struct cache_detail *cd = RPC_I(file_inode(filp))->private;
+
+ return cache_write(filp, buf, count, ppos, cd);
+}
+
+static __poll_t cache_poll_pipefs(struct file *filp, poll_table *wait)
+{
+ struct cache_detail *cd = RPC_I(file_inode(filp))->private;
+
+ return cache_poll(filp, wait, cd);
+}
+
+static long cache_ioctl_pipefs(struct file *filp,
+ unsigned int cmd, unsigned long arg)
+{
+ struct inode *inode = file_inode(filp);
+ struct cache_detail *cd = RPC_I(inode)->private;
+
+ return cache_ioctl(inode, filp, cmd, arg, cd);
+}
+
+static int cache_open_pipefs(struct inode *inode, struct file *filp)
+{
+ struct cache_detail *cd = RPC_I(inode)->private;
+
+ return cache_open(inode, filp, cd);
+}
+
+static int cache_release_pipefs(struct inode *inode, struct file *filp)
+{
+ struct cache_detail *cd = RPC_I(inode)->private;
+
+ return cache_release(inode, filp, cd);
+}
+
+const struct file_operations cache_file_operations_pipefs = {
+ .owner = THIS_MODULE,
+ .llseek = no_llseek,
+ .read = cache_read_pipefs,
+ .write = cache_write_pipefs,
+ .poll = cache_poll_pipefs,
+ .unlocked_ioctl = cache_ioctl_pipefs, /* for FIONREAD */
+ .open = cache_open_pipefs,
+ .release = cache_release_pipefs,
+};
+
+static int content_open_pipefs(struct inode *inode, struct file *filp)
+{
+ struct cache_detail *cd = RPC_I(inode)->private;
+
+ return content_open(inode, filp, cd);
+}
+
+static int content_release_pipefs(struct inode *inode, struct file *filp)
+{
+ struct cache_detail *cd = RPC_I(inode)->private;
+
+ return content_release(inode, filp, cd);
+}
+
+const struct file_operations content_file_operations_pipefs = {
+ .open = content_open_pipefs,
+ .read = seq_read,
+ .llseek = seq_lseek,
+ .release = content_release_pipefs,
+};
+
+static int open_flush_pipefs(struct inode *inode, struct file *filp)
+{
+ struct cache_detail *cd = RPC_I(inode)->private;
+
+ return open_flush(inode, filp, cd);
+}
+
+static int release_flush_pipefs(struct inode *inode, struct file *filp)
+{
+ struct cache_detail *cd = RPC_I(inode)->private;
+
+ return release_flush(inode, filp, cd);
+}
+
+static ssize_t read_flush_pipefs(struct file *filp, char __user *buf,
+ size_t count, loff_t *ppos)
+{
+ struct cache_detail *cd = RPC_I(file_inode(filp))->private;
+
+ return read_flush(filp, buf, count, ppos, cd);
+}
+
+static ssize_t write_flush_pipefs(struct file *filp,
+ const char __user *buf,
+ size_t count, loff_t *ppos)
+{
+ struct cache_detail *cd = RPC_I(file_inode(filp))->private;
+
+ return write_flush(filp, buf, count, ppos, cd);
+}
+
+const struct file_operations cache_flush_operations_pipefs = {
+ .open = open_flush_pipefs,
+ .read = read_flush_pipefs,
+ .write = write_flush_pipefs,
+ .release = release_flush_pipefs,
+ .llseek = no_llseek,
+};
+
+int sunrpc_cache_register_pipefs(struct dentry *parent,
+ const char *name, umode_t umode,
+ struct cache_detail *cd)
+{
+ struct dentry *dir = rpc_create_cache_dir(parent, name, umode, cd);
+ if (IS_ERR(dir))
+ return PTR_ERR(dir);
+ cd->pipefs = dir;
+ return 0;
+}
+EXPORT_SYMBOL_GPL(sunrpc_cache_register_pipefs);
+
+void sunrpc_cache_unregister_pipefs(struct cache_detail *cd)
+{
+ if (cd->pipefs) {
+ rpc_remove_cache_dir(cd->pipefs);
+ cd->pipefs = NULL;
+ }
+}
+EXPORT_SYMBOL_GPL(sunrpc_cache_unregister_pipefs);
+
+void sunrpc_cache_unhash(struct cache_detail *cd, struct cache_head *h)
+{
+ spin_lock(&cd->hash_lock);
+	if (!hlist_unhashed(&h->cache_list)) {
+ sunrpc_begin_cache_remove_entry(h, cd);
+ spin_unlock(&cd->hash_lock);
+ sunrpc_end_cache_remove_entry(h, cd);
+ } else
+ spin_unlock(&cd->hash_lock);
+}
+EXPORT_SYMBOL_GPL(sunrpc_cache_unhash);
diff --git a/net/sunrpc/clnt.c b/net/sunrpc/clnt.c
new file mode 100644
index 0000000000..339dfc5b92
--- /dev/null
+++ b/net/sunrpc/clnt.c
@@ -0,0 +1,3395 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * linux/net/sunrpc/clnt.c
+ *
+ * This file contains the high-level RPC interface.
+ * It is modeled as a finite state machine to support both synchronous
+ * and asynchronous requests.
+ *
+ * - RPC header generation and argument serialization.
+ * - Credential refresh.
+ * - TCP connect handling.
+ * - Retry of operation when it is suspected the operation failed because
+ * of uid squashing on the server, or when the credentials were stale
+ * and need to be refreshed, or when a packet was damaged in transit.
+ * This may have to be moved to the VFS layer.
+ *
+ * Copyright (C) 1992,1993 Rick Sladkey <jrs@world.std.com>
+ * Copyright (C) 1995,1996 Olaf Kirch <okir@monad.swb.de>
+ */
+
+
+#include <linux/module.h>
+#include <linux/types.h>
+#include <linux/kallsyms.h>
+#include <linux/mm.h>
+#include <linux/namei.h>
+#include <linux/mount.h>
+#include <linux/slab.h>
+#include <linux/rcupdate.h>
+#include <linux/utsname.h>
+#include <linux/workqueue.h>
+#include <linux/in.h>
+#include <linux/in6.h>
+#include <linux/un.h>
+
+#include <linux/sunrpc/clnt.h>
+#include <linux/sunrpc/addr.h>
+#include <linux/sunrpc/rpc_pipe_fs.h>
+#include <linux/sunrpc/metrics.h>
+#include <linux/sunrpc/bc_xprt.h>
+#include <trace/events/sunrpc.h>
+
+#include "sunrpc.h"
+#include "sysfs.h"
+#include "netns.h"
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+# define RPCDBG_FACILITY RPCDBG_CALL
+#endif
+
+/*
+ * All RPC clients are linked into this list
+ */
+
+static DECLARE_WAIT_QUEUE_HEAD(destroy_wait);
+
+
+static void call_start(struct rpc_task *task);
+static void call_reserve(struct rpc_task *task);
+static void call_reserveresult(struct rpc_task *task);
+static void call_allocate(struct rpc_task *task);
+static void call_encode(struct rpc_task *task);
+static void call_decode(struct rpc_task *task);
+static void call_bind(struct rpc_task *task);
+static void call_bind_status(struct rpc_task *task);
+static void call_transmit(struct rpc_task *task);
+static void call_status(struct rpc_task *task);
+static void call_transmit_status(struct rpc_task *task);
+static void call_refresh(struct rpc_task *task);
+static void call_refreshresult(struct rpc_task *task);
+static void call_connect(struct rpc_task *task);
+static void call_connect_status(struct rpc_task *task);
+
+static int rpc_encode_header(struct rpc_task *task,
+ struct xdr_stream *xdr);
+static int rpc_decode_header(struct rpc_task *task,
+ struct xdr_stream *xdr);
+static int rpc_ping(struct rpc_clnt *clnt);
+static int rpc_ping_noreply(struct rpc_clnt *clnt);
+static void rpc_check_timeout(struct rpc_task *task);
+
+static void rpc_register_client(struct rpc_clnt *clnt)
+{
+ struct net *net = rpc_net_ns(clnt);
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ spin_lock(&sn->rpc_client_lock);
+ list_add(&clnt->cl_clients, &sn->all_clients);
+ spin_unlock(&sn->rpc_client_lock);
+}
+
+static void rpc_unregister_client(struct rpc_clnt *clnt)
+{
+ struct net *net = rpc_net_ns(clnt);
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ spin_lock(&sn->rpc_client_lock);
+ list_del(&clnt->cl_clients);
+ spin_unlock(&sn->rpc_client_lock);
+}
+
+static void __rpc_clnt_remove_pipedir(struct rpc_clnt *clnt)
+{
+ rpc_remove_client_dir(clnt);
+}
+
+static void rpc_clnt_remove_pipedir(struct rpc_clnt *clnt)
+{
+ struct net *net = rpc_net_ns(clnt);
+ struct super_block *pipefs_sb;
+
+ pipefs_sb = rpc_get_sb_net(net);
+ if (pipefs_sb) {
+ if (pipefs_sb == clnt->pipefs_sb)
+ __rpc_clnt_remove_pipedir(clnt);
+ rpc_put_sb_net(net);
+ }
+}
+
+static struct dentry *rpc_setup_pipedir_sb(struct super_block *sb,
+ struct rpc_clnt *clnt)
+{
+ static uint32_t clntid;
+ const char *dir_name = clnt->cl_program->pipe_dir_name;
+ char name[15];
+ struct dentry *dir, *dentry;
+
+ dir = rpc_d_lookup_sb(sb, dir_name);
+ if (dir == NULL) {
+ pr_info("RPC: pipefs directory doesn't exist: %s\n", dir_name);
+ return dir;
+ }
+ for (;;) {
+ snprintf(name, sizeof(name), "clnt%x", (unsigned int)clntid++);
+ name[sizeof(name) - 1] = '\0';
+ dentry = rpc_create_client_dir(dir, name, clnt);
+ if (!IS_ERR(dentry))
+ break;
+ if (dentry == ERR_PTR(-EEXIST))
+ continue;
+ printk(KERN_INFO "RPC: Couldn't create pipefs entry"
+ " %s/%s, error %ld\n",
+ dir_name, name, PTR_ERR(dentry));
+ break;
+ }
+ dput(dir);
+ return dentry;
+}
+
+static int
+rpc_setup_pipedir(struct super_block *pipefs_sb, struct rpc_clnt *clnt)
+{
+ struct dentry *dentry;
+
+ clnt->pipefs_sb = pipefs_sb;
+
+ if (clnt->cl_program->pipe_dir_name != NULL) {
+ dentry = rpc_setup_pipedir_sb(pipefs_sb, clnt);
+ if (IS_ERR(dentry))
+ return PTR_ERR(dentry);
+ }
+ return 0;
+}
+
+static int rpc_clnt_skip_event(struct rpc_clnt *clnt, unsigned long event)
+{
+ if (clnt->cl_program->pipe_dir_name == NULL)
+ return 1;
+
+ switch (event) {
+ case RPC_PIPEFS_MOUNT:
+ if (clnt->cl_pipedir_objects.pdh_dentry != NULL)
+ return 1;
+ if (refcount_read(&clnt->cl_count) == 0)
+ return 1;
+ break;
+ case RPC_PIPEFS_UMOUNT:
+ if (clnt->cl_pipedir_objects.pdh_dentry == NULL)
+ return 1;
+ break;
+ }
+ return 0;
+}
+
+static int __rpc_clnt_handle_event(struct rpc_clnt *clnt, unsigned long event,
+ struct super_block *sb)
+{
+ struct dentry *dentry;
+
+ switch (event) {
+ case RPC_PIPEFS_MOUNT:
+ dentry = rpc_setup_pipedir_sb(sb, clnt);
+ if (!dentry)
+ return -ENOENT;
+ if (IS_ERR(dentry))
+ return PTR_ERR(dentry);
+ break;
+ case RPC_PIPEFS_UMOUNT:
+ __rpc_clnt_remove_pipedir(clnt);
+ break;
+ default:
+ printk(KERN_ERR "%s: unknown event: %ld\n", __func__, event);
+ return -ENOTSUPP;
+ }
+ return 0;
+}
+
+static int __rpc_pipefs_event(struct rpc_clnt *clnt, unsigned long event,
+ struct super_block *sb)
+{
+ int error = 0;
+
+ for (;; clnt = clnt->cl_parent) {
+ if (!rpc_clnt_skip_event(clnt, event))
+ error = __rpc_clnt_handle_event(clnt, event, sb);
+ if (error || clnt == clnt->cl_parent)
+ break;
+ }
+ return error;
+}
+
+static struct rpc_clnt *rpc_get_client_for_event(struct net *net, int event)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ struct rpc_clnt *clnt;
+
+ spin_lock(&sn->rpc_client_lock);
+ list_for_each_entry(clnt, &sn->all_clients, cl_clients) {
+ if (rpc_clnt_skip_event(clnt, event))
+ continue;
+ spin_unlock(&sn->rpc_client_lock);
+ return clnt;
+ }
+ spin_unlock(&sn->rpc_client_lock);
+ return NULL;
+}
+
+static int rpc_pipefs_event(struct notifier_block *nb, unsigned long event,
+ void *ptr)
+{
+ struct super_block *sb = ptr;
+ struct rpc_clnt *clnt;
+ int error = 0;
+
+ while ((clnt = rpc_get_client_for_event(sb->s_fs_info, event))) {
+ error = __rpc_pipefs_event(clnt, event, sb);
+ if (error)
+ break;
+ }
+ return error;
+}
+
+static struct notifier_block rpc_clients_block = {
+ .notifier_call = rpc_pipefs_event,
+ .priority = SUNRPC_PIPEFS_RPC_PRIO,
+};
+
+int rpc_clients_notifier_register(void)
+{
+ return rpc_pipefs_notifier_register(&rpc_clients_block);
+}
+
+void rpc_clients_notifier_unregister(void)
+{
+ return rpc_pipefs_notifier_unregister(&rpc_clients_block);
+}
+
+static struct rpc_xprt *rpc_clnt_set_transport(struct rpc_clnt *clnt,
+ struct rpc_xprt *xprt,
+ const struct rpc_timeout *timeout)
+{
+ struct rpc_xprt *old;
+
+ spin_lock(&clnt->cl_lock);
+ old = rcu_dereference_protected(clnt->cl_xprt,
+ lockdep_is_held(&clnt->cl_lock));
+
+ if (!xprt_bound(xprt))
+ clnt->cl_autobind = 1;
+
+ clnt->cl_timeout = timeout;
+ rcu_assign_pointer(clnt->cl_xprt, xprt);
+ spin_unlock(&clnt->cl_lock);
+
+ return old;
+}
+
+static void rpc_clnt_set_nodename(struct rpc_clnt *clnt, const char *nodename)
+{
+ clnt->cl_nodelen = strlcpy(clnt->cl_nodename,
+ nodename, sizeof(clnt->cl_nodename));
+}
+
+static int rpc_client_register(struct rpc_clnt *clnt,
+ rpc_authflavor_t pseudoflavor,
+ const char *client_name)
+{
+ struct rpc_auth_create_args auth_args = {
+ .pseudoflavor = pseudoflavor,
+ .target_name = client_name,
+ };
+ struct rpc_auth *auth;
+ struct net *net = rpc_net_ns(clnt);
+ struct super_block *pipefs_sb;
+ int err;
+
+ rpc_clnt_debugfs_register(clnt);
+
+ pipefs_sb = rpc_get_sb_net(net);
+ if (pipefs_sb) {
+ err = rpc_setup_pipedir(pipefs_sb, clnt);
+ if (err)
+ goto out;
+ }
+
+ rpc_register_client(clnt);
+ if (pipefs_sb)
+ rpc_put_sb_net(net);
+
+ auth = rpcauth_create(&auth_args, clnt);
+ if (IS_ERR(auth)) {
+ dprintk("RPC: Couldn't create auth handle (flavor %u)\n",
+ pseudoflavor);
+ err = PTR_ERR(auth);
+ goto err_auth;
+ }
+ return 0;
+err_auth:
+ pipefs_sb = rpc_get_sb_net(net);
+ rpc_unregister_client(clnt);
+ __rpc_clnt_remove_pipedir(clnt);
+out:
+ if (pipefs_sb)
+ rpc_put_sb_net(net);
+ rpc_sysfs_client_destroy(clnt);
+ rpc_clnt_debugfs_unregister(clnt);
+ return err;
+}
+
+static DEFINE_IDA(rpc_clids);
+
+void rpc_cleanup_clids(void)
+{
+ ida_destroy(&rpc_clids);
+}
+
+static int rpc_alloc_clid(struct rpc_clnt *clnt)
+{
+ int clid;
+
+ clid = ida_alloc(&rpc_clids, GFP_KERNEL);
+ if (clid < 0)
+ return clid;
+ clnt->cl_clid = clid;
+ return 0;
+}
+
+static void rpc_free_clid(struct rpc_clnt *clnt)
+{
+ ida_free(&rpc_clids, clnt->cl_clid);
+}
+
+static struct rpc_clnt * rpc_new_client(const struct rpc_create_args *args,
+ struct rpc_xprt_switch *xps,
+ struct rpc_xprt *xprt,
+ struct rpc_clnt *parent)
+{
+ const struct rpc_program *program = args->program;
+ const struct rpc_version *version;
+ struct rpc_clnt *clnt = NULL;
+ const struct rpc_timeout *timeout;
+ const char *nodename = args->nodename;
+ int err;
+
+ err = rpciod_up();
+ if (err)
+ goto out_no_rpciod;
+
+ err = -EINVAL;
+ if (args->version >= program->nrvers)
+ goto out_err;
+ version = program->version[args->version];
+ if (version == NULL)
+ goto out_err;
+
+ err = -ENOMEM;
+ clnt = kzalloc(sizeof(*clnt), GFP_KERNEL);
+ if (!clnt)
+ goto out_err;
+ clnt->cl_parent = parent ? : clnt;
+ clnt->cl_xprtsec = args->xprtsec;
+
+ err = rpc_alloc_clid(clnt);
+ if (err)
+ goto out_no_clid;
+
+ clnt->cl_cred = get_cred(args->cred);
+ clnt->cl_procinfo = version->procs;
+ clnt->cl_maxproc = version->nrprocs;
+ clnt->cl_prog = args->prognumber ? : program->number;
+ clnt->cl_vers = version->number;
+ clnt->cl_stats = program->stats;
+ clnt->cl_metrics = rpc_alloc_iostats(clnt);
+ rpc_init_pipe_dir_head(&clnt->cl_pipedir_objects);
+ err = -ENOMEM;
+ if (clnt->cl_metrics == NULL)
+ goto out_no_stats;
+ clnt->cl_program = program;
+ INIT_LIST_HEAD(&clnt->cl_tasks);
+ spin_lock_init(&clnt->cl_lock);
+
+ timeout = xprt->timeout;
+ if (args->timeout != NULL) {
+ memcpy(&clnt->cl_timeout_default, args->timeout,
+ sizeof(clnt->cl_timeout_default));
+ timeout = &clnt->cl_timeout_default;
+ }
+
+ rpc_clnt_set_transport(clnt, xprt, timeout);
+ xprt->main = true;
+ xprt_iter_init(&clnt->cl_xpi, xps);
+ xprt_switch_put(xps);
+
+ clnt->cl_rtt = &clnt->cl_rtt_default;
+ rpc_init_rtt(&clnt->cl_rtt_default, clnt->cl_timeout->to_initval);
+
+ refcount_set(&clnt->cl_count, 1);
+
+ if (nodename == NULL)
+ nodename = utsname()->nodename;
+ /* save the nodename */
+ rpc_clnt_set_nodename(clnt, nodename);
+
+ rpc_sysfs_client_setup(clnt, xps, rpc_net_ns(clnt));
+ err = rpc_client_register(clnt, args->authflavor, args->client_name);
+ if (err)
+ goto out_no_path;
+ if (parent)
+ refcount_inc(&parent->cl_count);
+
+ trace_rpc_clnt_new(clnt, xprt, args);
+ return clnt;
+
+out_no_path:
+ rpc_free_iostats(clnt->cl_metrics);
+out_no_stats:
+ put_cred(clnt->cl_cred);
+ rpc_free_clid(clnt);
+out_no_clid:
+ kfree(clnt);
+out_err:
+ rpciod_down();
+out_no_rpciod:
+ xprt_switch_put(xps);
+ xprt_put(xprt);
+ trace_rpc_clnt_new_err(program->name, args->servername, err);
+ return ERR_PTR(err);
+}
+
+static struct rpc_clnt *rpc_create_xprt(struct rpc_create_args *args,
+ struct rpc_xprt *xprt)
+{
+ struct rpc_clnt *clnt = NULL;
+ struct rpc_xprt_switch *xps;
+
+ if (args->bc_xprt && args->bc_xprt->xpt_bc_xps) {
+ WARN_ON_ONCE(!(args->protocol & XPRT_TRANSPORT_BC));
+ xps = args->bc_xprt->xpt_bc_xps;
+ xprt_switch_get(xps);
+ } else {
+ xps = xprt_switch_alloc(xprt, GFP_KERNEL);
+ if (xps == NULL) {
+ xprt_put(xprt);
+ return ERR_PTR(-ENOMEM);
+ }
+ if (xprt->bc_xprt) {
+ xprt_switch_get(xps);
+ xprt->bc_xprt->xpt_bc_xps = xps;
+ }
+ }
+ clnt = rpc_new_client(args, xps, xprt, NULL);
+ if (IS_ERR(clnt))
+ return clnt;
+
+ if (!(args->flags & RPC_CLNT_CREATE_NOPING)) {
+ int err = rpc_ping(clnt);
+ if (err != 0) {
+ rpc_shutdown_client(clnt);
+ return ERR_PTR(err);
+ }
+ } else if (args->flags & RPC_CLNT_CREATE_CONNECTED) {
+ int err = rpc_ping_noreply(clnt);
+ if (err != 0) {
+ rpc_shutdown_client(clnt);
+ return ERR_PTR(err);
+ }
+ }
+
+ clnt->cl_softrtry = 1;
+ if (args->flags & (RPC_CLNT_CREATE_HARDRTRY|RPC_CLNT_CREATE_SOFTERR)) {
+ clnt->cl_softrtry = 0;
+ if (args->flags & RPC_CLNT_CREATE_SOFTERR)
+ clnt->cl_softerr = 1;
+ }
+
+ if (args->flags & RPC_CLNT_CREATE_AUTOBIND)
+ clnt->cl_autobind = 1;
+ if (args->flags & RPC_CLNT_CREATE_NO_RETRANS_TIMEOUT)
+ clnt->cl_noretranstimeo = 1;
+ if (args->flags & RPC_CLNT_CREATE_DISCRTRY)
+ clnt->cl_discrtry = 1;
+ if (!(args->flags & RPC_CLNT_CREATE_QUIET))
+ clnt->cl_chatty = 1;
+
+ return clnt;
+}
+
+/**
+ * rpc_create - create an RPC client and transport with one call
+ * @args: rpc_clnt create argument structure
+ *
+ * Creates and initializes an RPC transport and an RPC client.
+ *
+ * It can ping the server in order to determine if it is up, and to see if
+ * it supports this program and version. RPC_CLNT_CREATE_NOPING disables
+ * this behavior so asynchronous tasks can also use rpc_create.
+ */
+struct rpc_clnt *rpc_create(struct rpc_create_args *args)
+{
+ struct rpc_xprt *xprt;
+ struct xprt_create xprtargs = {
+ .net = args->net,
+ .ident = args->protocol,
+ .srcaddr = args->saddress,
+ .dstaddr = args->address,
+ .addrlen = args->addrsize,
+ .servername = args->servername,
+ .bc_xprt = args->bc_xprt,
+ .xprtsec = args->xprtsec,
+ .connect_timeout = args->connect_timeout,
+ .reconnect_timeout = args->reconnect_timeout,
+ };
+ char servername[48];
+ struct rpc_clnt *clnt;
+ int i;
+
+ if (args->bc_xprt) {
+ WARN_ON_ONCE(!(args->protocol & XPRT_TRANSPORT_BC));
+ xprt = args->bc_xprt->xpt_bc_xprt;
+ if (xprt) {
+ xprt_get(xprt);
+ return rpc_create_xprt(args, xprt);
+ }
+ }
+
+ if (args->flags & RPC_CLNT_CREATE_INFINITE_SLOTS)
+ xprtargs.flags |= XPRT_CREATE_INFINITE_SLOTS;
+ if (args->flags & RPC_CLNT_CREATE_NO_IDLE_TIMEOUT)
+ xprtargs.flags |= XPRT_CREATE_NO_IDLE_TIMEOUT;
+ /*
+ * If the caller chooses not to specify a hostname, whip
+ * up a string representation of the passed-in address.
+ */
+ if (xprtargs.servername == NULL) {
+ struct sockaddr_un *sun =
+ (struct sockaddr_un *)args->address;
+ struct sockaddr_in *sin =
+ (struct sockaddr_in *)args->address;
+ struct sockaddr_in6 *sin6 =
+ (struct sockaddr_in6 *)args->address;
+
+ servername[0] = '\0';
+ switch (args->address->sa_family) {
+ case AF_LOCAL:
+ if (sun->sun_path[0])
+ snprintf(servername, sizeof(servername), "%s",
+ sun->sun_path);
+ else
+ snprintf(servername, sizeof(servername), "@%s",
+ sun->sun_path+1);
+ break;
+ case AF_INET:
+ snprintf(servername, sizeof(servername), "%pI4",
+ &sin->sin_addr.s_addr);
+ break;
+ case AF_INET6:
+ snprintf(servername, sizeof(servername), "%pI6",
+ &sin6->sin6_addr);
+ break;
+ default:
+ /* caller wants default server name, but
+ * address family isn't recognized. */
+ return ERR_PTR(-EINVAL);
+ }
+ xprtargs.servername = servername;
+ }
+
+ xprt = xprt_create_transport(&xprtargs);
+ if (IS_ERR(xprt))
+ return (struct rpc_clnt *)xprt;
+
+ /*
+ * By default, kernel RPC client connects from a reserved port.
+ * CAP_NET_BIND_SERVICE will not be set for unprivileged requesters,
+ * but it is always enabled for rpciod, which handles the connect
+ * operation.
+ */
+ xprt->resvport = 1;
+ if (args->flags & RPC_CLNT_CREATE_NONPRIVPORT)
+ xprt->resvport = 0;
+ xprt->reuseport = 0;
+ if (args->flags & RPC_CLNT_CREATE_REUSEPORT)
+ xprt->reuseport = 1;
+
+ clnt = rpc_create_xprt(args, xprt);
+ if (IS_ERR(clnt) || args->nconnect <= 1)
+ return clnt;
+
+ for (i = 0; i < args->nconnect - 1; i++) {
+ if (rpc_clnt_add_xprt(clnt, &xprtargs, NULL, NULL) < 0)
+ break;
+ }
+ return clnt;
+}
+EXPORT_SYMBOL_GPL(rpc_create);
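+
+/*
+ * A minimal sketch of client creation, assuming a caller that already has
+ * a filled-in sockaddr "addr" and an rpc_program "example_program" (both
+ * hypothetical here):
+ *
+ *	struct rpc_create_args args = {
+ *		.net		= net,
+ *		.protocol	= XPRT_TRANSPORT_TCP,
+ *		.address	= (struct sockaddr *)&addr,
+ *		.addrsize	= sizeof(addr),
+ *		.servername	= "server.example",
+ *		.program	= &example_program,
+ *		.version	= 1,
+ *		.authflavor	= RPC_AUTH_UNIX,
+ *		.cred		= current_cred(),
+ *	};
+ *	struct rpc_clnt *clnt = rpc_create(&args);
+ *
+ *	if (IS_ERR(clnt))
+ *		return PTR_ERR(clnt);
+ */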
+
+/*
+ * This function clones the RPC client structure. It allows us to share the
+ * same transport while varying parameters such as the authentication
+ * flavour.
+ */
+static struct rpc_clnt *__rpc_clone_client(struct rpc_create_args *args,
+ struct rpc_clnt *clnt)
+{
+ struct rpc_xprt_switch *xps;
+ struct rpc_xprt *xprt;
+ struct rpc_clnt *new;
+ int err;
+
+ err = -ENOMEM;
+ rcu_read_lock();
+ xprt = xprt_get(rcu_dereference(clnt->cl_xprt));
+ xps = xprt_switch_get(rcu_dereference(clnt->cl_xpi.xpi_xpswitch));
+ rcu_read_unlock();
+ if (xprt == NULL || xps == NULL) {
+ xprt_put(xprt);
+ xprt_switch_put(xps);
+ goto out_err;
+ }
+ args->servername = xprt->servername;
+ args->nodename = clnt->cl_nodename;
+
+ new = rpc_new_client(args, xps, xprt, clnt);
+ if (IS_ERR(new))
+ return new;
+
+ /* Turn off autobind on clones */
+ new->cl_autobind = 0;
+ new->cl_softrtry = clnt->cl_softrtry;
+ new->cl_softerr = clnt->cl_softerr;
+ new->cl_noretranstimeo = clnt->cl_noretranstimeo;
+ new->cl_discrtry = clnt->cl_discrtry;
+ new->cl_chatty = clnt->cl_chatty;
+ new->cl_principal = clnt->cl_principal;
+ new->cl_max_connect = clnt->cl_max_connect;
+ return new;
+
+out_err:
+ trace_rpc_clnt_clone_err(clnt, err);
+ return ERR_PTR(err);
+}
+
+/**
+ * rpc_clone_client - Clone an RPC client structure
+ *
+ * @clnt: RPC client whose parameters are copied
+ *
+ * Returns a fresh RPC client or an ERR_PTR.
+ */
+struct rpc_clnt *rpc_clone_client(struct rpc_clnt *clnt)
+{
+ struct rpc_create_args args = {
+ .program = clnt->cl_program,
+ .prognumber = clnt->cl_prog,
+ .version = clnt->cl_vers,
+ .authflavor = clnt->cl_auth->au_flavor,
+ .cred = clnt->cl_cred,
+ };
+ return __rpc_clone_client(&args, clnt);
+}
+EXPORT_SYMBOL_GPL(rpc_clone_client);
+
+/**
+ * rpc_clone_client_set_auth - Clone an RPC client structure and set its auth
+ *
+ * @clnt: RPC client whose parameters are copied
+ * @flavor: security flavor for new client
+ *
+ * Returns a fresh RPC client or an ERR_PTR.
+ */
+struct rpc_clnt *
+rpc_clone_client_set_auth(struct rpc_clnt *clnt, rpc_authflavor_t flavor)
+{
+ struct rpc_create_args args = {
+ .program = clnt->cl_program,
+ .prognumber = clnt->cl_prog,
+ .version = clnt->cl_vers,
+ .authflavor = flavor,
+ .cred = clnt->cl_cred,
+ };
+ return __rpc_clone_client(&args, clnt);
+}
+EXPORT_SYMBOL_GPL(rpc_clone_client_set_auth);
+
+/**
+ * rpc_switch_client_transport: switch the RPC transport on the fly
+ * @clnt: pointer to a struct rpc_clnt
+ * @args: pointer to the new transport arguments
+ * @timeout: pointer to the new timeout parameters
+ *
+ * This function allows the caller to switch the RPC transport for the
+ * rpc_clnt structure 'clnt' to allow it to connect to a mirrored NFS
+ * server, for instance. It assumes that the caller has ensured that
+ * there are no active RPC tasks by using some form of locking.
+ *
+ * Returns zero if "clnt" is now using the new xprt. Otherwise a
+ * negative errno is returned, and "clnt" continues to use the old
+ * xprt.
+ */
+int rpc_switch_client_transport(struct rpc_clnt *clnt,
+ struct xprt_create *args,
+ const struct rpc_timeout *timeout)
+{
+ const struct rpc_timeout *old_timeo;
+ rpc_authflavor_t pseudoflavor;
+ struct rpc_xprt_switch *xps, *oldxps;
+ struct rpc_xprt *xprt, *old;
+ struct rpc_clnt *parent;
+ int err;
+
+ args->xprtsec = clnt->cl_xprtsec;
+ xprt = xprt_create_transport(args);
+ if (IS_ERR(xprt))
+ return PTR_ERR(xprt);
+
+ xps = xprt_switch_alloc(xprt, GFP_KERNEL);
+ if (xps == NULL) {
+ xprt_put(xprt);
+ return -ENOMEM;
+ }
+
+ pseudoflavor = clnt->cl_auth->au_flavor;
+
+ old_timeo = clnt->cl_timeout;
+ old = rpc_clnt_set_transport(clnt, xprt, timeout);
+ oldxps = xprt_iter_xchg_switch(&clnt->cl_xpi, xps);
+
+ rpc_unregister_client(clnt);
+ __rpc_clnt_remove_pipedir(clnt);
+ rpc_sysfs_client_destroy(clnt);
+ rpc_clnt_debugfs_unregister(clnt);
+
+ /*
+ * A new transport was created. "clnt" therefore
+ * becomes the root of a new cl_parent tree. clnt's
+ * children, if it has any, still point to the old xprt.
+ */
+ parent = clnt->cl_parent;
+ clnt->cl_parent = clnt;
+
+ /*
+	 * The old rpc_auth cache cannot be re-used. GSS
+	 * contexts in particular are bound to a single
+	 * client/server pair.
+ */
+ err = rpc_client_register(clnt, pseudoflavor, NULL);
+ if (err)
+ goto out_revert;
+
+ synchronize_rcu();
+ if (parent != clnt)
+ rpc_release_client(parent);
+ xprt_switch_put(oldxps);
+ xprt_put(old);
+ trace_rpc_clnt_replace_xprt(clnt);
+ return 0;
+
+out_revert:
+ xps = xprt_iter_xchg_switch(&clnt->cl_xpi, oldxps);
+ rpc_clnt_set_transport(clnt, old, old_timeo);
+ clnt->cl_parent = parent;
+ rpc_client_register(clnt, pseudoflavor, NULL);
+ xprt_switch_put(xps);
+ xprt_put(xprt);
+ trace_rpc_clnt_replace_xprt_err(clnt);
+ return err;
+}
+EXPORT_SYMBOL_GPL(rpc_switch_client_transport);
+
+static
+int _rpc_clnt_xprt_iter_init(struct rpc_clnt *clnt, struct rpc_xprt_iter *xpi,
+ void func(struct rpc_xprt_iter *xpi, struct rpc_xprt_switch *xps))
+{
+ struct rpc_xprt_switch *xps;
+
+ rcu_read_lock();
+ xps = xprt_switch_get(rcu_dereference(clnt->cl_xpi.xpi_xpswitch));
+ rcu_read_unlock();
+ if (xps == NULL)
+ return -EAGAIN;
+ func(xpi, xps);
+ xprt_switch_put(xps);
+ return 0;
+}
+
+static
+int rpc_clnt_xprt_iter_init(struct rpc_clnt *clnt, struct rpc_xprt_iter *xpi)
+{
+ return _rpc_clnt_xprt_iter_init(clnt, xpi, xprt_iter_init_listall);
+}
+
+static
+int rpc_clnt_xprt_iter_offline_init(struct rpc_clnt *clnt,
+ struct rpc_xprt_iter *xpi)
+{
+ return _rpc_clnt_xprt_iter_init(clnt, xpi, xprt_iter_init_listoffline);
+}
+
+/**
+ * rpc_clnt_iterate_for_each_xprt - Apply a function to all transports
+ * @clnt: pointer to client
+ * @fn: function to apply
+ * @data: void pointer to function data
+ *
+ * Iterates through the list of RPC transports currently attached to the
+ * client and applies the function fn(clnt, xprt, data).
+ *
+ * On error, the iteration stops, and the function returns the error value.
+ */
+int rpc_clnt_iterate_for_each_xprt(struct rpc_clnt *clnt,
+ int (*fn)(struct rpc_clnt *, struct rpc_xprt *, void *),
+ void *data)
+{
+ struct rpc_xprt_iter xpi;
+ int ret;
+
+ ret = rpc_clnt_xprt_iter_init(clnt, &xpi);
+ if (ret)
+ return ret;
+ for (;;) {
+ struct rpc_xprt *xprt = xprt_iter_get_next(&xpi);
+
+ if (!xprt)
+ break;
+ ret = fn(clnt, xprt, data);
+ xprt_put(xprt);
+ if (ret < 0)
+ break;
+ }
+ xprt_iter_destroy(&xpi);
+ return ret;
+}
+EXPORT_SYMBOL_GPL(rpc_clnt_iterate_for_each_xprt);
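+
+/*
+ * Illustrative usage sketch (caller-side; count_xprts is a hypothetical
+ * helper, not defined in this file): count the transports currently
+ * attached to a client.
+ *
+ *	static int count_xprts(struct rpc_clnt *clnt, struct rpc_xprt *xprt,
+ *			       void *data)
+ *	{
+ *		(*(unsigned int *)data)++;
+ *		return 0;
+ *	}
+ *
+ *	unsigned int nr = 0;
+ *	rpc_clnt_iterate_for_each_xprt(clnt, count_xprts, &nr);
+ */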
+
+/*
+ * Kill all tasks for the given client.
+ * XXX: kill their descendants as well?
+ */
+void rpc_killall_tasks(struct rpc_clnt *clnt)
+{
+ struct rpc_task *rovr;
+
+ if (list_empty(&clnt->cl_tasks))
+ return;
+
+	/*
+	 * Hold cl_lock to prevent changes to the cl_tasks list...
+	 */
+ trace_rpc_clnt_killall(clnt);
+ spin_lock(&clnt->cl_lock);
+ list_for_each_entry(rovr, &clnt->cl_tasks, tk_task)
+ rpc_signal_task(rovr);
+ spin_unlock(&clnt->cl_lock);
+}
+EXPORT_SYMBOL_GPL(rpc_killall_tasks);
+
+/**
+ * rpc_cancel_tasks - try to cancel a set of RPC tasks
+ * @clnt: Pointer to RPC client
+ * @error: RPC task error value to set
+ * @fnmatch: Pointer to selector function
+ * @data: User data
+ *
+ * Uses @fnmatch to define a set of RPC tasks that are to be cancelled.
+ * The argument @error must be a negative error value.
+ */
+unsigned long rpc_cancel_tasks(struct rpc_clnt *clnt, int error,
+ bool (*fnmatch)(const struct rpc_task *,
+ const void *),
+ const void *data)
+{
+ struct rpc_task *task;
+ unsigned long count = 0;
+
+ if (list_empty(&clnt->cl_tasks))
+ return 0;
+	/*
+	 * Hold cl_lock to prevent changes to the cl_tasks list...
+	 */
+ spin_lock(&clnt->cl_lock);
+ list_for_each_entry(task, &clnt->cl_tasks, tk_task) {
+ if (!RPC_IS_ACTIVATED(task))
+ continue;
+ if (!fnmatch(task, data))
+ continue;
+ rpc_task_try_cancel(task, error);
+ count++;
+ }
+ spin_unlock(&clnt->cl_lock);
+ return count;
+}
+EXPORT_SYMBOL_GPL(rpc_cancel_tasks);
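+
+/*
+ * Illustrative usage sketch (caller-side; match_all is a hypothetical
+ * selector, not defined in this file): cancel every activated task on
+ * the client with -EIO.
+ *
+ *	static bool match_all(const struct rpc_task *task, const void *data)
+ *	{
+ *		return true;
+ *	}
+ *
+ *	unsigned long n = rpc_cancel_tasks(clnt, -EIO, match_all, NULL);
+ */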
+
+static int rpc_clnt_disconnect_xprt(struct rpc_clnt *clnt,
+ struct rpc_xprt *xprt, void *dummy)
+{
+ if (xprt_connected(xprt))
+ xprt_force_disconnect(xprt);
+ return 0;
+}
+
+void rpc_clnt_disconnect(struct rpc_clnt *clnt)
+{
+ rpc_clnt_iterate_for_each_xprt(clnt, rpc_clnt_disconnect_xprt, NULL);
+}
+EXPORT_SYMBOL_GPL(rpc_clnt_disconnect);
+
+/*
+ * Properly shut down an RPC client, terminating all outstanding
+ * requests.
+ */
+void rpc_shutdown_client(struct rpc_clnt *clnt)
+{
+ might_sleep();
+
+ trace_rpc_clnt_shutdown(clnt);
+
+ while (!list_empty(&clnt->cl_tasks)) {
+ rpc_killall_tasks(clnt);
+ wait_event_timeout(destroy_wait,
+ list_empty(&clnt->cl_tasks), 1*HZ);
+ }
+
+ rpc_release_client(clnt);
+}
+EXPORT_SYMBOL_GPL(rpc_shutdown_client);
+
+/*
+ * Free an RPC client
+ */
+static void rpc_free_client_work(struct work_struct *work)
+{
+ struct rpc_clnt *clnt = container_of(work, struct rpc_clnt, cl_work);
+
+ trace_rpc_clnt_free(clnt);
+
+	/* These might block on processes that might allocate memory,
+	 * so they cannot be called in rpciod; they are handled
+	 * separately here.
+ */
+ rpc_sysfs_client_destroy(clnt);
+ rpc_clnt_debugfs_unregister(clnt);
+ rpc_free_clid(clnt);
+ rpc_clnt_remove_pipedir(clnt);
+ xprt_put(rcu_dereference_raw(clnt->cl_xprt));
+
+ kfree(clnt);
+ rpciod_down();
+}
+static struct rpc_clnt *
+rpc_free_client(struct rpc_clnt *clnt)
+{
+ struct rpc_clnt *parent = NULL;
+
+ trace_rpc_clnt_release(clnt);
+ if (clnt->cl_parent != clnt)
+ parent = clnt->cl_parent;
+ rpc_unregister_client(clnt);
+ rpc_free_iostats(clnt->cl_metrics);
+ clnt->cl_metrics = NULL;
+ xprt_iter_destroy(&clnt->cl_xpi);
+ put_cred(clnt->cl_cred);
+
+ INIT_WORK(&clnt->cl_work, rpc_free_client_work);
+ schedule_work(&clnt->cl_work);
+ return parent;
+}
+
+/*
+ * Free an RPC client
+ */
+static struct rpc_clnt *
+rpc_free_auth(struct rpc_clnt *clnt)
+{
+ /*
+ * Note: RPCSEC_GSS may need to send NULL RPC calls in order to
+ * release remaining GSS contexts. This mechanism ensures
+ * that it can do so safely.
+ */
+ if (clnt->cl_auth != NULL) {
+ rpcauth_release(clnt->cl_auth);
+ clnt->cl_auth = NULL;
+ }
+ if (refcount_dec_and_test(&clnt->cl_count))
+ return rpc_free_client(clnt);
+ return NULL;
+}
+
+/*
+ * Release reference to the RPC client
+ */
+void
+rpc_release_client(struct rpc_clnt *clnt)
+{
+ do {
+ if (list_empty(&clnt->cl_tasks))
+ wake_up(&destroy_wait);
+ if (refcount_dec_not_one(&clnt->cl_count))
+ break;
+ clnt = rpc_free_auth(clnt);
+ } while (clnt != NULL);
+}
+EXPORT_SYMBOL_GPL(rpc_release_client);
+
+/**
+ * rpc_bind_new_program - bind a new RPC program to an existing client
+ * @old: old rpc_client
+ * @program: rpc program to set
+ * @vers: rpc program version
+ *
+ * Clones the rpc client and sets up a new RPC program. This is mainly
+ * of use for enabling different RPC programs to share the same transport.
+ * The Sun NFSv2/v3 ACL protocol can do this.
+ */
+struct rpc_clnt *rpc_bind_new_program(struct rpc_clnt *old,
+ const struct rpc_program *program,
+ u32 vers)
+{
+ struct rpc_create_args args = {
+ .program = program,
+ .prognumber = program->number,
+ .version = vers,
+ .authflavor = old->cl_auth->au_flavor,
+ .cred = old->cl_cred,
+ };
+ struct rpc_clnt *clnt;
+ int err;
+
+ clnt = __rpc_clone_client(&args, old);
+ if (IS_ERR(clnt))
+ goto out;
+ err = rpc_ping(clnt);
+ if (err != 0) {
+ rpc_shutdown_client(clnt);
+ clnt = ERR_PTR(err);
+ }
+out:
+ return clnt;
+}
+EXPORT_SYMBOL_GPL(rpc_bind_new_program);
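+
+/*
+ * Illustrative usage sketch (my_acl_program and the version number are
+ * hypothetical placeholders): let a second RPC program share the
+ * existing client's transport.
+ *
+ *	struct rpc_clnt *acl_clnt;
+ *
+ *	acl_clnt = rpc_bind_new_program(clnt, &my_acl_program, 3);
+ *	if (IS_ERR(acl_clnt))
+ *		return PTR_ERR(acl_clnt);
+ */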
+
+struct rpc_xprt *
+rpc_task_get_xprt(struct rpc_clnt *clnt, struct rpc_xprt *xprt)
+{
+ struct rpc_xprt_switch *xps;
+
+ if (!xprt)
+ return NULL;
+ rcu_read_lock();
+ xps = rcu_dereference(clnt->cl_xpi.xpi_xpswitch);
+ atomic_long_inc(&xps->xps_queuelen);
+ rcu_read_unlock();
+ atomic_long_inc(&xprt->queuelen);
+
+ return xprt;
+}
+
+static void
+rpc_task_release_xprt(struct rpc_clnt *clnt, struct rpc_xprt *xprt)
+{
+ struct rpc_xprt_switch *xps;
+
+ atomic_long_dec(&xprt->queuelen);
+ rcu_read_lock();
+ xps = rcu_dereference(clnt->cl_xpi.xpi_xpswitch);
+ atomic_long_dec(&xps->xps_queuelen);
+ rcu_read_unlock();
+
+ xprt_put(xprt);
+}
+
+void rpc_task_release_transport(struct rpc_task *task)
+{
+ struct rpc_xprt *xprt = task->tk_xprt;
+
+ if (xprt) {
+ task->tk_xprt = NULL;
+ if (task->tk_client)
+ rpc_task_release_xprt(task->tk_client, xprt);
+ else
+ xprt_put(xprt);
+ }
+}
+EXPORT_SYMBOL_GPL(rpc_task_release_transport);
+
+void rpc_task_release_client(struct rpc_task *task)
+{
+ struct rpc_clnt *clnt = task->tk_client;
+
+ rpc_task_release_transport(task);
+ if (clnt != NULL) {
+ /* Remove from client task list */
+ spin_lock(&clnt->cl_lock);
+ list_del(&task->tk_task);
+ spin_unlock(&clnt->cl_lock);
+ task->tk_client = NULL;
+
+ rpc_release_client(clnt);
+ }
+}
+
+static struct rpc_xprt *
+rpc_task_get_first_xprt(struct rpc_clnt *clnt)
+{
+ struct rpc_xprt *xprt;
+
+ rcu_read_lock();
+ xprt = xprt_get(rcu_dereference(clnt->cl_xprt));
+ rcu_read_unlock();
+ return rpc_task_get_xprt(clnt, xprt);
+}
+
+static struct rpc_xprt *
+rpc_task_get_next_xprt(struct rpc_clnt *clnt)
+{
+ return rpc_task_get_xprt(clnt, xprt_iter_get_next(&clnt->cl_xpi));
+}
+
+static
+void rpc_task_set_transport(struct rpc_task *task, struct rpc_clnt *clnt)
+{
+ if (task->tk_xprt) {
+ if (!(test_bit(XPRT_OFFLINE, &task->tk_xprt->state) &&
+ (task->tk_flags & RPC_TASK_MOVEABLE)))
+ return;
+ xprt_release(task);
+ xprt_put(task->tk_xprt);
+ }
+ if (task->tk_flags & RPC_TASK_NO_ROUND_ROBIN)
+ task->tk_xprt = rpc_task_get_first_xprt(clnt);
+ else
+ task->tk_xprt = rpc_task_get_next_xprt(clnt);
+}
+
+static
+void rpc_task_set_client(struct rpc_task *task, struct rpc_clnt *clnt)
+{
+ rpc_task_set_transport(task, clnt);
+ task->tk_client = clnt;
+ refcount_inc(&clnt->cl_count);
+ if (clnt->cl_softrtry)
+ task->tk_flags |= RPC_TASK_SOFT;
+ if (clnt->cl_softerr)
+ task->tk_flags |= RPC_TASK_TIMEOUT;
+ if (clnt->cl_noretranstimeo)
+ task->tk_flags |= RPC_TASK_NO_RETRANS_TIMEOUT;
+ /* Add to the client's list of all tasks */
+ spin_lock(&clnt->cl_lock);
+ list_add_tail(&task->tk_task, &clnt->cl_tasks);
+ spin_unlock(&clnt->cl_lock);
+}
+
+static void
+rpc_task_set_rpc_message(struct rpc_task *task, const struct rpc_message *msg)
+{
+ if (msg != NULL) {
+ task->tk_msg.rpc_proc = msg->rpc_proc;
+ task->tk_msg.rpc_argp = msg->rpc_argp;
+ task->tk_msg.rpc_resp = msg->rpc_resp;
+ task->tk_msg.rpc_cred = msg->rpc_cred;
+ if (!(task->tk_flags & RPC_TASK_CRED_NOREF))
+ get_cred(task->tk_msg.rpc_cred);
+ }
+}
+
+/*
+ * Default callback for async RPC calls
+ */
+static void
+rpc_default_callback(struct rpc_task *task, void *data)
+{
+}
+
+static const struct rpc_call_ops rpc_default_ops = {
+ .rpc_call_done = rpc_default_callback,
+};
+
+/**
+ * rpc_run_task - Allocate a new RPC task, then run rpc_execute against it
+ * @task_setup_data: pointer to task initialisation data
+ */
+struct rpc_task *rpc_run_task(const struct rpc_task_setup *task_setup_data)
+{
+ struct rpc_task *task;
+
+ task = rpc_new_task(task_setup_data);
+ if (IS_ERR(task))
+ return task;
+
+ if (!RPC_IS_ASYNC(task))
+ task->tk_flags |= RPC_TASK_CRED_NOREF;
+
+ rpc_task_set_client(task, task_setup_data->rpc_client);
+ rpc_task_set_rpc_message(task, task_setup_data->rpc_message);
+
+ if (task->tk_action == NULL)
+ rpc_call_start(task);
+
+ atomic_inc(&task->tk_count);
+ rpc_execute(task);
+ return task;
+}
+EXPORT_SYMBOL_GPL(rpc_run_task);
+
+/**
+ * rpc_call_sync - Perform a synchronous RPC call
+ * @clnt: pointer to RPC client
+ * @msg: RPC call parameters
+ * @flags: RPC call flags
+ */
+int rpc_call_sync(struct rpc_clnt *clnt, const struct rpc_message *msg, int flags)
+{
+ struct rpc_task *task;
+ struct rpc_task_setup task_setup_data = {
+ .rpc_client = clnt,
+ .rpc_message = msg,
+ .callback_ops = &rpc_default_ops,
+ .flags = flags,
+ };
+ int status;
+
+ WARN_ON_ONCE(flags & RPC_TASK_ASYNC);
+ if (flags & RPC_TASK_ASYNC) {
+ rpc_release_calldata(task_setup_data.callback_ops,
+ task_setup_data.callback_data);
+ return -EINVAL;
+ }
+
+ task = rpc_run_task(&task_setup_data);
+ if (IS_ERR(task))
+ return PTR_ERR(task);
+ status = task->tk_status;
+ rpc_put_task(task);
+ return status;
+}
+EXPORT_SYMBOL_GPL(rpc_call_sync);
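+
+/*
+ * Illustrative usage sketch (my_procedures, MY_PROC_FOO, args and res are
+ * hypothetical caller-defined names): a typical synchronous call fills in
+ * an rpc_message and checks the returned status.
+ *
+ *	struct rpc_message msg = {
+ *		.rpc_proc = &my_procedures[MY_PROC_FOO],
+ *		.rpc_argp = &args,
+ *		.rpc_resp = &res,
+ *	};
+ *	int err = rpc_call_sync(clnt, &msg, RPC_TASK_SOFT);
+ */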
+
+/**
+ * rpc_call_async - Perform an asynchronous RPC call
+ * @clnt: pointer to RPC client
+ * @msg: RPC call parameters
+ * @flags: RPC call flags
+ * @tk_ops: RPC call ops
+ * @data: user call data
+ */
+int
+rpc_call_async(struct rpc_clnt *clnt, const struct rpc_message *msg, int flags,
+ const struct rpc_call_ops *tk_ops, void *data)
+{
+ struct rpc_task *task;
+ struct rpc_task_setup task_setup_data = {
+ .rpc_client = clnt,
+ .rpc_message = msg,
+ .callback_ops = tk_ops,
+ .callback_data = data,
+ .flags = flags|RPC_TASK_ASYNC,
+ };
+
+ task = rpc_run_task(&task_setup_data);
+ if (IS_ERR(task))
+ return PTR_ERR(task);
+ rpc_put_task(task);
+ return 0;
+}
+EXPORT_SYMBOL_GPL(rpc_call_async);
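+
+/*
+ * Illustrative usage sketch (my_done and my_ops are hypothetical
+ * caller-defined names): an asynchronous caller supplies rpc_call_ops
+ * whose rpc_call_done inspects task->tk_status once the reply arrives.
+ *
+ *	static void my_done(struct rpc_task *task, void *calldata)
+ *	{
+ *		if (task->tk_status < 0)
+ *			pr_debug("call failed: %d\n", task->tk_status);
+ *	}
+ *
+ *	static const struct rpc_call_ops my_ops = {
+ *		.rpc_call_done = my_done,
+ *	};
+ *
+ *	err = rpc_call_async(clnt, &msg, RPC_TASK_SOFT, &my_ops, NULL);
+ */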
+
+#if defined(CONFIG_SUNRPC_BACKCHANNEL)
+static void call_bc_encode(struct rpc_task *task);
+
+/**
+ * rpc_run_bc_task - Allocate a new RPC task for backchannel use, then run
+ * rpc_execute against it
+ * @req: RPC request
+ */
+struct rpc_task *rpc_run_bc_task(struct rpc_rqst *req)
+{
+ struct rpc_task *task;
+ struct rpc_task_setup task_setup_data = {
+ .callback_ops = &rpc_default_ops,
+ .flags = RPC_TASK_SOFTCONN |
+ RPC_TASK_NO_RETRANS_TIMEOUT,
+ };
+
+ dprintk("RPC: rpc_run_bc_task req= %p\n", req);
+ /*
+ * Create an rpc_task to send the data
+ */
+ task = rpc_new_task(&task_setup_data);
+ if (IS_ERR(task)) {
+ xprt_free_bc_request(req);
+ return task;
+ }
+
+ xprt_init_bc_request(req, task);
+
+ task->tk_action = call_bc_encode;
+ atomic_inc(&task->tk_count);
+ WARN_ON_ONCE(atomic_read(&task->tk_count) != 2);
+ rpc_execute(task);
+
+ dprintk("RPC: rpc_run_bc_task: task= %p\n", task);
+ return task;
+}
+#endif /* CONFIG_SUNRPC_BACKCHANNEL */
+
+/**
+ * rpc_prepare_reply_pages - Prepare to receive a reply data payload into pages
+ * @req: RPC request to prepare
+ * @pages: vector of struct page pointers
+ * @base: offset in first page where receive should start, in bytes
+ * @len: expected size of the upper layer data payload, in bytes
+ * @hdrsize: expected size of upper layer reply header, in XDR words
+ *
+ */
+void rpc_prepare_reply_pages(struct rpc_rqst *req, struct page **pages,
+ unsigned int base, unsigned int len,
+ unsigned int hdrsize)
+{
+ hdrsize += RPC_REPHDRSIZE + req->rq_cred->cr_auth->au_ralign;
+
+ xdr_inline_pages(&req->rq_rcv_buf, hdrsize << 2, pages, base, len);
+ trace_rpc_xdr_reply_pages(req->rq_task, &req->rq_rcv_buf);
+}
+EXPORT_SYMBOL_GPL(rpc_prepare_reply_pages);
+
+void
+rpc_call_start(struct rpc_task *task)
+{
+ task->tk_action = call_start;
+}
+EXPORT_SYMBOL_GPL(rpc_call_start);
+
+/**
+ * rpc_peeraddr - extract remote peer address from clnt's xprt
+ * @clnt: RPC client structure
+ * @buf: target buffer
+ * @bufsize: length of target buffer
+ *
+ * Returns the number of bytes that are actually in the stored address.
+ */
+size_t rpc_peeraddr(struct rpc_clnt *clnt, struct sockaddr *buf, size_t bufsize)
+{
+ size_t bytes;
+ struct rpc_xprt *xprt;
+
+ rcu_read_lock();
+ xprt = rcu_dereference(clnt->cl_xprt);
+
+ bytes = xprt->addrlen;
+ if (bytes > bufsize)
+ bytes = bufsize;
+ memcpy(buf, &xprt->addr, bytes);
+ rcu_read_unlock();
+
+ return bytes;
+}
+EXPORT_SYMBOL_GPL(rpc_peeraddr);
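+
+/*
+ * Illustrative usage sketch (caller-side): copy the peer address into a
+ * sockaddr_storage, which is large enough for any supported family.
+ *
+ *	struct sockaddr_storage peer;
+ *	size_t len = rpc_peeraddr(clnt, (struct sockaddr *)&peer,
+ *				  sizeof(peer));
+ */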
+
+/**
+ * rpc_peeraddr2str - return remote peer address in printable format
+ * @clnt: RPC client structure
+ * @format: address format
+ *
+ * NB: the lifetime of the memory referenced by the returned pointer is
+ * the same as the rpc_xprt itself. As long as the caller uses this
+ * pointer, it must hold the RCU read lock.
+ */
+const char *rpc_peeraddr2str(struct rpc_clnt *clnt,
+ enum rpc_display_format_t format)
+{
+ struct rpc_xprt *xprt;
+
+ xprt = rcu_dereference(clnt->cl_xprt);
+
+ if (xprt->address_strings[format] != NULL)
+ return xprt->address_strings[format];
+ else
+ return "unprintable";
+}
+EXPORT_SYMBOL_GPL(rpc_peeraddr2str);
+
+static const struct sockaddr_in rpc_inaddr_loopback = {
+ .sin_family = AF_INET,
+ .sin_addr.s_addr = htonl(INADDR_ANY),
+};
+
+static const struct sockaddr_in6 rpc_in6addr_loopback = {
+ .sin6_family = AF_INET6,
+ .sin6_addr = IN6ADDR_ANY_INIT,
+};
+
+/*
+ * Try a getsockname() on a connected datagram socket. Using a
+ * connected datagram socket prevents leaving a socket in TIME_WAIT.
+ * This conserves the ephemeral port number space.
+ *
+ * Returns zero and fills in "buf" if successful; otherwise, a
+ * negative errno is returned.
+ */
+static int rpc_sockname(struct net *net, struct sockaddr *sap, size_t salen,
+ struct sockaddr *buf)
+{
+ struct socket *sock;
+ int err;
+
+ err = __sock_create(net, sap->sa_family,
+ SOCK_DGRAM, IPPROTO_UDP, &sock, 1);
+ if (err < 0) {
+ dprintk("RPC: can't create UDP socket (%d)\n", err);
+ goto out;
+ }
+
+ switch (sap->sa_family) {
+ case AF_INET:
+ err = kernel_bind(sock,
+ (struct sockaddr *)&rpc_inaddr_loopback,
+ sizeof(rpc_inaddr_loopback));
+ break;
+ case AF_INET6:
+ err = kernel_bind(sock,
+ (struct sockaddr *)&rpc_in6addr_loopback,
+ sizeof(rpc_in6addr_loopback));
+ break;
+ default:
+ err = -EAFNOSUPPORT;
+ goto out_release;
+ }
+ if (err < 0) {
+ dprintk("RPC: can't bind UDP socket (%d)\n", err);
+ goto out_release;
+ }
+
+ err = kernel_connect(sock, sap, salen, 0);
+ if (err < 0) {
+ dprintk("RPC: can't connect UDP socket (%d)\n", err);
+ goto out_release;
+ }
+
+ err = kernel_getsockname(sock, buf);
+ if (err < 0) {
+ dprintk("RPC: getsockname failed (%d)\n", err);
+ goto out_release;
+ }
+
+ err = 0;
+ if (buf->sa_family == AF_INET6) {
+ struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)buf;
+ sin6->sin6_scope_id = 0;
+ }
+ dprintk("RPC: %s succeeded\n", __func__);
+
+out_release:
+ sock_release(sock);
+out:
+ return err;
+}
+
+/*
+ * Scraping a connected socket failed, so we don't have a usable
+ * local address. Fallback: generate an address that will prevent
+ * the server from calling us back.
+ *
+ * Returns zero and fills in "buf" if successful; otherwise, a
+ * negative errno is returned.
+ */
+static int rpc_anyaddr(int family, struct sockaddr *buf, size_t buflen)
+{
+ switch (family) {
+ case AF_INET:
+ if (buflen < sizeof(rpc_inaddr_loopback))
+ return -EINVAL;
+ memcpy(buf, &rpc_inaddr_loopback,
+ sizeof(rpc_inaddr_loopback));
+ break;
+ case AF_INET6:
+ if (buflen < sizeof(rpc_in6addr_loopback))
+ return -EINVAL;
+ memcpy(buf, &rpc_in6addr_loopback,
+ sizeof(rpc_in6addr_loopback));
+ break;
+ default:
+ dprintk("RPC: %s: address family not supported\n",
+ __func__);
+ return -EAFNOSUPPORT;
+ }
+ dprintk("RPC: %s: succeeded\n", __func__);
+ return 0;
+}
+
+/**
+ * rpc_localaddr - discover local endpoint address for an RPC client
+ * @clnt: RPC client structure
+ * @buf: target buffer
+ * @buflen: size of target buffer, in bytes
+ *
+ * Returns zero and fills in "buf" and "buflen" if successful;
+ * otherwise, a negative errno is returned.
+ *
+ * This works even if the underlying transport is not currently connected,
+ * or if the upper layer never previously provided a source address.
+ *
+ * The result of this function call is transient: multiple calls in
+ * succession may give different results, depending on how local
+ * networking configuration changes over time.
+ */
+int rpc_localaddr(struct rpc_clnt *clnt, struct sockaddr *buf, size_t buflen)
+{
+ struct sockaddr_storage address;
+ struct sockaddr *sap = (struct sockaddr *)&address;
+ struct rpc_xprt *xprt;
+ struct net *net;
+ size_t salen;
+ int err;
+
+ rcu_read_lock();
+ xprt = rcu_dereference(clnt->cl_xprt);
+ salen = xprt->addrlen;
+ memcpy(sap, &xprt->addr, salen);
+ net = get_net(xprt->xprt_net);
+ rcu_read_unlock();
+
+ rpc_set_port(sap, 0);
+ err = rpc_sockname(net, sap, salen, buf);
+ put_net(net);
+ if (err != 0)
+ /* Couldn't discover local address, return ANYADDR */
+ return rpc_anyaddr(sap->sa_family, buf, buflen);
+ return 0;
+}
+EXPORT_SYMBOL_GPL(rpc_localaddr);
+
+void
+rpc_setbufsize(struct rpc_clnt *clnt, unsigned int sndsize, unsigned int rcvsize)
+{
+ struct rpc_xprt *xprt;
+
+ rcu_read_lock();
+ xprt = rcu_dereference(clnt->cl_xprt);
+ if (xprt->ops->set_buffer_size)
+ xprt->ops->set_buffer_size(xprt, sndsize, rcvsize);
+ rcu_read_unlock();
+}
+EXPORT_SYMBOL_GPL(rpc_setbufsize);
+
+/**
+ * rpc_net_ns - Get the network namespace for this RPC client
+ * @clnt: RPC client to query
+ *
+ */
+struct net *rpc_net_ns(struct rpc_clnt *clnt)
+{
+ struct net *ret;
+
+ rcu_read_lock();
+ ret = rcu_dereference(clnt->cl_xprt)->xprt_net;
+ rcu_read_unlock();
+ return ret;
+}
+EXPORT_SYMBOL_GPL(rpc_net_ns);
+
+/**
+ * rpc_max_payload - Get maximum payload size for a transport, in bytes
+ * @clnt: RPC client to query
+ *
+ * For stream transports, this is one RPC record fragment (see RFC
+ * 1831), as we don't support multi-record requests yet. For datagram
+ * transports, this is the size of an IP packet minus the IP, UDP, and
+ * RPC header sizes.
+ */
+size_t rpc_max_payload(struct rpc_clnt *clnt)
+{
+ size_t ret;
+
+ rcu_read_lock();
+ ret = rcu_dereference(clnt->cl_xprt)->max_payload;
+ rcu_read_unlock();
+ return ret;
+}
+EXPORT_SYMBOL_GPL(rpc_max_payload);
+
+/**
+ * rpc_max_bc_payload - Get maximum backchannel payload size, in bytes
+ * @clnt: RPC client to query
+ */
+size_t rpc_max_bc_payload(struct rpc_clnt *clnt)
+{
+ struct rpc_xprt *xprt;
+ size_t ret;
+
+ rcu_read_lock();
+ xprt = rcu_dereference(clnt->cl_xprt);
+ ret = xprt->ops->bc_maxpayload(xprt);
+ rcu_read_unlock();
+ return ret;
+}
+EXPORT_SYMBOL_GPL(rpc_max_bc_payload);
+
+unsigned int rpc_num_bc_slots(struct rpc_clnt *clnt)
+{
+ struct rpc_xprt *xprt;
+ unsigned int ret;
+
+ rcu_read_lock();
+ xprt = rcu_dereference(clnt->cl_xprt);
+ ret = xprt->ops->bc_num_slots(xprt);
+ rcu_read_unlock();
+ return ret;
+}
+EXPORT_SYMBOL_GPL(rpc_num_bc_slots);
+
+/**
+ * rpc_force_rebind - force transport to check that remote port is unchanged
+ * @clnt: client to rebind
+ *
+ */
+void rpc_force_rebind(struct rpc_clnt *clnt)
+{
+ if (clnt->cl_autobind) {
+ rcu_read_lock();
+ xprt_clear_bound(rcu_dereference(clnt->cl_xprt));
+ rcu_read_unlock();
+ }
+}
+EXPORT_SYMBOL_GPL(rpc_force_rebind);
+
+static int
+__rpc_restart_call(struct rpc_task *task, void (*action)(struct rpc_task *))
+{
+ task->tk_status = 0;
+ task->tk_rpc_status = 0;
+ task->tk_action = action;
+ return 1;
+}
+
+/*
+ * Restart an (async) RPC call. Usually called from within the
+ * exit handler.
+ */
+int
+rpc_restart_call(struct rpc_task *task)
+{
+ return __rpc_restart_call(task, call_start);
+}
+EXPORT_SYMBOL_GPL(rpc_restart_call);
+
+/*
+ * Restart an (async) RPC call from the call_prepare state.
+ * Usually called from within the exit handler.
+ */
+int
+rpc_restart_call_prepare(struct rpc_task *task)
+{
+ if (task->tk_ops->rpc_call_prepare != NULL)
+ return __rpc_restart_call(task, rpc_prepare_task);
+ return rpc_restart_call(task);
+}
+EXPORT_SYMBOL_GPL(rpc_restart_call_prepare);
+
+const char
+*rpc_proc_name(const struct rpc_task *task)
+{
+ const struct rpc_procinfo *proc = task->tk_msg.rpc_proc;
+
+ if (proc) {
+ if (proc->p_name)
+ return proc->p_name;
+ else
+ return "NULL";
+ } else
+ return "no proc";
+}
+
+static void
+__rpc_call_rpcerror(struct rpc_task *task, int tk_status, int rpc_status)
+{
+ trace_rpc_call_rpcerror(task, tk_status, rpc_status);
+ rpc_task_set_rpc_status(task, rpc_status);
+ rpc_exit(task, tk_status);
+}
+
+static void
+rpc_call_rpcerror(struct rpc_task *task, int status)
+{
+ __rpc_call_rpcerror(task, status, status);
+}
+
+/*
+ * 0. Initial state
+ *
+ * Other FSM states can be visited zero or more times, but
+ * this state is visited exactly once for each RPC.
+ */
+static void
+call_start(struct rpc_task *task)
+{
+ struct rpc_clnt *clnt = task->tk_client;
+ int idx = task->tk_msg.rpc_proc->p_statidx;
+
+ trace_rpc_request(task);
+
+ if (task->tk_client->cl_shutdown) {
+ rpc_call_rpcerror(task, -EIO);
+ return;
+ }
+
+ /* Increment call count (version might not be valid for ping) */
+ if (clnt->cl_program->version[clnt->cl_vers])
+ clnt->cl_program->version[clnt->cl_vers]->counts[idx]++;
+ clnt->cl_stats->rpccnt++;
+ task->tk_action = call_reserve;
+ rpc_task_set_transport(task, clnt);
+}
+
+/*
+ * 1. Reserve an RPC call slot
+ */
+static void
+call_reserve(struct rpc_task *task)
+{
+ task->tk_status = 0;
+ task->tk_action = call_reserveresult;
+ xprt_reserve(task);
+}
+
+static void call_retry_reserve(struct rpc_task *task);
+
+/*
+ * 1b. Grok the result of xprt_reserve()
+ */
+static void
+call_reserveresult(struct rpc_task *task)
+{
+ int status = task->tk_status;
+
+ /*
+ * After a call to xprt_reserve(), we must have either
+ * a request slot or else an error status.
+ */
+ task->tk_status = 0;
+ if (status >= 0) {
+ if (task->tk_rqstp) {
+ task->tk_action = call_refresh;
+ return;
+ }
+
+ rpc_call_rpcerror(task, -EIO);
+ return;
+ }
+
+ switch (status) {
+ case -ENOMEM:
+ rpc_delay(task, HZ >> 2);
+ fallthrough;
+ case -EAGAIN: /* woken up; retry */
+ task->tk_action = call_retry_reserve;
+ return;
+ default:
+ rpc_call_rpcerror(task, status);
+ }
+}
+
+/*
+ * 1c. Retry reserving an RPC call slot
+ */
+static void
+call_retry_reserve(struct rpc_task *task)
+{
+ task->tk_status = 0;
+ task->tk_action = call_reserveresult;
+ xprt_retry_reserve(task);
+}
+
+/*
+ * 2. Bind and/or refresh the credentials
+ */
+static void
+call_refresh(struct rpc_task *task)
+{
+ task->tk_action = call_refreshresult;
+ task->tk_status = 0;
+ task->tk_client->cl_stats->rpcauthrefresh++;
+ rpcauth_refreshcred(task);
+}
+
+/*
+ * 2a. Process the results of a credential refresh
+ */
+static void
+call_refreshresult(struct rpc_task *task)
+{
+ int status = task->tk_status;
+
+ task->tk_status = 0;
+ task->tk_action = call_refresh;
+ switch (status) {
+ case 0:
+ if (rpcauth_uptodatecred(task)) {
+ task->tk_action = call_allocate;
+ return;
+ }
+ /* Use rate-limiting and a max number of retries if refresh
+ * had status 0 but failed to update the cred.
+ */
+ fallthrough;
+ case -ETIMEDOUT:
+ rpc_delay(task, 3*HZ);
+ fallthrough;
+ case -EAGAIN:
+ status = -EACCES;
+ fallthrough;
+ case -EKEYEXPIRED:
+ if (!task->tk_cred_retry)
+ break;
+ task->tk_cred_retry--;
+ trace_rpc_retry_refresh_status(task);
+ return;
+ case -ENOMEM:
+ rpc_delay(task, HZ >> 4);
+ return;
+ }
+ trace_rpc_refresh_status(task);
+ rpc_call_rpcerror(task, status);
+}
+
+/*
+ * 2b. Allocate the buffer. For details, see sched.c:rpc_malloc.
+ * (Note: buffer memory is freed in xprt_release).
+ */
+static void
+call_allocate(struct rpc_task *task)
+{
+ const struct rpc_auth *auth = task->tk_rqstp->rq_cred->cr_auth;
+ struct rpc_rqst *req = task->tk_rqstp;
+ struct rpc_xprt *xprt = req->rq_xprt;
+ const struct rpc_procinfo *proc = task->tk_msg.rpc_proc;
+ int status;
+
+ task->tk_status = 0;
+ task->tk_action = call_encode;
+
+ if (req->rq_buffer)
+ return;
+
+ if (proc->p_proc != 0) {
+ BUG_ON(proc->p_arglen == 0);
+ if (proc->p_decode != NULL)
+ BUG_ON(proc->p_replen == 0);
+ }
+
+ /*
+ * Calculate the size (in quads) of the RPC call
+ * and reply headers, and convert both values
+ * to byte sizes.
+ */
+ req->rq_callsize = RPC_CALLHDRSIZE + (auth->au_cslack << 1) +
+ proc->p_arglen;
+ req->rq_callsize <<= 2;
+ /*
+ * Note: the reply buffer must at minimum allocate enough space
+ * for the 'struct accepted_reply' from RFC5531.
+ */
+ req->rq_rcvsize = RPC_REPHDRSIZE + auth->au_rslack + \
+ max_t(size_t, proc->p_replen, 2);
+ req->rq_rcvsize <<= 2;
+
+ status = xprt->ops->buf_alloc(task);
+ trace_rpc_buf_alloc(task, status);
+ if (status == 0)
+ return;
+ if (status != -ENOMEM) {
+ rpc_call_rpcerror(task, status);
+ return;
+ }
+
+ if (RPC_IS_ASYNC(task) || !fatal_signal_pending(current)) {
+ task->tk_action = call_allocate;
+ rpc_delay(task, HZ>>4);
+ return;
+ }
+
+ rpc_call_rpcerror(task, -ERESTARTSYS);
+}
+
+static int
+rpc_task_need_encode(struct rpc_task *task)
+{
+ return test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate) == 0 &&
+ (!(task->tk_flags & RPC_TASK_SENT) ||
+ !(task->tk_flags & RPC_TASK_NO_RETRANS_TIMEOUT) ||
+ xprt_request_need_retransmit(task));
+}
+
+static void
+rpc_xdr_encode(struct rpc_task *task)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+ struct xdr_stream xdr;
+
+ xdr_buf_init(&req->rq_snd_buf,
+ req->rq_buffer,
+ req->rq_callsize);
+ xdr_buf_init(&req->rq_rcv_buf,
+ req->rq_rbuffer,
+ req->rq_rcvsize);
+
+ req->rq_reply_bytes_recvd = 0;
+ req->rq_snd_buf.head[0].iov_len = 0;
+ xdr_init_encode(&xdr, &req->rq_snd_buf,
+ req->rq_snd_buf.head[0].iov_base, req);
+ if (rpc_encode_header(task, &xdr))
+ return;
+
+ task->tk_status = rpcauth_wrap_req(task, &xdr);
+}
+
+/*
+ * 3. Encode arguments of an RPC call
+ */
+static void
+call_encode(struct rpc_task *task)
+{
+ if (!rpc_task_need_encode(task))
+ goto out;
+
+ /* Dequeue task from the receive queue while we're encoding */
+ xprt_request_dequeue_xprt(task);
+ /* Encode here so that rpcsec_gss can use correct sequence number. */
+ rpc_xdr_encode(task);
+ /* Add task to reply queue before transmission to avoid races */
+ if (task->tk_status == 0 && rpc_reply_expected(task))
+ task->tk_status = xprt_request_enqueue_receive(task);
+ /* Did the encode result in an error condition? */
+ if (task->tk_status != 0) {
+ /* Was the error nonfatal? */
+ switch (task->tk_status) {
+ case -EAGAIN:
+ case -ENOMEM:
+ rpc_delay(task, HZ >> 4);
+ break;
+ case -EKEYEXPIRED:
+ if (!task->tk_cred_retry) {
+ rpc_call_rpcerror(task, task->tk_status);
+ } else {
+ task->tk_action = call_refresh;
+ task->tk_cred_retry--;
+ trace_rpc_retry_refresh_status(task);
+ }
+ break;
+ default:
+ rpc_call_rpcerror(task, task->tk_status);
+ }
+ return;
+ }
+
+ xprt_request_enqueue_transmit(task);
+out:
+ task->tk_action = call_transmit;
+ /* Check that the connection is OK */
+ if (!xprt_bound(task->tk_xprt))
+ task->tk_action = call_bind;
+ else if (!xprt_connected(task->tk_xprt))
+ task->tk_action = call_connect;
+}
+
+/*
+ * Helpers to check if the task was already transmitted, and
+ * to take action when that is the case.
+ */
+static bool
+rpc_task_transmitted(struct rpc_task *task)
+{
+ return !test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate);
+}
+
+static void
+rpc_task_handle_transmitted(struct rpc_task *task)
+{
+ xprt_end_transmit(task);
+ task->tk_action = call_transmit_status;
+}
+
+/*
+ * 4. Get the server port number if not yet set
+ */
+static void
+call_bind(struct rpc_task *task)
+{
+ struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt;
+
+ if (rpc_task_transmitted(task)) {
+ rpc_task_handle_transmitted(task);
+ return;
+ }
+
+ if (xprt_bound(xprt)) {
+ task->tk_action = call_connect;
+ return;
+ }
+
+ task->tk_action = call_bind_status;
+ if (!xprt_prepare_transmit(task))
+ return;
+
+ xprt->ops->rpcbind(task);
+}
+
+/*
+ * 4a. Sort out bind result
+ */
+static void
+call_bind_status(struct rpc_task *task)
+{
+ struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt;
+ int status = -EIO;
+
+ if (rpc_task_transmitted(task)) {
+ rpc_task_handle_transmitted(task);
+ return;
+ }
+
+ if (task->tk_status >= 0)
+ goto out_next;
+ if (xprt_bound(xprt)) {
+ task->tk_status = 0;
+ goto out_next;
+ }
+
+ switch (task->tk_status) {
+ case -ENOMEM:
+ rpc_delay(task, HZ >> 2);
+ goto retry_timeout;
+ case -EACCES:
+ trace_rpcb_prog_unavail_err(task);
+ /* fail immediately if this is an RPC ping */
+ if (task->tk_msg.rpc_proc->p_proc == 0) {
+ status = -EOPNOTSUPP;
+ break;
+ }
+ rpc_delay(task, 3*HZ);
+ goto retry_timeout;
+ case -ENOBUFS:
+ rpc_delay(task, HZ >> 2);
+ goto retry_timeout;
+ case -EAGAIN:
+ goto retry_timeout;
+ case -ETIMEDOUT:
+ trace_rpcb_timeout_err(task);
+ goto retry_timeout;
+ case -EPFNOSUPPORT:
+ /* server doesn't support any rpcbind version we know of */
+ trace_rpcb_bind_version_err(task);
+ break;
+ case -EPROTONOSUPPORT:
+ trace_rpcb_bind_version_err(task);
+ goto retry_timeout;
+ case -ECONNREFUSED: /* connection problems */
+ case -ECONNRESET:
+ case -ECONNABORTED:
+ case -ENOTCONN:
+ case -EHOSTDOWN:
+ case -ENETDOWN:
+ case -EHOSTUNREACH:
+ case -ENETUNREACH:
+ case -EPIPE:
+ trace_rpcb_unreachable_err(task);
+ if (!RPC_IS_SOFTCONN(task)) {
+ rpc_delay(task, 5*HZ);
+ goto retry_timeout;
+ }
+ status = task->tk_status;
+ break;
+ default:
+ trace_rpcb_unrecognized_err(task);
+ }
+
+ rpc_call_rpcerror(task, status);
+ return;
+out_next:
+ task->tk_action = call_connect;
+ return;
+retry_timeout:
+ task->tk_status = 0;
+ task->tk_action = call_bind;
+ rpc_check_timeout(task);
+}
+
+/*
+ * 4b. Connect to the RPC server
+ */
+static void
+call_connect(struct rpc_task *task)
+{
+ struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt;
+
+ if (rpc_task_transmitted(task)) {
+ rpc_task_handle_transmitted(task);
+ return;
+ }
+
+ if (xprt_connected(xprt)) {
+ task->tk_action = call_transmit;
+ return;
+ }
+
+ task->tk_action = call_connect_status;
+ if (task->tk_status < 0)
+ return;
+ if (task->tk_flags & RPC_TASK_NOCONNECT) {
+ rpc_call_rpcerror(task, -ENOTCONN);
+ return;
+ }
+ if (!xprt_prepare_transmit(task))
+ return;
+ xprt_connect(task);
+}
+
+/*
+ * 4c. Sort out connect result
+ */
+static void
+call_connect_status(struct rpc_task *task)
+{
+ struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt;
+ struct rpc_clnt *clnt = task->tk_client;
+ int status = task->tk_status;
+
+ if (rpc_task_transmitted(task)) {
+ rpc_task_handle_transmitted(task);
+ return;
+ }
+
+ trace_rpc_connect_status(task);
+
+ if (task->tk_status == 0) {
+ clnt->cl_stats->netreconn++;
+ goto out_next;
+ }
+ if (xprt_connected(xprt)) {
+ task->tk_status = 0;
+ goto out_next;
+ }
+
+ task->tk_status = 0;
+ switch (status) {
+ case -ECONNREFUSED:
+ case -ECONNRESET:
+ /* A positive refusal suggests a rebind is needed. */
+ if (RPC_IS_SOFTCONN(task))
+ break;
+ if (clnt->cl_autobind) {
+ rpc_force_rebind(clnt);
+ goto out_retry;
+ }
+ fallthrough;
+ case -ECONNABORTED:
+ case -ENETDOWN:
+ case -ENETUNREACH:
+ case -EHOSTUNREACH:
+ case -EPIPE:
+ case -EPROTO:
+ xprt_conditional_disconnect(task->tk_rqstp->rq_xprt,
+ task->tk_rqstp->rq_connect_cookie);
+ if (RPC_IS_SOFTCONN(task))
+ break;
+ /* retry with existing socket, after a delay */
+ rpc_delay(task, 3*HZ);
+ fallthrough;
+ case -EADDRINUSE:
+ case -ENOTCONN:
+ case -EAGAIN:
+ case -ETIMEDOUT:
+ if (!(task->tk_flags & RPC_TASK_NO_ROUND_ROBIN) &&
+ (task->tk_flags & RPC_TASK_MOVEABLE) &&
+ test_bit(XPRT_REMOVE, &xprt->state)) {
+ struct rpc_xprt *saved = task->tk_xprt;
+ struct rpc_xprt_switch *xps;
+
+ rcu_read_lock();
+ xps = xprt_switch_get(rcu_dereference(clnt->cl_xpi.xpi_xpswitch));
+ rcu_read_unlock();
+ if (xps->xps_nxprts > 1) {
+ long value;
+
+ xprt_release(task);
+ value = atomic_long_dec_return(&xprt->queuelen);
+ if (value == 0)
+ rpc_xprt_switch_remove_xprt(xps, saved,
+ true);
+ xprt_put(saved);
+ task->tk_xprt = NULL;
+ task->tk_action = call_start;
+ }
+ xprt_switch_put(xps);
+ if (!task->tk_xprt)
+ return;
+ }
+ goto out_retry;
+ case -ENOBUFS:
+ rpc_delay(task, HZ >> 2);
+ goto out_retry;
+ }
+ rpc_call_rpcerror(task, status);
+ return;
+out_next:
+ task->tk_action = call_transmit;
+ return;
+out_retry:
+ /* Check for timeouts before looping back to call_bind */
+ task->tk_action = call_bind;
+ rpc_check_timeout(task);
+}
+
+/*
+ * 5. Transmit the RPC request, and wait for reply
+ */
+static void
+call_transmit(struct rpc_task *task)
+{
+ if (rpc_task_transmitted(task)) {
+ rpc_task_handle_transmitted(task);
+ return;
+ }
+
+ task->tk_action = call_transmit_status;
+ if (!xprt_prepare_transmit(task))
+ return;
+ task->tk_status = 0;
+ if (test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) {
+ if (!xprt_connected(task->tk_xprt)) {
+ task->tk_status = -ENOTCONN;
+ return;
+ }
+ xprt_transmit(task);
+ }
+ xprt_end_transmit(task);
+}
+
+/*
+ * 5a. Handle cleanup after a transmission
+ */
+static void
+call_transmit_status(struct rpc_task *task)
+{
+ task->tk_action = call_status;
+
+ /*
+ * Common case: success. Force the compiler to put this
+ * test first.
+ */
+ if (rpc_task_transmitted(task)) {
+ task->tk_status = 0;
+ xprt_request_wait_receive(task);
+ return;
+ }
+
+ switch (task->tk_status) {
+ default:
+ break;
+ case -EBADMSG:
+ task->tk_status = 0;
+ task->tk_action = call_encode;
+ break;
+ /*
+ * Special cases: if we've been waiting on the
+ * socket's write_space() callback, or if the
+ * socket just returned a connection error,
+ * then hold onto the transport lock.
+ */
+ case -ENOMEM:
+ case -ENOBUFS:
+ rpc_delay(task, HZ>>2);
+ fallthrough;
+ case -EBADSLT:
+ case -EAGAIN:
+ task->tk_action = call_transmit;
+ task->tk_status = 0;
+ break;
+ case -ECONNREFUSED:
+ case -EHOSTDOWN:
+ case -ENETDOWN:
+ case -EHOSTUNREACH:
+ case -ENETUNREACH:
+ case -EPERM:
+ if (RPC_IS_SOFTCONN(task)) {
+ if (!task->tk_msg.rpc_proc->p_proc)
+ trace_xprt_ping(task->tk_xprt,
+ task->tk_status);
+ rpc_call_rpcerror(task, task->tk_status);
+ return;
+ }
+ fallthrough;
+ case -ECONNRESET:
+ case -ECONNABORTED:
+ case -EADDRINUSE:
+ case -ENOTCONN:
+ case -EPIPE:
+ task->tk_action = call_bind;
+ task->tk_status = 0;
+ break;
+ }
+ rpc_check_timeout(task);
+}
+
+#if defined(CONFIG_SUNRPC_BACKCHANNEL)
+static void call_bc_transmit(struct rpc_task *task);
+static void call_bc_transmit_status(struct rpc_task *task);
+
+static void
+call_bc_encode(struct rpc_task *task)
+{
+ xprt_request_enqueue_transmit(task);
+ task->tk_action = call_bc_transmit;
+}
+
+/*
+ * 5b. Send the backchannel RPC reply. On error, drop the reply. In
+ * addition, disconnect on connectivity errors.
+ */
+static void
+call_bc_transmit(struct rpc_task *task)
+{
+ task->tk_action = call_bc_transmit_status;
+ if (test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) {
+ if (!xprt_prepare_transmit(task))
+ return;
+ task->tk_status = 0;
+ xprt_transmit(task);
+ }
+ xprt_end_transmit(task);
+}
+
+static void
+call_bc_transmit_status(struct rpc_task *task)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+
+ if (rpc_task_transmitted(task))
+ task->tk_status = 0;
+
+ switch (task->tk_status) {
+ case 0:
+ /* Success */
+ case -ENETDOWN:
+ case -EHOSTDOWN:
+ case -EHOSTUNREACH:
+ case -ENETUNREACH:
+ case -ECONNRESET:
+ case -ECONNREFUSED:
+ case -EADDRINUSE:
+ case -ENOTCONN:
+ case -EPIPE:
+ break;
+ case -ENOMEM:
+ case -ENOBUFS:
+ rpc_delay(task, HZ>>2);
+ fallthrough;
+ case -EBADSLT:
+ case -EAGAIN:
+ task->tk_status = 0;
+ task->tk_action = call_bc_transmit;
+ return;
+ case -ETIMEDOUT:
+ /*
+ * Problem reaching the server. Disconnect and let the
+ * forechannel reestablish the connection. The server will
+ * have to retransmit the backchannel request and we'll
+ * reprocess it. Since these ops are idempotent, there's no
+ * need to cache our reply at this time.
+ */
+ printk(KERN_NOTICE "RPC: Could not send backchannel reply "
+ "error: %d\n", task->tk_status);
+ xprt_conditional_disconnect(req->rq_xprt,
+ req->rq_connect_cookie);
+ break;
+ default:
+ /*
+ * We were unable to reply and will have to drop the
+ * request. The server should reconnect and retransmit.
+ */
+ printk(KERN_NOTICE "RPC: Could not send backchannel reply "
+ "error: %d\n", task->tk_status);
+ break;
+ }
+ task->tk_action = rpc_exit_task;
+}
+#endif /* CONFIG_SUNRPC_BACKCHANNEL */
+
+/*
+ * 6. Sort out the RPC call status
+ */
+static void
+call_status(struct rpc_task *task)
+{
+ struct rpc_clnt *clnt = task->tk_client;
+ int status;
+
+ if (!task->tk_msg.rpc_proc->p_proc)
+ trace_xprt_ping(task->tk_xprt, task->tk_status);
+
+ status = task->tk_status;
+ if (status >= 0) {
+ task->tk_action = call_decode;
+ return;
+ }
+
+ trace_rpc_call_status(task);
+ task->tk_status = 0;
+ switch(status) {
+ case -EHOSTDOWN:
+ case -ENETDOWN:
+ case -EHOSTUNREACH:
+ case -ENETUNREACH:
+ case -EPERM:
+ if (RPC_IS_SOFTCONN(task))
+ goto out_exit;
+ /*
+ * Delay any retries for 3 seconds, then handle as if it
+ * were a timeout.
+ */
+ rpc_delay(task, 3*HZ);
+ fallthrough;
+ case -ETIMEDOUT:
+ break;
+ case -ECONNREFUSED:
+ case -ECONNRESET:
+ case -ECONNABORTED:
+ case -ENOTCONN:
+ rpc_force_rebind(clnt);
+ break;
+ case -EADDRINUSE:
+ rpc_delay(task, 3*HZ);
+ fallthrough;
+ case -EPIPE:
+ case -EAGAIN:
+ break;
+ case -ENFILE:
+ case -ENOBUFS:
+ case -ENOMEM:
+ rpc_delay(task, HZ>>2);
+ break;
+ case -EIO:
+ /* shutdown or soft timeout */
+ goto out_exit;
+ default:
+ if (clnt->cl_chatty)
+ printk("%s: RPC call returned error %d\n",
+ clnt->cl_program->name, -status);
+ goto out_exit;
+ }
+ task->tk_action = call_encode;
+ rpc_check_timeout(task);
+ return;
+out_exit:
+ rpc_call_rpcerror(task, status);
+}
+
+static bool
+rpc_check_connected(const struct rpc_rqst *req)
+{
+ /* No allocated request or transport? return true */
+ if (!req || !req->rq_xprt)
+ return true;
+ return xprt_connected(req->rq_xprt);
+}
+
+static void
+rpc_check_timeout(struct rpc_task *task)
+{
+ struct rpc_clnt *clnt = task->tk_client;
+
+ if (RPC_SIGNALLED(task))
+ return;
+
+ if (xprt_adjust_timeout(task->tk_rqstp) == 0)
+ return;
+
+ trace_rpc_timeout_status(task);
+ task->tk_timeouts++;
+
+ if (RPC_IS_SOFTCONN(task) && !rpc_check_connected(task->tk_rqstp)) {
+ rpc_call_rpcerror(task, -ETIMEDOUT);
+ return;
+ }
+
+ if (RPC_IS_SOFT(task)) {
+ /*
+		 * Once a "no retrans timeout" soft task (a.k.a. NFSv4) has
+ * been sent, it should time out only if the transport
+ * connection gets terminally broken.
+ */
+ if ((task->tk_flags & RPC_TASK_NO_RETRANS_TIMEOUT) &&
+ rpc_check_connected(task->tk_rqstp))
+ return;
+
+ if (clnt->cl_chatty) {
+ pr_notice_ratelimited(
+ "%s: server %s not responding, timed out\n",
+ clnt->cl_program->name,
+ task->tk_xprt->servername);
+ }
+ if (task->tk_flags & RPC_TASK_TIMEOUT)
+ rpc_call_rpcerror(task, -ETIMEDOUT);
+ else
+ __rpc_call_rpcerror(task, -EIO, -ETIMEDOUT);
+ return;
+ }
+
+ if (!(task->tk_flags & RPC_CALL_MAJORSEEN)) {
+ task->tk_flags |= RPC_CALL_MAJORSEEN;
+ if (clnt->cl_chatty) {
+ pr_notice_ratelimited(
+ "%s: server %s not responding, still trying\n",
+ clnt->cl_program->name,
+ task->tk_xprt->servername);
+ }
+ }
+ rpc_force_rebind(clnt);
+ /*
+ * Did our request time out due to an RPCSEC_GSS out-of-sequence
+ * event? RFC2203 requires the server to drop all such requests.
+ */
+ rpcauth_invalcred(task);
+}
+
+/*
+ * 7. Decode the RPC reply
+ */
+static void
+call_decode(struct rpc_task *task)
+{
+ struct rpc_clnt *clnt = task->tk_client;
+ struct rpc_rqst *req = task->tk_rqstp;
+ struct xdr_stream xdr;
+ int err;
+
+ if (!task->tk_msg.rpc_proc->p_decode) {
+ task->tk_action = rpc_exit_task;
+ return;
+ }
+
+ if (task->tk_flags & RPC_CALL_MAJORSEEN) {
+ if (clnt->cl_chatty) {
+ pr_notice_ratelimited("%s: server %s OK\n",
+ clnt->cl_program->name,
+ task->tk_xprt->servername);
+ }
+ task->tk_flags &= ~RPC_CALL_MAJORSEEN;
+ }
+
+ /*
+ * Did we ever call xprt_complete_rqst()? If not, we should assume
+ * the message is incomplete.
+ */
+ err = -EAGAIN;
+ if (!req->rq_reply_bytes_recvd)
+ goto out;
+
+ /* Ensure that we see all writes made by xprt_complete_rqst()
+ * before it changed req->rq_reply_bytes_recvd.
+ */
+ smp_rmb();
+
+ req->rq_rcv_buf.len = req->rq_private_buf.len;
+ trace_rpc_xdr_recvfrom(task, &req->rq_rcv_buf);
+
+ /* Check that the softirq receive buffer is valid */
+ WARN_ON(memcmp(&req->rq_rcv_buf, &req->rq_private_buf,
+ sizeof(req->rq_rcv_buf)) != 0);
+
+ xdr_init_decode(&xdr, &req->rq_rcv_buf,
+ req->rq_rcv_buf.head[0].iov_base, req);
+ err = rpc_decode_header(task, &xdr);
+out:
+ switch (err) {
+ case 0:
+ task->tk_action = rpc_exit_task;
+ task->tk_status = rpcauth_unwrap_resp(task, &xdr);
+ xdr_finish_decode(&xdr);
+ return;
+ case -EAGAIN:
+ task->tk_status = 0;
+ if (task->tk_client->cl_discrtry)
+ xprt_conditional_disconnect(req->rq_xprt,
+ req->rq_connect_cookie);
+ task->tk_action = call_encode;
+ rpc_check_timeout(task);
+ break;
+ case -EKEYREJECTED:
+ task->tk_action = call_reserve;
+ rpc_check_timeout(task);
+ rpcauth_invalcred(task);
+ /* Ensure we obtain a new XID if we retry! */
+ xprt_release(task);
+ }
+}
+
+static int
+rpc_encode_header(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ struct rpc_clnt *clnt = task->tk_client;
+ struct rpc_rqst *req = task->tk_rqstp;
+ __be32 *p;
+ int error;
+
+ error = -EMSGSIZE;
+ p = xdr_reserve_space(xdr, RPC_CALLHDRSIZE << 2);
+ if (!p)
+ goto out_fail;
+ *p++ = req->rq_xid;
+ *p++ = rpc_call;
+ *p++ = cpu_to_be32(RPC_VERSION);
+ *p++ = cpu_to_be32(clnt->cl_prog);
+ *p++ = cpu_to_be32(clnt->cl_vers);
+ *p = cpu_to_be32(task->tk_msg.rpc_proc->p_proc);
+
+ error = rpcauth_marshcred(task, xdr);
+ if (error < 0)
+ goto out_fail;
+ return 0;
+out_fail:
+ trace_rpc_bad_callhdr(task);
+ rpc_call_rpcerror(task, error);
+ return error;
+}
+
+static noinline int
+rpc_decode_header(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ struct rpc_clnt *clnt = task->tk_client;
+ int error;
+ __be32 *p;
+
+ /* RFC-1014 says that the representation of XDR data must be a
+ * multiple of four bytes
+	 * - if it isn't, pointer subtraction in the NFS client may give
+ * undefined results
+ */
+ if (task->tk_rqstp->rq_rcv_buf.len & 3)
+ goto out_unparsable;
+
+ p = xdr_inline_decode(xdr, 3 * sizeof(*p));
+ if (!p)
+ goto out_unparsable;
+ p++; /* skip XID */
+ if (*p++ != rpc_reply)
+ goto out_unparsable;
+ if (*p++ != rpc_msg_accepted)
+ goto out_msg_denied;
+
+ error = rpcauth_checkverf(task, xdr);
+ if (error)
+ goto out_verifier;
+
+ p = xdr_inline_decode(xdr, sizeof(*p));
+ if (!p)
+ goto out_unparsable;
+ switch (*p) {
+ case rpc_success:
+ return 0;
+ case rpc_prog_unavail:
+ trace_rpc__prog_unavail(task);
+ error = -EPFNOSUPPORT;
+ goto out_err;
+ case rpc_prog_mismatch:
+ trace_rpc__prog_mismatch(task);
+ error = -EPROTONOSUPPORT;
+ goto out_err;
+ case rpc_proc_unavail:
+ trace_rpc__proc_unavail(task);
+ error = -EOPNOTSUPP;
+ goto out_err;
+ case rpc_garbage_args:
+ case rpc_system_err:
+ trace_rpc__garbage_args(task);
+ error = -EIO;
+ break;
+ default:
+ goto out_unparsable;
+ }
+
+out_garbage:
+ clnt->cl_stats->rpcgarbage++;
+ if (task->tk_garb_retry) {
+ task->tk_garb_retry--;
+ task->tk_action = call_encode;
+ return -EAGAIN;
+ }
+out_err:
+ rpc_call_rpcerror(task, error);
+ return error;
+
+out_unparsable:
+ trace_rpc__unparsable(task);
+ error = -EIO;
+ goto out_garbage;
+
+out_verifier:
+ trace_rpc_bad_verifier(task);
+ switch (error) {
+ case -EPROTONOSUPPORT:
+ goto out_err;
+ case -EACCES:
+ /* Re-encode with a fresh cred */
+ fallthrough;
+ default:
+ goto out_garbage;
+ }
+
+out_msg_denied:
+ error = -EACCES;
+ p = xdr_inline_decode(xdr, sizeof(*p));
+ if (!p)
+ goto out_unparsable;
+ switch (*p++) {
+ case rpc_auth_error:
+ break;
+ case rpc_mismatch:
+ trace_rpc__mismatch(task);
+ error = -EPROTONOSUPPORT;
+ goto out_err;
+ default:
+ goto out_unparsable;
+ }
+
+ p = xdr_inline_decode(xdr, sizeof(*p));
+ if (!p)
+ goto out_unparsable;
+ switch (*p++) {
+ case rpc_autherr_rejectedcred:
+ case rpc_autherr_rejectedverf:
+ case rpcsec_gsserr_credproblem:
+ case rpcsec_gsserr_ctxproblem:
+ rpcauth_invalcred(task);
+ if (!task->tk_cred_retry)
+ break;
+ task->tk_cred_retry--;
+ trace_rpc__stale_creds(task);
+ return -EKEYREJECTED;
+ case rpc_autherr_badcred:
+ case rpc_autherr_badverf:
+ /* possibly garbled cred/verf? */
+ if (!task->tk_garb_retry)
+ break;
+ task->tk_garb_retry--;
+ trace_rpc__bad_creds(task);
+ task->tk_action = call_encode;
+ return -EAGAIN;
+ case rpc_autherr_tooweak:
+ trace_rpc__auth_tooweak(task);
+ pr_warn("RPC: server %s requires stronger authentication.\n",
+ task->tk_xprt->servername);
+ break;
+ default:
+ goto out_unparsable;
+ }
+ goto out_err;
+}
+
+static void rpcproc_encode_null(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
+ const void *obj)
+{
+}
+
+static int rpcproc_decode_null(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
+ void *obj)
+{
+ return 0;
+}
+
+static const struct rpc_procinfo rpcproc_null = {
+ .p_encode = rpcproc_encode_null,
+ .p_decode = rpcproc_decode_null,
+};
+
+static const struct rpc_procinfo rpcproc_null_noreply = {
+ .p_encode = rpcproc_encode_null,
+};
+
+static void
+rpc_null_call_prepare(struct rpc_task *task, void *data)
+{
+ task->tk_flags &= ~RPC_TASK_NO_RETRANS_TIMEOUT;
+ rpc_call_start(task);
+}
+
+static const struct rpc_call_ops rpc_null_ops = {
+ .rpc_call_prepare = rpc_null_call_prepare,
+ .rpc_call_done = rpc_default_callback,
+};
+
+static
+struct rpc_task *rpc_call_null_helper(struct rpc_clnt *clnt,
+ struct rpc_xprt *xprt, struct rpc_cred *cred, int flags,
+ const struct rpc_call_ops *ops, void *data)
+{
+ struct rpc_message msg = {
+ .rpc_proc = &rpcproc_null,
+ };
+ struct rpc_task_setup task_setup_data = {
+ .rpc_client = clnt,
+ .rpc_xprt = xprt,
+ .rpc_message = &msg,
+ .rpc_op_cred = cred,
+ .callback_ops = ops ?: &rpc_null_ops,
+ .callback_data = data,
+ .flags = flags | RPC_TASK_SOFT | RPC_TASK_SOFTCONN |
+ RPC_TASK_NULLCREDS,
+ };
+
+ return rpc_run_task(&task_setup_data);
+}
+
+struct rpc_task *rpc_call_null(struct rpc_clnt *clnt, struct rpc_cred *cred, int flags)
+{
+ return rpc_call_null_helper(clnt, NULL, cred, flags, NULL, NULL);
+}
+EXPORT_SYMBOL_GPL(rpc_call_null);
+
+static int rpc_ping(struct rpc_clnt *clnt)
+{
+ struct rpc_task *task;
+ int status;
+
+ if (clnt->cl_auth->au_ops->ping)
+ return clnt->cl_auth->au_ops->ping(clnt);
+
+ task = rpc_call_null_helper(clnt, NULL, NULL, 0, NULL, NULL);
+ if (IS_ERR(task))
+ return PTR_ERR(task);
+ status = task->tk_status;
+ rpc_put_task(task);
+ return status;
+}
+
+static int rpc_ping_noreply(struct rpc_clnt *clnt)
+{
+ struct rpc_message msg = {
+ .rpc_proc = &rpcproc_null_noreply,
+ };
+ struct rpc_task_setup task_setup_data = {
+ .rpc_client = clnt,
+ .rpc_message = &msg,
+ .callback_ops = &rpc_null_ops,
+ .flags = RPC_TASK_SOFT | RPC_TASK_SOFTCONN | RPC_TASK_NULLCREDS,
+ };
+ struct rpc_task *task;
+ int status;
+
+ task = rpc_run_task(&task_setup_data);
+ if (IS_ERR(task))
+ return PTR_ERR(task);
+ status = task->tk_status;
+ rpc_put_task(task);
+ return status;
+}
+
+struct rpc_cb_add_xprt_calldata {
+ struct rpc_xprt_switch *xps;
+ struct rpc_xprt *xprt;
+};
+
+static void rpc_cb_add_xprt_done(struct rpc_task *task, void *calldata)
+{
+ struct rpc_cb_add_xprt_calldata *data = calldata;
+
+ if (task->tk_status == 0)
+ rpc_xprt_switch_add_xprt(data->xps, data->xprt);
+}
+
+static void rpc_cb_add_xprt_release(void *calldata)
+{
+ struct rpc_cb_add_xprt_calldata *data = calldata;
+
+ xprt_put(data->xprt);
+ xprt_switch_put(data->xps);
+ kfree(data);
+}
+
+static const struct rpc_call_ops rpc_cb_add_xprt_call_ops = {
+ .rpc_call_prepare = rpc_null_call_prepare,
+ .rpc_call_done = rpc_cb_add_xprt_done,
+ .rpc_release = rpc_cb_add_xprt_release,
+};
+
+/**
+ * rpc_clnt_test_and_add_xprt - Test and add a new transport to a rpc_clnt
+ * @clnt: pointer to struct rpc_clnt
+ * @xps: pointer to struct rpc_xprt_switch
+ * @xprt: pointer to struct rpc_xprt
+ * @in_max_connect: pointer to the max_connect value for the passed in xprt transport
+ */
+int rpc_clnt_test_and_add_xprt(struct rpc_clnt *clnt,
+ struct rpc_xprt_switch *xps, struct rpc_xprt *xprt,
+ void *in_max_connect)
+{
+ struct rpc_cb_add_xprt_calldata *data;
+ struct rpc_task *task;
+ int max_connect = clnt->cl_max_connect;
+
+ if (in_max_connect)
+ max_connect = *(int *)in_max_connect;
+ if (xps->xps_nunique_destaddr_xprts + 1 > max_connect) {
+ rcu_read_lock();
+		pr_warn("SUNRPC: reached max allowed number (%d); did not add "
+			"transport to server: %s\n", max_connect,
+ rpc_peeraddr2str(clnt, RPC_DISPLAY_ADDR));
+ rcu_read_unlock();
+ return -EINVAL;
+ }
+
+ data = kmalloc(sizeof(*data), GFP_KERNEL);
+ if (!data)
+ return -ENOMEM;
+ data->xps = xprt_switch_get(xps);
+ data->xprt = xprt_get(xprt);
+ if (rpc_xprt_switch_has_addr(data->xps, (struct sockaddr *)&xprt->addr)) {
+ rpc_cb_add_xprt_release(data);
+ goto success;
+ }
+
+ task = rpc_call_null_helper(clnt, xprt, NULL, RPC_TASK_ASYNC,
+ &rpc_cb_add_xprt_call_ops, data);
+ if (IS_ERR(task))
+ return PTR_ERR(task);
+
+ data->xps->xps_nunique_destaddr_xprts++;
+ rpc_put_task(task);
+success:
+ return 1;
+}
+EXPORT_SYMBOL_GPL(rpc_clnt_test_and_add_xprt);
+
+static int rpc_clnt_add_xprt_helper(struct rpc_clnt *clnt,
+ struct rpc_xprt *xprt,
+ struct rpc_add_xprt_test *data)
+{
+ struct rpc_task *task;
+ int status = -EADDRINUSE;
+
+ /* Test the connection */
+ task = rpc_call_null_helper(clnt, xprt, NULL, 0, NULL, NULL);
+ if (IS_ERR(task))
+ return PTR_ERR(task);
+
+ status = task->tk_status;
+ rpc_put_task(task);
+
+ if (status < 0)
+ return status;
+
+	/* rpc_xprt_switch and rpc_xprt are dereferenced by add_xprt_test() */
+ data->add_xprt_test(clnt, xprt, data->data);
+
+ return 0;
+}
+
+/**
+ * rpc_clnt_setup_test_and_add_xprt()
+ *
+ * This is an rpc_clnt_add_xprt setup() function which returns 1 so:
+ * 1) caller of the test function must dereference the rpc_xprt_switch
+ * and the rpc_xprt.
+ * 2) test function must call rpc_xprt_switch_add_xprt, usually in
+ * the rpc_call_done routine.
+ *
+ * Upon success (return of 1), the test function adds the new
+ * transport to the rpc_clnt xprt switch
+ *
+ * @clnt: struct rpc_clnt to get the new transport
+ * @xps: the rpc_xprt_switch to hold the new transport
+ * @xprt: the rpc_xprt to test
+ * @data: a struct rpc_add_xprt_test pointer that holds the test function
+ * and test function call data
+ */
+int rpc_clnt_setup_test_and_add_xprt(struct rpc_clnt *clnt,
+ struct rpc_xprt_switch *xps,
+ struct rpc_xprt *xprt,
+ void *data)
+{
+ int status = -EADDRINUSE;
+
+ xprt = xprt_get(xprt);
+ xprt_switch_get(xps);
+
+ if (rpc_xprt_switch_has_addr(xps, (struct sockaddr *)&xprt->addr))
+ goto out_err;
+
+ status = rpc_clnt_add_xprt_helper(clnt, xprt, data);
+ if (status < 0)
+ goto out_err;
+
+ status = 1;
+out_err:
+ xprt_put(xprt);
+ xprt_switch_put(xps);
+ if (status < 0)
+		pr_info("RPC: %s failed: %d addr %s not added\n",
+			__func__, status,
+			xprt->address_strings[RPC_DISPLAY_ADDR]);
+ /* so that rpc_clnt_add_xprt does not call rpc_xprt_switch_add_xprt */
+ return status;
+}
+EXPORT_SYMBOL_GPL(rpc_clnt_setup_test_and_add_xprt);
+
+/**
+ * rpc_clnt_add_xprt - Add a new transport to a rpc_clnt
+ * @clnt: pointer to struct rpc_clnt
+ * @xprtargs: pointer to struct xprt_create
+ * @setup: callback to test and/or set up the connection
+ * @data: pointer to setup function data
+ *
+ * Creates a new transport using the parameters set in @xprtargs and
+ * adds it to @clnt.
+ * If a @setup callback is given, it tests (and may configure) the new
+ * transport before it is added.
+ *
+ */
+int rpc_clnt_add_xprt(struct rpc_clnt *clnt,
+ struct xprt_create *xprtargs,
+ int (*setup)(struct rpc_clnt *,
+ struct rpc_xprt_switch *,
+ struct rpc_xprt *,
+ void *),
+ void *data)
+{
+ struct rpc_xprt_switch *xps;
+ struct rpc_xprt *xprt;
+ unsigned long connect_timeout;
+ unsigned long reconnect_timeout;
+ unsigned char resvport, reuseport;
+ int ret = 0, ident;
+
+ rcu_read_lock();
+ xps = xprt_switch_get(rcu_dereference(clnt->cl_xpi.xpi_xpswitch));
+ xprt = xprt_iter_xprt(&clnt->cl_xpi);
+ if (xps == NULL || xprt == NULL) {
+ rcu_read_unlock();
+ xprt_switch_put(xps);
+ return -EAGAIN;
+ }
+ resvport = xprt->resvport;
+ reuseport = xprt->reuseport;
+ connect_timeout = xprt->connect_timeout;
+ reconnect_timeout = xprt->max_reconnect_timeout;
+ ident = xprt->xprt_class->ident;
+ rcu_read_unlock();
+
+ if (!xprtargs->ident)
+ xprtargs->ident = ident;
+ xprtargs->xprtsec = clnt->cl_xprtsec;
+ xprt = xprt_create_transport(xprtargs);
+ if (IS_ERR(xprt)) {
+ ret = PTR_ERR(xprt);
+ goto out_put_switch;
+ }
+ xprt->resvport = resvport;
+ xprt->reuseport = reuseport;
+
+ if (xprtargs->connect_timeout)
+ connect_timeout = xprtargs->connect_timeout;
+ if (xprtargs->reconnect_timeout)
+ reconnect_timeout = xprtargs->reconnect_timeout;
+ if (xprt->ops->set_connect_timeout != NULL)
+ xprt->ops->set_connect_timeout(xprt,
+ connect_timeout,
+ reconnect_timeout);
+
+ rpc_xprt_switch_set_roundrobin(xps);
+ if (setup) {
+ ret = setup(clnt, xps, xprt, data);
+ if (ret != 0)
+ goto out_put_xprt;
+ }
+ rpc_xprt_switch_add_xprt(xps, xprt);
+out_put_xprt:
+ xprt_put(xprt);
+out_put_switch:
+ xprt_switch_put(xps);
+ return ret;
+}
+EXPORT_SYMBOL_GPL(rpc_clnt_add_xprt);
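For illustration only (not part of this patch): a minimal sketch of how an upper layer might call rpc_clnt_add_xprt() with rpc_clnt_setup_test_and_add_xprt() as the setup callback. The example_* names are hypothetical, and the xprt_create fields shown are assumptions based on how existing SUNRPC callers fill them in.

#include <linux/sunrpc/clnt.h>
#include <linux/sunrpc/xprt.h>
#include <linux/sunrpc/xprtsock.h>

/* Invoked by the test machinery only after the NULL ping succeeded. */
static void example_xprt_tested_ok(struct rpc_clnt *clnt,
				   struct rpc_xprt *xprt, void *data)
{
	/* The setup callback returned 1, so adding the transport is up to us. */
	rpc_clnt_xprt_switch_add_xprt(clnt, xprt);
}

static int example_add_tcp_xprt(struct rpc_clnt *clnt, struct net *net,
				struct sockaddr *sap, size_t salen)
{
	struct rpc_add_xprt_test test = {
		.add_xprt_test	= example_xprt_tested_ok,
		.data		= NULL,
	};
	struct xprt_create xprtargs = {
		.ident		= XPRT_TRANSPORT_TCP,
		.net		= net,
		.dstaddr	= sap,
		.addrlen	= salen,
	};

	/* Creates the transport, pings it with a NULL RPC, then lets the
	 * callback above attach it to the client's xprt switch. */
	return rpc_clnt_add_xprt(clnt, &xprtargs,
				 rpc_clnt_setup_test_and_add_xprt, &test);
}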
+
+static int rpc_xprt_probe_trunked(struct rpc_clnt *clnt,
+ struct rpc_xprt *xprt,
+ struct rpc_add_xprt_test *data)
+{
+ struct rpc_xprt_switch *xps;
+ struct rpc_xprt *main_xprt;
+ int status = 0;
+
+ xprt_get(xprt);
+
+ rcu_read_lock();
+ main_xprt = xprt_get(rcu_dereference(clnt->cl_xprt));
+ xps = xprt_switch_get(rcu_dereference(clnt->cl_xpi.xpi_xpswitch));
+ status = rpc_cmp_addr_port((struct sockaddr *)&xprt->addr,
+ (struct sockaddr *)&main_xprt->addr);
+ rcu_read_unlock();
+ xprt_put(main_xprt);
+ if (status || !test_bit(XPRT_OFFLINE, &xprt->state))
+ goto out;
+
+ status = rpc_clnt_add_xprt_helper(clnt, xprt, data);
+out:
+ xprt_put(xprt);
+ xprt_switch_put(xps);
+ return status;
+}
+
+/* rpc_clnt_probe_trunked_xprts -- probe offlined transports for session trunking
+ * @clnt: rpc_clnt structure
+ * @data: the rpc_add_xprt_test that holds the probe function and its argument
+ *
+ * For each offlined transport found in the rpc_clnt structure, call
+ * rpc_xprt_probe_trunked(), which determines whether the transport
+ * still belongs to the trunking group.
+ */
+void rpc_clnt_probe_trunked_xprts(struct rpc_clnt *clnt,
+ struct rpc_add_xprt_test *data)
+{
+ struct rpc_xprt_iter xpi;
+ int ret;
+
+ ret = rpc_clnt_xprt_iter_offline_init(clnt, &xpi);
+ if (ret)
+ return;
+ for (;;) {
+ struct rpc_xprt *xprt = xprt_iter_get_next(&xpi);
+
+ if (!xprt)
+ break;
+ ret = rpc_xprt_probe_trunked(clnt, xprt, data);
+ xprt_put(xprt);
+ if (ret < 0)
+ break;
+ xprt_iter_rewind(&xpi);
+ }
+ xprt_iter_destroy(&xpi);
+}
+EXPORT_SYMBOL_GPL(rpc_clnt_probe_trunked_xprts);
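A companion sketch, also illustrative only: re-probing previously offlined transports with the same hypothetical rpc_add_xprt_test callback used in the earlier example.

static void example_reprobe_offline_xprts(struct rpc_clnt *clnt)
{
	struct rpc_add_xprt_test test = {
		.add_xprt_test	= example_xprt_tested_ok,	/* see sketch above */
		.data		= NULL,
	};

	/* Walks only transports marked XPRT_OFFLINE and pings each one. */
	rpc_clnt_probe_trunked_xprts(clnt, &test);
}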
+
+static int rpc_xprt_offline(struct rpc_clnt *clnt,
+ struct rpc_xprt *xprt,
+ void *data)
+{
+ struct rpc_xprt *main_xprt;
+ struct rpc_xprt_switch *xps;
+ int err = 0;
+
+ xprt_get(xprt);
+
+ rcu_read_lock();
+ main_xprt = xprt_get(rcu_dereference(clnt->cl_xprt));
+ xps = xprt_switch_get(rcu_dereference(clnt->cl_xpi.xpi_xpswitch));
+ err = rpc_cmp_addr_port((struct sockaddr *)&xprt->addr,
+ (struct sockaddr *)&main_xprt->addr);
+ rcu_read_unlock();
+ xprt_put(main_xprt);
+ if (err)
+ goto out;
+
+ if (wait_on_bit_lock(&xprt->state, XPRT_LOCKED, TASK_KILLABLE)) {
+ err = -EINTR;
+ goto out;
+ }
+ xprt_set_offline_locked(xprt, xps);
+
+ xprt_release_write(xprt, NULL);
+out:
+ xprt_put(xprt);
+ xprt_switch_put(xps);
+ return err;
+}
+
+/* rpc_clnt_manage_trunked_xprts -- offline trunked transports
+ * @clnt: rpc_clnt structure
+ *
+ * For each active transport found in the rpc_clnt structure, call
+ * rpc_xprt_offline(), which identifies trunked transports and marks
+ * them offline.
+ */
+void rpc_clnt_manage_trunked_xprts(struct rpc_clnt *clnt)
+{
+ rpc_clnt_iterate_for_each_xprt(clnt, rpc_xprt_offline, NULL);
+}
+EXPORT_SYMBOL_GPL(rpc_clnt_manage_trunked_xprts);
+
+struct connect_timeout_data {
+ unsigned long connect_timeout;
+ unsigned long reconnect_timeout;
+};
+
+static int
+rpc_xprt_set_connect_timeout(struct rpc_clnt *clnt,
+ struct rpc_xprt *xprt,
+ void *data)
+{
+ struct connect_timeout_data *timeo = data;
+
+ if (xprt->ops->set_connect_timeout)
+ xprt->ops->set_connect_timeout(xprt,
+ timeo->connect_timeout,
+ timeo->reconnect_timeout);
+ return 0;
+}
+
+void
+rpc_set_connect_timeout(struct rpc_clnt *clnt,
+ unsigned long connect_timeout,
+ unsigned long reconnect_timeout)
+{
+ struct connect_timeout_data timeout = {
+ .connect_timeout = connect_timeout,
+ .reconnect_timeout = reconnect_timeout,
+ };
+ rpc_clnt_iterate_for_each_xprt(clnt,
+ rpc_xprt_set_connect_timeout,
+ &timeout);
+}
+EXPORT_SYMBOL_GPL(rpc_set_connect_timeout);
+
+void rpc_clnt_xprt_switch_put(struct rpc_clnt *clnt)
+{
+ rcu_read_lock();
+ xprt_switch_put(rcu_dereference(clnt->cl_xpi.xpi_xpswitch));
+ rcu_read_unlock();
+}
+EXPORT_SYMBOL_GPL(rpc_clnt_xprt_switch_put);
+
+void rpc_clnt_xprt_set_online(struct rpc_clnt *clnt, struct rpc_xprt *xprt)
+{
+ struct rpc_xprt_switch *xps;
+
+ rcu_read_lock();
+ xps = rcu_dereference(clnt->cl_xpi.xpi_xpswitch);
+ rcu_read_unlock();
+ xprt_set_online_locked(xprt, xps);
+}
+
+void rpc_clnt_xprt_switch_add_xprt(struct rpc_clnt *clnt, struct rpc_xprt *xprt)
+{
+ if (rpc_clnt_xprt_switch_has_addr(clnt,
+ (const struct sockaddr *)&xprt->addr)) {
+ return rpc_clnt_xprt_set_online(clnt, xprt);
+ }
+ rcu_read_lock();
+ rpc_xprt_switch_add_xprt(rcu_dereference(clnt->cl_xpi.xpi_xpswitch),
+ xprt);
+ rcu_read_unlock();
+}
+EXPORT_SYMBOL_GPL(rpc_clnt_xprt_switch_add_xprt);
+
+void rpc_clnt_xprt_switch_remove_xprt(struct rpc_clnt *clnt, struct rpc_xprt *xprt)
+{
+ struct rpc_xprt_switch *xps;
+
+ rcu_read_lock();
+ xps = rcu_dereference(clnt->cl_xpi.xpi_xpswitch);
+ rpc_xprt_switch_remove_xprt(rcu_dereference(clnt->cl_xpi.xpi_xpswitch),
+ xprt, 0);
+ xps->xps_nunique_destaddr_xprts--;
+ rcu_read_unlock();
+}
+EXPORT_SYMBOL_GPL(rpc_clnt_xprt_switch_remove_xprt);
+
+bool rpc_clnt_xprt_switch_has_addr(struct rpc_clnt *clnt,
+ const struct sockaddr *sap)
+{
+ struct rpc_xprt_switch *xps;
+ bool ret;
+
+ rcu_read_lock();
+ xps = rcu_dereference(clnt->cl_xpi.xpi_xpswitch);
+ ret = rpc_xprt_switch_has_addr(xps, sap);
+ rcu_read_unlock();
+ return ret;
+}
+EXPORT_SYMBOL_GPL(rpc_clnt_xprt_switch_has_addr);
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+static void rpc_show_header(void)
+{
+ printk(KERN_INFO "-pid- flgs status -client- --rqstp- "
+ "-timeout ---ops--\n");
+}
+
+static void rpc_show_task(const struct rpc_clnt *clnt,
+ const struct rpc_task *task)
+{
+ const char *rpc_waitq = "none";
+
+ if (RPC_IS_QUEUED(task))
+ rpc_waitq = rpc_qname(task->tk_waitqueue);
+
+ printk(KERN_INFO "%5u %04x %6d %8p %8p %8ld %8p %sv%u %s a:%ps q:%s\n",
+ task->tk_pid, task->tk_flags, task->tk_status,
+ clnt, task->tk_rqstp, rpc_task_timeout(task), task->tk_ops,
+ clnt->cl_program->name, clnt->cl_vers, rpc_proc_name(task),
+ task->tk_action, rpc_waitq);
+}
+
+void rpc_show_tasks(struct net *net)
+{
+ struct rpc_clnt *clnt;
+ struct rpc_task *task;
+ int header = 0;
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ spin_lock(&sn->rpc_client_lock);
+ list_for_each_entry(clnt, &sn->all_clients, cl_clients) {
+ spin_lock(&clnt->cl_lock);
+ list_for_each_entry(task, &clnt->cl_tasks, tk_task) {
+ if (!header) {
+ rpc_show_header();
+ header++;
+ }
+ rpc_show_task(clnt, task);
+ }
+ spin_unlock(&clnt->cl_lock);
+ }
+ spin_unlock(&sn->rpc_client_lock);
+}
+#endif
+
+#if IS_ENABLED(CONFIG_SUNRPC_SWAP)
+static int
+rpc_clnt_swap_activate_callback(struct rpc_clnt *clnt,
+ struct rpc_xprt *xprt,
+ void *dummy)
+{
+ return xprt_enable_swap(xprt);
+}
+
+int
+rpc_clnt_swap_activate(struct rpc_clnt *clnt)
+{
+ while (clnt != clnt->cl_parent)
+ clnt = clnt->cl_parent;
+ if (atomic_inc_return(&clnt->cl_swapper) == 1)
+ return rpc_clnt_iterate_for_each_xprt(clnt,
+ rpc_clnt_swap_activate_callback, NULL);
+ return 0;
+}
+EXPORT_SYMBOL_GPL(rpc_clnt_swap_activate);
+
+static int
+rpc_clnt_swap_deactivate_callback(struct rpc_clnt *clnt,
+ struct rpc_xprt *xprt,
+ void *dummy)
+{
+ xprt_disable_swap(xprt);
+ return 0;
+}
+
+void
+rpc_clnt_swap_deactivate(struct rpc_clnt *clnt)
+{
+ while (clnt != clnt->cl_parent)
+ clnt = clnt->cl_parent;
+ if (atomic_dec_if_positive(&clnt->cl_swapper) == 0)
+ rpc_clnt_iterate_for_each_xprt(clnt,
+ rpc_clnt_swap_deactivate_callback, NULL);
+}
+EXPORT_SYMBOL_GPL(rpc_clnt_swap_deactivate);
+#endif /* CONFIG_SUNRPC_SWAP */
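Illustrative only: the per-transport iteration helper used by rpc_set_connect_timeout() and the swap hooks above follows a simple callback pattern. This hypothetical sketch counts the transports in a client's switch; the callback signature mirrors the callers in this file.

static int example_count_one_xprt(struct rpc_clnt *clnt,
				  struct rpc_xprt *xprt, void *data)
{
	unsigned int *count = data;

	(*count)++;
	return 0;	/* returning an error here would stop the walk */
}

static unsigned int example_count_xprts(struct rpc_clnt *clnt)
{
	unsigned int count = 0;

	rpc_clnt_iterate_for_each_xprt(clnt, example_count_one_xprt, &count);
	return count;
}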
diff --git a/net/sunrpc/debugfs.c b/net/sunrpc/debugfs.c
new file mode 100644
index 0000000000..a176d5a0b0
--- /dev/null
+++ b/net/sunrpc/debugfs.c
@@ -0,0 +1,294 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * debugfs interface for sunrpc
+ *
+ * (c) 2014 Jeff Layton <jlayton@primarydata.com>
+ */
+
+#include <linux/debugfs.h>
+#include <linux/sunrpc/sched.h>
+#include <linux/sunrpc/clnt.h>
+
+#include "netns.h"
+#include "fail.h"
+
+static struct dentry *topdir;
+static struct dentry *rpc_clnt_dir;
+static struct dentry *rpc_xprt_dir;
+
+static int
+tasks_show(struct seq_file *f, void *v)
+{
+ u32 xid = 0;
+ struct rpc_task *task = v;
+ struct rpc_clnt *clnt = task->tk_client;
+ const char *rpc_waitq = "none";
+
+ if (RPC_IS_QUEUED(task))
+ rpc_waitq = rpc_qname(task->tk_waitqueue);
+
+ if (task->tk_rqstp)
+ xid = be32_to_cpu(task->tk_rqstp->rq_xid);
+
+ seq_printf(f, "%5u %04x %6d 0x%x 0x%x %8ld %ps %sv%u %s a:%ps q:%s\n",
+ task->tk_pid, task->tk_flags, task->tk_status,
+ clnt->cl_clid, xid, rpc_task_timeout(task), task->tk_ops,
+ clnt->cl_program->name, clnt->cl_vers, rpc_proc_name(task),
+ task->tk_action, rpc_waitq);
+ return 0;
+}
+
+static void *
+tasks_start(struct seq_file *f, loff_t *ppos)
+ __acquires(&clnt->cl_lock)
+{
+ struct rpc_clnt *clnt = f->private;
+ loff_t pos = *ppos;
+ struct rpc_task *task;
+
+ spin_lock(&clnt->cl_lock);
+ list_for_each_entry(task, &clnt->cl_tasks, tk_task)
+ if (pos-- == 0)
+ return task;
+ return NULL;
+}
+
+static void *
+tasks_next(struct seq_file *f, void *v, loff_t *pos)
+{
+ struct rpc_clnt *clnt = f->private;
+ struct rpc_task *task = v;
+ struct list_head *next = task->tk_task.next;
+
+ ++*pos;
+
+ /* If there's another task on list, return it */
+ if (next == &clnt->cl_tasks)
+ return NULL;
+ return list_entry(next, struct rpc_task, tk_task);
+}
+
+static void
+tasks_stop(struct seq_file *f, void *v)
+ __releases(&clnt->cl_lock)
+{
+ struct rpc_clnt *clnt = f->private;
+ spin_unlock(&clnt->cl_lock);
+}
+
+static const struct seq_operations tasks_seq_operations = {
+ .start = tasks_start,
+ .next = tasks_next,
+ .stop = tasks_stop,
+ .show = tasks_show,
+};
+
+static int tasks_open(struct inode *inode, struct file *filp)
+{
+ int ret = seq_open(filp, &tasks_seq_operations);
+ if (!ret) {
+ struct seq_file *seq = filp->private_data;
+ struct rpc_clnt *clnt = seq->private = inode->i_private;
+
+ if (!refcount_inc_not_zero(&clnt->cl_count)) {
+ seq_release(inode, filp);
+ ret = -EINVAL;
+ }
+ }
+
+ return ret;
+}
+
+static int
+tasks_release(struct inode *inode, struct file *filp)
+{
+ struct seq_file *seq = filp->private_data;
+ struct rpc_clnt *clnt = seq->private;
+
+ rpc_release_client(clnt);
+ return seq_release(inode, filp);
+}
+
+static const struct file_operations tasks_fops = {
+ .owner = THIS_MODULE,
+ .open = tasks_open,
+ .read = seq_read,
+ .llseek = seq_lseek,
+ .release = tasks_release,
+};
+
+static int do_xprt_debugfs(struct rpc_clnt *clnt, struct rpc_xprt *xprt, void *numv)
+{
+ int len;
+	char name[24]; /* enough for "../../rpc_xprt/" + 8 hex digits + NULL */
+ char link[9]; /* enough for 8 hex digits + NULL */
+ int *nump = numv;
+
+ if (IS_ERR_OR_NULL(xprt->debugfs))
+ return 0;
+ len = snprintf(name, sizeof(name), "../../rpc_xprt/%s",
+ xprt->debugfs->d_name.name);
+ if (len >= sizeof(name))
+ return -1;
+ if (*nump == 0)
+ strcpy(link, "xprt");
+ else {
+ len = snprintf(link, sizeof(link), "xprt%d", *nump);
+ if (len >= sizeof(link))
+ return -1;
+ }
+ debugfs_create_symlink(link, clnt->cl_debugfs, name);
+ (*nump)++;
+ return 0;
+}
+
+void
+rpc_clnt_debugfs_register(struct rpc_clnt *clnt)
+{
+ int len;
+ char name[9]; /* enough for 8 hex digits + NULL */
+ int xprtnum = 0;
+
+ len = snprintf(name, sizeof(name), "%x", clnt->cl_clid);
+ if (len >= sizeof(name))
+ return;
+
+ /* make the per-client dir */
+ clnt->cl_debugfs = debugfs_create_dir(name, rpc_clnt_dir);
+
+ /* make tasks file */
+ debugfs_create_file("tasks", S_IFREG | 0400, clnt->cl_debugfs, clnt,
+ &tasks_fops);
+
+ rpc_clnt_iterate_for_each_xprt(clnt, do_xprt_debugfs, &xprtnum);
+}
+
+void
+rpc_clnt_debugfs_unregister(struct rpc_clnt *clnt)
+{
+ debugfs_remove_recursive(clnt->cl_debugfs);
+ clnt->cl_debugfs = NULL;
+}
+
+static int
+xprt_info_show(struct seq_file *f, void *v)
+{
+ struct rpc_xprt *xprt = f->private;
+
+ seq_printf(f, "netid: %s\n", xprt->address_strings[RPC_DISPLAY_NETID]);
+ seq_printf(f, "addr: %s\n", xprt->address_strings[RPC_DISPLAY_ADDR]);
+ seq_printf(f, "port: %s\n", xprt->address_strings[RPC_DISPLAY_PORT]);
+ seq_printf(f, "state: 0x%lx\n", xprt->state);
+ return 0;
+}
+
+static int
+xprt_info_open(struct inode *inode, struct file *filp)
+{
+ int ret;
+ struct rpc_xprt *xprt = inode->i_private;
+
+ ret = single_open(filp, xprt_info_show, xprt);
+
+ if (!ret) {
+ if (!xprt_get(xprt)) {
+ single_release(inode, filp);
+ ret = -EINVAL;
+ }
+ }
+ return ret;
+}
+
+static int
+xprt_info_release(struct inode *inode, struct file *filp)
+{
+ struct rpc_xprt *xprt = inode->i_private;
+
+ xprt_put(xprt);
+ return single_release(inode, filp);
+}
+
+static const struct file_operations xprt_info_fops = {
+ .owner = THIS_MODULE,
+ .open = xprt_info_open,
+ .read = seq_read,
+ .llseek = seq_lseek,
+ .release = xprt_info_release,
+};
+
+void
+rpc_xprt_debugfs_register(struct rpc_xprt *xprt)
+{
+ int len, id;
+ static atomic_t cur_id;
+ char name[9]; /* 8 hex digits + NULL term */
+
+ id = (unsigned int)atomic_inc_return(&cur_id);
+
+ len = snprintf(name, sizeof(name), "%x", id);
+ if (len >= sizeof(name))
+ return;
+
+	/* make the per-xprt dir */
+ xprt->debugfs = debugfs_create_dir(name, rpc_xprt_dir);
+
+	/* make the info file */
+ debugfs_create_file("info", S_IFREG | 0400, xprt->debugfs, xprt,
+ &xprt_info_fops);
+}
+
+void
+rpc_xprt_debugfs_unregister(struct rpc_xprt *xprt)
+{
+ debugfs_remove_recursive(xprt->debugfs);
+ xprt->debugfs = NULL;
+}
+
+#if IS_ENABLED(CONFIG_FAIL_SUNRPC)
+struct fail_sunrpc_attr fail_sunrpc = {
+ .attr = FAULT_ATTR_INITIALIZER,
+};
+EXPORT_SYMBOL_GPL(fail_sunrpc);
+
+static void fail_sunrpc_init(void)
+{
+ struct dentry *dir;
+
+ dir = fault_create_debugfs_attr("fail_sunrpc", NULL,
+ &fail_sunrpc.attr);
+
+ debugfs_create_bool("ignore-client-disconnect", S_IFREG | 0600, dir,
+ &fail_sunrpc.ignore_client_disconnect);
+
+ debugfs_create_bool("ignore-server-disconnect", S_IFREG | 0600, dir,
+ &fail_sunrpc.ignore_server_disconnect);
+
+ debugfs_create_bool("ignore-cache-wait", S_IFREG | 0600, dir,
+ &fail_sunrpc.ignore_cache_wait);
+}
+#else
+static void fail_sunrpc_init(void)
+{
+}
+#endif
+
+void __exit
+sunrpc_debugfs_exit(void)
+{
+ debugfs_remove_recursive(topdir);
+ topdir = NULL;
+ rpc_clnt_dir = NULL;
+ rpc_xprt_dir = NULL;
+}
+
+void __init
+sunrpc_debugfs_init(void)
+{
+ topdir = debugfs_create_dir("sunrpc", NULL);
+
+ rpc_clnt_dir = debugfs_create_dir("rpc_clnt", topdir);
+
+ rpc_xprt_dir = debugfs_create_dir("rpc_xprt", topdir);
+
+ fail_sunrpc_init();
+}
diff --git a/net/sunrpc/fail.h b/net/sunrpc/fail.h
new file mode 100644
index 0000000000..4b4b500df4
--- /dev/null
+++ b/net/sunrpc/fail.h
@@ -0,0 +1,25 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+/*
+ * Copyright (C) 2021, Oracle. All rights reserved.
+ */
+
+#ifndef _NET_SUNRPC_FAIL_H_
+#define _NET_SUNRPC_FAIL_H_
+
+#include <linux/fault-inject.h>
+
+#if IS_ENABLED(CONFIG_FAULT_INJECTION)
+
+struct fail_sunrpc_attr {
+ struct fault_attr attr;
+
+ bool ignore_client_disconnect;
+ bool ignore_server_disconnect;
+ bool ignore_cache_wait;
+};
+
+extern struct fail_sunrpc_attr fail_sunrpc;
+
+#endif /* CONFIG_FAULT_INJECTION */
+
+#endif /* _NET_SUNRPC_FAIL_H_ */
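Illustrative only: how a call site could consult the fault-injection attribute declared above. The example_* wrapper is hypothetical; should_fail() is the standard helper from <linux/fault-inject.h>.

#include <linux/fault-inject.h>
#include "fail.h"

static bool example_should_inject_disconnect(void)
{
#if IS_ENABLED(CONFIG_FAIL_SUNRPC)
	/* Honours the probability/interval knobs set via debugfs. */
	if (should_fail(&fail_sunrpc.attr, 1))
		return true;
#endif
	return false;
}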
diff --git a/net/sunrpc/netns.h b/net/sunrpc/netns.h
new file mode 100644
index 0000000000..4efb5f28d8
--- /dev/null
+++ b/net/sunrpc/netns.h
@@ -0,0 +1,44 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+#ifndef __SUNRPC_NETNS_H__
+#define __SUNRPC_NETNS_H__
+
+#include <net/net_namespace.h>
+#include <net/netns/generic.h>
+
+struct cache_detail;
+
+struct sunrpc_net {
+ struct proc_dir_entry *proc_net_rpc;
+ struct cache_detail *ip_map_cache;
+ struct cache_detail *unix_gid_cache;
+ struct cache_detail *rsc_cache;
+ struct cache_detail *rsi_cache;
+
+ struct super_block *pipefs_sb;
+ struct rpc_pipe *gssd_dummy;
+ struct mutex pipefs_sb_lock;
+
+ struct list_head all_clients;
+ spinlock_t rpc_client_lock;
+
+ struct rpc_clnt *rpcb_local_clnt;
+ struct rpc_clnt *rpcb_local_clnt4;
+ spinlock_t rpcb_clnt_lock;
+ unsigned int rpcb_users;
+ unsigned int rpcb_is_af_local : 1;
+
+ struct mutex gssp_lock;
+ struct rpc_clnt *gssp_clnt;
+ int use_gss_proxy;
+ int pipe_version;
+ atomic_t pipe_users;
+ struct proc_dir_entry *use_gssp_proc;
+ struct proc_dir_entry *gss_krb5_enctypes;
+};
+
+extern unsigned int sunrpc_net_id;
+
+int ip_map_cache_create(struct net *);
+void ip_map_cache_destroy(struct net *);
+
+#endif
diff --git a/net/sunrpc/rpc_pipe.c b/net/sunrpc/rpc_pipe.c
new file mode 100644
index 0000000000..f420d84573
--- /dev/null
+++ b/net/sunrpc/rpc_pipe.c
@@ -0,0 +1,1517 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * net/sunrpc/rpc_pipe.c
+ *
+ * Userland/kernel interface for rpcauth_gss.
+ * Code shamelessly plagiarized from fs/nfsd/nfsctl.c
+ * and fs/sysfs/inode.c
+ *
+ * Copyright (c) 2002, Trond Myklebust <trond.myklebust@fys.uio.no>
+ *
+ */
+#include <linux/module.h>
+#include <linux/slab.h>
+#include <linux/string.h>
+#include <linux/pagemap.h>
+#include <linux/mount.h>
+#include <linux/fs_context.h>
+#include <linux/namei.h>
+#include <linux/fsnotify.h>
+#include <linux/kernel.h>
+#include <linux/rcupdate.h>
+#include <linux/utsname.h>
+
+#include <asm/ioctls.h>
+#include <linux/poll.h>
+#include <linux/wait.h>
+#include <linux/seq_file.h>
+
+#include <linux/sunrpc/clnt.h>
+#include <linux/workqueue.h>
+#include <linux/sunrpc/rpc_pipe_fs.h>
+#include <linux/sunrpc/cache.h>
+#include <linux/nsproxy.h>
+#include <linux/notifier.h>
+
+#include "netns.h"
+#include "sunrpc.h"
+
+#define RPCDBG_FACILITY RPCDBG_DEBUG
+
+#define NET_NAME(net) ((net == &init_net) ? " (init_net)" : "")
+
+static struct file_system_type rpc_pipe_fs_type;
+static const struct rpc_pipe_ops gssd_dummy_pipe_ops;
+
+static struct kmem_cache *rpc_inode_cachep __read_mostly;
+
+#define RPC_UPCALL_TIMEOUT (30*HZ)
+
+static BLOCKING_NOTIFIER_HEAD(rpc_pipefs_notifier_list);
+
+int rpc_pipefs_notifier_register(struct notifier_block *nb)
+{
+ return blocking_notifier_chain_register(&rpc_pipefs_notifier_list, nb);
+}
+EXPORT_SYMBOL_GPL(rpc_pipefs_notifier_register);
+
+void rpc_pipefs_notifier_unregister(struct notifier_block *nb)
+{
+ blocking_notifier_chain_unregister(&rpc_pipefs_notifier_list, nb);
+}
+EXPORT_SYMBOL_GPL(rpc_pipefs_notifier_unregister);
+
+static void rpc_purge_list(wait_queue_head_t *waitq, struct list_head *head,
+ void (*destroy_msg)(struct rpc_pipe_msg *), int err)
+{
+ struct rpc_pipe_msg *msg;
+
+ if (list_empty(head))
+ return;
+ do {
+ msg = list_entry(head->next, struct rpc_pipe_msg, list);
+ list_del_init(&msg->list);
+ msg->errno = err;
+ destroy_msg(msg);
+ } while (!list_empty(head));
+
+ if (waitq)
+ wake_up(waitq);
+}
+
+static void
+rpc_timeout_upcall_queue(struct work_struct *work)
+{
+ LIST_HEAD(free_list);
+ struct rpc_pipe *pipe =
+ container_of(work, struct rpc_pipe, queue_timeout.work);
+ void (*destroy_msg)(struct rpc_pipe_msg *);
+ struct dentry *dentry;
+
+ spin_lock(&pipe->lock);
+ destroy_msg = pipe->ops->destroy_msg;
+ if (pipe->nreaders == 0) {
+ list_splice_init(&pipe->pipe, &free_list);
+ pipe->pipelen = 0;
+ }
+ dentry = dget(pipe->dentry);
+ spin_unlock(&pipe->lock);
+ rpc_purge_list(dentry ? &RPC_I(d_inode(dentry))->waitq : NULL,
+ &free_list, destroy_msg, -ETIMEDOUT);
+ dput(dentry);
+}
+
+ssize_t rpc_pipe_generic_upcall(struct file *filp, struct rpc_pipe_msg *msg,
+ char __user *dst, size_t buflen)
+{
+ char *data = (char *)msg->data + msg->copied;
+ size_t mlen = min(msg->len - msg->copied, buflen);
+ unsigned long left;
+
+ left = copy_to_user(dst, data, mlen);
+ if (left == mlen) {
+ msg->errno = -EFAULT;
+ return -EFAULT;
+ }
+
+ mlen -= left;
+ msg->copied += mlen;
+ msg->errno = 0;
+ return mlen;
+}
+EXPORT_SYMBOL_GPL(rpc_pipe_generic_upcall);
+
+/**
+ * rpc_queue_upcall - queue an upcall message to userspace
+ * @pipe: upcall pipe on which to queue given message
+ * @msg: message to queue
+ *
+ * Call with a @pipe created by rpc_mkpipe_data() and attached to the
+ * filesystem with rpc_mkpipe_dentry() to queue an upcall.
+ * A userspace process may then later read the upcall by performing a
+ * read on an open file for the pipe's inode. It is up to the caller to
+ * initialize the fields of @msg (other than @msg->list) appropriately.
+ */
+int
+rpc_queue_upcall(struct rpc_pipe *pipe, struct rpc_pipe_msg *msg)
+{
+ int res = -EPIPE;
+ struct dentry *dentry;
+
+ spin_lock(&pipe->lock);
+ if (pipe->nreaders) {
+ list_add_tail(&msg->list, &pipe->pipe);
+ pipe->pipelen += msg->len;
+ res = 0;
+ } else if (pipe->flags & RPC_PIPE_WAIT_FOR_OPEN) {
+ if (list_empty(&pipe->pipe))
+ queue_delayed_work(rpciod_workqueue,
+ &pipe->queue_timeout,
+ RPC_UPCALL_TIMEOUT);
+ list_add_tail(&msg->list, &pipe->pipe);
+ pipe->pipelen += msg->len;
+ res = 0;
+ }
+ dentry = dget(pipe->dentry);
+ spin_unlock(&pipe->lock);
+ if (dentry) {
+ wake_up(&RPC_I(d_inode(dentry))->waitq);
+ dput(dentry);
+ }
+ return res;
+}
+EXPORT_SYMBOL_GPL(rpc_queue_upcall);
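Illustrative only: a minimal pipe that uses rpc_pipe_generic_upcall() as its read side and a trivial downcall, plus a helper that queues a message. The example_* names are hypothetical; the rpc_pipe_msg fields used here follow rpc_pipe_generic_upcall() above, and the payload is assumed to be a kmalloc()ed buffer owned by the message.

#include <linux/slab.h>
#include <linux/sunrpc/rpc_pipe_fs.h>

static ssize_t example_downcall(struct file *filp, const char __user *src,
				size_t len)
{
	/* Parse the reply written by the userspace daemon here. */
	return len;
}

static void example_destroy_msg(struct rpc_pipe_msg *msg)
{
	kfree(msg->data);
	kfree(msg);
}

static const struct rpc_pipe_ops example_pipe_ops = {
	.upcall		= rpc_pipe_generic_upcall,
	.downcall	= example_downcall,
	.destroy_msg	= example_destroy_msg,
};

static int example_send_upcall(struct rpc_pipe *pipe, void *payload,
			       size_t len)
{
	struct rpc_pipe_msg *msg;
	int res;

	msg = kzalloc(sizeof(*msg), GFP_KERNEL);
	if (!msg)
		return -ENOMEM;
	msg->data = payload;
	msg->len = len;

	res = rpc_queue_upcall(pipe, msg);
	if (res < 0)
		example_destroy_msg(msg);	/* never queued: clean up here */
	return res;
}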
+
+static inline void
+rpc_inode_setowner(struct inode *inode, void *private)
+{
+ RPC_I(inode)->private = private;
+}
+
+static void
+rpc_close_pipes(struct inode *inode)
+{
+ struct rpc_pipe *pipe = RPC_I(inode)->pipe;
+ int need_release;
+ LIST_HEAD(free_list);
+
+ inode_lock(inode);
+ spin_lock(&pipe->lock);
+ need_release = pipe->nreaders != 0 || pipe->nwriters != 0;
+ pipe->nreaders = 0;
+ list_splice_init(&pipe->in_upcall, &free_list);
+ list_splice_init(&pipe->pipe, &free_list);
+ pipe->pipelen = 0;
+ pipe->dentry = NULL;
+ spin_unlock(&pipe->lock);
+ rpc_purge_list(&RPC_I(inode)->waitq, &free_list, pipe->ops->destroy_msg, -EPIPE);
+ pipe->nwriters = 0;
+ if (need_release && pipe->ops->release_pipe)
+ pipe->ops->release_pipe(inode);
+ cancel_delayed_work_sync(&pipe->queue_timeout);
+ rpc_inode_setowner(inode, NULL);
+ RPC_I(inode)->pipe = NULL;
+ inode_unlock(inode);
+}
+
+static struct inode *
+rpc_alloc_inode(struct super_block *sb)
+{
+ struct rpc_inode *rpci;
+ rpci = alloc_inode_sb(sb, rpc_inode_cachep, GFP_KERNEL);
+ if (!rpci)
+ return NULL;
+ return &rpci->vfs_inode;
+}
+
+static void
+rpc_free_inode(struct inode *inode)
+{
+ kmem_cache_free(rpc_inode_cachep, RPC_I(inode));
+}
+
+static int
+rpc_pipe_open(struct inode *inode, struct file *filp)
+{
+ struct rpc_pipe *pipe;
+ int first_open;
+ int res = -ENXIO;
+
+ inode_lock(inode);
+ pipe = RPC_I(inode)->pipe;
+ if (pipe == NULL)
+ goto out;
+ first_open = pipe->nreaders == 0 && pipe->nwriters == 0;
+ if (first_open && pipe->ops->open_pipe) {
+ res = pipe->ops->open_pipe(inode);
+ if (res)
+ goto out;
+ }
+ if (filp->f_mode & FMODE_READ)
+ pipe->nreaders++;
+ if (filp->f_mode & FMODE_WRITE)
+ pipe->nwriters++;
+ res = 0;
+out:
+ inode_unlock(inode);
+ return res;
+}
+
+static int
+rpc_pipe_release(struct inode *inode, struct file *filp)
+{
+ struct rpc_pipe *pipe;
+ struct rpc_pipe_msg *msg;
+ int last_close;
+
+ inode_lock(inode);
+ pipe = RPC_I(inode)->pipe;
+ if (pipe == NULL)
+ goto out;
+ msg = filp->private_data;
+ if (msg != NULL) {
+ spin_lock(&pipe->lock);
+ msg->errno = -EAGAIN;
+ list_del_init(&msg->list);
+ spin_unlock(&pipe->lock);
+ pipe->ops->destroy_msg(msg);
+ }
+ if (filp->f_mode & FMODE_WRITE)
+		pipe->nwriters--;
+ if (filp->f_mode & FMODE_READ) {
+		pipe->nreaders--;
+ if (pipe->nreaders == 0) {
+ LIST_HEAD(free_list);
+ spin_lock(&pipe->lock);
+ list_splice_init(&pipe->pipe, &free_list);
+ pipe->pipelen = 0;
+ spin_unlock(&pipe->lock);
+ rpc_purge_list(&RPC_I(inode)->waitq, &free_list,
+ pipe->ops->destroy_msg, -EAGAIN);
+ }
+ }
+ last_close = pipe->nwriters == 0 && pipe->nreaders == 0;
+ if (last_close && pipe->ops->release_pipe)
+ pipe->ops->release_pipe(inode);
+out:
+ inode_unlock(inode);
+ return 0;
+}
+
+static ssize_t
+rpc_pipe_read(struct file *filp, char __user *buf, size_t len, loff_t *offset)
+{
+ struct inode *inode = file_inode(filp);
+ struct rpc_pipe *pipe;
+ struct rpc_pipe_msg *msg;
+ int res = 0;
+
+ inode_lock(inode);
+ pipe = RPC_I(inode)->pipe;
+ if (pipe == NULL) {
+ res = -EPIPE;
+ goto out_unlock;
+ }
+ msg = filp->private_data;
+ if (msg == NULL) {
+ spin_lock(&pipe->lock);
+ if (!list_empty(&pipe->pipe)) {
+ msg = list_entry(pipe->pipe.next,
+ struct rpc_pipe_msg,
+ list);
+ list_move(&msg->list, &pipe->in_upcall);
+ pipe->pipelen -= msg->len;
+ filp->private_data = msg;
+ msg->copied = 0;
+ }
+ spin_unlock(&pipe->lock);
+ if (msg == NULL)
+ goto out_unlock;
+ }
+ /* NOTE: it is up to the callback to update msg->copied */
+ res = pipe->ops->upcall(filp, msg, buf, len);
+ if (res < 0 || msg->len == msg->copied) {
+ filp->private_data = NULL;
+ spin_lock(&pipe->lock);
+ list_del_init(&msg->list);
+ spin_unlock(&pipe->lock);
+ pipe->ops->destroy_msg(msg);
+ }
+out_unlock:
+ inode_unlock(inode);
+ return res;
+}
+
+static ssize_t
+rpc_pipe_write(struct file *filp, const char __user *buf, size_t len, loff_t *offset)
+{
+ struct inode *inode = file_inode(filp);
+ int res;
+
+ inode_lock(inode);
+ res = -EPIPE;
+ if (RPC_I(inode)->pipe != NULL)
+ res = RPC_I(inode)->pipe->ops->downcall(filp, buf, len);
+ inode_unlock(inode);
+ return res;
+}
+
+static __poll_t
+rpc_pipe_poll(struct file *filp, struct poll_table_struct *wait)
+{
+ struct inode *inode = file_inode(filp);
+ struct rpc_inode *rpci = RPC_I(inode);
+ __poll_t mask = EPOLLOUT | EPOLLWRNORM;
+
+ poll_wait(filp, &rpci->waitq, wait);
+
+ inode_lock(inode);
+ if (rpci->pipe == NULL)
+ mask |= EPOLLERR | EPOLLHUP;
+ else if (filp->private_data || !list_empty(&rpci->pipe->pipe))
+ mask |= EPOLLIN | EPOLLRDNORM;
+ inode_unlock(inode);
+ return mask;
+}
+
+static long
+rpc_pipe_ioctl(struct file *filp, unsigned int cmd, unsigned long arg)
+{
+ struct inode *inode = file_inode(filp);
+ struct rpc_pipe *pipe;
+ int len;
+
+ switch (cmd) {
+ case FIONREAD:
+ inode_lock(inode);
+ pipe = RPC_I(inode)->pipe;
+ if (pipe == NULL) {
+ inode_unlock(inode);
+ return -EPIPE;
+ }
+ spin_lock(&pipe->lock);
+ len = pipe->pipelen;
+ if (filp->private_data) {
+ struct rpc_pipe_msg *msg;
+ msg = filp->private_data;
+ len += msg->len - msg->copied;
+ }
+ spin_unlock(&pipe->lock);
+ inode_unlock(inode);
+ return put_user(len, (int __user *)arg);
+ default:
+ return -EINVAL;
+ }
+}
+
+static const struct file_operations rpc_pipe_fops = {
+ .owner = THIS_MODULE,
+ .llseek = no_llseek,
+ .read = rpc_pipe_read,
+ .write = rpc_pipe_write,
+ .poll = rpc_pipe_poll,
+ .unlocked_ioctl = rpc_pipe_ioctl,
+ .open = rpc_pipe_open,
+ .release = rpc_pipe_release,
+};
+
+static int
+rpc_show_info(struct seq_file *m, void *v)
+{
+ struct rpc_clnt *clnt = m->private;
+
+ rcu_read_lock();
+ seq_printf(m, "RPC server: %s\n",
+ rcu_dereference(clnt->cl_xprt)->servername);
+ seq_printf(m, "service: %s (%d) version %d\n", clnt->cl_program->name,
+ clnt->cl_prog, clnt->cl_vers);
+ seq_printf(m, "address: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_ADDR));
+ seq_printf(m, "protocol: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_PROTO));
+ seq_printf(m, "port: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_PORT));
+ rcu_read_unlock();
+ return 0;
+}
+
+static int
+rpc_info_open(struct inode *inode, struct file *file)
+{
+ struct rpc_clnt *clnt = NULL;
+ int ret = single_open(file, rpc_show_info, NULL);
+
+ if (!ret) {
+ struct seq_file *m = file->private_data;
+
+ spin_lock(&file->f_path.dentry->d_lock);
+ if (!d_unhashed(file->f_path.dentry))
+ clnt = RPC_I(inode)->private;
+ if (clnt != NULL && refcount_inc_not_zero(&clnt->cl_count)) {
+ spin_unlock(&file->f_path.dentry->d_lock);
+ m->private = clnt;
+ } else {
+ spin_unlock(&file->f_path.dentry->d_lock);
+ single_release(inode, file);
+ ret = -EINVAL;
+ }
+ }
+ return ret;
+}
+
+static int
+rpc_info_release(struct inode *inode, struct file *file)
+{
+ struct seq_file *m = file->private_data;
+ struct rpc_clnt *clnt = (struct rpc_clnt *)m->private;
+
+ if (clnt)
+ rpc_release_client(clnt);
+ return single_release(inode, file);
+}
+
+static const struct file_operations rpc_info_operations = {
+ .owner = THIS_MODULE,
+ .open = rpc_info_open,
+ .read = seq_read,
+ .llseek = seq_lseek,
+ .release = rpc_info_release,
+};
+
+
+/*
+ * Description of fs contents.
+ */
+struct rpc_filelist {
+ const char *name;
+ const struct file_operations *i_fop;
+ umode_t mode;
+};
+
+static struct inode *
+rpc_get_inode(struct super_block *sb, umode_t mode)
+{
+ struct inode *inode = new_inode(sb);
+ if (!inode)
+ return NULL;
+ inode->i_ino = get_next_ino();
+ inode->i_mode = mode;
+ inode->i_atime = inode->i_mtime = inode_set_ctime_current(inode);
+ switch (mode & S_IFMT) {
+ case S_IFDIR:
+ inode->i_fop = &simple_dir_operations;
+ inode->i_op = &simple_dir_inode_operations;
+ inc_nlink(inode);
+ break;
+ default:
+ break;
+ }
+ return inode;
+}
+
+static int __rpc_create_common(struct inode *dir, struct dentry *dentry,
+ umode_t mode,
+ const struct file_operations *i_fop,
+ void *private)
+{
+ struct inode *inode;
+
+ d_drop(dentry);
+ inode = rpc_get_inode(dir->i_sb, mode);
+ if (!inode)
+ goto out_err;
+ inode->i_ino = iunique(dir->i_sb, 100);
+ if (i_fop)
+ inode->i_fop = i_fop;
+ if (private)
+ rpc_inode_setowner(inode, private);
+ d_add(dentry, inode);
+ return 0;
+out_err:
+ printk(KERN_WARNING "%s: %s failed to allocate inode for dentry %pd\n",
+ __FILE__, __func__, dentry);
+ dput(dentry);
+ return -ENOMEM;
+}
+
+static int __rpc_create(struct inode *dir, struct dentry *dentry,
+ umode_t mode,
+ const struct file_operations *i_fop,
+ void *private)
+{
+ int err;
+
+ err = __rpc_create_common(dir, dentry, S_IFREG | mode, i_fop, private);
+ if (err)
+ return err;
+ fsnotify_create(dir, dentry);
+ return 0;
+}
+
+static int __rpc_mkdir(struct inode *dir, struct dentry *dentry,
+ umode_t mode,
+ const struct file_operations *i_fop,
+ void *private)
+{
+ int err;
+
+ err = __rpc_create_common(dir, dentry, S_IFDIR | mode, i_fop, private);
+ if (err)
+ return err;
+ inc_nlink(dir);
+ fsnotify_mkdir(dir, dentry);
+ return 0;
+}
+
+static void
+init_pipe(struct rpc_pipe *pipe)
+{
+ pipe->nreaders = 0;
+ pipe->nwriters = 0;
+ INIT_LIST_HEAD(&pipe->in_upcall);
+ INIT_LIST_HEAD(&pipe->in_downcall);
+ INIT_LIST_HEAD(&pipe->pipe);
+ pipe->pipelen = 0;
+ INIT_DELAYED_WORK(&pipe->queue_timeout,
+ rpc_timeout_upcall_queue);
+ pipe->ops = NULL;
+ spin_lock_init(&pipe->lock);
+ pipe->dentry = NULL;
+}
+
+void rpc_destroy_pipe_data(struct rpc_pipe *pipe)
+{
+ kfree(pipe);
+}
+EXPORT_SYMBOL_GPL(rpc_destroy_pipe_data);
+
+struct rpc_pipe *rpc_mkpipe_data(const struct rpc_pipe_ops *ops, int flags)
+{
+ struct rpc_pipe *pipe;
+
+ pipe = kzalloc(sizeof(struct rpc_pipe), GFP_KERNEL);
+ if (!pipe)
+ return ERR_PTR(-ENOMEM);
+ init_pipe(pipe);
+ pipe->ops = ops;
+ pipe->flags = flags;
+ return pipe;
+}
+EXPORT_SYMBOL_GPL(rpc_mkpipe_data);
+
+static int __rpc_mkpipe_dentry(struct inode *dir, struct dentry *dentry,
+ umode_t mode,
+ const struct file_operations *i_fop,
+ void *private,
+ struct rpc_pipe *pipe)
+{
+ struct rpc_inode *rpci;
+ int err;
+
+ err = __rpc_create_common(dir, dentry, S_IFIFO | mode, i_fop, private);
+ if (err)
+ return err;
+ rpci = RPC_I(d_inode(dentry));
+ rpci->private = private;
+ rpci->pipe = pipe;
+ fsnotify_create(dir, dentry);
+ return 0;
+}
+
+static int __rpc_rmdir(struct inode *dir, struct dentry *dentry)
+{
+ int ret;
+
+ dget(dentry);
+ ret = simple_rmdir(dir, dentry);
+ d_drop(dentry);
+ if (!ret)
+ fsnotify_rmdir(dir, dentry);
+ dput(dentry);
+ return ret;
+}
+
+static int __rpc_unlink(struct inode *dir, struct dentry *dentry)
+{
+ int ret;
+
+ dget(dentry);
+ ret = simple_unlink(dir, dentry);
+ d_drop(dentry);
+ if (!ret)
+ fsnotify_unlink(dir, dentry);
+ dput(dentry);
+ return ret;
+}
+
+static int __rpc_rmpipe(struct inode *dir, struct dentry *dentry)
+{
+ struct inode *inode = d_inode(dentry);
+
+ rpc_close_pipes(inode);
+ return __rpc_unlink(dir, dentry);
+}
+
+static struct dentry *__rpc_lookup_create_exclusive(struct dentry *parent,
+ const char *name)
+{
+ struct qstr q = QSTR_INIT(name, strlen(name));
+ struct dentry *dentry = d_hash_and_lookup(parent, &q);
+ if (!dentry) {
+ dentry = d_alloc(parent, &q);
+ if (!dentry)
+ return ERR_PTR(-ENOMEM);
+ }
+ if (d_really_is_negative(dentry))
+ return dentry;
+ dput(dentry);
+ return ERR_PTR(-EEXIST);
+}
+
+/*
+ * FIXME: This probably has races.
+ */
+static void __rpc_depopulate(struct dentry *parent,
+ const struct rpc_filelist *files,
+ int start, int eof)
+{
+ struct inode *dir = d_inode(parent);
+ struct dentry *dentry;
+ struct qstr name;
+ int i;
+
+ for (i = start; i < eof; i++) {
+ name.name = files[i].name;
+ name.len = strlen(files[i].name);
+ dentry = d_hash_and_lookup(parent, &name);
+
+ if (dentry == NULL)
+ continue;
+ if (d_really_is_negative(dentry))
+ goto next;
+ switch (d_inode(dentry)->i_mode & S_IFMT) {
+ default:
+ BUG();
+ case S_IFREG:
+ __rpc_unlink(dir, dentry);
+ break;
+ case S_IFDIR:
+ __rpc_rmdir(dir, dentry);
+ }
+next:
+ dput(dentry);
+ }
+}
+
+static void rpc_depopulate(struct dentry *parent,
+ const struct rpc_filelist *files,
+ int start, int eof)
+{
+ struct inode *dir = d_inode(parent);
+
+ inode_lock_nested(dir, I_MUTEX_CHILD);
+ __rpc_depopulate(parent, files, start, eof);
+ inode_unlock(dir);
+}
+
+static int rpc_populate(struct dentry *parent,
+ const struct rpc_filelist *files,
+ int start, int eof,
+ void *private)
+{
+ struct inode *dir = d_inode(parent);
+ struct dentry *dentry;
+ int i, err;
+
+ inode_lock(dir);
+ for (i = start; i < eof; i++) {
+ dentry = __rpc_lookup_create_exclusive(parent, files[i].name);
+ err = PTR_ERR(dentry);
+ if (IS_ERR(dentry))
+ goto out_bad;
+ switch (files[i].mode & S_IFMT) {
+ default:
+ BUG();
+ case S_IFREG:
+ err = __rpc_create(dir, dentry,
+ files[i].mode,
+ files[i].i_fop,
+ private);
+ break;
+ case S_IFDIR:
+ err = __rpc_mkdir(dir, dentry,
+ files[i].mode,
+ NULL,
+ private);
+ }
+ if (err != 0)
+ goto out_bad;
+ }
+ inode_unlock(dir);
+ return 0;
+out_bad:
+ __rpc_depopulate(parent, files, start, eof);
+ inode_unlock(dir);
+ printk(KERN_WARNING "%s: %s failed to populate directory %pd\n",
+ __FILE__, __func__, parent);
+ return err;
+}
+
+static struct dentry *rpc_mkdir_populate(struct dentry *parent,
+ const char *name, umode_t mode, void *private,
+ int (*populate)(struct dentry *, void *), void *args_populate)
+{
+ struct dentry *dentry;
+ struct inode *dir = d_inode(parent);
+ int error;
+
+ inode_lock_nested(dir, I_MUTEX_PARENT);
+ dentry = __rpc_lookup_create_exclusive(parent, name);
+ if (IS_ERR(dentry))
+ goto out;
+ error = __rpc_mkdir(dir, dentry, mode, NULL, private);
+ if (error != 0)
+ goto out_err;
+ if (populate != NULL) {
+ error = populate(dentry, args_populate);
+ if (error)
+ goto err_rmdir;
+ }
+out:
+ inode_unlock(dir);
+ return dentry;
+err_rmdir:
+ __rpc_rmdir(dir, dentry);
+out_err:
+ dentry = ERR_PTR(error);
+ goto out;
+}
+
+static int rpc_rmdir_depopulate(struct dentry *dentry,
+ void (*depopulate)(struct dentry *))
+{
+ struct dentry *parent;
+ struct inode *dir;
+ int error;
+
+ parent = dget_parent(dentry);
+ dir = d_inode(parent);
+ inode_lock_nested(dir, I_MUTEX_PARENT);
+ if (depopulate != NULL)
+ depopulate(dentry);
+ error = __rpc_rmdir(dir, dentry);
+ inode_unlock(dir);
+ dput(parent);
+ return error;
+}
+
+/**
+ * rpc_mkpipe_dentry - make an rpc_pipefs file for kernel<->userspace
+ * communication
+ * @parent: dentry of directory to create new "pipe" in
+ * @name: name of pipe
+ * @private: private data to associate with the pipe, for the caller's use
+ * @pipe: &rpc_pipe containing input parameters
+ *
+ * Data is made available for userspace to read by calls to
+ * rpc_queue_upcall(). The actual reads will result in calls to
+ * @pipe->ops->upcall, which will be called with the file pointer,
+ * message, and userspace buffer to copy to.
+ *
+ * Writes can come at any time, and do not necessarily have to be
+ * responses to upcalls. They will result in calls to @pipe->ops->downcall.
+ *
+ * The @private argument passed here will be available to all these methods
+ * from the file pointer, via RPC_I(file_inode(file))->private.
+ */
+struct dentry *rpc_mkpipe_dentry(struct dentry *parent, const char *name,
+ void *private, struct rpc_pipe *pipe)
+{
+ struct dentry *dentry;
+ struct inode *dir = d_inode(parent);
+ umode_t umode = S_IFIFO | 0600;
+ int err;
+
+ if (pipe->ops->upcall == NULL)
+ umode &= ~0444;
+ if (pipe->ops->downcall == NULL)
+ umode &= ~0222;
+
+ inode_lock_nested(dir, I_MUTEX_PARENT);
+ dentry = __rpc_lookup_create_exclusive(parent, name);
+ if (IS_ERR(dentry))
+ goto out;
+ err = __rpc_mkpipe_dentry(dir, dentry, umode, &rpc_pipe_fops,
+ private, pipe);
+ if (err)
+ goto out_err;
+out:
+ inode_unlock(dir);
+ return dentry;
+out_err:
+ dentry = ERR_PTR(err);
+ printk(KERN_WARNING "%s: %s() failed to create pipe %pd/%s (errno = %d)\n",
+ __FILE__, __func__, parent, name,
+ err);
+ goto out;
+}
+EXPORT_SYMBOL_GPL(rpc_mkpipe_dentry);
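Illustrative only: the create/remove lifecycle for such a pipe, reusing the hypothetical example_pipe_ops from the earlier sketch. @parent would typically be a client directory returned by rpc_create_client_dir().

#include <linux/err.h>

static struct rpc_pipe *example_pipe;
static struct dentry *example_pipe_dentry;

static int example_create_pipe(struct dentry *parent, void *private)
{
	example_pipe = rpc_mkpipe_data(&example_pipe_ops, 0);
	if (IS_ERR(example_pipe))
		return PTR_ERR(example_pipe);

	example_pipe_dentry = rpc_mkpipe_dentry(parent, "example", private,
						example_pipe);
	if (IS_ERR(example_pipe_dentry)) {
		rpc_destroy_pipe_data(example_pipe);
		return PTR_ERR(example_pipe_dentry);
	}
	return 0;
}

static void example_remove_pipe(void)
{
	rpc_unlink(example_pipe_dentry);
	rpc_destroy_pipe_data(example_pipe);
}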
+
+/**
+ * rpc_unlink - remove a pipe
+ * @dentry: dentry for the pipe, as returned from rpc_mkpipe
+ *
+ * After this call, lookups will no longer find the pipe, and any
+ * attempts to read or write using preexisting opens of the pipe will
+ * return -EPIPE.
+ */
+int
+rpc_unlink(struct dentry *dentry)
+{
+ struct dentry *parent;
+ struct inode *dir;
+ int error = 0;
+
+ parent = dget_parent(dentry);
+ dir = d_inode(parent);
+ inode_lock_nested(dir, I_MUTEX_PARENT);
+ error = __rpc_rmpipe(dir, dentry);
+ inode_unlock(dir);
+ dput(parent);
+ return error;
+}
+EXPORT_SYMBOL_GPL(rpc_unlink);
+
+/**
+ * rpc_init_pipe_dir_head - initialise a struct rpc_pipe_dir_head
+ * @pdh: pointer to struct rpc_pipe_dir_head
+ */
+void rpc_init_pipe_dir_head(struct rpc_pipe_dir_head *pdh)
+{
+ INIT_LIST_HEAD(&pdh->pdh_entries);
+ pdh->pdh_dentry = NULL;
+}
+EXPORT_SYMBOL_GPL(rpc_init_pipe_dir_head);
+
+/**
+ * rpc_init_pipe_dir_object - initialise a struct rpc_pipe_dir_object
+ * @pdo: pointer to struct rpc_pipe_dir_object
+ * @pdo_ops: pointer to const struct rpc_pipe_dir_object_ops
+ * @pdo_data: pointer to caller-defined data
+ */
+void rpc_init_pipe_dir_object(struct rpc_pipe_dir_object *pdo,
+ const struct rpc_pipe_dir_object_ops *pdo_ops,
+ void *pdo_data)
+{
+ INIT_LIST_HEAD(&pdo->pdo_head);
+ pdo->pdo_ops = pdo_ops;
+ pdo->pdo_data = pdo_data;
+}
+EXPORT_SYMBOL_GPL(rpc_init_pipe_dir_object);
+
+static int
+rpc_add_pipe_dir_object_locked(struct net *net,
+ struct rpc_pipe_dir_head *pdh,
+ struct rpc_pipe_dir_object *pdo)
+{
+ int ret = 0;
+
+ if (pdh->pdh_dentry)
+ ret = pdo->pdo_ops->create(pdh->pdh_dentry, pdo);
+ if (ret == 0)
+ list_add_tail(&pdo->pdo_head, &pdh->pdh_entries);
+ return ret;
+}
+
+static void
+rpc_remove_pipe_dir_object_locked(struct net *net,
+ struct rpc_pipe_dir_head *pdh,
+ struct rpc_pipe_dir_object *pdo)
+{
+ if (pdh->pdh_dentry)
+ pdo->pdo_ops->destroy(pdh->pdh_dentry, pdo);
+ list_del_init(&pdo->pdo_head);
+}
+
+/**
+ * rpc_add_pipe_dir_object - associate a rpc_pipe_dir_object to a directory
+ * @net: pointer to struct net
+ * @pdh: pointer to struct rpc_pipe_dir_head
+ * @pdo: pointer to struct rpc_pipe_dir_object
+ *
+ */
+int
+rpc_add_pipe_dir_object(struct net *net,
+ struct rpc_pipe_dir_head *pdh,
+ struct rpc_pipe_dir_object *pdo)
+{
+ int ret = 0;
+
+ if (list_empty(&pdo->pdo_head)) {
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ mutex_lock(&sn->pipefs_sb_lock);
+ ret = rpc_add_pipe_dir_object_locked(net, pdh, pdo);
+ mutex_unlock(&sn->pipefs_sb_lock);
+ }
+ return ret;
+}
+EXPORT_SYMBOL_GPL(rpc_add_pipe_dir_object);
+
+/**
+ * rpc_remove_pipe_dir_object - remove a rpc_pipe_dir_object from a directory
+ * @net: pointer to struct net
+ * @pdh: pointer to struct rpc_pipe_dir_head
+ * @pdo: pointer to struct rpc_pipe_dir_object
+ *
+ */
+void
+rpc_remove_pipe_dir_object(struct net *net,
+ struct rpc_pipe_dir_head *pdh,
+ struct rpc_pipe_dir_object *pdo)
+{
+ if (!list_empty(&pdo->pdo_head)) {
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ mutex_lock(&sn->pipefs_sb_lock);
+ rpc_remove_pipe_dir_object_locked(net, pdh, pdo);
+ mutex_unlock(&sn->pipefs_sb_lock);
+ }
+}
+EXPORT_SYMBOL_GPL(rpc_remove_pipe_dir_object);
+
+/**
+ * rpc_find_or_alloc_pipe_dir_object
+ * @net: pointer to struct net
+ * @pdh: pointer to struct rpc_pipe_dir_head
+ * @match: match struct rpc_pipe_dir_object to data
+ * @alloc: allocate a new struct rpc_pipe_dir_object
+ * @data: user defined data for match() and alloc()
+ *
+ */
+struct rpc_pipe_dir_object *
+rpc_find_or_alloc_pipe_dir_object(struct net *net,
+ struct rpc_pipe_dir_head *pdh,
+ int (*match)(struct rpc_pipe_dir_object *, void *),
+ struct rpc_pipe_dir_object *(*alloc)(void *),
+ void *data)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ struct rpc_pipe_dir_object *pdo;
+
+ mutex_lock(&sn->pipefs_sb_lock);
+ list_for_each_entry(pdo, &pdh->pdh_entries, pdo_head) {
+ if (!match(pdo, data))
+ continue;
+ goto out;
+ }
+ pdo = alloc(data);
+ if (!pdo)
+ goto out;
+ rpc_add_pipe_dir_object_locked(net, pdh, pdo);
+out:
+ mutex_unlock(&sn->pipefs_sb_lock);
+ return pdo;
+}
+EXPORT_SYMBOL_GPL(rpc_find_or_alloc_pipe_dir_object);
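Illustrative only: a hypothetical rpc_pipe_dir_object whose create/destroy methods wrap the pipe helpers sketched above, attached to a client with rpc_add_pipe_dir_object(). The ops signatures follow the pdo_ops usage in the helpers above.

static int example_pdo_create(struct dentry *dir,
			      struct rpc_pipe_dir_object *pdo)
{
	return example_create_pipe(dir, pdo->pdo_data);
}

static void example_pdo_destroy(struct dentry *dir,
				struct rpc_pipe_dir_object *pdo)
{
	example_remove_pipe();
}

static const struct rpc_pipe_dir_object_ops example_pdo_ops = {
	.create		= example_pdo_create,
	.destroy	= example_pdo_destroy,
};

static int example_attach_pdo(struct net *net, struct rpc_clnt *clnt,
			      struct rpc_pipe_dir_object *pdo, void *data)
{
	rpc_init_pipe_dir_object(pdo, &example_pdo_ops, data);
	/* Creates the pipe now if the client directory already exists,
	 * or later when rpc_create_client_dir() instantiates it. */
	return rpc_add_pipe_dir_object(net, &clnt->cl_pipedir_objects, pdo);
}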
+
+static void
+rpc_create_pipe_dir_objects(struct rpc_pipe_dir_head *pdh)
+{
+ struct rpc_pipe_dir_object *pdo;
+ struct dentry *dir = pdh->pdh_dentry;
+
+ list_for_each_entry(pdo, &pdh->pdh_entries, pdo_head)
+ pdo->pdo_ops->create(dir, pdo);
+}
+
+static void
+rpc_destroy_pipe_dir_objects(struct rpc_pipe_dir_head *pdh)
+{
+ struct rpc_pipe_dir_object *pdo;
+ struct dentry *dir = pdh->pdh_dentry;
+
+ list_for_each_entry(pdo, &pdh->pdh_entries, pdo_head)
+ pdo->pdo_ops->destroy(dir, pdo);
+}
+
+enum {
+ RPCAUTH_info,
+ RPCAUTH_EOF
+};
+
+static const struct rpc_filelist authfiles[] = {
+ [RPCAUTH_info] = {
+ .name = "info",
+ .i_fop = &rpc_info_operations,
+ .mode = S_IFREG | 0400,
+ },
+};
+
+static int rpc_clntdir_populate(struct dentry *dentry, void *private)
+{
+ return rpc_populate(dentry,
+ authfiles, RPCAUTH_info, RPCAUTH_EOF,
+ private);
+}
+
+static void rpc_clntdir_depopulate(struct dentry *dentry)
+{
+ rpc_depopulate(dentry, authfiles, RPCAUTH_info, RPCAUTH_EOF);
+}
+
+/**
+ * rpc_create_client_dir - Create a new rpc_client directory in rpc_pipefs
+ * @dentry: the parent of the new directory
+ * @name: the name of the new directory
+ * @rpc_client: rpc client to associate with this directory
+ *
+ * This creates a directory named @name under @dentry, associated with
+ * @rpc_client, which will contain a file named "info" with some basic
+ * information about the client, together with any "pipes" that may
+ * later be created using rpc_mkpipe_dentry().
+ */
+struct dentry *rpc_create_client_dir(struct dentry *dentry,
+ const char *name,
+ struct rpc_clnt *rpc_client)
+{
+ struct dentry *ret;
+
+ ret = rpc_mkdir_populate(dentry, name, 0555, NULL,
+ rpc_clntdir_populate, rpc_client);
+ if (!IS_ERR(ret)) {
+ rpc_client->cl_pipedir_objects.pdh_dentry = ret;
+ rpc_create_pipe_dir_objects(&rpc_client->cl_pipedir_objects);
+ }
+ return ret;
+}
+
+/**
+ * rpc_remove_client_dir - Remove a directory created with rpc_create_client_dir()
+ * @rpc_client: rpc_client for the pipe
+ */
+int rpc_remove_client_dir(struct rpc_clnt *rpc_client)
+{
+ struct dentry *dentry = rpc_client->cl_pipedir_objects.pdh_dentry;
+
+ if (dentry == NULL)
+ return 0;
+ rpc_destroy_pipe_dir_objects(&rpc_client->cl_pipedir_objects);
+ rpc_client->cl_pipedir_objects.pdh_dentry = NULL;
+ return rpc_rmdir_depopulate(dentry, rpc_clntdir_depopulate);
+}
+
+static const struct rpc_filelist cache_pipefs_files[3] = {
+ [0] = {
+ .name = "channel",
+ .i_fop = &cache_file_operations_pipefs,
+ .mode = S_IFREG | 0600,
+ },
+ [1] = {
+ .name = "content",
+ .i_fop = &content_file_operations_pipefs,
+ .mode = S_IFREG | 0400,
+ },
+ [2] = {
+ .name = "flush",
+ .i_fop = &cache_flush_operations_pipefs,
+ .mode = S_IFREG | 0600,
+ },
+};
+
+static int rpc_cachedir_populate(struct dentry *dentry, void *private)
+{
+ return rpc_populate(dentry,
+ cache_pipefs_files, 0, 3,
+ private);
+}
+
+static void rpc_cachedir_depopulate(struct dentry *dentry)
+{
+ rpc_depopulate(dentry, cache_pipefs_files, 0, 3);
+}
+
+struct dentry *rpc_create_cache_dir(struct dentry *parent, const char *name,
+ umode_t umode, struct cache_detail *cd)
+{
+ return rpc_mkdir_populate(parent, name, umode, NULL,
+ rpc_cachedir_populate, cd);
+}
+
+void rpc_remove_cache_dir(struct dentry *dentry)
+{
+ rpc_rmdir_depopulate(dentry, rpc_cachedir_depopulate);
+}
+
+/*
+ * populate the filesystem
+ */
+static const struct super_operations s_ops = {
+ .alloc_inode = rpc_alloc_inode,
+ .free_inode = rpc_free_inode,
+ .statfs = simple_statfs,
+};
+
+#define RPCAUTH_GSSMAGIC 0x67596969
+
+/*
+ * The fixed set of entries created in the root of the rpc_pipefs filesystem.
+ */
+enum {
+ RPCAUTH_lockd,
+ RPCAUTH_mount,
+ RPCAUTH_nfs,
+ RPCAUTH_portmap,
+ RPCAUTH_statd,
+ RPCAUTH_nfsd4_cb,
+ RPCAUTH_cache,
+ RPCAUTH_nfsd,
+ RPCAUTH_gssd,
+ RPCAUTH_RootEOF
+};
+
+static const struct rpc_filelist files[] = {
+ [RPCAUTH_lockd] = {
+ .name = "lockd",
+ .mode = S_IFDIR | 0555,
+ },
+ [RPCAUTH_mount] = {
+ .name = "mount",
+ .mode = S_IFDIR | 0555,
+ },
+ [RPCAUTH_nfs] = {
+ .name = "nfs",
+ .mode = S_IFDIR | 0555,
+ },
+ [RPCAUTH_portmap] = {
+ .name = "portmap",
+ .mode = S_IFDIR | 0555,
+ },
+ [RPCAUTH_statd] = {
+ .name = "statd",
+ .mode = S_IFDIR | 0555,
+ },
+ [RPCAUTH_nfsd4_cb] = {
+ .name = "nfsd4_cb",
+ .mode = S_IFDIR | 0555,
+ },
+ [RPCAUTH_cache] = {
+ .name = "cache",
+ .mode = S_IFDIR | 0555,
+ },
+ [RPCAUTH_nfsd] = {
+ .name = "nfsd",
+ .mode = S_IFDIR | 0555,
+ },
+ [RPCAUTH_gssd] = {
+ .name = "gssd",
+ .mode = S_IFDIR | 0555,
+ },
+};
+
+/*
+ * This call can be used only in RPC pipefs mount notification hooks.
+ */
+struct dentry *rpc_d_lookup_sb(const struct super_block *sb,
+ const unsigned char *dir_name)
+{
+ struct qstr dir = QSTR_INIT(dir_name, strlen(dir_name));
+ return d_hash_and_lookup(sb->s_root, &dir);
+}
+EXPORT_SYMBOL_GPL(rpc_d_lookup_sb);
+
+int rpc_pipefs_init_net(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ sn->gssd_dummy = rpc_mkpipe_data(&gssd_dummy_pipe_ops, 0);
+ if (IS_ERR(sn->gssd_dummy))
+ return PTR_ERR(sn->gssd_dummy);
+
+ mutex_init(&sn->pipefs_sb_lock);
+ sn->pipe_version = -1;
+ return 0;
+}
+
+void rpc_pipefs_exit_net(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ rpc_destroy_pipe_data(sn->gssd_dummy);
+}
+
+/*
+ * This call is used for per-network-namespace operations.
+ * Note: this function returns with pipefs_sb_lock held if the superblock was
+ * found. The lock must be released by rpc_put_sb_net() once all operations
+ * are complete.
+ */
+struct super_block *rpc_get_sb_net(const struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ mutex_lock(&sn->pipefs_sb_lock);
+ if (sn->pipefs_sb)
+ return sn->pipefs_sb;
+ mutex_unlock(&sn->pipefs_sb_lock);
+ return NULL;
+}
+EXPORT_SYMBOL_GPL(rpc_get_sb_net);
+
+void rpc_put_sb_net(const struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ WARN_ON(sn->pipefs_sb == NULL);
+ mutex_unlock(&sn->pipefs_sb_lock);
+}
+EXPORT_SYMBOL_GPL(rpc_put_sb_net);
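Illustrative only: the intended calling pattern for the two helpers above from a per-netns operation such as a pipefs notifier. The example_* function is hypothetical.

static int example_pipefs_operation(struct net *net)
{
	struct super_block *sb;

	sb = rpc_get_sb_net(net);
	if (!sb)
		return 0;	/* rpc_pipefs is not mounted in this netns */

	/* ... create or remove dentries under sb->s_root here ... */

	rpc_put_sb_net(net);	/* releases pipefs_sb_lock taken above */
	return 0;
}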
+
+static const struct rpc_filelist gssd_dummy_clnt_dir[] = {
+ [0] = {
+ .name = "clntXX",
+ .mode = S_IFDIR | 0555,
+ },
+};
+
+static ssize_t
+dummy_downcall(struct file *filp, const char __user *src, size_t len)
+{
+ return -EINVAL;
+}
+
+static const struct rpc_pipe_ops gssd_dummy_pipe_ops = {
+ .upcall = rpc_pipe_generic_upcall,
+ .downcall = dummy_downcall,
+};
+
+/*
+ * Here we present a bogus "info" file to keep rpc.gssd happy. We don't expect
+ * that it will ever use this info to handle an upcall, but rpc.gssd expects
+ * that this file will be there and have a certain format.
+ */
+static int
+rpc_dummy_info_show(struct seq_file *m, void *v)
+{
+ seq_printf(m, "RPC server: %s\n", utsname()->nodename);
+ seq_printf(m, "service: foo (1) version 0\n");
+ seq_printf(m, "address: 127.0.0.1\n");
+ seq_printf(m, "protocol: tcp\n");
+ seq_printf(m, "port: 0\n");
+ return 0;
+}
+DEFINE_SHOW_ATTRIBUTE(rpc_dummy_info);
+
+static const struct rpc_filelist gssd_dummy_info_file[] = {
+ [0] = {
+ .name = "info",
+ .i_fop = &rpc_dummy_info_fops,
+ .mode = S_IFREG | 0400,
+ },
+};
+
+/**
+ * rpc_gssd_dummy_populate - create a dummy gssd pipe
+ * @root: root of the rpc_pipefs filesystem
+ * @pipe_data: pipe data created when netns is initialized
+ *
+ * Create a dummy set of directories and a pipe that gssd can hold open to
+ * indicate that it is up and running.
+ */
+static struct dentry *
+rpc_gssd_dummy_populate(struct dentry *root, struct rpc_pipe *pipe_data)
+{
+ int ret = 0;
+ struct dentry *gssd_dentry;
+ struct dentry *clnt_dentry = NULL;
+ struct dentry *pipe_dentry = NULL;
+ struct qstr q = QSTR_INIT(files[RPCAUTH_gssd].name,
+ strlen(files[RPCAUTH_gssd].name));
+
+ /* We should never get this far if "gssd" doesn't exist */
+ gssd_dentry = d_hash_and_lookup(root, &q);
+ if (!gssd_dentry)
+ return ERR_PTR(-ENOENT);
+
+ ret = rpc_populate(gssd_dentry, gssd_dummy_clnt_dir, 0, 1, NULL);
+ if (ret) {
+ pipe_dentry = ERR_PTR(ret);
+ goto out;
+ }
+
+ q.name = gssd_dummy_clnt_dir[0].name;
+ q.len = strlen(gssd_dummy_clnt_dir[0].name);
+ clnt_dentry = d_hash_and_lookup(gssd_dentry, &q);
+ if (!clnt_dentry) {
+ __rpc_depopulate(gssd_dentry, gssd_dummy_clnt_dir, 0, 1);
+ pipe_dentry = ERR_PTR(-ENOENT);
+ goto out;
+ }
+
+ ret = rpc_populate(clnt_dentry, gssd_dummy_info_file, 0, 1, NULL);
+ if (ret) {
+ __rpc_depopulate(gssd_dentry, gssd_dummy_clnt_dir, 0, 1);
+ pipe_dentry = ERR_PTR(ret);
+ goto out;
+ }
+
+ pipe_dentry = rpc_mkpipe_dentry(clnt_dentry, "gssd", NULL, pipe_data);
+ if (IS_ERR(pipe_dentry)) {
+ __rpc_depopulate(clnt_dentry, gssd_dummy_info_file, 0, 1);
+ __rpc_depopulate(gssd_dentry, gssd_dummy_clnt_dir, 0, 1);
+ }
+out:
+ dput(clnt_dentry);
+ dput(gssd_dentry);
+ return pipe_dentry;
+}
+
+static void
+rpc_gssd_dummy_depopulate(struct dentry *pipe_dentry)
+{
+ struct dentry *clnt_dir = pipe_dentry->d_parent;
+ struct dentry *gssd_dir = clnt_dir->d_parent;
+
+ dget(pipe_dentry);
+ __rpc_rmpipe(d_inode(clnt_dir), pipe_dentry);
+ __rpc_depopulate(clnt_dir, gssd_dummy_info_file, 0, 1);
+ __rpc_depopulate(gssd_dir, gssd_dummy_clnt_dir, 0, 1);
+ dput(pipe_dentry);
+}
+
+static int
+rpc_fill_super(struct super_block *sb, struct fs_context *fc)
+{
+ struct inode *inode;
+ struct dentry *root, *gssd_dentry;
+ struct net *net = sb->s_fs_info;
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ int err;
+
+ sb->s_blocksize = PAGE_SIZE;
+ sb->s_blocksize_bits = PAGE_SHIFT;
+ sb->s_magic = RPCAUTH_GSSMAGIC;
+ sb->s_op = &s_ops;
+ sb->s_d_op = &simple_dentry_operations;
+ sb->s_time_gran = 1;
+
+ inode = rpc_get_inode(sb, S_IFDIR | 0555);
+ sb->s_root = root = d_make_root(inode);
+ if (!root)
+ return -ENOMEM;
+ if (rpc_populate(root, files, RPCAUTH_lockd, RPCAUTH_RootEOF, NULL))
+ return -ENOMEM;
+
+ gssd_dentry = rpc_gssd_dummy_populate(root, sn->gssd_dummy);
+ if (IS_ERR(gssd_dentry)) {
+ __rpc_depopulate(root, files, RPCAUTH_lockd, RPCAUTH_RootEOF);
+ return PTR_ERR(gssd_dentry);
+ }
+
+ dprintk("RPC: sending pipefs MOUNT notification for net %x%s\n",
+ net->ns.inum, NET_NAME(net));
+ mutex_lock(&sn->pipefs_sb_lock);
+ sn->pipefs_sb = sb;
+ err = blocking_notifier_call_chain(&rpc_pipefs_notifier_list,
+ RPC_PIPEFS_MOUNT,
+ sb);
+ if (err)
+ goto err_depopulate;
+ mutex_unlock(&sn->pipefs_sb_lock);
+ return 0;
+
+err_depopulate:
+ rpc_gssd_dummy_depopulate(gssd_dentry);
+ blocking_notifier_call_chain(&rpc_pipefs_notifier_list,
+ RPC_PIPEFS_UMOUNT,
+ sb);
+ sn->pipefs_sb = NULL;
+ __rpc_depopulate(root, files, RPCAUTH_lockd, RPCAUTH_RootEOF);
+ mutex_unlock(&sn->pipefs_sb_lock);
+ return err;
+}
+
+bool
+gssd_running(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ struct rpc_pipe *pipe = sn->gssd_dummy;
+
+ return pipe->nreaders || pipe->nwriters;
+}
+EXPORT_SYMBOL_GPL(gssd_running);
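Illustrative only: a hypothetical caller that skips an upcall when no gssd daemon holds the dummy pipe open.

static int example_try_gss_upcall(struct net *net, struct rpc_pipe *pipe,
				  struct rpc_pipe_msg *msg)
{
	if (!gssd_running(net))
		return -EACCES;	/* no daemon listening; fail fast */
	return rpc_queue_upcall(pipe, msg);
}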
+
+static int rpc_fs_get_tree(struct fs_context *fc)
+{
+ return get_tree_keyed(fc, rpc_fill_super, get_net(fc->net_ns));
+}
+
+static void rpc_fs_free_fc(struct fs_context *fc)
+{
+ if (fc->s_fs_info)
+ put_net(fc->s_fs_info);
+}
+
+static const struct fs_context_operations rpc_fs_context_ops = {
+ .free = rpc_fs_free_fc,
+ .get_tree = rpc_fs_get_tree,
+};
+
+static int rpc_init_fs_context(struct fs_context *fc)
+{
+ put_user_ns(fc->user_ns);
+ fc->user_ns = get_user_ns(fc->net_ns->user_ns);
+ fc->ops = &rpc_fs_context_ops;
+ return 0;
+}
+
+static void rpc_kill_sb(struct super_block *sb)
+{
+ struct net *net = sb->s_fs_info;
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ mutex_lock(&sn->pipefs_sb_lock);
+ if (sn->pipefs_sb != sb) {
+ mutex_unlock(&sn->pipefs_sb_lock);
+ goto out;
+ }
+ sn->pipefs_sb = NULL;
+ dprintk("RPC: sending pipefs UMOUNT notification for net %x%s\n",
+ net->ns.inum, NET_NAME(net));
+ blocking_notifier_call_chain(&rpc_pipefs_notifier_list,
+ RPC_PIPEFS_UMOUNT,
+ sb);
+ mutex_unlock(&sn->pipefs_sb_lock);
+out:
+ kill_litter_super(sb);
+ put_net(net);
+}
+
+static struct file_system_type rpc_pipe_fs_type = {
+ .owner = THIS_MODULE,
+ .name = "rpc_pipefs",
+ .init_fs_context = rpc_init_fs_context,
+ .kill_sb = rpc_kill_sb,
+};
+MODULE_ALIAS_FS("rpc_pipefs");
+MODULE_ALIAS("rpc_pipefs");
+
+static void
+init_once(void *foo)
+{
+ struct rpc_inode *rpci = (struct rpc_inode *) foo;
+
+ inode_init_once(&rpci->vfs_inode);
+ rpci->private = NULL;
+ rpci->pipe = NULL;
+ init_waitqueue_head(&rpci->waitq);
+}
+
+int register_rpc_pipefs(void)
+{
+ int err;
+
+ rpc_inode_cachep = kmem_cache_create("rpc_inode_cache",
+ sizeof(struct rpc_inode),
+ 0, (SLAB_HWCACHE_ALIGN|SLAB_RECLAIM_ACCOUNT|
+ SLAB_MEM_SPREAD|SLAB_ACCOUNT),
+ init_once);
+ if (!rpc_inode_cachep)
+ return -ENOMEM;
+ err = rpc_clients_notifier_register();
+ if (err)
+ goto err_notifier;
+ err = register_filesystem(&rpc_pipe_fs_type);
+ if (err)
+ goto err_register;
+ return 0;
+
+err_register:
+ rpc_clients_notifier_unregister();
+err_notifier:
+ kmem_cache_destroy(rpc_inode_cachep);
+ return err;
+}
+
+void unregister_rpc_pipefs(void)
+{
+ rpc_clients_notifier_unregister();
+ unregister_filesystem(&rpc_pipe_fs_type);
+ kmem_cache_destroy(rpc_inode_cachep);
+}
diff --git a/net/sunrpc/rpcb_clnt.c b/net/sunrpc/rpcb_clnt.c
new file mode 100644
index 0000000000..102c3818bc
--- /dev/null
+++ b/net/sunrpc/rpcb_clnt.c
@@ -0,0 +1,1121 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * In-kernel rpcbind client supporting versions 2, 3, and 4 of the rpcbind
+ * protocol
+ *
+ * Based on RFC 1833: "Binding Protocols for ONC RPC Version 2" and
+ * RFC 3530: "Network File System (NFS) version 4 Protocol"
+ *
+ * Original: Gilles Quillard, Bull Open Source, 2005 <gilles.quillard@bull.net>
+ * Updated: Chuck Lever, Oracle Corporation, 2007 <chuck.lever@oracle.com>
+ *
+ * Descended from net/sunrpc/pmap_clnt.c,
+ * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de>
+ */
+
+#include <linux/module.h>
+
+#include <linux/types.h>
+#include <linux/socket.h>
+#include <linux/un.h>
+#include <linux/in.h>
+#include <linux/in6.h>
+#include <linux/kernel.h>
+#include <linux/errno.h>
+#include <linux/mutex.h>
+#include <linux/slab.h>
+#include <net/ipv6.h>
+
+#include <linux/sunrpc/clnt.h>
+#include <linux/sunrpc/addr.h>
+#include <linux/sunrpc/sched.h>
+#include <linux/sunrpc/xprtsock.h>
+
+#include <trace/events/sunrpc.h>
+
+#include "netns.h"
+
+#define RPCBIND_SOCK_PATHNAME "/var/run/rpcbind.sock"
+#define RPCBIND_SOCK_ABSTRACT_NAME "\0/run/rpcbind.sock"
+
+#define RPCBIND_PROGRAM (100000u)
+#define RPCBIND_PORT (111u)
+
+#define RPCBVERS_2 (2u)
+#define RPCBVERS_3 (3u)
+#define RPCBVERS_4 (4u)
+
+enum {
+ RPCBPROC_NULL,
+ RPCBPROC_SET,
+ RPCBPROC_UNSET,
+ RPCBPROC_GETPORT,
+ RPCBPROC_GETADDR = 3, /* alias for GETPORT */
+ RPCBPROC_DUMP,
+ RPCBPROC_CALLIT,
+ RPCBPROC_BCAST = 5, /* alias for CALLIT */
+ RPCBPROC_GETTIME,
+ RPCBPROC_UADDR2TADDR,
+ RPCBPROC_TADDR2UADDR,
+ RPCBPROC_GETVERSADDR,
+ RPCBPROC_INDIRECT,
+ RPCBPROC_GETADDRLIST,
+ RPCBPROC_GETSTAT,
+};
+
+/*
+ * r_owner
+ *
+ * The "owner" is allowed to unset a service in the rpcbind database.
+ *
+ * For AF_LOCAL SET/UNSET requests, rpcbind treats this string as a
+ * UID which it maps to a local user name via a password lookup.
+ * In all other cases it is ignored.
+ *
+ * For SET/UNSET requests, user space provides a value, even for
+ * network requests, and GETADDR uses an empty string. We follow
+ * those precedents here.
+ */
+#define RPCB_OWNER_STRING "0"
+#define RPCB_MAXOWNERLEN sizeof(RPCB_OWNER_STRING)
+
+/*
+ * XDR data type sizes
+ */
+#define RPCB_program_sz (1)
+#define RPCB_version_sz (1)
+#define RPCB_protocol_sz (1)
+#define RPCB_port_sz (1)
+#define RPCB_boolean_sz (1)
+
+#define RPCB_netid_sz (1 + XDR_QUADLEN(RPCBIND_MAXNETIDLEN))
+#define RPCB_addr_sz (1 + XDR_QUADLEN(RPCBIND_MAXUADDRLEN))
+#define RPCB_ownerstring_sz (1 + XDR_QUADLEN(RPCB_MAXOWNERLEN))
+
+/*
+ * XDR argument and result sizes
+ */
+#define RPCB_mappingargs_sz (RPCB_program_sz + RPCB_version_sz + \
+ RPCB_protocol_sz + RPCB_port_sz)
+#define RPCB_getaddrargs_sz (RPCB_program_sz + RPCB_version_sz + \
+ RPCB_netid_sz + RPCB_addr_sz + \
+ RPCB_ownerstring_sz)
+
+#define RPCB_getportres_sz RPCB_port_sz
+#define RPCB_setres_sz RPCB_boolean_sz
+
+/*
+ * Note that RFC 1833 does not put any size restrictions on the
+ * address string returned by the remote rpcbind database.
+ */
+#define RPCB_getaddrres_sz RPCB_addr_sz
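+
+/*
+ * These sizes count 4-byte XDR words, not bytes: XDR_QUADLEN(n) is the
+ * number of words needed to hold n bytes.  Encoders convert back to a
+ * byte count with a shift; for example, rpcb_enc_mapping() reserves
+ * RPCB_mappingargs_sz << 2 == 16 bytes for the four-word mapping
+ * arguments (program, version, protocol, port).
+ */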
+
+static void rpcb_getport_done(struct rpc_task *, void *);
+static void rpcb_map_release(void *data);
+static const struct rpc_program rpcb_program;
+
+struct rpcbind_args {
+ struct rpc_xprt * r_xprt;
+
+ u32 r_prog;
+ u32 r_vers;
+ u32 r_prot;
+ unsigned short r_port;
+ const char * r_netid;
+ const char * r_addr;
+ const char * r_owner;
+
+ int r_status;
+};
+
+static const struct rpc_procinfo rpcb_procedures2[];
+static const struct rpc_procinfo rpcb_procedures3[];
+static const struct rpc_procinfo rpcb_procedures4[];
+
+struct rpcb_info {
+ u32 rpc_vers;
+ const struct rpc_procinfo *rpc_proc;
+};
+
+static const struct rpcb_info rpcb_next_version[];
+static const struct rpcb_info rpcb_next_version6[];
+
+static const struct rpc_call_ops rpcb_getport_ops = {
+ .rpc_call_done = rpcb_getport_done,
+ .rpc_release = rpcb_map_release,
+};
+
+static void rpcb_wake_rpcbind_waiters(struct rpc_xprt *xprt, int status)
+{
+ xprt_clear_binding(xprt);
+ rpc_wake_up_status(&xprt->binding, status);
+}
+
+static void rpcb_map_release(void *data)
+{
+ struct rpcbind_args *map = data;
+
+ rpcb_wake_rpcbind_waiters(map->r_xprt, map->r_status);
+ xprt_put(map->r_xprt);
+ kfree(map->r_addr);
+ kfree(map);
+}
+
+static int rpcb_get_local(struct net *net)
+{
+ int cnt;
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ spin_lock(&sn->rpcb_clnt_lock);
+ if (sn->rpcb_users)
+ sn->rpcb_users++;
+ cnt = sn->rpcb_users;
+ spin_unlock(&sn->rpcb_clnt_lock);
+
+ return cnt;
+}
+
+void rpcb_put_local(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ struct rpc_clnt *clnt = sn->rpcb_local_clnt;
+ struct rpc_clnt *clnt4 = sn->rpcb_local_clnt4;
+ int shutdown = 0;
+
+ spin_lock(&sn->rpcb_clnt_lock);
+ if (sn->rpcb_users) {
+ if (--sn->rpcb_users == 0) {
+ sn->rpcb_local_clnt = NULL;
+ sn->rpcb_local_clnt4 = NULL;
+ }
+ shutdown = !sn->rpcb_users;
+ }
+ spin_unlock(&sn->rpcb_clnt_lock);
+
+ if (shutdown) {
+ /*
+	 * Last user is gone: shut down the rpcbind v4 and v2
+	 * clients created by rpcb_create_local().
+ */
+ if (clnt4)
+ rpc_shutdown_client(clnt4);
+ if (clnt)
+ rpc_shutdown_client(clnt);
+ }
+}
+
+static void rpcb_set_local(struct net *net, struct rpc_clnt *clnt,
+ struct rpc_clnt *clnt4,
+ bool is_af_local)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ /* Protected by rpcb_create_local_mutex */
+ sn->rpcb_local_clnt = clnt;
+ sn->rpcb_local_clnt4 = clnt4;
+ sn->rpcb_is_af_local = is_af_local ? 1 : 0;
+ smp_wmb();
+ sn->rpcb_users = 1;
+}
+
+/*
+ * Evaluate to the actual length of the `sockaddr_un' structure.  Unlike
+ * the usual SUN_LEN(), this variant also handles abstract AF_LOCAL
+ * names, whose sun_path begins with a NUL byte.
+ */
+# define SUN_LEN(ptr) (offsetof(struct sockaddr_un, sun_path) \
+ + 1 + strlen((ptr)->sun_path + 1))
+
+/*
+ * Returns zero on success, otherwise a negative errno value
+ * is returned.
+ */
+static int rpcb_create_af_local(struct net *net,
+ const struct sockaddr_un *addr)
+{
+ struct rpc_create_args args = {
+ .net = net,
+ .protocol = XPRT_TRANSPORT_LOCAL,
+ .address = (struct sockaddr *)addr,
+ .addrsize = SUN_LEN(addr),
+ .servername = "localhost",
+ .program = &rpcb_program,
+ .version = RPCBVERS_2,
+ .authflavor = RPC_AUTH_NULL,
+ .cred = current_cred(),
+ /*
+ * We turn off the idle timeout to prevent the kernel
+ * from automatically disconnecting the socket.
+ * Otherwise, we'd have to cache the mount namespace
+ * of the caller and somehow pass that to the socket
+ * reconnect code.
+ */
+ .flags = RPC_CLNT_CREATE_NO_IDLE_TIMEOUT,
+ };
+ struct rpc_clnt *clnt, *clnt4;
+ int result = 0;
+
+ /*
+ * Because we requested an RPC PING at transport creation time,
+ * this works only if the user space portmapper is rpcbind, and
+ * it's listening on AF_LOCAL on the named socket.
+ */
+ clnt = rpc_create(&args);
+ if (IS_ERR(clnt)) {
+ result = PTR_ERR(clnt);
+ goto out;
+ }
+
+ clnt4 = rpc_bind_new_program(clnt, &rpcb_program, RPCBVERS_4);
+ if (IS_ERR(clnt4))
+ clnt4 = NULL;
+
+ rpcb_set_local(net, clnt, clnt4, true);
+
+out:
+ return result;
+}
+
+static int rpcb_create_local_abstract(struct net *net)
+{
+ static const struct sockaddr_un rpcb_localaddr_abstract = {
+ .sun_family = AF_LOCAL,
+ .sun_path = RPCBIND_SOCK_ABSTRACT_NAME,
+ };
+
+ return rpcb_create_af_local(net, &rpcb_localaddr_abstract);
+}
+
+static int rpcb_create_local_unix(struct net *net)
+{
+ static const struct sockaddr_un rpcb_localaddr_unix = {
+ .sun_family = AF_LOCAL,
+ .sun_path = RPCBIND_SOCK_PATHNAME,
+ };
+
+ return rpcb_create_af_local(net, &rpcb_localaddr_unix);
+}
+
+/*
+ * Returns zero on success, otherwise a negative errno value
+ * is returned.
+ */
+static int rpcb_create_local_net(struct net *net)
+{
+ static const struct sockaddr_in rpcb_inaddr_loopback = {
+ .sin_family = AF_INET,
+ .sin_addr.s_addr = htonl(INADDR_LOOPBACK),
+ .sin_port = htons(RPCBIND_PORT),
+ };
+ struct rpc_create_args args = {
+ .net = net,
+ .protocol = XPRT_TRANSPORT_TCP,
+ .address = (struct sockaddr *)&rpcb_inaddr_loopback,
+ .addrsize = sizeof(rpcb_inaddr_loopback),
+ .servername = "localhost",
+ .program = &rpcb_program,
+ .version = RPCBVERS_2,
+ .authflavor = RPC_AUTH_UNIX,
+ .cred = current_cred(),
+ .flags = RPC_CLNT_CREATE_NOPING,
+ };
+ struct rpc_clnt *clnt, *clnt4;
+ int result = 0;
+
+ clnt = rpc_create(&args);
+ if (IS_ERR(clnt)) {
+ result = PTR_ERR(clnt);
+ goto out;
+ }
+
+ /*
+ * This results in an RPC ping. On systems running portmapper,
+ * the v4 ping will fail. Proceed anyway, but disallow rpcb
+ * v4 upcalls.
+ */
+ clnt4 = rpc_bind_new_program(clnt, &rpcb_program, RPCBVERS_4);
+ if (IS_ERR(clnt4))
+ clnt4 = NULL;
+
+ rpcb_set_local(net, clnt, clnt4, false);
+
+out:
+ return result;
+}
+
+/*
+ * Returns zero on success, otherwise a negative errno value
+ * is returned.
+ */
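+/*
+ * The helpers above are tried in order: the abstract AF_LOCAL name used
+ * by modern rpcbind first, then the traditional named socket under
+ * /var/run, and finally rpcbind v2 over TCP to the IPv4 loopback
+ * address.
+ */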
+int rpcb_create_local(struct net *net)
+{
+ static DEFINE_MUTEX(rpcb_create_local_mutex);
+ int result = 0;
+
+ if (rpcb_get_local(net))
+ return result;
+
+ mutex_lock(&rpcb_create_local_mutex);
+ if (rpcb_get_local(net))
+ goto out;
+
+ if (rpcb_create_local_abstract(net) != 0 &&
+ rpcb_create_local_unix(net) != 0)
+ result = rpcb_create_local_net(net);
+
+out:
+ mutex_unlock(&rpcb_create_local_mutex);
+ return result;
+}
+
+static struct rpc_clnt *rpcb_create(struct net *net, const char *nodename,
+ const char *hostname,
+ struct sockaddr *srvaddr, size_t salen,
+ int proto, u32 version,
+ const struct cred *cred,
+ const struct rpc_timeout *timeo)
+{
+ struct rpc_create_args args = {
+ .net = net,
+ .protocol = proto,
+ .address = srvaddr,
+ .addrsize = salen,
+ .timeout = timeo,
+ .servername = hostname,
+ .nodename = nodename,
+ .program = &rpcb_program,
+ .version = version,
+ .authflavor = RPC_AUTH_UNIX,
+ .cred = cred,
+ .flags = (RPC_CLNT_CREATE_NOPING |
+ RPC_CLNT_CREATE_NONPRIVPORT),
+ };
+
+ switch (srvaddr->sa_family) {
+ case AF_INET:
+ ((struct sockaddr_in *)srvaddr)->sin_port = htons(RPCBIND_PORT);
+ break;
+ case AF_INET6:
+ ((struct sockaddr_in6 *)srvaddr)->sin6_port = htons(RPCBIND_PORT);
+ break;
+ default:
+ return ERR_PTR(-EAFNOSUPPORT);
+ }
+
+ return rpc_create(&args);
+}
+
+static int rpcb_register_call(struct sunrpc_net *sn, struct rpc_clnt *clnt, struct rpc_message *msg, bool is_set)
+{
+ int flags = RPC_TASK_NOCONNECT;
+ int error, result = 0;
+
+ if (is_set || !sn->rpcb_is_af_local)
+ flags = RPC_TASK_SOFTCONN;
+ msg->rpc_resp = &result;
+
+ error = rpc_call_sync(clnt, msg, flags);
+ if (error < 0)
+ return error;
+
+ if (!result)
+ return -EACCES;
+ return 0;
+}
+
+/**
+ * rpcb_register - set or unset a port registration with the local rpcbind svc
+ * @net: target network namespace
+ * @prog: RPC program number to bind
+ * @vers: RPC version number to bind
+ * @prot: transport protocol to register
+ * @port: port value to register
+ *
+ * Returns zero if the registration request was dispatched successfully
+ * and the rpcbind daemon returned success. Otherwise, returns an errno
+ * value that reflects the nature of the error (request could not be
+ * dispatched, timed out, or rpcbind returned an error).
+ *
+ * RPC services invoke this function to advertise their contact
+ * information via the system's rpcbind daemon. RPC services
+ * invoke this function once for each [program, version, transport]
+ * tuple they wish to advertise.
+ *
+ * Callers may also unregister RPC services that are no longer
+ * available by setting the passed-in port to zero. This removes
+ * all registered transports for [program, version] from the local
+ * rpcbind database.
+ *
+ * This function uses rpcbind protocol version 2 to contact the
+ * local rpcbind daemon.
+ *
+ * Registration works over both AF_INET and AF_INET6, and services
+ * registered via this function are advertised as available for any
+ * address. If the local rpcbind daemon is listening on AF_INET6,
+ * services registered via this function will be advertised on
+ * IN6ADDR_ANY (ie available for all AF_INET and AF_INET6
+ * addresses).
+ */
+int rpcb_register(struct net *net, u32 prog, u32 vers, int prot, unsigned short port)
+{
+ struct rpcbind_args map = {
+ .r_prog = prog,
+ .r_vers = vers,
+ .r_prot = prot,
+ .r_port = port,
+ };
+ struct rpc_message msg = {
+ .rpc_argp = &map,
+ };
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ bool is_set = false;
+
+ trace_pmap_register(prog, vers, prot, port);
+
+ msg.rpc_proc = &rpcb_procedures2[RPCBPROC_UNSET];
+ if (port != 0) {
+ msg.rpc_proc = &rpcb_procedures2[RPCBPROC_SET];
+ is_set = true;
+ }
+
+ return rpcb_register_call(sn, sn->rpcb_local_clnt, &msg, is_set);
+}
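+
+/*
+ * Usage sketch (illustrative only; the program, version, and port
+ * numbers are hypothetical placeholders):
+ *
+ *	err = rpcb_register(net, 400042, 1, IPPROTO_TCP, 2049);
+ *	if (err < 0)
+ *		return err;
+ *
+ * Later, passing port 0 removes every transport registered for
+ * [400042, 1]:
+ *
+ *	err = rpcb_register(net, 400042, 1, IPPROTO_TCP, 0);
+ */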
+
+/*
+ * Fill in AF_INET family-specific arguments to register
+ */
+static int rpcb_register_inet4(struct sunrpc_net *sn,
+ const struct sockaddr *sap,
+ struct rpc_message *msg)
+{
+ const struct sockaddr_in *sin = (const struct sockaddr_in *)sap;
+ struct rpcbind_args *map = msg->rpc_argp;
+ unsigned short port = ntohs(sin->sin_port);
+ bool is_set = false;
+ int result;
+
+ map->r_addr = rpc_sockaddr2uaddr(sap, GFP_KERNEL);
+
+ msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET];
+ if (port != 0) {
+ msg->rpc_proc = &rpcb_procedures4[RPCBPROC_SET];
+ is_set = true;
+ }
+
+ result = rpcb_register_call(sn, sn->rpcb_local_clnt4, msg, is_set);
+ kfree(map->r_addr);
+ return result;
+}
+
+/*
+ * Fill in AF_INET6 family-specific arguments to register
+ */
+static int rpcb_register_inet6(struct sunrpc_net *sn,
+ const struct sockaddr *sap,
+ struct rpc_message *msg)
+{
+ const struct sockaddr_in6 *sin6 = (const struct sockaddr_in6 *)sap;
+ struct rpcbind_args *map = msg->rpc_argp;
+ unsigned short port = ntohs(sin6->sin6_port);
+ bool is_set = false;
+ int result;
+
+ map->r_addr = rpc_sockaddr2uaddr(sap, GFP_KERNEL);
+
+ msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET];
+ if (port != 0) {
+ msg->rpc_proc = &rpcb_procedures4[RPCBPROC_SET];
+ is_set = true;
+ }
+
+ result = rpcb_register_call(sn, sn->rpcb_local_clnt4, msg, is_set);
+ kfree(map->r_addr);
+ return result;
+}
+
+static int rpcb_unregister_all_protofamilies(struct sunrpc_net *sn,
+ struct rpc_message *msg)
+{
+ struct rpcbind_args *map = msg->rpc_argp;
+
+ trace_rpcb_unregister(map->r_prog, map->r_vers, map->r_netid);
+
+ map->r_addr = "";
+ msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET];
+
+ return rpcb_register_call(sn, sn->rpcb_local_clnt4, msg, false);
+}
+
+/**
+ * rpcb_v4_register - set or unset a port registration with the local rpcbind
+ * @net: target network namespace
+ * @program: RPC program number of service to (un)register
+ * @version: RPC version number of service to (un)register
+ * @address: address family, IP address, and port to (un)register
+ * @netid: netid of transport protocol to (un)register
+ *
+ * Returns zero if the registration request was dispatched successfully
+ * and the rpcbind daemon returned success. Otherwise, returns an errno
+ * value that reflects the nature of the error (request could not be
+ * dispatched, timed out, or rpcbind returned an error).
+ *
+ * RPC services invoke this function to advertise their contact
+ * information via the system's rpcbind daemon. RPC services
+ * invoke this function once for each [program, version, address,
+ * netid] tuple they wish to advertise.
+ *
+ * Callers may also unregister RPC services that are registered at a
+ * specific address by setting the port number in @address to zero.
+ * They may unregister all registered protocol families at once for
+ * a service by passing a NULL @address argument. If @netid is ""
+ * then all netids for [program, version, address] are unregistered.
+ *
+ * This function uses rpcbind protocol version 4 to contact the
+ * local rpcbind daemon. The local rpcbind daemon must support
+ * version 4 of the rpcbind protocol in order for these functions
+ * to register a service successfully.
+ *
+ * Supported netids include "udp" and "tcp" for UDP and TCP over
+ * IPv4, and "udp6" and "tcp6" for UDP and TCP over IPv6,
+ * respectively.
+ *
+ * The contents of @address determine the address family and the
+ * port to be registered. The usual practice is to pass INADDR_ANY
+ * as the raw address, but specifying a non-zero address is also
+ * supported by this API if the caller wishes to advertise an RPC
+ * service on a specific network interface.
+ *
+ * Note that passing in INADDR_ANY does not create the same service
+ * registration as IN6ADDR_ANY. The former advertises an RPC
+ * service on any IPv4 address, but not on IPv6. The latter
+ * advertises the service on all IPv4 and IPv6 addresses.
+ */
+int rpcb_v4_register(struct net *net, const u32 program, const u32 version,
+ const struct sockaddr *address, const char *netid)
+{
+ struct rpcbind_args map = {
+ .r_prog = program,
+ .r_vers = version,
+ .r_netid = netid,
+ .r_owner = RPCB_OWNER_STRING,
+ };
+ struct rpc_message msg = {
+ .rpc_argp = &map,
+ };
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ if (sn->rpcb_local_clnt4 == NULL)
+ return -EPROTONOSUPPORT;
+
+ if (address == NULL)
+ return rpcb_unregister_all_protofamilies(sn, &msg);
+
+ trace_rpcb_register(map.r_prog, map.r_vers, map.r_addr, map.r_netid);
+
+ switch (address->sa_family) {
+ case AF_INET:
+ return rpcb_register_inet4(sn, address, &msg);
+ case AF_INET6:
+ return rpcb_register_inet6(sn, address, &msg);
+ }
+
+ return -EAFNOSUPPORT;
+}
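+
+/*
+ * Usage sketch (illustrative only; the program and version numbers are
+ * hypothetical placeholders):
+ *
+ *	struct sockaddr_in sin = {
+ *		.sin_family	 = AF_INET,
+ *		.sin_addr.s_addr = htonl(INADDR_ANY),
+ *		.sin_port	 = htons(2049),
+ *	};
+ *
+ *	err = rpcb_v4_register(net, 400042, 1,
+ *			       (struct sockaddr *)&sin, "tcp");
+ *
+ * Passing a NULL @address later unregisters every protocol family for
+ * [400042, 1] in a single call.
+ */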
+
+static struct rpc_task *rpcb_call_async(struct rpc_clnt *rpcb_clnt,
+ struct rpcbind_args *map, const struct rpc_procinfo *proc)
+{
+ struct rpc_message msg = {
+ .rpc_proc = proc,
+ .rpc_argp = map,
+ .rpc_resp = map,
+ };
+ struct rpc_task_setup task_setup_data = {
+ .rpc_client = rpcb_clnt,
+ .rpc_message = &msg,
+ .callback_ops = &rpcb_getport_ops,
+ .callback_data = map,
+ .flags = RPC_TASK_ASYNC | RPC_TASK_SOFTCONN,
+ };
+
+ return rpc_run_task(&task_setup_data);
+}
+
+/*
+ * When rpc clients have been cloned, we want to make sure that we
+ * use the program number, version, etc. of the actual owner of the
+ * xprt. To do so, we walk back up the tree of parents
+ * to find whoever created the transport and/or whoever has the
+ * autobind flag set.
+ */
+static struct rpc_clnt *rpcb_find_transport_owner(struct rpc_clnt *clnt)
+{
+ struct rpc_clnt *parent = clnt->cl_parent;
+ struct rpc_xprt_switch *xps = rcu_access_pointer(clnt->cl_xpi.xpi_xpswitch);
+
+ while (parent != clnt) {
+ if (rcu_access_pointer(parent->cl_xpi.xpi_xpswitch) != xps)
+ break;
+ if (clnt->cl_autobind)
+ break;
+ clnt = parent;
+ parent = parent->cl_parent;
+ }
+ return clnt;
+}
+
+/**
+ * rpcb_getport_async - obtain the port for a given RPC service on a given host
+ * @task: task that is waiting for portmapper request
+ *
+ * This one can be called for an ongoing RPC request, and can be used in
+ * an async (rpciod) context.
+ */
+void rpcb_getport_async(struct rpc_task *task)
+{
+ struct rpc_clnt *clnt;
+ const struct rpc_procinfo *proc;
+ u32 bind_version;
+ struct rpc_xprt *xprt;
+ struct rpc_clnt *rpcb_clnt;
+ struct rpcbind_args *map;
+ struct rpc_task *child;
+ struct sockaddr_storage addr;
+ struct sockaddr *sap = (struct sockaddr *)&addr;
+ size_t salen;
+ int status;
+
+ rcu_read_lock();
+ clnt = rpcb_find_transport_owner(task->tk_client);
+ rcu_read_unlock();
+ xprt = xprt_get(task->tk_xprt);
+
+ /* Put self on the wait queue to ensure we get notified if
+ * some other task is already attempting to bind the port */
+ rpc_sleep_on_timeout(&xprt->binding, task,
+ NULL, jiffies + xprt->bind_timeout);
+
+ if (xprt_test_and_set_binding(xprt)) {
+ xprt_put(xprt);
+ return;
+ }
+
+ /* Someone else may have bound if we slept */
+ if (xprt_bound(xprt)) {
+ status = 0;
+ goto bailout_nofree;
+ }
+
+ /* Parent transport's destination address */
+ salen = rpc_peeraddr(clnt, sap, sizeof(addr));
+
+ /* Don't ever use rpcbind v2 for AF_INET6 requests */
+ switch (sap->sa_family) {
+ case AF_INET:
+ proc = rpcb_next_version[xprt->bind_index].rpc_proc;
+ bind_version = rpcb_next_version[xprt->bind_index].rpc_vers;
+ break;
+ case AF_INET6:
+ proc = rpcb_next_version6[xprt->bind_index].rpc_proc;
+ bind_version = rpcb_next_version6[xprt->bind_index].rpc_vers;
+ break;
+ default:
+ status = -EAFNOSUPPORT;
+ goto bailout_nofree;
+ }
+ if (proc == NULL) {
+ xprt->bind_index = 0;
+ status = -EPFNOSUPPORT;
+ goto bailout_nofree;
+ }
+
+ trace_rpcb_getport(clnt, task, bind_version);
+
+ rpcb_clnt = rpcb_create(xprt->xprt_net,
+ clnt->cl_nodename,
+ xprt->servername, sap, salen,
+ xprt->prot, bind_version,
+ clnt->cl_cred,
+ task->tk_client->cl_timeout);
+ if (IS_ERR(rpcb_clnt)) {
+ status = PTR_ERR(rpcb_clnt);
+ goto bailout_nofree;
+ }
+
+ map = kzalloc(sizeof(struct rpcbind_args), rpc_task_gfp_mask());
+ if (!map) {
+ status = -ENOMEM;
+ goto bailout_release_client;
+ }
+ map->r_prog = clnt->cl_prog;
+ map->r_vers = clnt->cl_vers;
+ map->r_prot = xprt->prot;
+ map->r_port = 0;
+ map->r_xprt = xprt;
+ map->r_status = -EIO;
+
+ switch (bind_version) {
+ case RPCBVERS_4:
+ case RPCBVERS_3:
+ map->r_netid = xprt->address_strings[RPC_DISPLAY_NETID];
+ map->r_addr = rpc_sockaddr2uaddr(sap, rpc_task_gfp_mask());
+ if (!map->r_addr) {
+ status = -ENOMEM;
+ goto bailout_free_args;
+ }
+ map->r_owner = "";
+ break;
+ case RPCBVERS_2:
+ map->r_addr = NULL;
+ break;
+ default:
+ BUG();
+ }
+
+ child = rpcb_call_async(rpcb_clnt, map, proc);
+ rpc_release_client(rpcb_clnt);
+ if (IS_ERR(child)) {
+ /* rpcb_map_release() has freed the arguments */
+ return;
+ }
+
+ xprt->stat.bind_count++;
+ rpc_put_task(child);
+ return;
+
+bailout_free_args:
+ kfree(map);
+bailout_release_client:
+ rpc_release_client(rpcb_clnt);
+bailout_nofree:
+ rpcb_wake_rpcbind_waiters(xprt, status);
+ task->tk_status = status;
+ xprt_put(xprt);
+}
+EXPORT_SYMBOL_GPL(rpcb_getport_async);
+
+/*
+ * The rpcbind child task invokes this callback as its ->rpc_call_done
+ * method (see rpcb_getport_ops).
+ */
+static void rpcb_getport_done(struct rpc_task *child, void *data)
+{
+ struct rpcbind_args *map = data;
+ struct rpc_xprt *xprt = map->r_xprt;
+
+ map->r_status = child->tk_status;
+
+ /* Garbage reply: retry with a lesser rpcbind version */
+ if (map->r_status == -EIO)
+ map->r_status = -EPROTONOSUPPORT;
+
+ /* rpcbind server doesn't support this rpcbind protocol version */
+ if (map->r_status == -EPROTONOSUPPORT)
+ xprt->bind_index++;
+
+ if (map->r_status < 0) {
+ /* rpcbind server not available on remote host? */
+ map->r_port = 0;
+
+ } else if (map->r_port == 0) {
+ /* Requested RPC service wasn't registered on remote host */
+ map->r_status = -EACCES;
+ } else {
+ /* Succeeded */
+ map->r_status = 0;
+ }
+
+ trace_rpcb_setport(child, map->r_status, map->r_port);
+ xprt->ops->set_port(xprt, map->r_port);
+ if (map->r_port)
+ xprt_set_bound(xprt);
+}
+
+/*
+ * XDR functions for rpcbind
+ */
+
+static void rpcb_enc_mapping(struct rpc_rqst *req, struct xdr_stream *xdr,
+ const void *data)
+{
+ const struct rpcbind_args *rpcb = data;
+ __be32 *p;
+
+ p = xdr_reserve_space(xdr, RPCB_mappingargs_sz << 2);
+ *p++ = cpu_to_be32(rpcb->r_prog);
+ *p++ = cpu_to_be32(rpcb->r_vers);
+ *p++ = cpu_to_be32(rpcb->r_prot);
+ *p = cpu_to_be32(rpcb->r_port);
+}
+
+static int rpcb_dec_getport(struct rpc_rqst *req, struct xdr_stream *xdr,
+ void *data)
+{
+ struct rpcbind_args *rpcb = data;
+ unsigned long port;
+ __be32 *p;
+
+ rpcb->r_port = 0;
+
+ p = xdr_inline_decode(xdr, 4);
+ if (unlikely(p == NULL))
+ return -EIO;
+
+ port = be32_to_cpup(p);
+ if (unlikely(port > USHRT_MAX))
+ return -EIO;
+
+ rpcb->r_port = port;
+ return 0;
+}
+
+static int rpcb_dec_set(struct rpc_rqst *req, struct xdr_stream *xdr,
+ void *data)
+{
+ unsigned int *boolp = data;
+ __be32 *p;
+
+ p = xdr_inline_decode(xdr, 4);
+ if (unlikely(p == NULL))
+ return -EIO;
+
+ *boolp = 0;
+ if (*p != xdr_zero)
+ *boolp = 1;
+ return 0;
+}
+
+static void encode_rpcb_string(struct xdr_stream *xdr, const char *string,
+ const u32 maxstrlen)
+{
+ __be32 *p;
+ u32 len;
+
+ len = strlen(string);
+ WARN_ON_ONCE(len > maxstrlen);
+ if (len > maxstrlen)
+ /* truncate and hope for the best */
+ len = maxstrlen;
+ p = xdr_reserve_space(xdr, 4 + len);
+ xdr_encode_opaque(p, string, len);
+}
+
+static void rpcb_enc_getaddr(struct rpc_rqst *req, struct xdr_stream *xdr,
+ const void *data)
+{
+ const struct rpcbind_args *rpcb = data;
+ __be32 *p;
+
+ p = xdr_reserve_space(xdr, (RPCB_program_sz + RPCB_version_sz) << 2);
+ *p++ = cpu_to_be32(rpcb->r_prog);
+ *p = cpu_to_be32(rpcb->r_vers);
+
+ encode_rpcb_string(xdr, rpcb->r_netid, RPCBIND_MAXNETIDLEN);
+ encode_rpcb_string(xdr, rpcb->r_addr, RPCBIND_MAXUADDRLEN);
+ encode_rpcb_string(xdr, rpcb->r_owner, RPCB_MAXOWNERLEN);
+}
+
+static int rpcb_dec_getaddr(struct rpc_rqst *req, struct xdr_stream *xdr,
+ void *data)
+{
+ struct rpcbind_args *rpcb = data;
+ struct sockaddr_storage address;
+ struct sockaddr *sap = (struct sockaddr *)&address;
+ __be32 *p;
+ u32 len;
+
+ rpcb->r_port = 0;
+
+ p = xdr_inline_decode(xdr, 4);
+ if (unlikely(p == NULL))
+ goto out_fail;
+ len = be32_to_cpup(p);
+
+ /*
+ * If the returned universal address is a null string,
+ * the requested RPC service was not registered.
+ */
+ if (len == 0)
+ return 0;
+
+ if (unlikely(len > RPCBIND_MAXUADDRLEN))
+ goto out_fail;
+
+ p = xdr_inline_decode(xdr, len);
+ if (unlikely(p == NULL))
+ goto out_fail;
+
+ if (rpc_uaddr2sockaddr(req->rq_xprt->xprt_net, (char *)p, len,
+ sap, sizeof(address)) == 0)
+ goto out_fail;
+ rpcb->r_port = rpc_get_port(sap);
+
+ return 0;
+
+out_fail:
+ return -EIO;
+}
+
+/*
+ * Not all rpcbind procedures described in RFC 1833 are implemented
+ * since the Linux kernel RPC code requires only these.
+ */
+
+static const struct rpc_procinfo rpcb_procedures2[] = {
+ [RPCBPROC_SET] = {
+ .p_proc = RPCBPROC_SET,
+ .p_encode = rpcb_enc_mapping,
+ .p_decode = rpcb_dec_set,
+ .p_arglen = RPCB_mappingargs_sz,
+ .p_replen = RPCB_setres_sz,
+ .p_statidx = RPCBPROC_SET,
+ .p_timer = 0,
+ .p_name = "SET",
+ },
+ [RPCBPROC_UNSET] = {
+ .p_proc = RPCBPROC_UNSET,
+ .p_encode = rpcb_enc_mapping,
+ .p_decode = rpcb_dec_set,
+ .p_arglen = RPCB_mappingargs_sz,
+ .p_replen = RPCB_setres_sz,
+ .p_statidx = RPCBPROC_UNSET,
+ .p_timer = 0,
+ .p_name = "UNSET",
+ },
+ [RPCBPROC_GETPORT] = {
+ .p_proc = RPCBPROC_GETPORT,
+ .p_encode = rpcb_enc_mapping,
+ .p_decode = rpcb_dec_getport,
+ .p_arglen = RPCB_mappingargs_sz,
+ .p_replen = RPCB_getportres_sz,
+ .p_statidx = RPCBPROC_GETPORT,
+ .p_timer = 0,
+ .p_name = "GETPORT",
+ },
+};
+
+static const struct rpc_procinfo rpcb_procedures3[] = {
+ [RPCBPROC_SET] = {
+ .p_proc = RPCBPROC_SET,
+ .p_encode = rpcb_enc_getaddr,
+ .p_decode = rpcb_dec_set,
+ .p_arglen = RPCB_getaddrargs_sz,
+ .p_replen = RPCB_setres_sz,
+ .p_statidx = RPCBPROC_SET,
+ .p_timer = 0,
+ .p_name = "SET",
+ },
+ [RPCBPROC_UNSET] = {
+ .p_proc = RPCBPROC_UNSET,
+ .p_encode = rpcb_enc_getaddr,
+ .p_decode = rpcb_dec_set,
+ .p_arglen = RPCB_getaddrargs_sz,
+ .p_replen = RPCB_setres_sz,
+ .p_statidx = RPCBPROC_UNSET,
+ .p_timer = 0,
+ .p_name = "UNSET",
+ },
+ [RPCBPROC_GETADDR] = {
+ .p_proc = RPCBPROC_GETADDR,
+ .p_encode = rpcb_enc_getaddr,
+ .p_decode = rpcb_dec_getaddr,
+ .p_arglen = RPCB_getaddrargs_sz,
+ .p_replen = RPCB_getaddrres_sz,
+ .p_statidx = RPCBPROC_GETADDR,
+ .p_timer = 0,
+ .p_name = "GETADDR",
+ },
+};
+
+static const struct rpc_procinfo rpcb_procedures4[] = {
+ [RPCBPROC_SET] = {
+ .p_proc = RPCBPROC_SET,
+ .p_encode = rpcb_enc_getaddr,
+ .p_decode = rpcb_dec_set,
+ .p_arglen = RPCB_getaddrargs_sz,
+ .p_replen = RPCB_setres_sz,
+ .p_statidx = RPCBPROC_SET,
+ .p_timer = 0,
+ .p_name = "SET",
+ },
+ [RPCBPROC_UNSET] = {
+ .p_proc = RPCBPROC_UNSET,
+ .p_encode = rpcb_enc_getaddr,
+ .p_decode = rpcb_dec_set,
+ .p_arglen = RPCB_getaddrargs_sz,
+ .p_replen = RPCB_setres_sz,
+ .p_statidx = RPCBPROC_UNSET,
+ .p_timer = 0,
+ .p_name = "UNSET",
+ },
+ [RPCBPROC_GETADDR] = {
+ .p_proc = RPCBPROC_GETADDR,
+ .p_encode = rpcb_enc_getaddr,
+ .p_decode = rpcb_dec_getaddr,
+ .p_arglen = RPCB_getaddrargs_sz,
+ .p_replen = RPCB_getaddrres_sz,
+ .p_statidx = RPCBPROC_GETADDR,
+ .p_timer = 0,
+ .p_name = "GETADDR",
+ },
+};
+
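+/*
+ * rpcb_getport_async() walks these tables via xprt->bind_index,
+ * advancing to the next entry (see rpcb_getport_done()) when the remote
+ * rpcbind reports -EPROTONOSUPPORT: AF_INET queries use only rpcbind v2
+ * GETPORT, while AF_INET6 queries try v4 GETADDR first and then fall
+ * back to v3 GETADDR.
+ */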
+static const struct rpcb_info rpcb_next_version[] = {
+ {
+ .rpc_vers = RPCBVERS_2,
+ .rpc_proc = &rpcb_procedures2[RPCBPROC_GETPORT],
+ },
+ {
+ .rpc_proc = NULL,
+ },
+};
+
+static const struct rpcb_info rpcb_next_version6[] = {
+ {
+ .rpc_vers = RPCBVERS_4,
+ .rpc_proc = &rpcb_procedures4[RPCBPROC_GETADDR],
+ },
+ {
+ .rpc_vers = RPCBVERS_3,
+ .rpc_proc = &rpcb_procedures3[RPCBPROC_GETADDR],
+ },
+ {
+ .rpc_proc = NULL,
+ },
+};
+
+static unsigned int rpcb_version2_counts[ARRAY_SIZE(rpcb_procedures2)];
+static const struct rpc_version rpcb_version2 = {
+ .number = RPCBVERS_2,
+ .nrprocs = ARRAY_SIZE(rpcb_procedures2),
+ .procs = rpcb_procedures2,
+ .counts = rpcb_version2_counts,
+};
+
+static unsigned int rpcb_version3_counts[ARRAY_SIZE(rpcb_procedures3)];
+static const struct rpc_version rpcb_version3 = {
+ .number = RPCBVERS_3,
+ .nrprocs = ARRAY_SIZE(rpcb_procedures3),
+ .procs = rpcb_procedures3,
+ .counts = rpcb_version3_counts,
+};
+
+static unsigned int rpcb_version4_counts[ARRAY_SIZE(rpcb_procedures4)];
+static const struct rpc_version rpcb_version4 = {
+ .number = RPCBVERS_4,
+ .nrprocs = ARRAY_SIZE(rpcb_procedures4),
+ .procs = rpcb_procedures4,
+ .counts = rpcb_version4_counts,
+};
+
+static const struct rpc_version *rpcb_version[] = {
+ NULL,
+ NULL,
+ &rpcb_version2,
+ &rpcb_version3,
+ &rpcb_version4
+};
+
+static struct rpc_stat rpcb_stats;
+
+static const struct rpc_program rpcb_program = {
+ .name = "rpcbind",
+ .number = RPCBIND_PROGRAM,
+ .nrvers = ARRAY_SIZE(rpcb_version),
+ .version = rpcb_version,
+ .stats = &rpcb_stats,
+};
diff --git a/net/sunrpc/sched.c b/net/sunrpc/sched.c
new file mode 100644
index 0000000000..6debf4fd42
--- /dev/null
+++ b/net/sunrpc/sched.c
@@ -0,0 +1,1361 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * linux/net/sunrpc/sched.c
+ *
+ * Scheduling for synchronous and asynchronous RPC requests.
+ *
+ * Copyright (C) 1996 Olaf Kirch, <okir@monad.swb.de>
+ *
+ * TCP NFS related read + write fixes
+ * (C) 1999 Dave Airlie, University of Limerick, Ireland <airlied@linux.ie>
+ */
+
+#include <linux/module.h>
+
+#include <linux/sched.h>
+#include <linux/interrupt.h>
+#include <linux/slab.h>
+#include <linux/mempool.h>
+#include <linux/smp.h>
+#include <linux/spinlock.h>
+#include <linux/mutex.h>
+#include <linux/freezer.h>
+#include <linux/sched/mm.h>
+
+#include <linux/sunrpc/clnt.h>
+#include <linux/sunrpc/metrics.h>
+
+#include "sunrpc.h"
+
+#define CREATE_TRACE_POINTS
+#include <trace/events/sunrpc.h>
+
+/*
+ * RPC slabs and memory pools
+ */
+#define RPC_BUFFER_MAXSIZE (2048)
+#define RPC_BUFFER_POOLSIZE (8)
+#define RPC_TASK_POOLSIZE (8)
+static struct kmem_cache *rpc_task_slabp __read_mostly;
+static struct kmem_cache *rpc_buffer_slabp __read_mostly;
+static mempool_t *rpc_task_mempool __read_mostly;
+static mempool_t *rpc_buffer_mempool __read_mostly;
+
+static void rpc_async_schedule(struct work_struct *);
+static void rpc_release_task(struct rpc_task *task);
+static void __rpc_queue_timer_fn(struct work_struct *);
+
+/*
+ * RPC tasks sit here while waiting for conditions to improve.
+ */
+static struct rpc_wait_queue delay_queue;
+
+/*
+ * rpciod-related stuff
+ */
+struct workqueue_struct *rpciod_workqueue __read_mostly;
+struct workqueue_struct *xprtiod_workqueue __read_mostly;
+EXPORT_SYMBOL_GPL(xprtiod_workqueue);
+
+gfp_t rpc_task_gfp_mask(void)
+{
+ if (current->flags & PF_WQ_WORKER)
+ return GFP_KERNEL | __GFP_NORETRY | __GFP_NOWARN;
+ return GFP_KERNEL;
+}
+EXPORT_SYMBOL_GPL(rpc_task_gfp_mask);
+
+bool rpc_task_set_rpc_status(struct rpc_task *task, int rpc_status)
+{
+ if (cmpxchg(&task->tk_rpc_status, 0, rpc_status) == 0)
+ return true;
+ return false;
+}
+
+unsigned long
+rpc_task_timeout(const struct rpc_task *task)
+{
+ unsigned long timeout = READ_ONCE(task->tk_timeout);
+
+ if (timeout != 0) {
+ unsigned long now = jiffies;
+ if (time_before(now, timeout))
+ return timeout - now;
+ }
+ return 0;
+}
+EXPORT_SYMBOL_GPL(rpc_task_timeout);
+
+/*
+ * Disable the timer for a given RPC task. Should be called while
+ * holding queue->lock in order to avoid races with
+ * __rpc_queue_timer_fn().
+ */
+static void
+__rpc_disable_timer(struct rpc_wait_queue *queue, struct rpc_task *task)
+{
+ if (list_empty(&task->u.tk_wait.timer_list))
+ return;
+ task->tk_timeout = 0;
+ list_del(&task->u.tk_wait.timer_list);
+ if (list_empty(&queue->timer_list.list))
+ cancel_delayed_work(&queue->timer_list.dwork);
+}
+
+static void
+rpc_set_queue_timer(struct rpc_wait_queue *queue, unsigned long expires)
+{
+ unsigned long now = jiffies;
+ queue->timer_list.expires = expires;
+ if (time_before_eq(expires, now))
+ expires = 0;
+ else
+ expires -= now;
+ mod_delayed_work(rpciod_workqueue, &queue->timer_list.dwork, expires);
+}
+
+/*
+ * Set up a timer for the current task.
+ */
+static void
+__rpc_add_timer(struct rpc_wait_queue *queue, struct rpc_task *task,
+ unsigned long timeout)
+{
+ task->tk_timeout = timeout;
+ if (list_empty(&queue->timer_list.list) || time_before(timeout, queue->timer_list.expires))
+ rpc_set_queue_timer(queue, timeout);
+ list_add(&task->u.tk_wait.timer_list, &queue->timer_list.list);
+}
+
+static void rpc_set_waitqueue_priority(struct rpc_wait_queue *queue, int priority)
+{
+ if (queue->priority != priority) {
+ queue->priority = priority;
+ queue->nr = 1U << priority;
+ }
+}
+
+static void rpc_reset_waitqueue_priority(struct rpc_wait_queue *queue)
+{
+ rpc_set_waitqueue_priority(queue, queue->maxpriority);
+}
+
+/*
+ * Add a request to a queue list. Tasks that share a tk_owner are
+ * chained behind the first such task on its u.tk_wait.links list, so
+ * that the scheduler can batch-service requests from a single owner.
+ */
+static void
+__rpc_list_enqueue_task(struct list_head *q, struct rpc_task *task)
+{
+ struct rpc_task *t;
+
+ list_for_each_entry(t, q, u.tk_wait.list) {
+ if (t->tk_owner == task->tk_owner) {
+ list_add_tail(&task->u.tk_wait.links,
+ &t->u.tk_wait.links);
+ /* Cache the queue head in task->u.tk_wait.list */
+ task->u.tk_wait.list.next = q;
+ task->u.tk_wait.list.prev = NULL;
+ return;
+ }
+ }
+ INIT_LIST_HEAD(&task->u.tk_wait.links);
+ list_add_tail(&task->u.tk_wait.list, q);
+}
+
+/*
+ * Remove request from a queue list
+ */
+static void
+__rpc_list_dequeue_task(struct rpc_task *task)
+{
+ struct list_head *q;
+ struct rpc_task *t;
+
+ if (task->u.tk_wait.list.prev == NULL) {
+ list_del(&task->u.tk_wait.links);
+ return;
+ }
+ if (!list_empty(&task->u.tk_wait.links)) {
+ t = list_first_entry(&task->u.tk_wait.links,
+ struct rpc_task,
+ u.tk_wait.links);
+ /* Assume __rpc_list_enqueue_task() cached the queue head */
+ q = t->u.tk_wait.list.next;
+ list_add_tail(&t->u.tk_wait.list, q);
+ list_del(&task->u.tk_wait.links);
+ }
+ list_del(&task->u.tk_wait.list);
+}
+
+/*
+ * Add new request to a priority queue.
+ */
+static void __rpc_add_wait_queue_priority(struct rpc_wait_queue *queue,
+ struct rpc_task *task,
+ unsigned char queue_priority)
+{
+ if (unlikely(queue_priority > queue->maxpriority))
+ queue_priority = queue->maxpriority;
+ __rpc_list_enqueue_task(&queue->tasks[queue_priority], task);
+}
+
+/*
+ * Add new request to wait queue.
+ */
+static void __rpc_add_wait_queue(struct rpc_wait_queue *queue,
+ struct rpc_task *task,
+ unsigned char queue_priority)
+{
+ INIT_LIST_HEAD(&task->u.tk_wait.timer_list);
+ if (RPC_IS_PRIORITY(queue))
+ __rpc_add_wait_queue_priority(queue, task, queue_priority);
+ else
+ list_add_tail(&task->u.tk_wait.list, &queue->tasks[0]);
+ task->tk_waitqueue = queue;
+ queue->qlen++;
+ /* barrier matches the read in rpc_wake_up_task_queue_locked() */
+ smp_wmb();
+ rpc_set_queued(task);
+}
+
+/*
+ * Remove request from a priority queue.
+ */
+static void __rpc_remove_wait_queue_priority(struct rpc_task *task)
+{
+ __rpc_list_dequeue_task(task);
+}
+
+/*
+ * Remove request from queue.
+ * Note: must be called with spin lock held.
+ */
+static void __rpc_remove_wait_queue(struct rpc_wait_queue *queue, struct rpc_task *task)
+{
+ __rpc_disable_timer(queue, task);
+ if (RPC_IS_PRIORITY(queue))
+ __rpc_remove_wait_queue_priority(task);
+ else
+ list_del(&task->u.tk_wait.list);
+ queue->qlen--;
+}
+
+static void __rpc_init_priority_wait_queue(struct rpc_wait_queue *queue, const char *qname, unsigned char nr_queues)
+{
+ int i;
+
+ spin_lock_init(&queue->lock);
+ for (i = 0; i < ARRAY_SIZE(queue->tasks); i++)
+ INIT_LIST_HEAD(&queue->tasks[i]);
+ queue->maxpriority = nr_queues - 1;
+ rpc_reset_waitqueue_priority(queue);
+ queue->qlen = 0;
+ queue->timer_list.expires = 0;
+ INIT_DELAYED_WORK(&queue->timer_list.dwork, __rpc_queue_timer_fn);
+ INIT_LIST_HEAD(&queue->timer_list.list);
+ rpc_assign_waitqueue_name(queue, qname);
+}
+
+void rpc_init_priority_wait_queue(struct rpc_wait_queue *queue, const char *qname)
+{
+ __rpc_init_priority_wait_queue(queue, qname, RPC_NR_PRIORITY);
+}
+EXPORT_SYMBOL_GPL(rpc_init_priority_wait_queue);
+
+void rpc_init_wait_queue(struct rpc_wait_queue *queue, const char *qname)
+{
+ __rpc_init_priority_wait_queue(queue, qname, 1);
+}
+EXPORT_SYMBOL_GPL(rpc_init_wait_queue);
+
+void rpc_destroy_wait_queue(struct rpc_wait_queue *queue)
+{
+ cancel_delayed_work_sync(&queue->timer_list.dwork);
+}
+EXPORT_SYMBOL_GPL(rpc_destroy_wait_queue);
+
+static int rpc_wait_bit_killable(struct wait_bit_key *key, int mode)
+{
+ schedule();
+ if (signal_pending_state(mode, current))
+ return -ERESTARTSYS;
+ return 0;
+}
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) || IS_ENABLED(CONFIG_TRACEPOINTS)
+static void rpc_task_set_debuginfo(struct rpc_task *task)
+{
+ struct rpc_clnt *clnt = task->tk_client;
+
+ /* Might be a task carrying a reverse-direction operation */
+ if (!clnt) {
+ static atomic_t rpc_pid;
+
+ task->tk_pid = atomic_inc_return(&rpc_pid);
+ return;
+ }
+
+ task->tk_pid = atomic_inc_return(&clnt->cl_pid);
+}
+#else
+static inline void rpc_task_set_debuginfo(struct rpc_task *task)
+{
+}
+#endif
+
+static void rpc_set_active(struct rpc_task *task)
+{
+ rpc_task_set_debuginfo(task);
+ set_bit(RPC_TASK_ACTIVE, &task->tk_runstate);
+ trace_rpc_task_begin(task, NULL);
+}
+
+/*
+ * Mark an RPC call as having completed by clearing the 'active' bit
+ * and then waking up all tasks that were sleeping.
+ */
+static int rpc_complete_task(struct rpc_task *task)
+{
+ void *m = &task->tk_runstate;
+ wait_queue_head_t *wq = bit_waitqueue(m, RPC_TASK_ACTIVE);
+ struct wait_bit_key k = __WAIT_BIT_KEY_INITIALIZER(m, RPC_TASK_ACTIVE);
+ unsigned long flags;
+ int ret;
+
+ trace_rpc_task_complete(task, NULL);
+
+ spin_lock_irqsave(&wq->lock, flags);
+ clear_bit(RPC_TASK_ACTIVE, &task->tk_runstate);
+ ret = atomic_dec_and_test(&task->tk_count);
+ if (waitqueue_active(wq))
+ __wake_up_locked_key(wq, TASK_NORMAL, &k);
+ spin_unlock_irqrestore(&wq->lock, flags);
+ return ret;
+}
+
+/*
+ * Allow callers to wait for completion of an RPC call
+ *
+ * Note the use of out_of_line_wait_on_bit() rather than wait_on_bit()
+ * to enforce taking of the wq->lock and hence avoid races with
+ * rpc_complete_task().
+ */
+int rpc_wait_for_completion_task(struct rpc_task *task)
+{
+ return out_of_line_wait_on_bit(&task->tk_runstate, RPC_TASK_ACTIVE,
+ rpc_wait_bit_killable, TASK_KILLABLE|TASK_FREEZABLE_UNSAFE);
+}
+EXPORT_SYMBOL_GPL(rpc_wait_for_completion_task);
+
+/*
+ * Make an RPC task runnable.
+ *
+ * Note: If the task is ASYNC, and is being made runnable after sitting on an
+ * rpc_wait_queue, this must be called with the queue spinlock held to protect
+ * the wait queue operation.
+ * Note the ordering of rpc_test_and_set_running() and rpc_clear_queued(),
+ * which is needed to ensure that __rpc_execute() doesn't loop (due to the
+ * lockless RPC_IS_QUEUED() test) before we've had a chance to test
+ * the RPC_TASK_RUNNING flag.
+ */
+static void rpc_make_runnable(struct workqueue_struct *wq,
+ struct rpc_task *task)
+{
+ bool need_wakeup = !rpc_test_and_set_running(task);
+
+ rpc_clear_queued(task);
+ if (!need_wakeup)
+ return;
+ if (RPC_IS_ASYNC(task)) {
+ INIT_WORK(&task->u.tk_work, rpc_async_schedule);
+ queue_work(wq, &task->u.tk_work);
+ } else
+ wake_up_bit(&task->tk_runstate, RPC_TASK_QUEUED);
+}
+
+/*
+ * Prepare for sleeping on a wait queue.
+ * By always appending tasks to the list we ensure FIFO behavior.
+ * NB: An RPC task will only receive interrupt-driven events as long
+ * as it's on a wait queue.
+ */
+static void __rpc_do_sleep_on_priority(struct rpc_wait_queue *q,
+ struct rpc_task *task,
+ unsigned char queue_priority)
+{
+ trace_rpc_task_sleep(task, q);
+
+ __rpc_add_wait_queue(q, task, queue_priority);
+}
+
+static void __rpc_sleep_on_priority(struct rpc_wait_queue *q,
+ struct rpc_task *task,
+ unsigned char queue_priority)
+{
+ if (WARN_ON_ONCE(RPC_IS_QUEUED(task)))
+ return;
+ __rpc_do_sleep_on_priority(q, task, queue_priority);
+}
+
+static void __rpc_sleep_on_priority_timeout(struct rpc_wait_queue *q,
+ struct rpc_task *task, unsigned long timeout,
+ unsigned char queue_priority)
+{
+ if (WARN_ON_ONCE(RPC_IS_QUEUED(task)))
+ return;
+ if (time_is_after_jiffies(timeout)) {
+ __rpc_do_sleep_on_priority(q, task, queue_priority);
+ __rpc_add_timer(q, task, timeout);
+ } else
+ task->tk_status = -ETIMEDOUT;
+}
+
+static void rpc_set_tk_callback(struct rpc_task *task, rpc_action action)
+{
+ if (action && !WARN_ON_ONCE(task->tk_callback != NULL))
+ task->tk_callback = action;
+}
+
+static bool rpc_sleep_check_activated(struct rpc_task *task)
+{
+ /* We shouldn't ever put an inactive task to sleep */
+ if (WARN_ON_ONCE(!RPC_IS_ACTIVATED(task))) {
+ task->tk_status = -EIO;
+ rpc_put_task_async(task);
+ return false;
+ }
+ return true;
+}
+
+void rpc_sleep_on_timeout(struct rpc_wait_queue *q, struct rpc_task *task,
+ rpc_action action, unsigned long timeout)
+{
+ if (!rpc_sleep_check_activated(task))
+ return;
+
+ rpc_set_tk_callback(task, action);
+
+ /*
+ * Protect the queue operations.
+ */
+ spin_lock(&q->lock);
+ __rpc_sleep_on_priority_timeout(q, task, timeout, task->tk_priority);
+ spin_unlock(&q->lock);
+}
+EXPORT_SYMBOL_GPL(rpc_sleep_on_timeout);
+
+void rpc_sleep_on(struct rpc_wait_queue *q, struct rpc_task *task,
+ rpc_action action)
+{
+ if (!rpc_sleep_check_activated(task))
+ return;
+
+ rpc_set_tk_callback(task, action);
+
+ WARN_ON_ONCE(task->tk_timeout != 0);
+ /*
+ * Protect the queue operations.
+ */
+ spin_lock(&q->lock);
+ __rpc_sleep_on_priority(q, task, task->tk_priority);
+ spin_unlock(&q->lock);
+}
+EXPORT_SYMBOL_GPL(rpc_sleep_on);
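+
+/*
+ * Typical usage sketch (illustrative only; "example_queue" stands for a
+ * hypothetical rpc_wait_queue initialized with rpc_init_wait_queue()):
+ *
+ *	static void example_action(struct rpc_task *task)
+ *	{
+ *		rpc_sleep_on(&example_queue, task, NULL);
+ *	}
+ *
+ * Once the awaited condition holds, another context wakes the sleepers
+ * with rpc_wake_up(&example_queue), or hands them all an error with
+ * rpc_wake_up_status(&example_queue, -EIO).
+ */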
+
+void rpc_sleep_on_priority_timeout(struct rpc_wait_queue *q,
+ struct rpc_task *task, unsigned long timeout, int priority)
+{
+ if (!rpc_sleep_check_activated(task))
+ return;
+
+ priority -= RPC_PRIORITY_LOW;
+ /*
+ * Protect the queue operations.
+ */
+ spin_lock(&q->lock);
+ __rpc_sleep_on_priority_timeout(q, task, timeout, priority);
+ spin_unlock(&q->lock);
+}
+EXPORT_SYMBOL_GPL(rpc_sleep_on_priority_timeout);
+
+void rpc_sleep_on_priority(struct rpc_wait_queue *q, struct rpc_task *task,
+ int priority)
+{
+ if (!rpc_sleep_check_activated(task))
+ return;
+
+ WARN_ON_ONCE(task->tk_timeout != 0);
+ priority -= RPC_PRIORITY_LOW;
+ /*
+ * Protect the queue operations.
+ */
+ spin_lock(&q->lock);
+ __rpc_sleep_on_priority(q, task, priority);
+ spin_unlock(&q->lock);
+}
+EXPORT_SYMBOL_GPL(rpc_sleep_on_priority);
+
+/**
+ * __rpc_do_wake_up_task_on_wq - wake up a single rpc_task
+ * @wq: workqueue on which to run task
+ * @queue: wait queue
+ * @task: task to be woken up
+ *
+ * Caller must hold queue->lock, and have cleared the task queued flag.
+ */
+static void __rpc_do_wake_up_task_on_wq(struct workqueue_struct *wq,
+ struct rpc_wait_queue *queue,
+ struct rpc_task *task)
+{
+ /* Has the task been executed yet? If not, we cannot wake it up! */
+ if (!RPC_IS_ACTIVATED(task)) {
+ printk(KERN_ERR "RPC: Inactive task (%p) being woken up!\n", task);
+ return;
+ }
+
+ trace_rpc_task_wakeup(task, queue);
+
+ __rpc_remove_wait_queue(queue, task);
+
+ rpc_make_runnable(wq, task);
+}
+
+/*
+ * Wake up a queued task while the queue lock is being held
+ */
+static struct rpc_task *
+rpc_wake_up_task_on_wq_queue_action_locked(struct workqueue_struct *wq,
+ struct rpc_wait_queue *queue, struct rpc_task *task,
+ bool (*action)(struct rpc_task *, void *), void *data)
+{
+ if (RPC_IS_QUEUED(task)) {
+ smp_rmb();
+ if (task->tk_waitqueue == queue) {
+ if (action == NULL || action(task, data)) {
+ __rpc_do_wake_up_task_on_wq(wq, queue, task);
+ return task;
+ }
+ }
+ }
+ return NULL;
+}
+
+/*
+ * Wake up a queued task while the queue lock is being held
+ */
+static void rpc_wake_up_task_queue_locked(struct rpc_wait_queue *queue,
+ struct rpc_task *task)
+{
+ rpc_wake_up_task_on_wq_queue_action_locked(rpciod_workqueue, queue,
+ task, NULL, NULL);
+}
+
+/*
+ * Wake up a task on a specific queue
+ */
+void rpc_wake_up_queued_task(struct rpc_wait_queue *queue, struct rpc_task *task)
+{
+ if (!RPC_IS_QUEUED(task))
+ return;
+ spin_lock(&queue->lock);
+ rpc_wake_up_task_queue_locked(queue, task);
+ spin_unlock(&queue->lock);
+}
+EXPORT_SYMBOL_GPL(rpc_wake_up_queued_task);
+
+static bool rpc_task_action_set_status(struct rpc_task *task, void *status)
+{
+ task->tk_status = *(int *)status;
+ return true;
+}
+
+static void
+rpc_wake_up_task_queue_set_status_locked(struct rpc_wait_queue *queue,
+ struct rpc_task *task, int status)
+{
+ rpc_wake_up_task_on_wq_queue_action_locked(rpciod_workqueue, queue,
+ task, rpc_task_action_set_status, &status);
+}
+
+/**
+ * rpc_wake_up_queued_task_set_status - wake up a task and set task->tk_status
+ * @queue: pointer to rpc_wait_queue
+ * @task: pointer to rpc_task
+ * @status: integer error value
+ *
+ * If @task is queued on @queue, then it is woken up, and @task->tk_status is
+ * set to the value of @status.
+ */
+void
+rpc_wake_up_queued_task_set_status(struct rpc_wait_queue *queue,
+ struct rpc_task *task, int status)
+{
+ if (!RPC_IS_QUEUED(task))
+ return;
+ spin_lock(&queue->lock);
+ rpc_wake_up_task_queue_set_status_locked(queue, task, status);
+ spin_unlock(&queue->lock);
+}
+
+/*
+ * Wake up the next task on a priority queue.
+ */
+static struct rpc_task *__rpc_find_next_queued_priority(struct rpc_wait_queue *queue)
+{
+ struct list_head *q;
+ struct rpc_task *task;
+
+ /*
+ * Service the privileged queue.
+ */
+ q = &queue->tasks[RPC_NR_PRIORITY - 1];
+ if (queue->maxpriority > RPC_PRIORITY_PRIVILEGED && !list_empty(q)) {
+ task = list_first_entry(q, struct rpc_task, u.tk_wait.list);
+ goto out;
+ }
+
+ /*
+ * Service a batch of tasks from a single owner.
+ */
+ q = &queue->tasks[queue->priority];
+ if (!list_empty(q) && queue->nr) {
+ queue->nr--;
+ task = list_first_entry(q, struct rpc_task, u.tk_wait.list);
+ goto out;
+ }
+
+ /*
+ * Service the next queue.
+ */
+ do {
+ if (q == &queue->tasks[0])
+ q = &queue->tasks[queue->maxpriority];
+ else
+ q = q - 1;
+ if (!list_empty(q)) {
+ task = list_first_entry(q, struct rpc_task, u.tk_wait.list);
+ goto new_queue;
+ }
+ } while (q != &queue->tasks[queue->priority]);
+
+ rpc_reset_waitqueue_priority(queue);
+ return NULL;
+
+new_queue:
+ rpc_set_waitqueue_priority(queue, (unsigned int)(q - &queue->tasks[0]));
+out:
+ return task;
+}
+
+static struct rpc_task *__rpc_find_next_queued(struct rpc_wait_queue *queue)
+{
+ if (RPC_IS_PRIORITY(queue))
+ return __rpc_find_next_queued_priority(queue);
+ if (!list_empty(&queue->tasks[0]))
+ return list_first_entry(&queue->tasks[0], struct rpc_task, u.tk_wait.list);
+ return NULL;
+}
+
+/*
+ * Wake up the first task on the wait queue.
+ */
+struct rpc_task *rpc_wake_up_first_on_wq(struct workqueue_struct *wq,
+ struct rpc_wait_queue *queue,
+ bool (*func)(struct rpc_task *, void *), void *data)
+{
+ struct rpc_task *task = NULL;
+
+ spin_lock(&queue->lock);
+ task = __rpc_find_next_queued(queue);
+ if (task != NULL)
+ task = rpc_wake_up_task_on_wq_queue_action_locked(wq, queue,
+ task, func, data);
+ spin_unlock(&queue->lock);
+
+ return task;
+}
+
+/*
+ * Wake up the first task on the wait queue.
+ */
+struct rpc_task *rpc_wake_up_first(struct rpc_wait_queue *queue,
+ bool (*func)(struct rpc_task *, void *), void *data)
+{
+ return rpc_wake_up_first_on_wq(rpciod_workqueue, queue, func, data);
+}
+EXPORT_SYMBOL_GPL(rpc_wake_up_first);
+
+static bool rpc_wake_up_next_func(struct rpc_task *task, void *data)
+{
+ return true;
+}
+
+/*
+ * Wake up the next task on the wait queue.
+ */
+struct rpc_task *rpc_wake_up_next(struct rpc_wait_queue *queue)
+{
+ return rpc_wake_up_first(queue, rpc_wake_up_next_func, NULL);
+}
+EXPORT_SYMBOL_GPL(rpc_wake_up_next);
+
+/**
+ * rpc_wake_up_locked - wake up all rpc_tasks
+ * @queue: rpc_wait_queue on which the tasks are sleeping
+ *
+ */
+static void rpc_wake_up_locked(struct rpc_wait_queue *queue)
+{
+ struct rpc_task *task;
+
+ for (;;) {
+ task = __rpc_find_next_queued(queue);
+ if (task == NULL)
+ break;
+ rpc_wake_up_task_queue_locked(queue, task);
+ }
+}
+
+/**
+ * rpc_wake_up - wake up all rpc_tasks
+ * @queue: rpc_wait_queue on which the tasks are sleeping
+ *
+ * Grabs queue->lock
+ */
+void rpc_wake_up(struct rpc_wait_queue *queue)
+{
+ spin_lock(&queue->lock);
+ rpc_wake_up_locked(queue);
+ spin_unlock(&queue->lock);
+}
+EXPORT_SYMBOL_GPL(rpc_wake_up);
+
+/**
+ * rpc_wake_up_status_locked - wake up all rpc_tasks and set their status value.
+ * @queue: rpc_wait_queue on which the tasks are sleeping
+ * @status: status value to set
+ */
+static void rpc_wake_up_status_locked(struct rpc_wait_queue *queue, int status)
+{
+ struct rpc_task *task;
+
+ for (;;) {
+ task = __rpc_find_next_queued(queue);
+ if (task == NULL)
+ break;
+ rpc_wake_up_task_queue_set_status_locked(queue, task, status);
+ }
+}
+
+/**
+ * rpc_wake_up_status - wake up all rpc_tasks and set their status value.
+ * @queue: rpc_wait_queue on which the tasks are sleeping
+ * @status: status value to set
+ *
+ * Grabs queue->lock
+ */
+void rpc_wake_up_status(struct rpc_wait_queue *queue, int status)
+{
+ spin_lock(&queue->lock);
+ rpc_wake_up_status_locked(queue, status);
+ spin_unlock(&queue->lock);
+}
+EXPORT_SYMBOL_GPL(rpc_wake_up_status);
+
+static void __rpc_queue_timer_fn(struct work_struct *work)
+{
+ struct rpc_wait_queue *queue = container_of(work,
+ struct rpc_wait_queue,
+ timer_list.dwork.work);
+ struct rpc_task *task, *n;
+ unsigned long expires, now, timeo;
+
+ spin_lock(&queue->lock);
+ expires = now = jiffies;
+ list_for_each_entry_safe(task, n, &queue->timer_list.list, u.tk_wait.timer_list) {
+ timeo = task->tk_timeout;
+ if (time_after_eq(now, timeo)) {
+ trace_rpc_task_timeout(task, task->tk_action);
+ task->tk_status = -ETIMEDOUT;
+ rpc_wake_up_task_queue_locked(queue, task);
+ continue;
+ }
+ if (expires == now || time_after(expires, timeo))
+ expires = timeo;
+ }
+ if (!list_empty(&queue->timer_list.list))
+ rpc_set_queue_timer(queue, expires);
+ spin_unlock(&queue->lock);
+}
+
+static void __rpc_atrun(struct rpc_task *task)
+{
+ if (task->tk_status == -ETIMEDOUT)
+ task->tk_status = 0;
+}
+
+/*
+ * Run a task at a later time
+ */
+void rpc_delay(struct rpc_task *task, unsigned long delay)
+{
+ rpc_sleep_on_timeout(&delay_queue, task, __rpc_atrun, jiffies + delay);
+}
+EXPORT_SYMBOL_GPL(rpc_delay);
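+
+/*
+ * Callers typically invoke this from a tk_action state, for example
+ * rpc_delay(task, 3 * HZ) to retry three seconds later; __rpc_atrun()
+ * clears the resulting -ETIMEDOUT status when the task resumes.
+ */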
+
+/*
+ * Helper to call task->tk_ops->rpc_call_prepare
+ */
+void rpc_prepare_task(struct rpc_task *task)
+{
+ task->tk_ops->rpc_call_prepare(task, task->tk_calldata);
+}
+
+static void
+rpc_init_task_statistics(struct rpc_task *task)
+{
+ /* Initialize retry counters */
+ task->tk_garb_retry = 2;
+ task->tk_cred_retry = 2;
+
+ /* starting timestamp */
+ task->tk_start = ktime_get();
+}
+
+static void
+rpc_reset_task_statistics(struct rpc_task *task)
+{
+ task->tk_timeouts = 0;
+ task->tk_flags &= ~(RPC_CALL_MAJORSEEN|RPC_TASK_SENT);
+ rpc_init_task_statistics(task);
+}
+
+/*
+ * Helper that calls task->tk_ops->rpc_call_done if it exists
+ */
+void rpc_exit_task(struct rpc_task *task)
+{
+ trace_rpc_task_end(task, task->tk_action);
+ task->tk_action = NULL;
+ if (task->tk_ops->rpc_count_stats)
+ task->tk_ops->rpc_count_stats(task, task->tk_calldata);
+ else if (task->tk_client)
+ rpc_count_iostats(task, task->tk_client->cl_metrics);
+ if (task->tk_ops->rpc_call_done != NULL) {
+ trace_rpc_task_call_done(task, task->tk_ops->rpc_call_done);
+ task->tk_ops->rpc_call_done(task, task->tk_calldata);
+ if (task->tk_action != NULL) {
+ /* Always release the RPC slot and buffer memory */
+ xprt_release(task);
+ rpc_reset_task_statistics(task);
+ }
+ }
+}
+
+void rpc_signal_task(struct rpc_task *task)
+{
+ struct rpc_wait_queue *queue;
+
+ if (!RPC_IS_ACTIVATED(task))
+ return;
+
+ if (!rpc_task_set_rpc_status(task, -ERESTARTSYS))
+ return;
+ trace_rpc_task_signalled(task, task->tk_action);
+ set_bit(RPC_TASK_SIGNALLED, &task->tk_runstate);
+ smp_mb__after_atomic();
+ queue = READ_ONCE(task->tk_waitqueue);
+ if (queue)
+ rpc_wake_up_queued_task(queue, task);
+}
+
+void rpc_task_try_cancel(struct rpc_task *task, int error)
+{
+ struct rpc_wait_queue *queue;
+
+ if (!rpc_task_set_rpc_status(task, error))
+ return;
+ queue = READ_ONCE(task->tk_waitqueue);
+ if (queue)
+ rpc_wake_up_queued_task(queue, task);
+}
+
+void rpc_exit(struct rpc_task *task, int status)
+{
+ task->tk_status = status;
+ task->tk_action = rpc_exit_task;
+ rpc_wake_up_queued_task(task->tk_waitqueue, task);
+}
+EXPORT_SYMBOL_GPL(rpc_exit);
+
+void rpc_release_calldata(const struct rpc_call_ops *ops, void *calldata)
+{
+ if (ops->rpc_release != NULL)
+ ops->rpc_release(calldata);
+}
+
+static bool xprt_needs_memalloc(struct rpc_xprt *xprt, struct rpc_task *tk)
+{
+ if (!xprt)
+ return false;
+ if (!atomic_read(&xprt->swapper))
+ return false;
+ return test_bit(XPRT_LOCKED, &xprt->state) && xprt->snd_task == tk;
+}
+
+/*
+ * This is the RPC `scheduler' (or rather, the finite state machine).
+ */
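+/*
+ * Each pass of the loop below runs either the pending tk_callback or
+ * the next tk_action state.  If the task is still runnable afterwards,
+ * the loop simply continues; if the task queued itself on a wait queue,
+ * a synchronous task sleeps here on RPC_TASK_QUEUED, while an
+ * asynchronous task returns and is requeued later by
+ * rpc_make_runnable().
+ */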
+static void __rpc_execute(struct rpc_task *task)
+{
+ struct rpc_wait_queue *queue;
+ int task_is_async = RPC_IS_ASYNC(task);
+ int status = 0;
+ unsigned long pflags = current->flags;
+
+ WARN_ON_ONCE(RPC_IS_QUEUED(task));
+ if (RPC_IS_QUEUED(task))
+ return;
+
+ for (;;) {
+ void (*do_action)(struct rpc_task *);
+
+ /*
+ * Perform the next FSM step or a pending callback.
+ *
+ * tk_action may be NULL if the task has been killed.
+ */
+ do_action = task->tk_action;
+ /* Tasks with an RPC error status should exit */
+ if (do_action && do_action != rpc_exit_task &&
+ (status = READ_ONCE(task->tk_rpc_status)) != 0) {
+ task->tk_status = status;
+ do_action = rpc_exit_task;
+ }
+ /* Callbacks override all actions */
+ if (task->tk_callback) {
+ do_action = task->tk_callback;
+ task->tk_callback = NULL;
+ }
+ if (!do_action)
+ break;
+ if (RPC_IS_SWAPPER(task) ||
+ xprt_needs_memalloc(task->tk_xprt, task))
+ current->flags |= PF_MEMALLOC;
+
+ trace_rpc_task_run_action(task, do_action);
+ do_action(task);
+
+ /*
+ * Lockless check for whether task is sleeping or not.
+ */
+ if (!RPC_IS_QUEUED(task)) {
+ cond_resched();
+ continue;
+ }
+
+ /*
+ * The queue->lock protects against races with
+ * rpc_make_runnable().
+ *
+ * Note that once we clear RPC_TASK_RUNNING on an asynchronous
+ * rpc_task, rpc_make_runnable() can assign it to a
+		 * different workqueue. After that point we can no longer
+		 * assume that it is safe to dereference the rpc_task pointer.
+ */
+ queue = task->tk_waitqueue;
+ spin_lock(&queue->lock);
+ if (!RPC_IS_QUEUED(task)) {
+ spin_unlock(&queue->lock);
+ continue;
+ }
+ /* Wake up any task that has an exit status */
+ if (READ_ONCE(task->tk_rpc_status) != 0) {
+ rpc_wake_up_task_queue_locked(queue, task);
+ spin_unlock(&queue->lock);
+ continue;
+ }
+ rpc_clear_running(task);
+ spin_unlock(&queue->lock);
+ if (task_is_async)
+ goto out;
+
+ /* sync task: sleep here */
+ trace_rpc_task_sync_sleep(task, task->tk_action);
+ status = out_of_line_wait_on_bit(&task->tk_runstate,
+ RPC_TASK_QUEUED, rpc_wait_bit_killable,
+ TASK_KILLABLE|TASK_FREEZABLE);
+ if (status < 0) {
+ /*
+ * When a sync task receives a signal, it exits with
+ * -ERESTARTSYS. In order to catch any callbacks that
+ * clean up after sleeping on some queue, we don't
+ * break the loop here, but go around once more.
+ */
+ rpc_signal_task(task);
+ }
+ trace_rpc_task_sync_wake(task, task->tk_action);
+ }
+
+ /* Release all resources associated with the task */
+ rpc_release_task(task);
+out:
+ current_restore_flags(pflags, PF_MEMALLOC);
+}
+
+/*
+ * User-visible entry point to the scheduler.
+ *
+ * This may be called recursively if e.g. an async NFS task updates
+ * the attributes and finds that dirty pages must be flushed.
+ * NOTE: Upon exit of this function the task is guaranteed to be
+ * released. In particular, rpc_release_task() will have been
+ * called, so the task's memory may already have been freed.
+ */
+void rpc_execute(struct rpc_task *task)
+{
+ bool is_async = RPC_IS_ASYNC(task);
+
+ rpc_set_active(task);
+ rpc_make_runnable(rpciod_workqueue, task);
+ if (!is_async) {
+ unsigned int pflags = memalloc_nofs_save();
+ __rpc_execute(task);
+ memalloc_nofs_restore(pflags);
+ }
+}
+
+static void rpc_async_schedule(struct work_struct *work)
+{
+ unsigned int pflags = memalloc_nofs_save();
+
+ __rpc_execute(container_of(work, struct rpc_task, u.tk_work));
+ memalloc_nofs_restore(pflags);
+}
+
+/**
+ * rpc_malloc - allocate RPC buffer resources
+ * @task: RPC task
+ *
+ * A single memory region is allocated, which is split between the
+ * RPC call and RPC reply that this task is being used for. When
+ * this RPC is retired, the memory is released by calling rpc_free.
+ *
+ * To prevent rpciod from hanging, this allocator never sleeps,
+ * returning -ENOMEM and suppressing the warning if the request cannot
+ * be serviced immediately. The caller can arrange to sleep in a
+ * way that is safe for rpciod.
+ *
+ * Most requests are 'small' (under 2KiB) and can be serviced from a
+ * mempool, ensuring that NFS reads and writes can always proceed,
+ * and that there is good locality of reference for these buffers.
+ */
+int rpc_malloc(struct rpc_task *task)
+{
+ struct rpc_rqst *rqst = task->tk_rqstp;
+ size_t size = rqst->rq_callsize + rqst->rq_rcvsize;
+ struct rpc_buffer *buf;
+ gfp_t gfp = rpc_task_gfp_mask();
+
+ size += sizeof(struct rpc_buffer);
+ if (size <= RPC_BUFFER_MAXSIZE) {
+ buf = kmem_cache_alloc(rpc_buffer_slabp, gfp);
+ /* Reach for the mempool if dynamic allocation fails */
+ if (!buf && RPC_IS_ASYNC(task))
+ buf = mempool_alloc(rpc_buffer_mempool, GFP_NOWAIT);
+ } else
+ buf = kmalloc(size, gfp);
+
+ if (!buf)
+ return -ENOMEM;
+
+ buf->len = size;
+ rqst->rq_buffer = buf->data;
+ rqst->rq_rbuffer = (char *)rqst->rq_buffer + rqst->rq_callsize;
+ return 0;
+}
+EXPORT_SYMBOL_GPL(rpc_malloc);
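+
+/*
+ * Resulting layout (illustrative): one region of
+ * sizeof(struct rpc_buffer) + rq_callsize + rq_rcvsize bytes, where
+ * rqst->rq_buffer points just past the header and rqst->rq_rbuffer
+ * points rq_callsize bytes further on, at the start of the reply area.
+ * rpc_free() recovers the header with container_of() to pick the
+ * matching release path (mempool vs. kfree).
+ */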
+
+/**
+ * rpc_free - free RPC buffer resources allocated via rpc_malloc
+ * @task: RPC task
+ *
+ */
+void rpc_free(struct rpc_task *task)
+{
+ void *buffer = task->tk_rqstp->rq_buffer;
+ size_t size;
+ struct rpc_buffer *buf;
+
+ buf = container_of(buffer, struct rpc_buffer, data);
+ size = buf->len;
+
+ if (size <= RPC_BUFFER_MAXSIZE)
+ mempool_free(buf, rpc_buffer_mempool);
+ else
+ kfree(buf);
+}
+EXPORT_SYMBOL_GPL(rpc_free);
+
+/*
+ * Creation and deletion of RPC task structures
+ */
+static void rpc_init_task(struct rpc_task *task, const struct rpc_task_setup *task_setup_data)
+{
+ memset(task, 0, sizeof(*task));
+ atomic_set(&task->tk_count, 1);
+ task->tk_flags = task_setup_data->flags;
+ task->tk_ops = task_setup_data->callback_ops;
+ task->tk_calldata = task_setup_data->callback_data;
+ INIT_LIST_HEAD(&task->tk_task);
+
+ task->tk_priority = task_setup_data->priority - RPC_PRIORITY_LOW;
+ task->tk_owner = current->tgid;
+
+ /* Initialize workqueue for async tasks */
+ task->tk_workqueue = task_setup_data->workqueue;
+
+ task->tk_xprt = rpc_task_get_xprt(task_setup_data->rpc_client,
+ xprt_get(task_setup_data->rpc_xprt));
+
+ task->tk_op_cred = get_rpccred(task_setup_data->rpc_op_cred);
+
+ if (task->tk_ops->rpc_call_prepare != NULL)
+ task->tk_action = rpc_prepare_task;
+
+ rpc_init_task_statistics(task);
+}
+
+static struct rpc_task *rpc_alloc_task(void)
+{
+ struct rpc_task *task;
+
+ task = kmem_cache_alloc(rpc_task_slabp, rpc_task_gfp_mask());
+ if (task)
+ return task;
+ return mempool_alloc(rpc_task_mempool, GFP_NOWAIT);
+}
+
+/*
+ * Create a new task for the specified client.
+ */
+struct rpc_task *rpc_new_task(const struct rpc_task_setup *setup_data)
+{
+ struct rpc_task *task = setup_data->task;
+ unsigned short flags = 0;
+
+ if (task == NULL) {
+ task = rpc_alloc_task();
+ if (task == NULL) {
+ rpc_release_calldata(setup_data->callback_ops,
+ setup_data->callback_data);
+ return ERR_PTR(-ENOMEM);
+ }
+ flags = RPC_TASK_DYNAMIC;
+ }
+
+ rpc_init_task(task, setup_data);
+ task->tk_flags |= flags;
+ return task;
+}
+
+/*
+ * rpc_free_task - release rpc task and perform cleanups
+ *
+ * Note that we free up the rpc_task _after_ rpc_release_calldata()
+ * in order to work around a workqueue dependency issue.
+ *
+ * Tejun Heo states:
+ * "Workqueue currently considers two work items to be the same if they're
+ * on the same address and won't execute them concurrently - ie. it
+ * makes a work item which is queued again while being executed wait
+ * for the previous execution to complete.
+ *
+ * If a work function frees the work item, and then waits for an event
+ * which should be performed by another work item and *that* work item
+ * recycles the freed work item, it can create a false dependency loop.
+ * There really is no reliable way to detect this short of verifying
+ * every memory free."
+ *
+ */
+static void rpc_free_task(struct rpc_task *task)
+{
+ unsigned short tk_flags = task->tk_flags;
+
+ put_rpccred(task->tk_op_cred);
+ rpc_release_calldata(task->tk_ops, task->tk_calldata);
+
+ if (tk_flags & RPC_TASK_DYNAMIC)
+ mempool_free(task, rpc_task_mempool);
+}
+
+static void rpc_async_release(struct work_struct *work)
+{
+ unsigned int pflags = memalloc_nofs_save();
+
+ rpc_free_task(container_of(work, struct rpc_task, u.tk_work));
+ memalloc_nofs_restore(pflags);
+}
+
+static void rpc_release_resources_task(struct rpc_task *task)
+{
+ xprt_release(task);
+ if (task->tk_msg.rpc_cred) {
+ if (!(task->tk_flags & RPC_TASK_CRED_NOREF))
+ put_cred(task->tk_msg.rpc_cred);
+ task->tk_msg.rpc_cred = NULL;
+ }
+ rpc_task_release_client(task);
+}
+
+static void rpc_final_put_task(struct rpc_task *task,
+ struct workqueue_struct *q)
+{
+ if (q != NULL) {
+ INIT_WORK(&task->u.tk_work, rpc_async_release);
+ queue_work(q, &task->u.tk_work);
+ } else
+ rpc_free_task(task);
+}
+
+static void rpc_do_put_task(struct rpc_task *task, struct workqueue_struct *q)
+{
+ if (atomic_dec_and_test(&task->tk_count)) {
+ rpc_release_resources_task(task);
+ rpc_final_put_task(task, q);
+ }
+}
+
+void rpc_put_task(struct rpc_task *task)
+{
+ rpc_do_put_task(task, NULL);
+}
+EXPORT_SYMBOL_GPL(rpc_put_task);
+
+void rpc_put_task_async(struct rpc_task *task)
+{
+ rpc_do_put_task(task, task->tk_workqueue);
+}
+EXPORT_SYMBOL_GPL(rpc_put_task_async);
+
+static void rpc_release_task(struct rpc_task *task)
+{
+ WARN_ON_ONCE(RPC_IS_QUEUED(task));
+
+ rpc_release_resources_task(task);
+
+ /*
+ * Note: at this point we have been removed from rpc_clnt->cl_tasks,
+ * so it should be safe to use task->tk_count as a test for whether
+ * or not any other processes still hold references to our rpc_task.
+ */
+ if (atomic_read(&task->tk_count) != 1 + !RPC_IS_ASYNC(task)) {
+ /* Wake up anyone who may be waiting for task completion */
+ if (!rpc_complete_task(task))
+ return;
+ } else {
+ if (!atomic_dec_and_test(&task->tk_count))
+ return;
+ }
+ rpc_final_put_task(task, task->tk_workqueue);
+}
+
+int rpciod_up(void)
+{
+ return try_module_get(THIS_MODULE) ? 0 : -EINVAL;
+}
+
+void rpciod_down(void)
+{
+ module_put(THIS_MODULE);
+}
+
+/*
+ * Start up the rpciod workqueue.
+ */
+static int rpciod_start(void)
+{
+ struct workqueue_struct *wq;
+
+ /*
+ * Create the rpciod and xprtiod workqueues.
+ */
+ wq = alloc_workqueue("rpciod", WQ_MEM_RECLAIM | WQ_UNBOUND, 0);
+ if (!wq)
+ goto out_failed;
+ rpciod_workqueue = wq;
+ wq = alloc_workqueue("xprtiod", WQ_UNBOUND | WQ_MEM_RECLAIM, 0);
+ if (!wq)
+ goto free_rpciod;
+ xprtiod_workqueue = wq;
+ return 1;
+free_rpciod:
+ wq = rpciod_workqueue;
+ rpciod_workqueue = NULL;
+ destroy_workqueue(wq);
+out_failed:
+ return 0;
+}
+
+static void rpciod_stop(void)
+{
+ struct workqueue_struct *wq = NULL;
+
+ if (rpciod_workqueue == NULL)
+ return;
+
+ wq = rpciod_workqueue;
+ rpciod_workqueue = NULL;
+ destroy_workqueue(wq);
+ wq = xprtiod_workqueue;
+ xprtiod_workqueue = NULL;
+ destroy_workqueue(wq);
+}
+
+void
+rpc_destroy_mempool(void)
+{
+ rpciod_stop();
+ mempool_destroy(rpc_buffer_mempool);
+ mempool_destroy(rpc_task_mempool);
+ kmem_cache_destroy(rpc_task_slabp);
+ kmem_cache_destroy(rpc_buffer_slabp);
+ rpc_destroy_wait_queue(&delay_queue);
+}
+
+int
+rpc_init_mempool(void)
+{
+ /*
+ * The following is not strictly a mempool initialisation,
+ * but there is no harm in doing it here
+ */
+ rpc_init_wait_queue(&delay_queue, "delayq");
+ if (!rpciod_start())
+ goto err_nomem;
+
+ rpc_task_slabp = kmem_cache_create("rpc_tasks",
+ sizeof(struct rpc_task),
+ 0, SLAB_HWCACHE_ALIGN,
+ NULL);
+ if (!rpc_task_slabp)
+ goto err_nomem;
+ rpc_buffer_slabp = kmem_cache_create("rpc_buffers",
+ RPC_BUFFER_MAXSIZE,
+ 0, SLAB_HWCACHE_ALIGN,
+ NULL);
+ if (!rpc_buffer_slabp)
+ goto err_nomem;
+ rpc_task_mempool = mempool_create_slab_pool(RPC_TASK_POOLSIZE,
+ rpc_task_slabp);
+ if (!rpc_task_mempool)
+ goto err_nomem;
+ rpc_buffer_mempool = mempool_create_slab_pool(RPC_BUFFER_POOLSIZE,
+ rpc_buffer_slabp);
+ if (!rpc_buffer_mempool)
+ goto err_nomem;
+ return 0;
+err_nomem:
+ rpc_destroy_mempool();
+ return -ENOMEM;
+}
diff --git a/net/sunrpc/socklib.c b/net/sunrpc/socklib.c
new file mode 100644
index 0000000000..1b2b84feee
--- /dev/null
+++ b/net/sunrpc/socklib.c
@@ -0,0 +1,324 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * linux/net/sunrpc/socklib.c
+ *
+ * Common socket helper routines for RPC client and server
+ *
+ * Copyright (C) 1995, 1996 Olaf Kirch <okir@monad.swb.de>
+ */
+
+#include <linux/compiler.h>
+#include <linux/netdevice.h>
+#include <linux/gfp.h>
+#include <linux/skbuff.h>
+#include <linux/types.h>
+#include <linux/pagemap.h>
+#include <linux/udp.h>
+#include <linux/sunrpc/msg_prot.h>
+#include <linux/sunrpc/sched.h>
+#include <linux/sunrpc/xdr.h>
+#include <linux/export.h>
+
+#include "socklib.h"
+
+/*
+ * Helper structure for copying from an sk_buff.
+ */
+struct xdr_skb_reader {
+ struct sk_buff *skb;
+ unsigned int offset;
+ size_t count;
+ __wsum csum;
+};
+
+typedef size_t (*xdr_skb_read_actor)(struct xdr_skb_reader *desc, void *to,
+ size_t len);
+
+/**
+ * xdr_skb_read_bits - copy some data bits from skb to internal buffer
+ * @desc: sk_buff copy helper
+ * @to: copy destination
+ * @len: number of bytes to copy
+ *
+ * Possibly called several times to iterate over an sk_buff and copy
+ * data out of it.
+ */
+static size_t
+xdr_skb_read_bits(struct xdr_skb_reader *desc, void *to, size_t len)
+{
+ if (len > desc->count)
+ len = desc->count;
+ if (unlikely(skb_copy_bits(desc->skb, desc->offset, to, len)))
+ return 0;
+ desc->count -= len;
+ desc->offset += len;
+ return len;
+}
+
+/**
+ * xdr_skb_read_and_csum_bits - copy and checksum from skb to buffer
+ * @desc: sk_buff copy helper
+ * @to: copy destination
+ * @len: number of bytes to copy
+ *
+ * Same as xdr_skb_read_bits(), but calculates a checksum at the same time.
+ */
+static size_t xdr_skb_read_and_csum_bits(struct xdr_skb_reader *desc, void *to, size_t len)
+{
+ unsigned int pos;
+ __wsum csum2;
+
+ if (len > desc->count)
+ len = desc->count;
+ pos = desc->offset;
+ csum2 = skb_copy_and_csum_bits(desc->skb, pos, to, len);
+ desc->csum = csum_block_add(desc->csum, csum2, pos);
+ desc->count -= len;
+ desc->offset += len;
+ return len;
+}
+
+/**
+ * xdr_partial_copy_from_skb - copy data out of an skb
+ * @xdr: target XDR buffer
+ * @base: starting offset
+ * @desc: sk_buff copy helper
+ * @copy_actor: virtual method for copying data
+ *
+ */
+static ssize_t
+xdr_partial_copy_from_skb(struct xdr_buf *xdr, unsigned int base, struct xdr_skb_reader *desc, xdr_skb_read_actor copy_actor)
+{
+ struct page **ppage = xdr->pages;
+ unsigned int len, pglen = xdr->page_len;
+ ssize_t copied = 0;
+ size_t ret;
+
+ len = xdr->head[0].iov_len;
+ if (base < len) {
+ len -= base;
+ ret = copy_actor(desc, (char *)xdr->head[0].iov_base + base, len);
+ copied += ret;
+ if (ret != len || !desc->count)
+ goto out;
+ base = 0;
+ } else
+ base -= len;
+
+ if (unlikely(pglen == 0))
+ goto copy_tail;
+ if (unlikely(base >= pglen)) {
+ base -= pglen;
+ goto copy_tail;
+ }
+ if (base || xdr->page_base) {
+ pglen -= base;
+ base += xdr->page_base;
+ ppage += base >> PAGE_SHIFT;
+ base &= ~PAGE_MASK;
+ }
+ do {
+ char *kaddr;
+
+ /* ACL likes to be lazy in allocating pages - ACLs
+ * are small by default but can get huge. */
+ if ((xdr->flags & XDRBUF_SPARSE_PAGES) && *ppage == NULL) {
+ *ppage = alloc_page(GFP_NOWAIT | __GFP_NOWARN);
+ if (unlikely(*ppage == NULL)) {
+ if (copied == 0)
+ copied = -ENOMEM;
+ goto out;
+ }
+ }
+
+ len = PAGE_SIZE;
+ kaddr = kmap_atomic(*ppage);
+ if (base) {
+ len -= base;
+ if (pglen < len)
+ len = pglen;
+ ret = copy_actor(desc, kaddr + base, len);
+ base = 0;
+ } else {
+ if (pglen < len)
+ len = pglen;
+ ret = copy_actor(desc, kaddr, len);
+ }
+ flush_dcache_page(*ppage);
+ kunmap_atomic(kaddr);
+ copied += ret;
+ if (ret != len || !desc->count)
+ goto out;
+ ppage++;
+ } while ((pglen -= len) != 0);
+copy_tail:
+ len = xdr->tail[0].iov_len;
+ if (base < len)
+ copied += copy_actor(desc, (char *)xdr->tail[0].iov_base + base, len - base);
+out:
+ return copied;
+}
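The walk above amounts to: consume the head iovec first, then the page array, then the tail iovec. A compact sketch of that offset accounting, using only the xdr_buf fields referenced above; the demo_ names are hypothetical and for illustration only.

	#include <linux/sunrpc/xdr.h>

	enum demo_xdr_region { DEMO_XDR_HEAD, DEMO_XDR_PAGES, DEMO_XDR_TAIL, DEMO_XDR_PAST_END };

	static enum demo_xdr_region demo_xdr_offset_region(const struct xdr_buf *xdr,
							   unsigned int base)
	{
		if (base < xdr->head[0].iov_len)
			return DEMO_XDR_HEAD;
		base -= xdr->head[0].iov_len;
		if (base < xdr->page_len)
			return DEMO_XDR_PAGES;
		base -= xdr->page_len;
		if (base < xdr->tail[0].iov_len)
			return DEMO_XDR_TAIL;
		return DEMO_XDR_PAST_END;
	}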
+
+/**
+ * csum_partial_copy_to_xdr - checksum and copy data
+ * @xdr: target XDR buffer
+ * @skb: source skb
+ *
+ * We have set things up such that we perform the checksum of the UDP
+ * packet in parallel with the copies into the RPC client iovec. -DaveM
+ */
+int csum_partial_copy_to_xdr(struct xdr_buf *xdr, struct sk_buff *skb)
+{
+ struct xdr_skb_reader desc;
+
+ desc.skb = skb;
+ desc.offset = 0;
+ desc.count = skb->len - desc.offset;
+
+ if (skb_csum_unnecessary(skb))
+ goto no_checksum;
+
+ desc.csum = csum_partial(skb->data, desc.offset, skb->csum);
+ if (xdr_partial_copy_from_skb(xdr, 0, &desc, xdr_skb_read_and_csum_bits) < 0)
+ return -1;
+ if (desc.offset != skb->len) {
+ __wsum csum2;
+ csum2 = skb_checksum(skb, desc.offset, skb->len - desc.offset, 0);
+ desc.csum = csum_block_add(desc.csum, csum2, desc.offset);
+ }
+ if (desc.count)
+ return -1;
+ if (csum_fold(desc.csum))
+ return -1;
+ if (unlikely(skb->ip_summed == CHECKSUM_COMPLETE) &&
+ !skb->csum_complete_sw)
+ netdev_rx_csum_fault(skb->dev, skb);
+ return 0;
+no_checksum:
+ if (xdr_partial_copy_from_skb(xdr, 0, &desc, xdr_skb_read_bits) < 0)
+ return -1;
+ if (desc.count)
+ return -1;
+ return 0;
+}
+EXPORT_SYMBOL_GPL(csum_partial_copy_to_xdr);
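The final test relies on the usual Internet-checksum property: when the one's-complement sum over the pseudo-header, UDP header and payload (checksum field included) is folded to 16 bits, a correct datagram yields zero. A one-line sketch of that check; the demo_ name is hypothetical.

	#include <linux/types.h>
	#include <net/checksum.h>

	static bool demo_udp_csum_ok(__wsum sum_including_pseudo_header)
	{
		/* mirrors the "if (csum_fold(desc.csum)) return -1" check above */
		return !csum_fold(sum_including_pseudo_header);
	}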
+
+static inline int xprt_sendmsg(struct socket *sock, struct msghdr *msg,
+ size_t seek)
+{
+ if (seek)
+ iov_iter_advance(&msg->msg_iter, seek);
+ return sock_sendmsg(sock, msg);
+}
+
+static int xprt_send_kvec(struct socket *sock, struct msghdr *msg,
+ struct kvec *vec, size_t seek)
+{
+ iov_iter_kvec(&msg->msg_iter, ITER_SOURCE, vec, 1, vec->iov_len);
+ return xprt_sendmsg(sock, msg, seek);
+}
+
+static int xprt_send_pagedata(struct socket *sock, struct msghdr *msg,
+ struct xdr_buf *xdr, size_t base)
+{
+ iov_iter_bvec(&msg->msg_iter, ITER_SOURCE, xdr->bvec, xdr_buf_pagecount(xdr),
+ xdr->page_len + xdr->page_base);
+ return xprt_sendmsg(sock, msg, base + xdr->page_base);
+}
+
+/* Common case:
+ * - stream transport
+ * - sending from byte 0 of the message
+ * - the message is wholly contained in @xdr's head iovec
+ */
+static int xprt_send_rm_and_kvec(struct socket *sock, struct msghdr *msg,
+ rpc_fraghdr marker, struct kvec *vec,
+ size_t base)
+{
+ struct kvec iov[2] = {
+ [0] = {
+ .iov_base = &marker,
+ .iov_len = sizeof(marker)
+ },
+ [1] = *vec,
+ };
+ size_t len = iov[0].iov_len + iov[1].iov_len;
+
+ iov_iter_kvec(&msg->msg_iter, ITER_SOURCE, iov, 2, len);
+ return xprt_sendmsg(sock, msg, base);
+}
+
+/**
+ * xprt_sock_sendmsg - write an xdr_buf directly to a socket
+ * @sock: open socket to send on
+ * @msg: socket message metadata
+ * @xdr: xdr_buf containing this request
+ * @base: starting position in the buffer
+ * @marker: stream record marker field
+ * @sent_p: return the total number of bytes successfully queued for sending
+ *
+ * Return values:
+ * On success, returns zero and fills in @sent_p.
+ * %-ENOTSOCK if @sock is not a struct socket.
+ */
+int xprt_sock_sendmsg(struct socket *sock, struct msghdr *msg,
+ struct xdr_buf *xdr, unsigned int base,
+ rpc_fraghdr marker, unsigned int *sent_p)
+{
+ unsigned int rmsize = marker ? sizeof(marker) : 0;
+ unsigned int remainder = rmsize + xdr->len - base;
+ unsigned int want;
+ int err = 0;
+
+ *sent_p = 0;
+
+ if (unlikely(!sock))
+ return -ENOTSOCK;
+
+ msg->msg_flags |= MSG_MORE;
+ want = xdr->head[0].iov_len + rmsize;
+ if (base < want) {
+ unsigned int len = want - base;
+
+ remainder -= len;
+ if (remainder == 0)
+ msg->msg_flags &= ~MSG_MORE;
+ if (rmsize)
+ err = xprt_send_rm_and_kvec(sock, msg, marker,
+ &xdr->head[0], base);
+ else
+ err = xprt_send_kvec(sock, msg, &xdr->head[0], base);
+ if (remainder == 0 || err != len)
+ goto out;
+ *sent_p += err;
+ base = 0;
+ } else {
+ base -= want;
+ }
+
+ if (base < xdr->page_len) {
+ unsigned int len = xdr->page_len - base;
+
+ remainder -= len;
+ if (remainder == 0)
+ msg->msg_flags &= ~MSG_MORE;
+ err = xprt_send_pagedata(sock, msg, xdr, base);
+ if (remainder == 0 || err != len)
+ goto out;
+ *sent_p += err;
+ base = 0;
+ } else {
+ base -= xdr->page_len;
+ }
+
+ if (base >= xdr->tail[0].iov_len)
+ return 0;
+ msg->msg_flags &= ~MSG_MORE;
+ err = xprt_send_kvec(sock, msg, &xdr->tail[0], base);
+out:
+ if (err > 0) {
+ *sent_p += err;
+ err = 0;
+ }
+ return err;
+}
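For stream transports, @marker carries the RPC record-marking header: the low 31 bits hold the fragment length and the top bit flags the final fragment. A sketch of how a caller might build it, assuming the whole xdr_buf goes out as a single fragment; the demo_ name is hypothetical.

	#include <asm/byteorder.h>
	#include <linux/types.h>
	#include <linux/sunrpc/msg_prot.h>
	#include <linux/sunrpc/xdr.h>

	static rpc_fraghdr demo_record_marker(const struct xdr_buf *xdr, bool last)
	{
		u32 len = xdr->len & RPC_FRAGMENT_SIZE_MASK;

		return cpu_to_be32(last ? RPC_LAST_STREAM_FRAGMENT | len : len);
	}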
diff --git a/net/sunrpc/socklib.h b/net/sunrpc/socklib.h
new file mode 100644
index 0000000000..c48114ad6f
--- /dev/null
+++ b/net/sunrpc/socklib.h
@@ -0,0 +1,15 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+/*
+ * Copyright (C) 1995-1997 Olaf Kirch <okir@monad.swb.de>
+ * Copyright (C) 2020, Oracle.
+ */
+
+#ifndef _NET_SUNRPC_SOCKLIB_H_
+#define _NET_SUNRPC_SOCKLIB_H_
+
+int csum_partial_copy_to_xdr(struct xdr_buf *xdr, struct sk_buff *skb);
+int xprt_sock_sendmsg(struct socket *sock, struct msghdr *msg,
+ struct xdr_buf *xdr, unsigned int base,
+ rpc_fraghdr marker, unsigned int *sent_p);
+
+#endif /* _NET_SUNRPC_SOCKLIB_H_ */
diff --git a/net/sunrpc/stats.c b/net/sunrpc/stats.c
new file mode 100644
index 0000000000..65fc1297c6
--- /dev/null
+++ b/net/sunrpc/stats.c
@@ -0,0 +1,348 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * linux/net/sunrpc/stats.c
+ *
+ * procfs-based user access to generic RPC statistics. The stats files
+ * reside in /proc/net/rpc.
+ *
+ * The read routines assume that the buffer passed in is just big enough.
+ * If you implement an RPC service that has its own stats routine which
+ * appends the generic RPC stats, make sure you don't exceed the PAGE_SIZE
+ * limit.
+ *
+ * Copyright (C) 1995, 1996, 1997 Olaf Kirch <okir@monad.swb.de>
+ */
+
+#include <linux/module.h>
+#include <linux/slab.h>
+
+#include <linux/init.h>
+#include <linux/kernel.h>
+#include <linux/proc_fs.h>
+#include <linux/seq_file.h>
+#include <linux/sunrpc/clnt.h>
+#include <linux/sunrpc/svcsock.h>
+#include <linux/sunrpc/metrics.h>
+#include <linux/rcupdate.h>
+
+#include <trace/events/sunrpc.h>
+
+#include "netns.h"
+
+#define RPCDBG_FACILITY RPCDBG_MISC
+
+/*
+ * Get RPC client stats
+ */
+static int rpc_proc_show(struct seq_file *seq, void *v) {
+ const struct rpc_stat *statp = seq->private;
+ const struct rpc_program *prog = statp->program;
+ unsigned int i, j;
+
+ seq_printf(seq,
+ "net %u %u %u %u\n",
+ statp->netcnt,
+ statp->netudpcnt,
+ statp->nettcpcnt,
+ statp->nettcpconn);
+ seq_printf(seq,
+ "rpc %u %u %u\n",
+ statp->rpccnt,
+ statp->rpcretrans,
+ statp->rpcauthrefresh);
+
+ for (i = 0; i < prog->nrvers; i++) {
+ const struct rpc_version *vers = prog->version[i];
+ if (!vers)
+ continue;
+ seq_printf(seq, "proc%u %u",
+ vers->number, vers->nrprocs);
+ for (j = 0; j < vers->nrprocs; j++)
+ seq_printf(seq, " %u", vers->counts[j]);
+ seq_putc(seq, '\n');
+ }
+ return 0;
+}
+
+static int rpc_proc_open(struct inode *inode, struct file *file)
+{
+ return single_open(file, rpc_proc_show, pde_data(inode));
+}
+
+static const struct proc_ops rpc_proc_ops = {
+ .proc_open = rpc_proc_open,
+ .proc_read = seq_read,
+ .proc_lseek = seq_lseek,
+ .proc_release = single_release,
+};
+
+/*
+ * Get RPC server stats
+ */
+void svc_seq_show(struct seq_file *seq, const struct svc_stat *statp)
+{
+ const struct svc_program *prog = statp->program;
+ const struct svc_version *vers;
+ unsigned int i, j, k;
+ unsigned long count;
+
+ seq_printf(seq,
+ "net %u %u %u %u\n",
+ statp->netcnt,
+ statp->netudpcnt,
+ statp->nettcpcnt,
+ statp->nettcpconn);
+ seq_printf(seq,
+ "rpc %u %u %u %u %u\n",
+ statp->rpccnt,
+ statp->rpcbadfmt+statp->rpcbadauth+statp->rpcbadclnt,
+ statp->rpcbadfmt,
+ statp->rpcbadauth,
+ statp->rpcbadclnt);
+
+ for (i = 0; i < prog->pg_nvers; i++) {
+ vers = prog->pg_vers[i];
+ if (!vers)
+ continue;
+ seq_printf(seq, "proc%d %u", i, vers->vs_nproc);
+ for (j = 0; j < vers->vs_nproc; j++) {
+ count = 0;
+ for_each_possible_cpu(k)
+ count += per_cpu(vers->vs_count[j], k);
+ seq_printf(seq, " %lu", count);
+ }
+ seq_putc(seq, '\n');
+ }
+}
+EXPORT_SYMBOL_GPL(svc_seq_show);
+
+/**
+ * rpc_alloc_iostats - allocate an rpc_iostats structure
+ * @clnt: RPC program, version, and xprt
+ *
+ */
+struct rpc_iostats *rpc_alloc_iostats(struct rpc_clnt *clnt)
+{
+ struct rpc_iostats *stats;
+ int i;
+
+ stats = kcalloc(clnt->cl_maxproc, sizeof(*stats), GFP_KERNEL);
+ if (stats) {
+ for (i = 0; i < clnt->cl_maxproc; i++)
+ spin_lock_init(&stats[i].om_lock);
+ }
+ return stats;
+}
+EXPORT_SYMBOL_GPL(rpc_alloc_iostats);
+
+/**
+ * rpc_free_iostats - release an rpc_iostats structure
+ * @stats: doomed rpc_iostats structure
+ *
+ */
+void rpc_free_iostats(struct rpc_iostats *stats)
+{
+ kfree(stats);
+}
+EXPORT_SYMBOL_GPL(rpc_free_iostats);
+
+/**
+ * rpc_count_iostats_metrics - tally up per-task stats
+ * @task: completed rpc_task
+ * @op_metrics: stat structure for OP that will accumulate stats from @task
+ */
+void rpc_count_iostats_metrics(const struct rpc_task *task,
+ struct rpc_iostats *op_metrics)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+ ktime_t backlog, execute, now;
+
+ if (!op_metrics || !req)
+ return;
+
+ now = ktime_get();
+ spin_lock(&op_metrics->om_lock);
+
+ op_metrics->om_ops++;
+ /* kernel API: om_ops must never become larger than om_ntrans */
+ op_metrics->om_ntrans += max(req->rq_ntrans, 1);
+ op_metrics->om_timeouts += task->tk_timeouts;
+
+ op_metrics->om_bytes_sent += req->rq_xmit_bytes_sent;
+ op_metrics->om_bytes_recv += req->rq_reply_bytes_recvd;
+
+ backlog = 0;
+ if (ktime_to_ns(req->rq_xtime)) {
+ backlog = ktime_sub(req->rq_xtime, task->tk_start);
+ op_metrics->om_queue = ktime_add(op_metrics->om_queue, backlog);
+ }
+
+ op_metrics->om_rtt = ktime_add(op_metrics->om_rtt, req->rq_rtt);
+
+ execute = ktime_sub(now, task->tk_start);
+ op_metrics->om_execute = ktime_add(op_metrics->om_execute, execute);
+ if (task->tk_status < 0)
+ op_metrics->om_error_status++;
+
+ spin_unlock(&op_metrics->om_lock);
+
+ trace_rpc_stats_latency(req->rq_task, backlog, req->rq_rtt, execute);
+}
+EXPORT_SYMBOL_GPL(rpc_count_iostats_metrics);
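Consumers of these counters (for example mountstats-style reporting) typically derive averages rather than showing raw sums. A sketch of one such derivation; the demo_ helper is hypothetical, not part of this patch.

	#include <linux/ktime.h>
	#include <linux/math64.h>
	#include <linux/sunrpc/metrics.h>

	static u64 demo_avg_rtt_ms(const struct rpc_iostats *m)
	{
		if (!m->om_ops)
			return 0;
		/* total accumulated round-trip time divided by completed operations */
		return div_u64(ktime_to_ms(m->om_rtt), (u32)m->om_ops);
	}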
+
+/**
+ * rpc_count_iostats - tally up per-task stats
+ * @task: completed rpc_task
+ * @stats: array of stat structures
+ *
+ * Uses the statidx from @task
+ */
+void rpc_count_iostats(const struct rpc_task *task, struct rpc_iostats *stats)
+{
+ rpc_count_iostats_metrics(task,
+ &stats[task->tk_msg.rpc_proc->p_statidx]);
+}
+EXPORT_SYMBOL_GPL(rpc_count_iostats);
+
+static void _print_name(struct seq_file *seq, unsigned int op,
+ const struct rpc_procinfo *procs)
+{
+ if (procs[op].p_name)
+ seq_printf(seq, "\t%12s: ", procs[op].p_name);
+ else if (op == 0)
+ seq_printf(seq, "\t NULL: ");
+ else
+ seq_printf(seq, "\t%12u: ", op);
+}
+
+static void _add_rpc_iostats(struct rpc_iostats *a, struct rpc_iostats *b)
+{
+ a->om_ops += b->om_ops;
+ a->om_ntrans += b->om_ntrans;
+ a->om_timeouts += b->om_timeouts;
+ a->om_bytes_sent += b->om_bytes_sent;
+ a->om_bytes_recv += b->om_bytes_recv;
+ a->om_queue = ktime_add(a->om_queue, b->om_queue);
+ a->om_rtt = ktime_add(a->om_rtt, b->om_rtt);
+ a->om_execute = ktime_add(a->om_execute, b->om_execute);
+ a->om_error_status += b->om_error_status;
+}
+
+static void _print_rpc_iostats(struct seq_file *seq, struct rpc_iostats *stats,
+ int op, const struct rpc_procinfo *procs)
+{
+ _print_name(seq, op, procs);
+ seq_printf(seq, "%lu %lu %lu %llu %llu %llu %llu %llu %lu\n",
+ stats->om_ops,
+ stats->om_ntrans,
+ stats->om_timeouts,
+ stats->om_bytes_sent,
+ stats->om_bytes_recv,
+ ktime_to_ms(stats->om_queue),
+ ktime_to_ms(stats->om_rtt),
+ ktime_to_ms(stats->om_execute),
+ stats->om_error_status);
+}
+
+static int do_print_stats(struct rpc_clnt *clnt, struct rpc_xprt *xprt, void *seqv)
+{
+ struct seq_file *seq = seqv;
+
+ xprt->ops->print_stats(xprt, seq);
+ return 0;
+}
+
+void rpc_clnt_show_stats(struct seq_file *seq, struct rpc_clnt *clnt)
+{
+ unsigned int op, maxproc = clnt->cl_maxproc;
+
+ if (!clnt->cl_metrics)
+ return;
+
+ seq_printf(seq, "\tRPC iostats version: %s ", RPC_IOSTATS_VERS);
+ seq_printf(seq, "p/v: %u/%u (%s)\n",
+ clnt->cl_prog, clnt->cl_vers, clnt->cl_program->name);
+
+ rpc_clnt_iterate_for_each_xprt(clnt, do_print_stats, seq);
+
+ seq_printf(seq, "\tper-op statistics\n");
+ for (op = 0; op < maxproc; op++) {
+ struct rpc_iostats stats = {};
+ struct rpc_clnt *next = clnt;
+ do {
+ _add_rpc_iostats(&stats, &next->cl_metrics[op]);
+ if (next == next->cl_parent)
+ break;
+ next = next->cl_parent;
+ } while (next);
+ _print_rpc_iostats(seq, &stats, op, clnt->cl_procinfo);
+ }
+}
+EXPORT_SYMBOL_GPL(rpc_clnt_show_stats);
+
+/*
+ * Register/unregister RPC proc files
+ */
+static inline struct proc_dir_entry *
+do_register(struct net *net, const char *name, void *data,
+ const struct proc_ops *proc_ops)
+{
+ struct sunrpc_net *sn;
+
+ dprintk("RPC: registering /proc/net/rpc/%s\n", name);
+ sn = net_generic(net, sunrpc_net_id);
+ return proc_create_data(name, 0, sn->proc_net_rpc, proc_ops, data);
+}
+
+struct proc_dir_entry *
+rpc_proc_register(struct net *net, struct rpc_stat *statp)
+{
+ return do_register(net, statp->program->name, statp, &rpc_proc_ops);
+}
+EXPORT_SYMBOL_GPL(rpc_proc_register);
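A program that wants a /proc/net/rpc/<name> entry owns a struct rpc_stat whose ->program points back at the program, and registers it per network namespace. A rough sketch of that wiring; the demo_ names are hypothetical.

	#include <linux/errno.h>
	#include <linux/sunrpc/clnt.h>
	#include <linux/sunrpc/stats.h>

	static struct rpc_stat demo_stats;		/* counters start at zero */

	static int demo_stats_init(struct net *net, struct rpc_program *demo_program)
	{
		demo_stats.program = demo_program;
		demo_program->stats = &demo_stats;	/* client code bumps these counters */
		return rpc_proc_register(net, &demo_stats) ? 0 : -ENOMEM;
	}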
+
+void
+rpc_proc_unregister(struct net *net, const char *name)
+{
+ struct sunrpc_net *sn;
+
+ sn = net_generic(net, sunrpc_net_id);
+ remove_proc_entry(name, sn->proc_net_rpc);
+}
+EXPORT_SYMBOL_GPL(rpc_proc_unregister);
+
+struct proc_dir_entry *
+svc_proc_register(struct net *net, struct svc_stat *statp, const struct proc_ops *proc_ops)
+{
+ return do_register(net, statp->program->pg_name, statp, proc_ops);
+}
+EXPORT_SYMBOL_GPL(svc_proc_register);
+
+void
+svc_proc_unregister(struct net *net, const char *name)
+{
+ struct sunrpc_net *sn;
+
+ sn = net_generic(net, sunrpc_net_id);
+ remove_proc_entry(name, sn->proc_net_rpc);
+}
+EXPORT_SYMBOL_GPL(svc_proc_unregister);
+
+int rpc_proc_init(struct net *net)
+{
+ struct sunrpc_net *sn;
+
+ dprintk("RPC: registering /proc/net/rpc\n");
+ sn = net_generic(net, sunrpc_net_id);
+ sn->proc_net_rpc = proc_mkdir("rpc", net->proc_net);
+ if (sn->proc_net_rpc == NULL)
+ return -ENOMEM;
+
+ return 0;
+}
+
+void rpc_proc_exit(struct net *net)
+{
+ dprintk("RPC: unregistering /proc/net/rpc\n");
+ remove_proc_entry("rpc", net->proc_net);
+}
diff --git a/net/sunrpc/sunrpc.h b/net/sunrpc/sunrpc.h
new file mode 100644
index 0000000000..d4a362c9e4
--- /dev/null
+++ b/net/sunrpc/sunrpc.h
@@ -0,0 +1,42 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+/******************************************************************************
+
+(c) 2008 NetApp. All Rights Reserved.
+
+
+******************************************************************************/
+
+/*
+ * Functions and macros used internally by RPC
+ */
+
+#ifndef _NET_SUNRPC_SUNRPC_H
+#define _NET_SUNRPC_SUNRPC_H
+
+#include <linux/net.h>
+
+/*
+ * Header for dynamically allocated rpc buffers.
+ */
+struct rpc_buffer {
+ size_t len;
+ char data[];
+};
+
+static inline int sock_is_loopback(struct sock *sk)
+{
+ struct dst_entry *dst;
+ int loopback = 0;
+ rcu_read_lock();
+ dst = rcu_dereference(sk->sk_dst_cache);
+ if (dst && dst->dev &&
+ (dst->dev->features & NETIF_F_LOOPBACK))
+ loopback = 1;
+ rcu_read_unlock();
+ return loopback;
+}
+
+int rpc_clients_notifier_register(void);
+void rpc_clients_notifier_unregister(void);
+void auth_domain_cleanup(void);
+#endif /* _NET_SUNRPC_SUNRPC_H */
diff --git a/net/sunrpc/sunrpc_syms.c b/net/sunrpc/sunrpc_syms.c
new file mode 100644
index 0000000000..691c0000e9
--- /dev/null
+++ b/net/sunrpc/sunrpc_syms.c
@@ -0,0 +1,153 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * linux/net/sunrpc/sunrpc_syms.c
+ *
+ * Symbols exported by the sunrpc module.
+ *
+ * Copyright (C) 1997 Olaf Kirch <okir@monad.swb.de>
+ */
+
+#include <linux/module.h>
+
+#include <linux/types.h>
+#include <linux/uio.h>
+#include <linux/unistd.h>
+#include <linux/init.h>
+
+#include <linux/sunrpc/sched.h>
+#include <linux/sunrpc/clnt.h>
+#include <linux/sunrpc/svc.h>
+#include <linux/sunrpc/svcsock.h>
+#include <linux/sunrpc/auth.h>
+#include <linux/workqueue.h>
+#include <linux/sunrpc/rpc_pipe_fs.h>
+#include <linux/sunrpc/xprtsock.h>
+
+#include "sunrpc.h"
+#include "sysfs.h"
+#include "netns.h"
+
+unsigned int sunrpc_net_id;
+EXPORT_SYMBOL_GPL(sunrpc_net_id);
+
+static __net_init int sunrpc_init_net(struct net *net)
+{
+ int err;
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ err = rpc_proc_init(net);
+ if (err)
+ goto err_proc;
+
+ err = ip_map_cache_create(net);
+ if (err)
+ goto err_ipmap;
+
+ err = unix_gid_cache_create(net);
+ if (err)
+ goto err_unixgid;
+
+ err = rpc_pipefs_init_net(net);
+ if (err)
+ goto err_pipefs;
+
+ INIT_LIST_HEAD(&sn->all_clients);
+ spin_lock_init(&sn->rpc_client_lock);
+ spin_lock_init(&sn->rpcb_clnt_lock);
+ return 0;
+
+err_pipefs:
+ unix_gid_cache_destroy(net);
+err_unixgid:
+ ip_map_cache_destroy(net);
+err_ipmap:
+ rpc_proc_exit(net);
+err_proc:
+ return err;
+}
+
+static __net_exit void sunrpc_exit_net(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ rpc_pipefs_exit_net(net);
+ unix_gid_cache_destroy(net);
+ ip_map_cache_destroy(net);
+ rpc_proc_exit(net);
+ WARN_ON_ONCE(!list_empty(&sn->all_clients));
+}
+
+static struct pernet_operations sunrpc_net_ops = {
+ .init = sunrpc_init_net,
+ .exit = sunrpc_exit_net,
+ .id = &sunrpc_net_id,
+ .size = sizeof(struct sunrpc_net),
+};
+
+static int __init
+init_sunrpc(void)
+{
+ int err = rpc_init_mempool();
+ if (err)
+ goto out;
+ err = rpcauth_init_module();
+ if (err)
+ goto out2;
+
+ cache_initialize();
+
+ err = register_pernet_subsys(&sunrpc_net_ops);
+ if (err)
+ goto out3;
+
+ err = register_rpc_pipefs();
+ if (err)
+ goto out4;
+
+ err = rpc_sysfs_init();
+ if (err)
+ goto out5;
+
+ sunrpc_debugfs_init();
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+ rpc_register_sysctl();
+#endif
+ svc_init_xprt_sock(); /* svc sock transport */
+ init_socket_xprt(); /* clnt sock transport */
+ return 0;
+
+out5:
+ unregister_rpc_pipefs();
+out4:
+ unregister_pernet_subsys(&sunrpc_net_ops);
+out3:
+ rpcauth_remove_module();
+out2:
+ rpc_destroy_mempool();
+out:
+ return err;
+}
+
+static void __exit
+cleanup_sunrpc(void)
+{
+ rpc_sysfs_exit();
+ rpc_cleanup_clids();
+ xprt_cleanup_ids();
+ xprt_multipath_cleanup_ids();
+ rpcauth_remove_module();
+ cleanup_socket_xprt();
+ svc_cleanup_xprt_sock();
+ sunrpc_debugfs_exit();
+ unregister_rpc_pipefs();
+ rpc_destroy_mempool();
+ unregister_pernet_subsys(&sunrpc_net_ops);
+ auth_domain_cleanup();
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+ rpc_unregister_sysctl();
+#endif
+ rcu_barrier(); /* Wait for completion of call_rcu()'s */
+}
+MODULE_LICENSE("GPL");
+fs_initcall(init_sunrpc); /* Ensure we're initialised before nfs */
+module_exit(cleanup_sunrpc);
diff --git a/net/sunrpc/svc.c b/net/sunrpc/svc.c
new file mode 100644
index 0000000000..812fda9d45
--- /dev/null
+++ b/net/sunrpc/svc.c
@@ -0,0 +1,1764 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * linux/net/sunrpc/svc.c
+ *
+ * High-level RPC service routines
+ *
+ * Copyright (C) 1995, 1996 Olaf Kirch <okir@monad.swb.de>
+ *
+ * Multiple thread pools and NUMAisation
+ * Copyright (c) 2006 Silicon Graphics, Inc.
+ * by Greg Banks <gnb@melbourne.sgi.com>
+ */
+
+#include <linux/linkage.h>
+#include <linux/sched/signal.h>
+#include <linux/errno.h>
+#include <linux/net.h>
+#include <linux/in.h>
+#include <linux/mm.h>
+#include <linux/interrupt.h>
+#include <linux/module.h>
+#include <linux/kthread.h>
+#include <linux/slab.h>
+
+#include <linux/sunrpc/types.h>
+#include <linux/sunrpc/xdr.h>
+#include <linux/sunrpc/stats.h>
+#include <linux/sunrpc/svcsock.h>
+#include <linux/sunrpc/clnt.h>
+#include <linux/sunrpc/bc_xprt.h>
+
+#include <trace/events/sunrpc.h>
+
+#include "fail.h"
+
+#define RPCDBG_FACILITY RPCDBG_SVCDSP
+
+static void svc_unregister(const struct svc_serv *serv, struct net *net);
+
+#define SVC_POOL_DEFAULT SVC_POOL_GLOBAL
+
+/*
+ * Mode for mapping cpus to pools.
+ */
+enum {
+ SVC_POOL_AUTO = -1, /* choose one of the others */
+ SVC_POOL_GLOBAL, /* no mapping, just a single global pool
+ * (legacy & UP mode) */
+ SVC_POOL_PERCPU, /* one pool per cpu */
+ SVC_POOL_PERNODE /* one pool per numa node */
+};
+
+/*
+ * Structure for mapping cpus to pools and vice versa.
+ * Setup once during sunrpc initialisation.
+ */
+
+struct svc_pool_map {
+ int count; /* How many svc_servs use us */
+ int mode; /* Note: int not enum to avoid
+ * warnings about "enumeration value
+ * not handled in switch" */
+ unsigned int npools;
+ unsigned int *pool_to; /* maps pool id to cpu or node */
+ unsigned int *to_pool; /* maps cpu or node to pool id */
+};
+
+static struct svc_pool_map svc_pool_map = {
+ .mode = SVC_POOL_DEFAULT
+};
+
+static DEFINE_MUTEX(svc_pool_map_mutex);/* protects svc_pool_map.count only */
+
+static int
+param_set_pool_mode(const char *val, const struct kernel_param *kp)
+{
+ int *ip = (int *)kp->arg;
+ struct svc_pool_map *m = &svc_pool_map;
+ int err;
+
+ mutex_lock(&svc_pool_map_mutex);
+
+ err = -EBUSY;
+ if (m->count)
+ goto out;
+
+ err = 0;
+ if (!strncmp(val, "auto", 4))
+ *ip = SVC_POOL_AUTO;
+ else if (!strncmp(val, "global", 6))
+ *ip = SVC_POOL_GLOBAL;
+ else if (!strncmp(val, "percpu", 6))
+ *ip = SVC_POOL_PERCPU;
+ else if (!strncmp(val, "pernode", 7))
+ *ip = SVC_POOL_PERNODE;
+ else
+ err = -EINVAL;
+
+out:
+ mutex_unlock(&svc_pool_map_mutex);
+ return err;
+}
+
+static int
+param_get_pool_mode(char *buf, const struct kernel_param *kp)
+{
+ int *ip = (int *)kp->arg;
+
+ switch (*ip)
+ {
+ case SVC_POOL_AUTO:
+ return sysfs_emit(buf, "auto\n");
+ case SVC_POOL_GLOBAL:
+ return sysfs_emit(buf, "global\n");
+ case SVC_POOL_PERCPU:
+ return sysfs_emit(buf, "percpu\n");
+ case SVC_POOL_PERNODE:
+ return sysfs_emit(buf, "pernode\n");
+ default:
+ return sysfs_emit(buf, "%d\n", *ip);
+ }
+}
+
+module_param_call(pool_mode, param_set_pool_mode, param_get_pool_mode,
+ &svc_pool_map.mode, 0644);
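Because the parameter is registered with mode 0644, the mapping mode can also be changed at runtime through /sys/module/sunrpc/parameters/pool_mode, provided no service currently holds a pool-map reference (otherwise the setter above returns -EBUSY). The parser accepts "auto", "global", "percpu" and "pernode". A trivial userspace sketch, illustrative only:

	#include <stdio.h>

	int main(void)
	{
		FILE *f = fopen("/sys/module/sunrpc/parameters/pool_mode", "w");

		if (!f)
			return 1;
		fputs("pernode", f);		/* or "auto", "global", "percpu" */
		return fclose(f) ? 1 : 0;
	}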
+
+/*
+ * Detect best pool mapping mode heuristically,
+ * according to the machine's topology.
+ */
+static int
+svc_pool_map_choose_mode(void)
+{
+ unsigned int node;
+
+ if (nr_online_nodes > 1) {
+ /*
+ * Actually have multiple NUMA nodes,
+ * so split pools on NUMA node boundaries
+ */
+ return SVC_POOL_PERNODE;
+ }
+
+ node = first_online_node;
+ if (nr_cpus_node(node) > 2) {
+ /*
+ * Non-trivial SMP, or CONFIG_NUMA on
+ * non-NUMA hardware, e.g. with a generic
+ * x86_64 kernel on Xeons. In this case we
+ * want to divide the pools on cpu boundaries.
+ */
+ return SVC_POOL_PERCPU;
+ }
+
+ /* default: one global pool */
+ return SVC_POOL_GLOBAL;
+}
+
+/*
+ * Allocate the to_pool[] and pool_to[] arrays.
+ * Returns 0 on success or an errno.
+ */
+static int
+svc_pool_map_alloc_arrays(struct svc_pool_map *m, unsigned int maxpools)
+{
+ m->to_pool = kcalloc(maxpools, sizeof(unsigned int), GFP_KERNEL);
+ if (!m->to_pool)
+ goto fail;
+ m->pool_to = kcalloc(maxpools, sizeof(unsigned int), GFP_KERNEL);
+ if (!m->pool_to)
+ goto fail_free;
+
+ return 0;
+
+fail_free:
+ kfree(m->to_pool);
+ m->to_pool = NULL;
+fail:
+ return -ENOMEM;
+}
+
+/*
+ * Initialise the pool map for SVC_POOL_PERCPU mode.
+ * Returns number of pools or <0 on error.
+ */
+static int
+svc_pool_map_init_percpu(struct svc_pool_map *m)
+{
+ unsigned int maxpools = nr_cpu_ids;
+ unsigned int pidx = 0;
+ unsigned int cpu;
+ int err;
+
+ err = svc_pool_map_alloc_arrays(m, maxpools);
+ if (err)
+ return err;
+
+ for_each_online_cpu(cpu) {
+ BUG_ON(pidx >= maxpools);
+ m->to_pool[cpu] = pidx;
+ m->pool_to[pidx] = cpu;
+ pidx++;
+ }
+ /* cpus brought online later all get mapped to pool0, sorry */
+
+ return pidx;
+};
+
+
+/*
+ * Initialise the pool map for SVC_POOL_PERNODE mode.
+ * Returns number of pools or <0 on error.
+ */
+static int
+svc_pool_map_init_pernode(struct svc_pool_map *m)
+{
+ unsigned int maxpools = nr_node_ids;
+ unsigned int pidx = 0;
+ unsigned int node;
+ int err;
+
+ err = svc_pool_map_alloc_arrays(m, maxpools);
+ if (err)
+ return err;
+
+ for_each_node_with_cpus(node) {
+ /* some architectures (e.g. SN2) have cpuless nodes */
+ BUG_ON(pidx > maxpools);
+ m->to_pool[node] = pidx;
+ m->pool_to[pidx] = node;
+ pidx++;
+ }
+ /* nodes brought online later all get mapped to pool0, sorry */
+
+ return pidx;
+}
+
+
+/*
+ * Add a reference to the global map of cpus to pools (and
+ * vice versa) if pools are in use.
+ * Initialise the map if we're the first user.
+ * Returns the number of pools. If this is '1', no reference
+ * was taken.
+ */
+static unsigned int
+svc_pool_map_get(void)
+{
+ struct svc_pool_map *m = &svc_pool_map;
+ int npools = -1;
+
+ mutex_lock(&svc_pool_map_mutex);
+
+ if (m->count++) {
+ mutex_unlock(&svc_pool_map_mutex);
+ WARN_ON_ONCE(m->npools <= 1);
+ return m->npools;
+ }
+
+ if (m->mode == SVC_POOL_AUTO)
+ m->mode = svc_pool_map_choose_mode();
+
+ switch (m->mode) {
+ case SVC_POOL_PERCPU:
+ npools = svc_pool_map_init_percpu(m);
+ break;
+ case SVC_POOL_PERNODE:
+ npools = svc_pool_map_init_pernode(m);
+ break;
+ }
+
+ if (npools <= 0) {
+ /* default, or memory allocation failure */
+ npools = 1;
+ m->mode = SVC_POOL_GLOBAL;
+ }
+ m->npools = npools;
+
+ if (npools == 1)
+ /* service is unpooled, so doesn't hold a reference */
+ m->count--;
+
+ mutex_unlock(&svc_pool_map_mutex);
+ return npools;
+}
+
+/*
+ * Drop a reference to the global map of cpus to pools, if
+ * pools were in use, i.e. if npools > 1.
+ * When the last reference is dropped, the map data is
+ * freed; this allows the sysadmin to change the pool
+ * mode using the pool_mode module option without
+ * rebooting or re-loading sunrpc.ko.
+ */
+static void
+svc_pool_map_put(int npools)
+{
+ struct svc_pool_map *m = &svc_pool_map;
+
+ if (npools <= 1)
+ return;
+ mutex_lock(&svc_pool_map_mutex);
+
+ if (!--m->count) {
+ kfree(m->to_pool);
+ m->to_pool = NULL;
+ kfree(m->pool_to);
+ m->pool_to = NULL;
+ m->npools = 0;
+ }
+
+ mutex_unlock(&svc_pool_map_mutex);
+}
+
+static int svc_pool_map_get_node(unsigned int pidx)
+{
+ const struct svc_pool_map *m = &svc_pool_map;
+
+ if (m->count) {
+ if (m->mode == SVC_POOL_PERCPU)
+ return cpu_to_node(m->pool_to[pidx]);
+ if (m->mode == SVC_POOL_PERNODE)
+ return m->pool_to[pidx];
+ }
+ return NUMA_NO_NODE;
+}
+/*
+ * Set the given thread's cpus_allowed mask so that it
+ * will only run on cpus in the given pool.
+ */
+static inline void
+svc_pool_map_set_cpumask(struct task_struct *task, unsigned int pidx)
+{
+ struct svc_pool_map *m = &svc_pool_map;
+ unsigned int node = m->pool_to[pidx];
+
+ /*
+ * The caller checks for sv_nrpools > 1, which
+ * implies that we've been initialized.
+ */
+ WARN_ON_ONCE(m->count == 0);
+ if (m->count == 0)
+ return;
+
+ switch (m->mode) {
+ case SVC_POOL_PERCPU:
+ {
+ set_cpus_allowed_ptr(task, cpumask_of(node));
+ break;
+ }
+ case SVC_POOL_PERNODE:
+ {
+ set_cpus_allowed_ptr(task, cpumask_of_node(node));
+ break;
+ }
+ }
+}
+
+/**
+ * svc_pool_for_cpu - Select pool to run a thread on this cpu
+ * @serv: An RPC service
+ *
+ * Use the active CPU and the svc_pool_map's mode setting to
+ * select the svc thread pool to use. Once initialized, the
+ * svc_pool_map does not change.
+ *
+ * Return value:
+ * A pointer to an svc_pool
+ */
+struct svc_pool *svc_pool_for_cpu(struct svc_serv *serv)
+{
+ struct svc_pool_map *m = &svc_pool_map;
+ int cpu = raw_smp_processor_id();
+ unsigned int pidx = 0;
+
+ if (serv->sv_nrpools <= 1)
+ return serv->sv_pools;
+
+ switch (m->mode) {
+ case SVC_POOL_PERCPU:
+ pidx = m->to_pool[cpu];
+ break;
+ case SVC_POOL_PERNODE:
+ pidx = m->to_pool[cpu_to_node(cpu)];
+ break;
+ }
+
+ return &serv->sv_pools[pidx % serv->sv_nrpools];
+}
+
+int svc_rpcb_setup(struct svc_serv *serv, struct net *net)
+{
+ int err;
+
+ err = rpcb_create_local(net);
+ if (err)
+ return err;
+
+ /* Remove any stale portmap registrations */
+ svc_unregister(serv, net);
+ return 0;
+}
+EXPORT_SYMBOL_GPL(svc_rpcb_setup);
+
+void svc_rpcb_cleanup(struct svc_serv *serv, struct net *net)
+{
+ svc_unregister(serv, net);
+ rpcb_put_local(net);
+}
+EXPORT_SYMBOL_GPL(svc_rpcb_cleanup);
+
+static int svc_uses_rpcbind(struct svc_serv *serv)
+{
+ struct svc_program *progp;
+ unsigned int i;
+
+ for (progp = serv->sv_program; progp; progp = progp->pg_next) {
+ for (i = 0; i < progp->pg_nvers; i++) {
+ if (progp->pg_vers[i] == NULL)
+ continue;
+ if (!progp->pg_vers[i]->vs_hidden)
+ return 1;
+ }
+ }
+
+ return 0;
+}
+
+int svc_bind(struct svc_serv *serv, struct net *net)
+{
+ if (!svc_uses_rpcbind(serv))
+ return 0;
+ return svc_rpcb_setup(serv, net);
+}
+EXPORT_SYMBOL_GPL(svc_bind);
+
+#if defined(CONFIG_SUNRPC_BACKCHANNEL)
+static void
+__svc_init_bc(struct svc_serv *serv)
+{
+ INIT_LIST_HEAD(&serv->sv_cb_list);
+ spin_lock_init(&serv->sv_cb_lock);
+ init_waitqueue_head(&serv->sv_cb_waitq);
+}
+#else
+static void
+__svc_init_bc(struct svc_serv *serv)
+{
+}
+#endif
+
+/*
+ * Create an RPC service
+ */
+static struct svc_serv *
+__svc_create(struct svc_program *prog, unsigned int bufsize, int npools,
+ int (*threadfn)(void *data))
+{
+ struct svc_serv *serv;
+ unsigned int vers;
+ unsigned int xdrsize;
+ unsigned int i;
+
+ if (!(serv = kzalloc(sizeof(*serv), GFP_KERNEL)))
+ return NULL;
+ serv->sv_name = prog->pg_name;
+ serv->sv_program = prog;
+ kref_init(&serv->sv_refcnt);
+ serv->sv_stats = prog->pg_stats;
+ if (bufsize > RPCSVC_MAXPAYLOAD)
+ bufsize = RPCSVC_MAXPAYLOAD;
+ serv->sv_max_payload = bufsize? bufsize : 4096;
+ serv->sv_max_mesg = roundup(serv->sv_max_payload + PAGE_SIZE, PAGE_SIZE);
+ serv->sv_threadfn = threadfn;
+ xdrsize = 0;
+ while (prog) {
+ prog->pg_lovers = prog->pg_nvers-1;
+ for (vers=0; vers<prog->pg_nvers ; vers++)
+ if (prog->pg_vers[vers]) {
+ prog->pg_hivers = vers;
+ if (prog->pg_lovers > vers)
+ prog->pg_lovers = vers;
+ if (prog->pg_vers[vers]->vs_xdrsize > xdrsize)
+ xdrsize = prog->pg_vers[vers]->vs_xdrsize;
+ }
+ prog = prog->pg_next;
+ }
+ serv->sv_xdrsize = xdrsize;
+ INIT_LIST_HEAD(&serv->sv_tempsocks);
+ INIT_LIST_HEAD(&serv->sv_permsocks);
+ timer_setup(&serv->sv_temptimer, NULL, 0);
+ spin_lock_init(&serv->sv_lock);
+
+ __svc_init_bc(serv);
+
+ serv->sv_nrpools = npools;
+ serv->sv_pools =
+ kcalloc(serv->sv_nrpools, sizeof(struct svc_pool),
+ GFP_KERNEL);
+ if (!serv->sv_pools) {
+ kfree(serv);
+ return NULL;
+ }
+
+ for (i = 0; i < serv->sv_nrpools; i++) {
+ struct svc_pool *pool = &serv->sv_pools[i];
+
+ dprintk("svc: initialising pool %u for %s\n",
+ i, serv->sv_name);
+
+ pool->sp_id = i;
+ INIT_LIST_HEAD(&pool->sp_sockets);
+ INIT_LIST_HEAD(&pool->sp_all_threads);
+ spin_lock_init(&pool->sp_lock);
+
+ percpu_counter_init(&pool->sp_messages_arrived, 0, GFP_KERNEL);
+ percpu_counter_init(&pool->sp_sockets_queued, 0, GFP_KERNEL);
+ percpu_counter_init(&pool->sp_threads_woken, 0, GFP_KERNEL);
+ }
+
+ return serv;
+}
+
+/**
+ * svc_create - Create an RPC service
+ * @prog: the RPC program the new service will handle
+ * @bufsize: maximum message size for @prog
+ * @threadfn: a function to service RPC requests for @prog
+ *
+ * Returns an instantiated struct svc_serv object or NULL.
+ */
+struct svc_serv *svc_create(struct svc_program *prog, unsigned int bufsize,
+ int (*threadfn)(void *data))
+{
+ return __svc_create(prog, bufsize, 1, threadfn);
+}
+EXPORT_SYMBOL_GPL(svc_create);
+
+/**
+ * svc_create_pooled - Create an RPC service with pooled threads
+ * @prog: the RPC program the new service will handle
+ * @bufsize: maximum message size for @prog
+ * @threadfn: a function to service RPC requests for @prog
+ *
+ * Returns an instantiated struct svc_serv object or NULL.
+ */
+struct svc_serv *svc_create_pooled(struct svc_program *prog,
+ unsigned int bufsize,
+ int (*threadfn)(void *data))
+{
+ struct svc_serv *serv;
+ unsigned int npools = svc_pool_map_get();
+
+ serv = __svc_create(prog, bufsize, npools, threadfn);
+ if (!serv)
+ goto out_err;
+ return serv;
+out_err:
+ svc_pool_map_put(npools);
+ return NULL;
+}
+EXPORT_SYMBOL_GPL(svc_create_pooled);
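Taken together with svc_set_num_threads() below, the usual bring-up sequence for a pooled service looks roughly like the sketch that follows. The demo_ names are hypothetical; real users (nfsd, lockd, the NFS client callback service) follow the same shape.

	#include <linux/sunrpc/svc.h>

	static int demo_threadfn(void *data)
	{
		struct svc_rqst *rqstp = data;

		/*
		 * A real thread function loops here, receiving and dispatching
		 * requests until it is asked to stop.
		 */
		svc_exit_thread(rqstp);
		return 0;
	}

	static struct svc_serv *demo_start(struct svc_program *demo_program)
	{
		struct svc_serv *serv;

		serv = svc_create_pooled(demo_program, 64 * 1024, demo_threadfn);
		if (!serv)
			return NULL;
		if (svc_set_num_threads(serv, NULL, 4) < 0) {	/* spawn 4 threads */
			svc_put(serv);
			return NULL;
		}
		return serv;
	}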
+
+/*
+ * Destroy an RPC service. Should be called with appropriate locking to
+ * protect sv_permsocks and sv_tempsocks.
+ */
+void
+svc_destroy(struct kref *ref)
+{
+ struct svc_serv *serv = container_of(ref, struct svc_serv, sv_refcnt);
+ unsigned int i;
+
+ dprintk("svc: svc_destroy(%s)\n", serv->sv_program->pg_name);
+ timer_shutdown_sync(&serv->sv_temptimer);
+
+ /*
+ * The last user is gone, so by this point all sockets must already
+ * have been destroyed. Check this.
+ */
+ BUG_ON(!list_empty(&serv->sv_permsocks));
+ BUG_ON(!list_empty(&serv->sv_tempsocks));
+
+ cache_clean_deferred(serv);
+
+ svc_pool_map_put(serv->sv_nrpools);
+
+ for (i = 0; i < serv->sv_nrpools; i++) {
+ struct svc_pool *pool = &serv->sv_pools[i];
+
+ percpu_counter_destroy(&pool->sp_messages_arrived);
+ percpu_counter_destroy(&pool->sp_sockets_queued);
+ percpu_counter_destroy(&pool->sp_threads_woken);
+ }
+ kfree(serv->sv_pools);
+ kfree(serv);
+}
+EXPORT_SYMBOL_GPL(svc_destroy);
+
+static bool
+svc_init_buffer(struct svc_rqst *rqstp, unsigned int size, int node)
+{
+ unsigned long pages, ret;
+
+ /* bc_xprt uses fore channel allocated buffers */
+ if (svc_is_backchannel(rqstp))
+ return true;
+
+ pages = size / PAGE_SIZE + 1; /* extra page as we hold both request and reply.
+ * We assume each is at most one page
+ */
+ WARN_ON_ONCE(pages > RPCSVC_MAXPAGES);
+ if (pages > RPCSVC_MAXPAGES)
+ pages = RPCSVC_MAXPAGES;
+
+ ret = alloc_pages_bulk_array_node(GFP_KERNEL, node, pages,
+ rqstp->rq_pages);
+ return ret == pages;
+}
+
+/*
+ * Release an RPC server buffer
+ */
+static void
+svc_release_buffer(struct svc_rqst *rqstp)
+{
+ unsigned int i;
+
+ for (i = 0; i < ARRAY_SIZE(rqstp->rq_pages); i++)
+ if (rqstp->rq_pages[i])
+ put_page(rqstp->rq_pages[i]);
+}
+
+struct svc_rqst *
+svc_rqst_alloc(struct svc_serv *serv, struct svc_pool *pool, int node)
+{
+ struct svc_rqst *rqstp;
+
+ rqstp = kzalloc_node(sizeof(*rqstp), GFP_KERNEL, node);
+ if (!rqstp)
+ return rqstp;
+
+ folio_batch_init(&rqstp->rq_fbatch);
+
+ __set_bit(RQ_BUSY, &rqstp->rq_flags);
+ rqstp->rq_server = serv;
+ rqstp->rq_pool = pool;
+
+ rqstp->rq_scratch_page = alloc_pages_node(node, GFP_KERNEL, 0);
+ if (!rqstp->rq_scratch_page)
+ goto out_enomem;
+
+ rqstp->rq_argp = kmalloc_node(serv->sv_xdrsize, GFP_KERNEL, node);
+ if (!rqstp->rq_argp)
+ goto out_enomem;
+
+ rqstp->rq_resp = kmalloc_node(serv->sv_xdrsize, GFP_KERNEL, node);
+ if (!rqstp->rq_resp)
+ goto out_enomem;
+
+ if (!svc_init_buffer(rqstp, serv->sv_max_mesg, node))
+ goto out_enomem;
+
+ return rqstp;
+out_enomem:
+ svc_rqst_free(rqstp);
+ return NULL;
+}
+EXPORT_SYMBOL_GPL(svc_rqst_alloc);
+
+static struct svc_rqst *
+svc_prepare_thread(struct svc_serv *serv, struct svc_pool *pool, int node)
+{
+ struct svc_rqst *rqstp;
+
+ rqstp = svc_rqst_alloc(serv, pool, node);
+ if (!rqstp)
+ return ERR_PTR(-ENOMEM);
+
+ svc_get(serv);
+ spin_lock_bh(&serv->sv_lock);
+ serv->sv_nrthreads += 1;
+ spin_unlock_bh(&serv->sv_lock);
+
+ spin_lock_bh(&pool->sp_lock);
+ pool->sp_nrthreads++;
+ list_add_rcu(&rqstp->rq_all, &pool->sp_all_threads);
+ spin_unlock_bh(&pool->sp_lock);
+ return rqstp;
+}
+
+/**
+ * svc_pool_wake_idle_thread - Awaken an idle thread in @pool
+ * @pool: service thread pool
+ *
+ * Can be called from soft IRQ or process context. Finding an idle
+ * service thread and marking it BUSY is atomic with respect to
+ * other calls to svc_pool_wake_idle_thread().
+ *
+ */
+void svc_pool_wake_idle_thread(struct svc_pool *pool)
+{
+ struct svc_rqst *rqstp;
+
+ rcu_read_lock();
+ list_for_each_entry_rcu(rqstp, &pool->sp_all_threads, rq_all) {
+ if (test_and_set_bit(RQ_BUSY, &rqstp->rq_flags))
+ continue;
+
+ WRITE_ONCE(rqstp->rq_qtime, ktime_get());
+ wake_up_process(rqstp->rq_task);
+ rcu_read_unlock();
+ percpu_counter_inc(&pool->sp_threads_woken);
+ trace_svc_wake_up(rqstp->rq_task->pid);
+ return;
+ }
+ rcu_read_unlock();
+
+ set_bit(SP_CONGESTED, &pool->sp_flags);
+}
+
+static struct svc_pool *
+svc_pool_next(struct svc_serv *serv, struct svc_pool *pool, unsigned int *state)
+{
+ return pool ? pool : &serv->sv_pools[(*state)++ % serv->sv_nrpools];
+}
+
+static struct task_struct *
+svc_pool_victim(struct svc_serv *serv, struct svc_pool *pool, unsigned int *state)
+{
+ unsigned int i;
+ struct task_struct *task = NULL;
+
+ if (pool != NULL) {
+ spin_lock_bh(&pool->sp_lock);
+ } else {
+ for (i = 0; i < serv->sv_nrpools; i++) {
+ pool = &serv->sv_pools[--(*state) % serv->sv_nrpools];
+ spin_lock_bh(&pool->sp_lock);
+ if (!list_empty(&pool->sp_all_threads))
+ goto found_pool;
+ spin_unlock_bh(&pool->sp_lock);
+ }
+ return NULL;
+ }
+
+found_pool:
+ if (!list_empty(&pool->sp_all_threads)) {
+ struct svc_rqst *rqstp;
+
+ rqstp = list_entry(pool->sp_all_threads.next, struct svc_rqst, rq_all);
+ set_bit(RQ_VICTIM, &rqstp->rq_flags);
+ list_del_rcu(&rqstp->rq_all);
+ task = rqstp->rq_task;
+ }
+ spin_unlock_bh(&pool->sp_lock);
+ return task;
+}
+
+static int
+svc_start_kthreads(struct svc_serv *serv, struct svc_pool *pool, int nrservs)
+{
+ struct svc_rqst *rqstp;
+ struct task_struct *task;
+ struct svc_pool *chosen_pool;
+ unsigned int state = serv->sv_nrthreads-1;
+ int node;
+
+ do {
+ nrservs--;
+ chosen_pool = svc_pool_next(serv, pool, &state);
+ node = svc_pool_map_get_node(chosen_pool->sp_id);
+
+ rqstp = svc_prepare_thread(serv, chosen_pool, node);
+ if (IS_ERR(rqstp))
+ return PTR_ERR(rqstp);
+ task = kthread_create_on_node(serv->sv_threadfn, rqstp,
+ node, "%s", serv->sv_name);
+ if (IS_ERR(task)) {
+ svc_exit_thread(rqstp);
+ return PTR_ERR(task);
+ }
+
+ rqstp->rq_task = task;
+ if (serv->sv_nrpools > 1)
+ svc_pool_map_set_cpumask(task, chosen_pool->sp_id);
+
+ svc_sock_update_bufs(serv);
+ wake_up_process(task);
+ } while (nrservs > 0);
+
+ return 0;
+}
+
+static int
+svc_stop_kthreads(struct svc_serv *serv, struct svc_pool *pool, int nrservs)
+{
+ struct svc_rqst *rqstp;
+ struct task_struct *task;
+ unsigned int state = serv->sv_nrthreads-1;
+
+ do {
+ task = svc_pool_victim(serv, pool, &state);
+ if (task == NULL)
+ break;
+ rqstp = kthread_data(task);
+ /* Did we lose a race to the thread function (sv_threadfn)? */
+ if (kthread_stop(task) == -EINTR)
+ svc_exit_thread(rqstp);
+ nrservs++;
+ } while (nrservs < 0);
+ return 0;
+}
+
+/**
+ * svc_set_num_threads - adjust number of threads per RPC service
+ * @serv: RPC service to adjust
+ * @pool: Specific pool from which to choose threads, or NULL
+ * @nrservs: New number of threads for @serv (0 or less means kill all threads)
+ *
+ * Create or destroy threads to make the number of threads for @serv the
+ * given number. If @pool is non-NULL, change only threads in that pool;
+ * otherwise, round-robin between all pools for @serv. @serv's
+ * sv_nrthreads is adjusted for each thread created or destroyed.
+ *
+ * Caller must ensure mutual exclusion between this and server startup or
+ * shutdown.
+ *
+ * Returns zero on success or a negative errno if an error occurred while
+ * starting a thread.
+ */
+int
+svc_set_num_threads(struct svc_serv *serv, struct svc_pool *pool, int nrservs)
+{
+ if (pool == NULL) {
+ nrservs -= serv->sv_nrthreads;
+ } else {
+ spin_lock_bh(&pool->sp_lock);
+ nrservs -= pool->sp_nrthreads;
+ spin_unlock_bh(&pool->sp_lock);
+ }
+
+ if (nrservs > 0)
+ return svc_start_kthreads(serv, pool, nrservs);
+ if (nrservs < 0)
+ return svc_stop_kthreads(serv, pool, nrservs);
+ return 0;
+}
+EXPORT_SYMBOL_GPL(svc_set_num_threads);
+
+/**
+ * svc_rqst_replace_page - Replace one page in rq_pages[]
+ * @rqstp: svc_rqst with pages to replace
+ * @page: replacement page
+ *
+ * When replacing a page in rq_pages, batch the release of the
+ * replaced pages to avoid hammering the page allocator.
+ *
+ * Return values:
+ * %true: page replaced
+ * %false: array bounds checking failed
+ */
+bool svc_rqst_replace_page(struct svc_rqst *rqstp, struct page *page)
+{
+ struct page **begin = rqstp->rq_pages;
+ struct page **end = &rqstp->rq_pages[RPCSVC_MAXPAGES];
+
+ if (unlikely(rqstp->rq_next_page < begin || rqstp->rq_next_page > end)) {
+ trace_svc_replace_page_err(rqstp);
+ return false;
+ }
+
+ if (*rqstp->rq_next_page) {
+ if (!folio_batch_add(&rqstp->rq_fbatch,
+ page_folio(*rqstp->rq_next_page)))
+ __folio_batch_release(&rqstp->rq_fbatch);
+ }
+
+ get_page(page);
+ *(rqstp->rq_next_page++) = page;
+ return true;
+}
+EXPORT_SYMBOL_GPL(svc_rqst_replace_page);
+
+/**
+ * svc_rqst_release_pages - Release Reply buffer pages
+ * @rqstp: RPC transaction context
+ *
+ * Release response pages that might still be in flight after
+ * svc_send, and any spliced filesystem-owned pages.
+ */
+void svc_rqst_release_pages(struct svc_rqst *rqstp)
+{
+ int i, count = rqstp->rq_next_page - rqstp->rq_respages;
+
+ if (count) {
+ release_pages(rqstp->rq_respages, count);
+ for (i = 0; i < count; i++)
+ rqstp->rq_respages[i] = NULL;
+ }
+}
+
+/*
+ * Called from a server thread as it's exiting. Caller must hold the "service
+ * mutex" for the service.
+ */
+void
+svc_rqst_free(struct svc_rqst *rqstp)
+{
+ folio_batch_release(&rqstp->rq_fbatch);
+ svc_release_buffer(rqstp);
+ if (rqstp->rq_scratch_page)
+ put_page(rqstp->rq_scratch_page);
+ kfree(rqstp->rq_resp);
+ kfree(rqstp->rq_argp);
+ kfree(rqstp->rq_auth_data);
+ kfree_rcu(rqstp, rq_rcu_head);
+}
+EXPORT_SYMBOL_GPL(svc_rqst_free);
+
+void
+svc_exit_thread(struct svc_rqst *rqstp)
+{
+ struct svc_serv *serv = rqstp->rq_server;
+ struct svc_pool *pool = rqstp->rq_pool;
+
+ spin_lock_bh(&pool->sp_lock);
+ pool->sp_nrthreads--;
+ if (!test_and_set_bit(RQ_VICTIM, &rqstp->rq_flags))
+ list_del_rcu(&rqstp->rq_all);
+ spin_unlock_bh(&pool->sp_lock);
+
+ spin_lock_bh(&serv->sv_lock);
+ serv->sv_nrthreads -= 1;
+ spin_unlock_bh(&serv->sv_lock);
+ svc_sock_update_bufs(serv);
+
+ svc_rqst_free(rqstp);
+
+ svc_put(serv);
+}
+EXPORT_SYMBOL_GPL(svc_exit_thread);
+
+/*
+ * Register an "inet" protocol family netid with the local
+ * rpcbind daemon via an rpcbind v4 SET request.
+ *
+ * No netconfig infrastructure is available in the kernel, so
+ * we map IP_ protocol numbers to netids by hand.
+ *
+ * Returns zero on success; a negative errno value is returned
+ * if any error occurs.
+ */
+static int __svc_rpcb_register4(struct net *net, const u32 program,
+ const u32 version,
+ const unsigned short protocol,
+ const unsigned short port)
+{
+ const struct sockaddr_in sin = {
+ .sin_family = AF_INET,
+ .sin_addr.s_addr = htonl(INADDR_ANY),
+ .sin_port = htons(port),
+ };
+ const char *netid;
+ int error;
+
+ switch (protocol) {
+ case IPPROTO_UDP:
+ netid = RPCBIND_NETID_UDP;
+ break;
+ case IPPROTO_TCP:
+ netid = RPCBIND_NETID_TCP;
+ break;
+ default:
+ return -ENOPROTOOPT;
+ }
+
+ error = rpcb_v4_register(net, program, version,
+ (const struct sockaddr *)&sin, netid);
+
+ /*
+ * User space didn't support rpcbind v4, so retry this
+ * registration request with the legacy rpcbind v2 protocol.
+ */
+ if (error == -EPROTONOSUPPORT)
+ error = rpcb_register(net, program, version, protocol, port);
+
+ return error;
+}
+
+#if IS_ENABLED(CONFIG_IPV6)
+/*
+ * Register an "inet6" protocol family netid with the local
+ * rpcbind daemon via an rpcbind v4 SET request.
+ *
+ * No netconfig infrastructure is available in the kernel, so
+ * we map IP_ protocol numbers to netids by hand.
+ *
+ * Returns zero on success; a negative errno value is returned
+ * if any error occurs.
+ */
+static int __svc_rpcb_register6(struct net *net, const u32 program,
+ const u32 version,
+ const unsigned short protocol,
+ const unsigned short port)
+{
+ const struct sockaddr_in6 sin6 = {
+ .sin6_family = AF_INET6,
+ .sin6_addr = IN6ADDR_ANY_INIT,
+ .sin6_port = htons(port),
+ };
+ const char *netid;
+ int error;
+
+ switch (protocol) {
+ case IPPROTO_UDP:
+ netid = RPCBIND_NETID_UDP6;
+ break;
+ case IPPROTO_TCP:
+ netid = RPCBIND_NETID_TCP6;
+ break;
+ default:
+ return -ENOPROTOOPT;
+ }
+
+ error = rpcb_v4_register(net, program, version,
+ (const struct sockaddr *)&sin6, netid);
+
+ /*
+ * User space didn't support rpcbind version 4, so we won't
+ * use a PF_INET6 listener.
+ */
+ if (error == -EPROTONOSUPPORT)
+ error = -EAFNOSUPPORT;
+
+ return error;
+}
+#endif /* IS_ENABLED(CONFIG_IPV6) */
+
+/*
+ * Register a kernel RPC service via rpcbind version 4.
+ *
+ * Returns zero on success; a negative errno value is returned
+ * if any error occurs.
+ */
+static int __svc_register(struct net *net, const char *progname,
+ const u32 program, const u32 version,
+ const int family,
+ const unsigned short protocol,
+ const unsigned short port)
+{
+ int error = -EAFNOSUPPORT;
+
+ switch (family) {
+ case PF_INET:
+ error = __svc_rpcb_register4(net, program, version,
+ protocol, port);
+ break;
+#if IS_ENABLED(CONFIG_IPV6)
+ case PF_INET6:
+ error = __svc_rpcb_register6(net, program, version,
+ protocol, port);
+#endif
+ }
+
+ trace_svc_register(progname, version, family, protocol, port, error);
+ return error;
+}
+
+int svc_rpcbind_set_version(struct net *net,
+ const struct svc_program *progp,
+ u32 version, int family,
+ unsigned short proto,
+ unsigned short port)
+{
+ return __svc_register(net, progp->pg_name, progp->pg_prog,
+ version, family, proto, port);
+
+}
+EXPORT_SYMBOL_GPL(svc_rpcbind_set_version);
+
+int svc_generic_rpcbind_set(struct net *net,
+ const struct svc_program *progp,
+ u32 version, int family,
+ unsigned short proto,
+ unsigned short port)
+{
+ const struct svc_version *vers = progp->pg_vers[version];
+ int error;
+
+ if (vers == NULL)
+ return 0;
+
+ if (vers->vs_hidden) {
+ trace_svc_noregister(progp->pg_name, version, proto,
+ port, family, 0);
+ return 0;
+ }
+
+ /*
+ * Don't register a UDP port if we need congestion
+ * control.
+ */
+ if (vers->vs_need_cong_ctrl && proto == IPPROTO_UDP)
+ return 0;
+
+ error = svc_rpcbind_set_version(net, progp, version,
+ family, proto, port);
+
+ return (vers->vs_rpcb_optnl) ? 0 : error;
+}
+EXPORT_SYMBOL_GPL(svc_generic_rpcbind_set);
+
+/**
+ * svc_register - register an RPC service with the local portmapper
+ * @serv: svc_serv struct for the service to register
+ * @net: net namespace for the service to register
+ * @family: protocol family of service's listener socket
+ * @proto: transport protocol number to advertise
+ * @port: port to advertise
+ *
+ * Service is registered for any address in the passed-in protocol family
+ */
+int svc_register(const struct svc_serv *serv, struct net *net,
+ const int family, const unsigned short proto,
+ const unsigned short port)
+{
+ struct svc_program *progp;
+ unsigned int i;
+ int error = 0;
+
+ WARN_ON_ONCE(proto == 0 && port == 0);
+ if (proto == 0 && port == 0)
+ return -EINVAL;
+
+ for (progp = serv->sv_program; progp; progp = progp->pg_next) {
+ for (i = 0; i < progp->pg_nvers; i++) {
+
+ error = progp->pg_rpcbind_set(net, progp, i,
+ family, proto, port);
+ if (error < 0) {
+ printk(KERN_WARNING "svc: failed to register "
+ "%sv%u RPC service (errno %d).\n",
+ progp->pg_name, i, -error);
+ break;
+ }
+ }
+ }
+
+ return error;
+}
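
As a hedged illustration (not part of the patch, and using an invented helper name): svc_register() is normally driven from inside the RPC server core when listeners are set up, once per address family and transport protocol. A minimal sketch of such an in-core caller, assuming a ready svc_serv, a net namespace and a known port:

#include <linux/in.h>
#include <linux/sunrpc/svc.h>

/* Hypothetical helper: advertise one listener on TCP and UDP. */
static int example_advertise(struct svc_serv *serv, struct net *net,
			     unsigned short port)
{
	int err;

	/* IPv4 first; svc_register() tries rpcbind v4 and falls back
	 * to the legacy v2 protocol if user space runs portmap.
	 */
	err = svc_register(serv, net, PF_INET, IPPROTO_TCP, port);
	if (err < 0)
		return err;
	err = svc_register(serv, net, PF_INET, IPPROTO_UDP, port);
	if (err < 0)
		return err;

#if IS_ENABLED(CONFIG_IPV6)
	/* IPv6 netids require rpcbind v4 in user space. */
	err = svc_register(serv, net, PF_INET6, IPPROTO_TCP, port);
#endif
	return err;
}
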
+
+/*
+ * If user space is running rpcbind, it should take the v4 UNSET
+ * and clear everything for this [program, version]. If user space
+ * is running portmap, it will reject the v4 UNSET, but won't have
+ * any "inet6" entries anyway. So a PMAP_UNSET should be sufficient
+ * in this case to clear all existing entries for [program, version].
+ */
+static void __svc_unregister(struct net *net, const u32 program, const u32 version,
+ const char *progname)
+{
+ int error;
+
+ error = rpcb_v4_register(net, program, version, NULL, "");
+
+ /*
+ * User space didn't support rpcbind v4, so retry this
+ * request with the legacy rpcbind v2 protocol.
+ */
+ if (error == -EPROTONOSUPPORT)
+ error = rpcb_register(net, program, version, 0, 0);
+
+ trace_svc_unregister(progname, version, error);
+}
+
+/*
+ * All netids, bind addresses and ports registered for [program, version]
+ * are removed from the local rpcbind database (if the service is not
+ * hidden) to make way for a new instance of the service.
+ *
+ * The result of unregistration is reported via the svc_unregister
+ * tracepoint for those who want verification of the result, but is
+ * otherwise not important.
+ */
+static void svc_unregister(const struct svc_serv *serv, struct net *net)
+{
+ struct sighand_struct *sighand;
+ struct svc_program *progp;
+ unsigned long flags;
+ unsigned int i;
+
+ clear_thread_flag(TIF_SIGPENDING);
+
+ for (progp = serv->sv_program; progp; progp = progp->pg_next) {
+ for (i = 0; i < progp->pg_nvers; i++) {
+ if (progp->pg_vers[i] == NULL)
+ continue;
+ if (progp->pg_vers[i]->vs_hidden)
+ continue;
+ __svc_unregister(net, progp->pg_prog, i, progp->pg_name);
+ }
+ }
+
+ rcu_read_lock();
+ sighand = rcu_dereference(current->sighand);
+ spin_lock_irqsave(&sighand->siglock, flags);
+ recalc_sigpending();
+ spin_unlock_irqrestore(&sighand->siglock, flags);
+ rcu_read_unlock();
+}
+
+/*
+ * dprintk the given error with the address of the client that caused it.
+ */
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+static __printf(2, 3)
+void svc_printk(struct svc_rqst *rqstp, const char *fmt, ...)
+{
+ struct va_format vaf;
+ va_list args;
+ char buf[RPC_MAX_ADDRBUFLEN];
+
+ va_start(args, fmt);
+
+ vaf.fmt = fmt;
+ vaf.va = &args;
+
+ dprintk("svc: %s: %pV", svc_print_addr(rqstp, buf, sizeof(buf)), &vaf);
+
+ va_end(args);
+}
+#else
+static __printf(2, 3) void svc_printk(struct svc_rqst *rqstp, const char *fmt, ...) {}
+#endif
+
+__be32
+svc_generic_init_request(struct svc_rqst *rqstp,
+ const struct svc_program *progp,
+ struct svc_process_info *ret)
+{
+ const struct svc_version *versp = NULL; /* compiler food */
+ const struct svc_procedure *procp = NULL;
+
+	if (rqstp->rq_vers >= progp->pg_nvers)
+ goto err_bad_vers;
+ versp = progp->pg_vers[rqstp->rq_vers];
+ if (!versp)
+ goto err_bad_vers;
+
+ /*
+ * Some protocol versions (namely NFSv4) require some form of
+ * congestion control. (See RFC 7530 section 3.1 paragraph 2)
+ * In other words, UDP is not allowed. We mark those when setting
+ * up the svc_xprt, and verify that here.
+ *
+ * The spec is not very clear about what error should be returned
+ * when someone tries to access a server that is listening on UDP
+ * for lower versions. RPC_PROG_MISMATCH seems to be the closest
+ * fit.
+ */
+ if (versp->vs_need_cong_ctrl && rqstp->rq_xprt &&
+ !test_bit(XPT_CONG_CTRL, &rqstp->rq_xprt->xpt_flags))
+ goto err_bad_vers;
+
+ if (rqstp->rq_proc >= versp->vs_nproc)
+ goto err_bad_proc;
+ rqstp->rq_procinfo = procp = &versp->vs_proc[rqstp->rq_proc];
+ if (!procp)
+ goto err_bad_proc;
+
+ /* Initialize storage for argp and resp */
+ memset(rqstp->rq_argp, 0, procp->pc_argzero);
+ memset(rqstp->rq_resp, 0, procp->pc_ressize);
+
+ /* Bump per-procedure stats counter */
+ this_cpu_inc(versp->vs_count[rqstp->rq_proc]);
+
+ ret->dispatch = versp->vs_dispatch;
+ return rpc_success;
+err_bad_vers:
+ ret->mismatch.lovers = progp->pg_lovers;
+ ret->mismatch.hivers = progp->pg_hivers;
+ return rpc_prog_mismatch;
+err_bad_proc:
+ return rpc_proc_unavail;
+}
+EXPORT_SYMBOL_GPL(svc_generic_init_request);
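
The generic helpers above only consult a handful of svc_program and svc_version members. The fragment below is a deliberately partial, illustrative sketch of how a hypothetical service might wire them up; the program number, names and table sizes are invented, and the externs stand in for pieces a real service (nfsd, lockd, ...) provides itself:

#include <linux/kernel.h>
#include <linux/sunrpc/svc.h>
#include <linux/sunrpc/svcauth.h>

/* Assumed to be provided elsewhere by the (hypothetical) service: */
extern const struct svc_procedure example_procedures[];
extern int example_dispatch(struct svc_rqst *rqstp);

/*
 * Only the members consulted by svc_generic_init_request() and
 * svc_generic_rpcbind_set() are shown; a real service must also supply
 * the per-CPU vs_count array, the XDR sizing fields, pg_class, etc.
 */
static const struct svc_version example_version1 = {
	.vs_vers		= 1,
	.vs_nproc		= 2,		/* assumed table size */
	.vs_proc		= example_procedures,
	.vs_dispatch		= example_dispatch,
	.vs_hidden		= false,	/* do advertise via rpcbind */
	.vs_need_cong_ctrl	= false,	/* UDP is acceptable */
	.vs_rpcb_optnl		= true,		/* tolerate rpcbind failures */
};

static const struct svc_version *example_versions[] = {
	[1] = &example_version1,
};

static struct svc_program example_program = {
	.pg_name		= "example",
	.pg_prog		= 400999,	/* made-up program number */
	.pg_nvers		= ARRAY_SIZE(example_versions),
	.pg_vers		= example_versions,
	.pg_authenticate	= svc_set_client,	/* must not be NULL */
	.pg_init_request	= svc_generic_init_request,
	.pg_rpcbind_set		= svc_generic_rpcbind_set,
};
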
+
+/*
+ * Common routine for processing the RPC request.
+ */
+static int
+svc_process_common(struct svc_rqst *rqstp)
+{
+ struct xdr_stream *xdr = &rqstp->rq_res_stream;
+ struct svc_program *progp;
+ const struct svc_procedure *procp = NULL;
+ struct svc_serv *serv = rqstp->rq_server;
+ struct svc_process_info process;
+ enum svc_auth_status auth_res;
+ unsigned int aoffset;
+ int rc;
+ __be32 *p;
+
+ /* Will be turned off by GSS integrity and privacy services */
+ set_bit(RQ_SPLICE_OK, &rqstp->rq_flags);
+ /* Will be turned off only when NFSv4 Sessions are used */
+ set_bit(RQ_USEDEFERRAL, &rqstp->rq_flags);
+ clear_bit(RQ_DROPME, &rqstp->rq_flags);
+
+ /* Construct the first words of the reply: */
+ svcxdr_init_encode(rqstp);
+ xdr_stream_encode_be32(xdr, rqstp->rq_xid);
+ xdr_stream_encode_be32(xdr, rpc_reply);
+
+ p = xdr_inline_decode(&rqstp->rq_arg_stream, XDR_UNIT * 4);
+ if (unlikely(!p))
+ goto err_short_len;
+ if (*p++ != cpu_to_be32(RPC_VERSION))
+ goto err_bad_rpc;
+
+ xdr_stream_encode_be32(xdr, rpc_msg_accepted);
+
+ rqstp->rq_prog = be32_to_cpup(p++);
+ rqstp->rq_vers = be32_to_cpup(p++);
+ rqstp->rq_proc = be32_to_cpup(p);
+
+ for (progp = serv->sv_program; progp; progp = progp->pg_next)
+ if (rqstp->rq_prog == progp->pg_prog)
+ break;
+
+ /*
+ * Decode auth data, and add verifier to reply buffer.
+ * We do this before anything else in order to get a decent
+ * auth verifier.
+ */
+ auth_res = svc_authenticate(rqstp);
+ /* Also give the program a chance to reject this call: */
+ if (auth_res == SVC_OK && progp)
+ auth_res = progp->pg_authenticate(rqstp);
+ trace_svc_authenticate(rqstp, auth_res);
+ switch (auth_res) {
+ case SVC_OK:
+ break;
+ case SVC_GARBAGE:
+ goto err_garbage_args;
+ case SVC_SYSERR:
+ goto err_system_err;
+ case SVC_DENIED:
+ goto err_bad_auth;
+ case SVC_CLOSE:
+ goto close;
+ case SVC_DROP:
+ goto dropit;
+ case SVC_COMPLETE:
+ goto sendit;
+ default:
+ pr_warn_once("Unexpected svc_auth_status (%d)\n", auth_res);
+ goto err_system_err;
+ }
+
+ if (progp == NULL)
+ goto err_bad_prog;
+
+ switch (progp->pg_init_request(rqstp, progp, &process)) {
+ case rpc_success:
+ break;
+ case rpc_prog_unavail:
+ goto err_bad_prog;
+ case rpc_prog_mismatch:
+ goto err_bad_vers;
+ case rpc_proc_unavail:
+ goto err_bad_proc;
+ }
+
+ procp = rqstp->rq_procinfo;
+ /* Should this check go into the dispatcher? */
+ if (!procp || !procp->pc_func)
+ goto err_bad_proc;
+
+ /* Syntactic check complete */
+ serv->sv_stats->rpccnt++;
+ trace_svc_process(rqstp, progp->pg_name);
+
+ aoffset = xdr_stream_pos(xdr);
+
+ /* un-reserve some of the out-queue now that we have a
+ * better idea of reply size
+ */
+ if (procp->pc_xdrressize)
+ svc_reserve_auth(rqstp, procp->pc_xdrressize<<2);
+
+ /* Call the function that processes the request. */
+ rc = process.dispatch(rqstp);
+ if (procp->pc_release)
+ procp->pc_release(rqstp);
+ xdr_finish_decode(xdr);
+
+ if (!rc)
+ goto dropit;
+ if (rqstp->rq_auth_stat != rpc_auth_ok)
+ goto err_bad_auth;
+
+ if (*rqstp->rq_accept_statp != rpc_success)
+ xdr_truncate_encode(xdr, aoffset);
+
+ if (procp->pc_encode == NULL)
+ goto dropit;
+
+ sendit:
+ if (svc_authorise(rqstp))
+ goto close_xprt;
+ return 1; /* Caller can now send it */
+
+ dropit:
+ svc_authorise(rqstp); /* doesn't hurt to call this twice */
+ dprintk("svc: svc_process dropit\n");
+ return 0;
+
+ close:
+ svc_authorise(rqstp);
+close_xprt:
+ if (rqstp->rq_xprt && test_bit(XPT_TEMP, &rqstp->rq_xprt->xpt_flags))
+ svc_xprt_close(rqstp->rq_xprt);
+ dprintk("svc: svc_process close\n");
+ return 0;
+
+err_short_len:
+ svc_printk(rqstp, "short len %u, dropping request\n",
+ rqstp->rq_arg.len);
+ goto close_xprt;
+
+err_bad_rpc:
+ serv->sv_stats->rpcbadfmt++;
+ xdr_stream_encode_u32(xdr, RPC_MSG_DENIED);
+ xdr_stream_encode_u32(xdr, RPC_MISMATCH);
+ /* Only RPCv2 supported */
+ xdr_stream_encode_u32(xdr, RPC_VERSION);
+ xdr_stream_encode_u32(xdr, RPC_VERSION);
+ return 1; /* don't wrap */
+
+err_bad_auth:
+ dprintk("svc: authentication failed (%d)\n",
+ be32_to_cpu(rqstp->rq_auth_stat));
+ serv->sv_stats->rpcbadauth++;
+ /* Restore write pointer to location of reply status: */
+ xdr_truncate_encode(xdr, XDR_UNIT * 2);
+ xdr_stream_encode_u32(xdr, RPC_MSG_DENIED);
+ xdr_stream_encode_u32(xdr, RPC_AUTH_ERROR);
+ xdr_stream_encode_be32(xdr, rqstp->rq_auth_stat);
+ goto sendit;
+
+err_bad_prog:
+ dprintk("svc: unknown program %d\n", rqstp->rq_prog);
+ serv->sv_stats->rpcbadfmt++;
+ *rqstp->rq_accept_statp = rpc_prog_unavail;
+ goto sendit;
+
+err_bad_vers:
+ svc_printk(rqstp, "unknown version (%d for prog %d, %s)\n",
+ rqstp->rq_vers, rqstp->rq_prog, progp->pg_name);
+
+ serv->sv_stats->rpcbadfmt++;
+ *rqstp->rq_accept_statp = rpc_prog_mismatch;
+
+ /*
+ * svc_authenticate() has already added the verifier and
+ * advanced the stream just past rq_accept_statp.
+ */
+ xdr_stream_encode_u32(xdr, process.mismatch.lovers);
+ xdr_stream_encode_u32(xdr, process.mismatch.hivers);
+ goto sendit;
+
+err_bad_proc:
+ svc_printk(rqstp, "unknown procedure (%d)\n", rqstp->rq_proc);
+
+ serv->sv_stats->rpcbadfmt++;
+ *rqstp->rq_accept_statp = rpc_proc_unavail;
+ goto sendit;
+
+err_garbage_args:
+ svc_printk(rqstp, "failed to decode RPC header\n");
+
+ serv->sv_stats->rpcbadfmt++;
+ *rqstp->rq_accept_statp = rpc_garbage_args;
+ goto sendit;
+
+err_system_err:
+ serv->sv_stats->rpcbadfmt++;
+ *rqstp->rq_accept_statp = rpc_system_err;
+ goto sendit;
+}
+
+/**
+ * svc_process - Execute one RPC transaction
+ * @rqstp: RPC transaction context
+ *
+ */
+void svc_process(struct svc_rqst *rqstp)
+{
+ struct kvec *resv = &rqstp->rq_res.head[0];
+ __be32 *p;
+
+#if IS_ENABLED(CONFIG_FAIL_SUNRPC)
+ if (!fail_sunrpc.ignore_server_disconnect &&
+ should_fail(&fail_sunrpc.attr, 1))
+ svc_xprt_deferred_close(rqstp->rq_xprt);
+#endif
+
+ /*
+	 * Set up the response xdr_buf.
+	 * Initially it has just one page.
+ */
+ rqstp->rq_next_page = &rqstp->rq_respages[1];
+ resv->iov_base = page_address(rqstp->rq_respages[0]);
+ resv->iov_len = 0;
+ rqstp->rq_res.pages = rqstp->rq_next_page;
+ rqstp->rq_res.len = 0;
+ rqstp->rq_res.page_base = 0;
+ rqstp->rq_res.page_len = 0;
+ rqstp->rq_res.buflen = PAGE_SIZE;
+ rqstp->rq_res.tail[0].iov_base = NULL;
+ rqstp->rq_res.tail[0].iov_len = 0;
+
+ svcxdr_init_decode(rqstp);
+ p = xdr_inline_decode(&rqstp->rq_arg_stream, XDR_UNIT * 2);
+ if (unlikely(!p))
+ goto out_drop;
+ rqstp->rq_xid = *p++;
+ if (unlikely(*p != rpc_call))
+ goto out_baddir;
+
+ if (!svc_process_common(rqstp))
+ goto out_drop;
+ svc_send(rqstp);
+ return;
+
+out_baddir:
+ svc_printk(rqstp, "bad direction 0x%08x, dropping request\n",
+ be32_to_cpu(*p));
+ rqstp->rq_server->sv_stats->rpcbadfmt++;
+out_drop:
+ svc_drop(rqstp);
+}
+
+#if defined(CONFIG_SUNRPC_BACKCHANNEL)
+/*
+ * Process a backchannel RPC request that arrived over an existing
+ * outbound connection
+ */
+int
+bc_svc_process(struct svc_serv *serv, struct rpc_rqst *req,
+ struct svc_rqst *rqstp)
+{
+ struct rpc_task *task;
+ int proc_error;
+ int error;
+
+ dprintk("svc: %s(%p)\n", __func__, req);
+
+ /* Build the svc_rqst used by the common processing routine */
+ rqstp->rq_xid = req->rq_xid;
+ rqstp->rq_prot = req->rq_xprt->prot;
+ rqstp->rq_server = serv;
+ rqstp->rq_bc_net = req->rq_xprt->xprt_net;
+
+ rqstp->rq_addrlen = sizeof(req->rq_xprt->addr);
+ memcpy(&rqstp->rq_addr, &req->rq_xprt->addr, rqstp->rq_addrlen);
+ memcpy(&rqstp->rq_arg, &req->rq_rcv_buf, sizeof(rqstp->rq_arg));
+ memcpy(&rqstp->rq_res, &req->rq_snd_buf, sizeof(rqstp->rq_res));
+
+ /* Adjust the argument buffer length */
+ rqstp->rq_arg.len = req->rq_private_buf.len;
+ if (rqstp->rq_arg.len <= rqstp->rq_arg.head[0].iov_len) {
+ rqstp->rq_arg.head[0].iov_len = rqstp->rq_arg.len;
+ rqstp->rq_arg.page_len = 0;
+ } else if (rqstp->rq_arg.len <= rqstp->rq_arg.head[0].iov_len +
+ rqstp->rq_arg.page_len)
+ rqstp->rq_arg.page_len = rqstp->rq_arg.len -
+ rqstp->rq_arg.head[0].iov_len;
+ else
+ rqstp->rq_arg.len = rqstp->rq_arg.head[0].iov_len +
+ rqstp->rq_arg.page_len;
+
+ /* Reset the response buffer */
+ rqstp->rq_res.head[0].iov_len = 0;
+
+ /*
+ * Skip the XID and calldir fields because they've already
+ * been processed by the caller.
+ */
+ svcxdr_init_decode(rqstp);
+ if (!xdr_inline_decode(&rqstp->rq_arg_stream, XDR_UNIT * 2)) {
+ error = -EINVAL;
+ goto out;
+ }
+
+ /* Parse and execute the bc call */
+ proc_error = svc_process_common(rqstp);
+
+ atomic_dec(&req->rq_xprt->bc_slot_count);
+ if (!proc_error) {
+ /* Processing error: drop the request */
+ xprt_free_bc_request(req);
+ error = -EINVAL;
+ goto out;
+ }
+ /* Finally, send the reply synchronously */
+ memcpy(&req->rq_snd_buf, &rqstp->rq_res, sizeof(req->rq_snd_buf));
+ task = rpc_run_bc_task(req);
+ if (IS_ERR(task)) {
+ error = PTR_ERR(task);
+ goto out;
+ }
+
+ WARN_ON_ONCE(atomic_read(&task->tk_count) != 1);
+ error = task->tk_status;
+ rpc_put_task(task);
+
+out:
+ dprintk("svc: %s(), error=%d\n", __func__, error);
+ return error;
+}
+EXPORT_SYMBOL_GPL(bc_svc_process);
+#endif /* CONFIG_SUNRPC_BACKCHANNEL */
+
+/**
+ * svc_max_payload - Return transport-specific limit on the RPC payload
+ * @rqstp: RPC transaction context
+ *
+ * Returns the maximum number of payload bytes the current transport
+ * allows.
+ */
+u32 svc_max_payload(const struct svc_rqst *rqstp)
+{
+ u32 max = rqstp->rq_xprt->xpt_class->xcl_max_payload;
+
+ if (rqstp->rq_server->sv_max_payload < max)
+ max = rqstp->rq_server->sv_max_payload;
+ return max;
+}
+EXPORT_SYMBOL_GPL(svc_max_payload);
+
+/**
+ * svc_proc_name - Return RPC procedure name in string form
+ * @rqstp: svc_rqst to operate on
+ *
+ * Return value:
+ * Pointer to a NUL-terminated string
+ */
+const char *svc_proc_name(const struct svc_rqst *rqstp)
+{
+ if (rqstp && rqstp->rq_procinfo)
+ return rqstp->rq_procinfo->pc_name;
+ return "unknown";
+}
+
+
+/**
+ * svc_encode_result_payload - mark a range of bytes as a result payload
+ * @rqstp: svc_rqst to operate on
+ * @offset: payload's byte offset in rqstp->rq_res
+ * @length: size of payload, in bytes
+ *
+ * Returns zero on success, or a negative errno if a permanent
+ * error occurred.
+ */
+int svc_encode_result_payload(struct svc_rqst *rqstp, unsigned int offset,
+ unsigned int length)
+{
+ return rqstp->rq_xprt->xpt_ops->xpo_result_payload(rqstp, offset,
+ length);
+}
+EXPORT_SYMBOL_GPL(svc_encode_result_payload);
+
+/**
+ * svc_fill_write_vector - Construct data argument for VFS write call
+ * @rqstp: svc_rqst to operate on
+ * @payload: xdr_buf containing only the write data payload
+ *
+ * Fills in rqstp::rq_vec, and returns the number of elements.
+ */
+unsigned int svc_fill_write_vector(struct svc_rqst *rqstp,
+ struct xdr_buf *payload)
+{
+ struct page **pages = payload->pages;
+ struct kvec *first = payload->head;
+ struct kvec *vec = rqstp->rq_vec;
+ size_t total = payload->len;
+ unsigned int i;
+
+ /* Some types of transport can present the write payload
+ * entirely in rq_arg.pages. In this case, @first is empty.
+ */
+ i = 0;
+ if (first->iov_len) {
+ vec[i].iov_base = first->iov_base;
+ vec[i].iov_len = min_t(size_t, total, first->iov_len);
+ total -= vec[i].iov_len;
+ ++i;
+ }
+
+ while (total) {
+ vec[i].iov_base = page_address(*pages);
+ vec[i].iov_len = min_t(size_t, total, PAGE_SIZE);
+ total -= vec[i].iov_len;
+ ++i;
+ ++pages;
+ }
+
+ WARN_ON_ONCE(i > ARRAY_SIZE(rqstp->rq_vec));
+ return i;
+}
+EXPORT_SYMBOL_GPL(svc_fill_write_vector);
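
A sketch of how a caller might turn the filled-in rq_vec into an iov_iter for a VFS write. The surrounding NFS-specific plumbing is omitted, and the helper name is invented; the iov_iter_kvec()/vfs_iter_write() pairing mirrors how in-tree consumers use this, but this fragment is illustrative only:

#include <linux/fs.h>
#include <linux/uio.h>
#include <linux/sunrpc/svc.h>

/* Hypothetical write path built on svc_fill_write_vector(). */
static ssize_t example_do_write(struct svc_rqst *rqstp, struct file *file,
				struct xdr_buf *payload, loff_t pos)
{
	struct iov_iter iter;
	unsigned int nvecs;

	nvecs = svc_fill_write_vector(rqstp, payload);
	iov_iter_kvec(&iter, ITER_SOURCE, rqstp->rq_vec, nvecs, payload->len);
	return vfs_iter_write(file, &iter, &pos, 0);
}
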
+
+/**
+ * svc_fill_symlink_pathname - Construct pathname argument for VFS symlink call
+ * @rqstp: svc_rqst to operate on
+ * @first: buffer containing first section of pathname
+ * @p: buffer containing remaining section of pathname
+ * @total: total length of the pathname argument
+ *
+ * The VFS symlink API demands a NUL-terminated pathname in mapped memory.
+ * Returns pointer to a NUL-terminated string, or an ERR_PTR. Caller must free
+ * the returned string.
+ */
+char *svc_fill_symlink_pathname(struct svc_rqst *rqstp, struct kvec *first,
+ void *p, size_t total)
+{
+ size_t len, remaining;
+ char *result, *dst;
+
+ result = kmalloc(total + 1, GFP_KERNEL);
+ if (!result)
+ return ERR_PTR(-ESERVERFAULT);
+
+ dst = result;
+ remaining = total;
+
+ len = min_t(size_t, total, first->iov_len);
+ if (len) {
+ memcpy(dst, first->iov_base, len);
+ dst += len;
+ remaining -= len;
+ }
+
+ if (remaining) {
+ len = min_t(size_t, remaining, PAGE_SIZE);
+ memcpy(dst, p, len);
+ dst += len;
+ }
+
+ *dst = '\0';
+
+ /* Sanity check: Linux doesn't allow the pathname argument to
+ * contain a NUL byte.
+ */
+ if (strlen(result) != total) {
+ kfree(result);
+ return ERR_PTR(-EINVAL);
+ }
+ return result;
+}
+EXPORT_SYMBOL_GPL(svc_fill_symlink_pathname);
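
For completeness, a hedged usage sketch: how a symlink-creating procedure might consume the helper. The handler name is invented and the actual VFS call is elided; only the allocate/use/free contract of svc_fill_symlink_pathname() is demonstrated:

#include <linux/err.h>
#include <linux/slab.h>
#include <linux/sunrpc/svc.h>

/* Hypothetical SYMLINK argument handling. */
static int example_symlink(struct svc_rqst *rqstp, struct kvec *first,
			   void *rest, size_t total)
{
	char *path;

	path = svc_fill_symlink_pathname(rqstp, first, rest, total);
	if (IS_ERR(path))
		return PTR_ERR(path);

	/* ... hand "path" to the VFS (e.g. vfs_symlink()) here ... */

	kfree(path);	/* caller must free the returned string */
	return 0;
}
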
diff --git a/net/sunrpc/svc_xprt.c b/net/sunrpc/svc_xprt.c
new file mode 100644
index 0000000000..5cfe5c7408
--- /dev/null
+++ b/net/sunrpc/svc_xprt.c
@@ -0,0 +1,1450 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * linux/net/sunrpc/svc_xprt.c
+ *
+ * Author: Tom Tucker <tom@opengridcomputing.com>
+ */
+
+#include <linux/sched.h>
+#include <linux/sched/mm.h>
+#include <linux/errno.h>
+#include <linux/freezer.h>
+#include <linux/kthread.h>
+#include <linux/slab.h>
+#include <net/sock.h>
+#include <linux/sunrpc/addr.h>
+#include <linux/sunrpc/stats.h>
+#include <linux/sunrpc/svc_xprt.h>
+#include <linux/sunrpc/svcsock.h>
+#include <linux/sunrpc/xprt.h>
+#include <linux/module.h>
+#include <linux/netdevice.h>
+#include <trace/events/sunrpc.h>
+
+#define RPCDBG_FACILITY RPCDBG_SVCXPRT
+
+static unsigned int svc_rpc_per_connection_limit __read_mostly;
+module_param(svc_rpc_per_connection_limit, uint, 0644);
+
+
+static struct svc_deferred_req *svc_deferred_dequeue(struct svc_xprt *xprt);
+static int svc_deferred_recv(struct svc_rqst *rqstp);
+static struct cache_deferred_req *svc_defer(struct cache_req *req);
+static void svc_age_temp_xprts(struct timer_list *t);
+static void svc_delete_xprt(struct svc_xprt *xprt);
+
+/* apparently the "standard" is that clients close
+ * idle connections after 5 minutes, servers after
+ * 6 minutes
+ * http://nfsv4bat.org/Documents/ConnectAThon/1996/nfstcp.pdf
+ */
+static int svc_conn_age_period = 6*60;
+
+/* List of registered transport classes */
+static DEFINE_SPINLOCK(svc_xprt_class_lock);
+static LIST_HEAD(svc_xprt_class_list);
+
+/* SMP locking strategy:
+ *
+ * svc_pool->sp_lock protects most of the fields of that pool.
+ * svc_serv->sv_lock protects sv_tempsocks, sv_permsocks, sv_tmpcnt.
+ * when both need to be taken (rare), svc_serv->sv_lock is first.
+ * The "service mutex" protects svc_serv->sv_nrthread.
+ * svc_sock->sk_lock protects the svc_sock->sk_deferred list
+ * and the ->sk_info_authunix cache.
+ *
+ * The XPT_BUSY bit in xprt->xpt_flags prevents a transport being
+ * enqueued more than once. During normal transport processing this bit
+ * is set by svc_xprt_enqueue and cleared by svc_xprt_received.
+ * Providers should not manipulate this bit directly.
+ *
+ * Some flags can be set to certain values at any time
+ * providing that certain rules are followed:
+ *
+ * XPT_CONN, XPT_DATA:
+ * - Can be set or cleared at any time.
+ * - After a set, svc_xprt_enqueue must be called to enqueue
+ * the transport for processing.
+ * - After a clear, the transport must be read/accepted.
+ * If this succeeds, it must be set again.
+ * XPT_CLOSE:
+ *		- Can be set at any time. It is never cleared.
+ * XPT_DEAD:
+ * - Can only be set while XPT_BUSY is held which ensures
+ * that no other thread will be using the transport or will
+ * try to set XPT_DEAD.
+ */
+
+/**
+ * svc_reg_xprt_class - Register a server-side RPC transport class
+ * @xcl: New transport class to be registered
+ *
+ * Returns zero on success; otherwise a negative errno is returned.
+ */
+int svc_reg_xprt_class(struct svc_xprt_class *xcl)
+{
+ struct svc_xprt_class *cl;
+ int res = -EEXIST;
+
+ INIT_LIST_HEAD(&xcl->xcl_list);
+ spin_lock(&svc_xprt_class_lock);
+ /* Make sure there isn't already a class with the same name */
+ list_for_each_entry(cl, &svc_xprt_class_list, xcl_list) {
+ if (strcmp(xcl->xcl_name, cl->xcl_name) == 0)
+ goto out;
+ }
+ list_add_tail(&xcl->xcl_list, &svc_xprt_class_list);
+ res = 0;
+out:
+ spin_unlock(&svc_xprt_class_lock);
+ return res;
+}
+EXPORT_SYMBOL_GPL(svc_reg_xprt_class);
+
+/**
+ * svc_unreg_xprt_class - Unregister a server-side RPC transport class
+ * @xcl: Transport class to be unregistered
+ *
+ */
+void svc_unreg_xprt_class(struct svc_xprt_class *xcl)
+{
+ spin_lock(&svc_xprt_class_lock);
+ list_del_init(&xcl->xcl_list);
+ spin_unlock(&svc_xprt_class_lock);
+}
+EXPORT_SYMBOL_GPL(svc_unreg_xprt_class);
+
+/**
+ * svc_print_xprts - Format the transport list for printing
+ * @buf: target buffer for the formatted transport list
+ * @maxlen: length of target buffer
+ *
+ * Fills in @buf with a string containing a list of transport names, each name
+ * terminated with '\n'. If the buffer is too small, some entries may be
+ * missing, but it is guaranteed that all lines in the output buffer are
+ * complete.
+ *
+ * Returns positive length of the filled-in string.
+ */
+int svc_print_xprts(char *buf, int maxlen)
+{
+ struct svc_xprt_class *xcl;
+ char tmpstr[80];
+ int len = 0;
+ buf[0] = '\0';
+
+ spin_lock(&svc_xprt_class_lock);
+ list_for_each_entry(xcl, &svc_xprt_class_list, xcl_list) {
+ int slen;
+
+ slen = snprintf(tmpstr, sizeof(tmpstr), "%s %d\n",
+ xcl->xcl_name, xcl->xcl_max_payload);
+ if (slen >= sizeof(tmpstr) || len + slen >= maxlen)
+ break;
+ len += slen;
+ strcat(buf, tmpstr);
+ }
+ spin_unlock(&svc_xprt_class_lock);
+
+ return len;
+}
+
+/**
+ * svc_xprt_deferred_close - Close a transport
+ * @xprt: transport instance
+ *
+ * Used in contexts that need to defer the work of shutting down
+ * the transport to an nfsd thread.
+ */
+void svc_xprt_deferred_close(struct svc_xprt *xprt)
+{
+ if (!test_and_set_bit(XPT_CLOSE, &xprt->xpt_flags))
+ svc_xprt_enqueue(xprt);
+}
+EXPORT_SYMBOL_GPL(svc_xprt_deferred_close);
+
+static void svc_xprt_free(struct kref *kref)
+{
+ struct svc_xprt *xprt =
+ container_of(kref, struct svc_xprt, xpt_ref);
+ struct module *owner = xprt->xpt_class->xcl_owner;
+ if (test_bit(XPT_CACHE_AUTH, &xprt->xpt_flags))
+ svcauth_unix_info_release(xprt);
+ put_cred(xprt->xpt_cred);
+ put_net_track(xprt->xpt_net, &xprt->ns_tracker);
+ /* See comment on corresponding get in xs_setup_bc_tcp(): */
+ if (xprt->xpt_bc_xprt)
+ xprt_put(xprt->xpt_bc_xprt);
+ if (xprt->xpt_bc_xps)
+ xprt_switch_put(xprt->xpt_bc_xps);
+ trace_svc_xprt_free(xprt);
+ xprt->xpt_ops->xpo_free(xprt);
+ module_put(owner);
+}
+
+void svc_xprt_put(struct svc_xprt *xprt)
+{
+ kref_put(&xprt->xpt_ref, svc_xprt_free);
+}
+EXPORT_SYMBOL_GPL(svc_xprt_put);
+
+/*
+ * Called by transport drivers to initialize the transport independent
+ * portion of the transport instance.
+ */
+void svc_xprt_init(struct net *net, struct svc_xprt_class *xcl,
+ struct svc_xprt *xprt, struct svc_serv *serv)
+{
+ memset(xprt, 0, sizeof(*xprt));
+ xprt->xpt_class = xcl;
+ xprt->xpt_ops = xcl->xcl_ops;
+ kref_init(&xprt->xpt_ref);
+ xprt->xpt_server = serv;
+ INIT_LIST_HEAD(&xprt->xpt_list);
+ INIT_LIST_HEAD(&xprt->xpt_ready);
+ INIT_LIST_HEAD(&xprt->xpt_deferred);
+ INIT_LIST_HEAD(&xprt->xpt_users);
+ mutex_init(&xprt->xpt_mutex);
+ spin_lock_init(&xprt->xpt_lock);
+ set_bit(XPT_BUSY, &xprt->xpt_flags);
+ xprt->xpt_net = get_net_track(net, &xprt->ns_tracker, GFP_ATOMIC);
+ strcpy(xprt->xpt_remotebuf, "uninitialized");
+}
+EXPORT_SYMBOL_GPL(svc_xprt_init);
+
+static struct svc_xprt *__svc_xpo_create(struct svc_xprt_class *xcl,
+ struct svc_serv *serv,
+ struct net *net,
+ const int family,
+ const unsigned short port,
+ int flags)
+{
+ struct sockaddr_in sin = {
+ .sin_family = AF_INET,
+ .sin_addr.s_addr = htonl(INADDR_ANY),
+ .sin_port = htons(port),
+ };
+#if IS_ENABLED(CONFIG_IPV6)
+ struct sockaddr_in6 sin6 = {
+ .sin6_family = AF_INET6,
+ .sin6_addr = IN6ADDR_ANY_INIT,
+ .sin6_port = htons(port),
+ };
+#endif
+ struct svc_xprt *xprt;
+ struct sockaddr *sap;
+ size_t len;
+
+ switch (family) {
+ case PF_INET:
+ sap = (struct sockaddr *)&sin;
+ len = sizeof(sin);
+ break;
+#if IS_ENABLED(CONFIG_IPV6)
+ case PF_INET6:
+ sap = (struct sockaddr *)&sin6;
+ len = sizeof(sin6);
+ break;
+#endif
+ default:
+ return ERR_PTR(-EAFNOSUPPORT);
+ }
+
+ xprt = xcl->xcl_ops->xpo_create(serv, net, sap, len, flags);
+ if (IS_ERR(xprt))
+ trace_svc_xprt_create_err(serv->sv_program->pg_name,
+ xcl->xcl_name, sap, len, xprt);
+ return xprt;
+}
+
+/**
+ * svc_xprt_received - start next receiver thread
+ * @xprt: controlling transport
+ *
+ * The caller must hold the XPT_BUSY bit and must
+ * not thereafter touch transport data.
+ *
+ * Note: XPT_DATA only gets cleared when a read-attempt finds no (or
+ * insufficient) data.
+ */
+void svc_xprt_received(struct svc_xprt *xprt)
+{
+ if (!test_bit(XPT_BUSY, &xprt->xpt_flags)) {
+ WARN_ONCE(1, "xprt=0x%p already busy!", xprt);
+ return;
+ }
+
+ /* As soon as we clear busy, the xprt could be closed and
+ * 'put', so we need a reference to call svc_xprt_enqueue with:
+ */
+ svc_xprt_get(xprt);
+ smp_mb__before_atomic();
+ clear_bit(XPT_BUSY, &xprt->xpt_flags);
+ svc_xprt_enqueue(xprt);
+ svc_xprt_put(xprt);
+}
+EXPORT_SYMBOL_GPL(svc_xprt_received);
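
Putting the two halves of the XPT_BUSY handoff together, a transport provider's receive path typically follows the shape below. This is a hedged sketch with invented callback names (see svcsock.c for the real implementations): the data-ready side only marks the transport and enqueues it, and the recvfrom side calls svc_xprt_received() once another thread may safely poll the same transport:

#include <linux/sunrpc/svc.h>
#include <linux/sunrpc/svc_xprt.h>

/* Hypothetical provider callbacks illustrating the XPT_BUSY handoff. */
static void example_data_ready(struct svc_xprt *xprt)
{
	/* Typically called from softirq context when bytes arrive. */
	set_bit(XPT_DATA, &xprt->xpt_flags);
	svc_xprt_enqueue(xprt);	/* returns early if the xprt is already busy */
}

static int example_recvfrom(struct svc_rqst *rqstp)
{
	struct svc_xprt *xprt = rqstp->rq_xprt;
	int len;

	len = 0; /* ... pull a complete RPC record into rqstp->rq_arg ... */

	/* Let another service thread look at this transport again. */
	svc_xprt_received(xprt);
	return len;
}
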
+
+void svc_add_new_perm_xprt(struct svc_serv *serv, struct svc_xprt *new)
+{
+ clear_bit(XPT_TEMP, &new->xpt_flags);
+ spin_lock_bh(&serv->sv_lock);
+ list_add(&new->xpt_list, &serv->sv_permsocks);
+ spin_unlock_bh(&serv->sv_lock);
+ svc_xprt_received(new);
+}
+
+static int _svc_xprt_create(struct svc_serv *serv, const char *xprt_name,
+ struct net *net, const int family,
+ const unsigned short port, int flags,
+ const struct cred *cred)
+{
+ struct svc_xprt_class *xcl;
+
+ spin_lock(&svc_xprt_class_lock);
+ list_for_each_entry(xcl, &svc_xprt_class_list, xcl_list) {
+ struct svc_xprt *newxprt;
+ unsigned short newport;
+
+ if (strcmp(xprt_name, xcl->xcl_name))
+ continue;
+
+ if (!try_module_get(xcl->xcl_owner))
+ goto err;
+
+ spin_unlock(&svc_xprt_class_lock);
+ newxprt = __svc_xpo_create(xcl, serv, net, family, port, flags);
+ if (IS_ERR(newxprt)) {
+ module_put(xcl->xcl_owner);
+ return PTR_ERR(newxprt);
+ }
+ newxprt->xpt_cred = get_cred(cred);
+ svc_add_new_perm_xprt(serv, newxprt);
+ newport = svc_xprt_local_port(newxprt);
+ return newport;
+ }
+ err:
+ spin_unlock(&svc_xprt_class_lock);
+ /* This errno is exposed to user space. Provide a reasonable
+ * perror msg for a bad transport. */
+ return -EPROTONOSUPPORT;
+}
+
+/**
+ * svc_xprt_create - Add a new listener to @serv
+ * @serv: target RPC service
+ * @xprt_name: transport class name
+ * @net: network namespace
+ * @family: network address family
+ * @port: listener port
+ * @flags: SVC_SOCK flags
+ * @cred: credential to bind to this transport
+ *
+ * Return values:
+ * %0: New listener added successfully
+ * %-EPROTONOSUPPORT: Requested transport type not supported
+ */
+int svc_xprt_create(struct svc_serv *serv, const char *xprt_name,
+ struct net *net, const int family,
+ const unsigned short port, int flags,
+ const struct cred *cred)
+{
+ int err;
+
+ err = _svc_xprt_create(serv, xprt_name, net, family, port, flags, cred);
+ if (err == -EPROTONOSUPPORT) {
+ request_module("svc%s", xprt_name);
+ err = _svc_xprt_create(serv, xprt_name, net, family, port, flags, cred);
+ }
+ return err;
+}
+EXPORT_SYMBOL_GPL(svc_xprt_create);
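
A sketch of the usual start-up sequence, assuming the caller already owns a svc_serv and wants UDP and TCP listeners on one port. The helper name is invented; the SVC_SOCK_DEFAULTS flag and the cleanup-on-failure pattern follow existing in-tree callers, but this fragment itself is illustrative:

#include <linux/sunrpc/svc.h>
#include <linux/sunrpc/svc_xprt.h>
#include <linux/sunrpc/svcsock.h>

/* Hypothetical listener setup for a service. */
static int example_create_listeners(struct svc_serv *serv, struct net *net,
				    unsigned short port,
				    const struct cred *cred)
{
	int err;

	err = svc_xprt_create(serv, "udp", net, PF_INET, port,
			      SVC_SOCK_DEFAULTS, cred);
	if (err < 0)
		return err;

	err = svc_xprt_create(serv, "tcp", net, PF_INET, port,
			      SVC_SOCK_DEFAULTS, cred);
	if (err < 0) {
		svc_xprt_destroy_all(serv, net);
		return err;
	}
	return 0;
}
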
+
+/*
+ * Copy the local and remote xprt addresses to the rqstp structure
+ */
+void svc_xprt_copy_addrs(struct svc_rqst *rqstp, struct svc_xprt *xprt)
+{
+ memcpy(&rqstp->rq_addr, &xprt->xpt_remote, xprt->xpt_remotelen);
+ rqstp->rq_addrlen = xprt->xpt_remotelen;
+
+ /*
+ * Destination address in request is needed for binding the
+ * source address in RPC replies/callbacks later.
+ */
+ memcpy(&rqstp->rq_daddr, &xprt->xpt_local, xprt->xpt_locallen);
+ rqstp->rq_daddrlen = xprt->xpt_locallen;
+}
+EXPORT_SYMBOL_GPL(svc_xprt_copy_addrs);
+
+/**
+ * svc_print_addr - Format rq_addr field for printing
+ * @rqstp: svc_rqst struct containing address to print
+ * @buf: target buffer for formatted address
+ * @len: length of target buffer
+ *
+ */
+char *svc_print_addr(struct svc_rqst *rqstp, char *buf, size_t len)
+{
+ return __svc_print_addr(svc_addr(rqstp), buf, len);
+}
+EXPORT_SYMBOL_GPL(svc_print_addr);
+
+static bool svc_xprt_slots_in_range(struct svc_xprt *xprt)
+{
+ unsigned int limit = svc_rpc_per_connection_limit;
+ int nrqsts = atomic_read(&xprt->xpt_nr_rqsts);
+
+ return limit == 0 || (nrqsts >= 0 && nrqsts < limit);
+}
+
+static bool svc_xprt_reserve_slot(struct svc_rqst *rqstp, struct svc_xprt *xprt)
+{
+ if (!test_bit(RQ_DATA, &rqstp->rq_flags)) {
+ if (!svc_xprt_slots_in_range(xprt))
+ return false;
+ atomic_inc(&xprt->xpt_nr_rqsts);
+ set_bit(RQ_DATA, &rqstp->rq_flags);
+ }
+ return true;
+}
+
+static void svc_xprt_release_slot(struct svc_rqst *rqstp)
+{
+ struct svc_xprt *xprt = rqstp->rq_xprt;
+ if (test_and_clear_bit(RQ_DATA, &rqstp->rq_flags)) {
+ atomic_dec(&xprt->xpt_nr_rqsts);
+ smp_wmb(); /* See smp_rmb() in svc_xprt_ready() */
+ svc_xprt_enqueue(xprt);
+ }
+}
+
+static bool svc_xprt_ready(struct svc_xprt *xprt)
+{
+ unsigned long xpt_flags;
+
+ /*
+ * If another cpu has recently updated xpt_flags,
+ * sk_sock->flags, xpt_reserved, or xpt_nr_rqsts, we need to
+ * know about it; otherwise it's possible that both that cpu and
+ * this one could call svc_xprt_enqueue() without either
+ * svc_xprt_enqueue() recognizing that the conditions below
+ * are satisfied, and we could stall indefinitely:
+ */
+ smp_rmb();
+ xpt_flags = READ_ONCE(xprt->xpt_flags);
+
+ trace_svc_xprt_enqueue(xprt, xpt_flags);
+ if (xpt_flags & BIT(XPT_BUSY))
+ return false;
+ if (xpt_flags & (BIT(XPT_CONN) | BIT(XPT_CLOSE) | BIT(XPT_HANDSHAKE)))
+ return true;
+ if (xpt_flags & (BIT(XPT_DATA) | BIT(XPT_DEFERRED))) {
+ if (xprt->xpt_ops->xpo_has_wspace(xprt) &&
+ svc_xprt_slots_in_range(xprt))
+ return true;
+ trace_svc_xprt_no_write_space(xprt);
+ return false;
+ }
+ return false;
+}
+
+/**
+ * svc_xprt_enqueue - Queue a transport on an idle nfsd thread
+ * @xprt: transport with data pending
+ *
+ */
+void svc_xprt_enqueue(struct svc_xprt *xprt)
+{
+ struct svc_pool *pool;
+
+ if (!svc_xprt_ready(xprt))
+ return;
+
+ /* Mark transport as busy. It will remain in this state until
+ * the provider calls svc_xprt_received. We update XPT_BUSY
+ * atomically because it also guards against trying to enqueue
+ * the transport twice.
+ */
+ if (test_and_set_bit(XPT_BUSY, &xprt->xpt_flags))
+ return;
+
+ pool = svc_pool_for_cpu(xprt->xpt_server);
+
+ percpu_counter_inc(&pool->sp_sockets_queued);
+ spin_lock_bh(&pool->sp_lock);
+ list_add_tail(&xprt->xpt_ready, &pool->sp_sockets);
+ spin_unlock_bh(&pool->sp_lock);
+
+ svc_pool_wake_idle_thread(pool);
+}
+EXPORT_SYMBOL_GPL(svc_xprt_enqueue);
+
+/*
+ * Dequeue the first transport, if there is one.
+ */
+static struct svc_xprt *svc_xprt_dequeue(struct svc_pool *pool)
+{
+ struct svc_xprt *xprt = NULL;
+
+ if (list_empty(&pool->sp_sockets))
+ goto out;
+
+ spin_lock_bh(&pool->sp_lock);
+ if (likely(!list_empty(&pool->sp_sockets))) {
+ xprt = list_first_entry(&pool->sp_sockets,
+ struct svc_xprt, xpt_ready);
+ list_del_init(&xprt->xpt_ready);
+ svc_xprt_get(xprt);
+ }
+ spin_unlock_bh(&pool->sp_lock);
+out:
+ return xprt;
+}
+
+/**
+ * svc_reserve - change the space reserved for the reply to a request.
+ * @rqstp: The request in question
+ * @space: new max space to reserve
+ *
+ * Each request reserves some space on the output queue of the transport
+ * to make sure the reply fits. This function reduces that reserved
+ * space to be the amount of space used already, plus @space.
+ *
+ */
+void svc_reserve(struct svc_rqst *rqstp, int space)
+{
+ struct svc_xprt *xprt = rqstp->rq_xprt;
+
+ space += rqstp->rq_res.head[0].iov_len;
+
+ if (xprt && space < rqstp->rq_reserved) {
+ atomic_sub((rqstp->rq_reserved - space), &xprt->xpt_reserved);
+ rqstp->rq_reserved = space;
+ smp_wmb(); /* See smp_rmb() in svc_xprt_ready() */
+ svc_xprt_enqueue(xprt);
+ }
+}
+EXPORT_SYMBOL_GPL(svc_reserve);
+
+static void free_deferred(struct svc_xprt *xprt, struct svc_deferred_req *dr)
+{
+ if (!dr)
+ return;
+
+ xprt->xpt_ops->xpo_release_ctxt(xprt, dr->xprt_ctxt);
+ kfree(dr);
+}
+
+static void svc_xprt_release(struct svc_rqst *rqstp)
+{
+ struct svc_xprt *xprt = rqstp->rq_xprt;
+
+ xprt->xpt_ops->xpo_release_ctxt(xprt, rqstp->rq_xprt_ctxt);
+ rqstp->rq_xprt_ctxt = NULL;
+
+ free_deferred(xprt, rqstp->rq_deferred);
+ rqstp->rq_deferred = NULL;
+
+ svc_rqst_release_pages(rqstp);
+ rqstp->rq_res.page_len = 0;
+ rqstp->rq_res.page_base = 0;
+
+ /* Reset response buffer and release
+ * the reservation.
+ * But first, check that enough space was reserved
+ * for the reply, otherwise we have a bug!
+ */
+ if ((rqstp->rq_res.len) > rqstp->rq_reserved)
+ printk(KERN_ERR "RPC request reserved %d but used %d\n",
+ rqstp->rq_reserved,
+ rqstp->rq_res.len);
+
+ rqstp->rq_res.head[0].iov_len = 0;
+ svc_reserve(rqstp, 0);
+ svc_xprt_release_slot(rqstp);
+ rqstp->rq_xprt = NULL;
+ svc_xprt_put(xprt);
+}
+
+/**
+ * svc_wake_up - Wake up a service thread for non-transport work
+ * @serv: RPC service
+ *
+ * Some svc_serv's will have occasional work to do, even when an xprt is not
+ * waiting to be serviced. This function is there to "kick" a task in one of
+ * those services so that it can wake up and do that work. Note that we only
+ * bother with pool 0 as we don't need to wake up more than one thread for
+ * this purpose.
+ */
+void svc_wake_up(struct svc_serv *serv)
+{
+ struct svc_pool *pool = &serv->sv_pools[0];
+
+ set_bit(SP_TASK_PENDING, &pool->sp_flags);
+ svc_pool_wake_idle_thread(pool);
+}
+EXPORT_SYMBOL_GPL(svc_wake_up);
+
+int svc_port_is_privileged(struct sockaddr *sin)
+{
+ switch (sin->sa_family) {
+ case AF_INET:
+ return ntohs(((struct sockaddr_in *)sin)->sin_port)
+ < PROT_SOCK;
+ case AF_INET6:
+ return ntohs(((struct sockaddr_in6 *)sin)->sin6_port)
+ < PROT_SOCK;
+ default:
+ return 0;
+ }
+}
+
+/*
+ * Make sure that we don't have too many active connections. If we have,
+ * something must be dropped. It's not clear what will happen if we allow
+ * "too many" connections, but when dealing with network-facing software,
+ * we have to code defensively. Here we do that by imposing hard limits.
+ *
+ * There's no point in trying to do random drop here for DoS
+ * prevention. An NFS client does one reconnect every 15 seconds; an
+ * attacker can easily beat that.
+ *
+ * The only somewhat efficient mechanism would be to drop old
+ * connections from the same IP first. But right now we don't even
+ * record the client IP in svc_sock.
+ *
+ * Single-threaded services that expect a lot of clients will probably
+ * need to set sv_maxconn to override the default value, which is based
+ * on the number of threads.
+ */
+static void svc_check_conn_limits(struct svc_serv *serv)
+{
+ unsigned int limit = serv->sv_maxconn ? serv->sv_maxconn :
+ (serv->sv_nrthreads+3) * 20;
+
+ if (serv->sv_tmpcnt > limit) {
+ struct svc_xprt *xprt = NULL;
+ spin_lock_bh(&serv->sv_lock);
+ if (!list_empty(&serv->sv_tempsocks)) {
+ /* Try to help the admin */
+ net_notice_ratelimited("%s: too many open connections, consider increasing the %s\n",
+ serv->sv_name, serv->sv_maxconn ?
+ "max number of connections" :
+ "number of threads");
+ /*
+ * Always select the oldest connection. It's not fair,
+ * but so is life
+ */
+ xprt = list_entry(serv->sv_tempsocks.prev,
+ struct svc_xprt,
+ xpt_list);
+ set_bit(XPT_CLOSE, &xprt->xpt_flags);
+ svc_xprt_get(xprt);
+ }
+ spin_unlock_bh(&serv->sv_lock);
+
+ if (xprt) {
+ svc_xprt_enqueue(xprt);
+ svc_xprt_put(xprt);
+ }
+ }
+}
+
+static bool svc_alloc_arg(struct svc_rqst *rqstp)
+{
+ struct svc_serv *serv = rqstp->rq_server;
+ struct xdr_buf *arg = &rqstp->rq_arg;
+ unsigned long pages, filled, ret;
+
+ pages = (serv->sv_max_mesg + 2 * PAGE_SIZE) >> PAGE_SHIFT;
+ if (pages > RPCSVC_MAXPAGES) {
+ pr_warn_once("svc: warning: pages=%lu > RPCSVC_MAXPAGES=%lu\n",
+ pages, RPCSVC_MAXPAGES);
+ /* use as many pages as possible */
+ pages = RPCSVC_MAXPAGES;
+ }
+
+ for (filled = 0; filled < pages; filled = ret) {
+ ret = alloc_pages_bulk_array(GFP_KERNEL, pages,
+ rqstp->rq_pages);
+ if (ret > filled)
+ /* Made progress, don't sleep yet */
+ continue;
+
+ set_current_state(TASK_IDLE);
+ if (kthread_should_stop()) {
+ set_current_state(TASK_RUNNING);
+ return false;
+ }
+ trace_svc_alloc_arg_err(pages, ret);
+ memalloc_retry_wait(GFP_KERNEL);
+ }
+ rqstp->rq_page_end = &rqstp->rq_pages[pages];
+ rqstp->rq_pages[pages] = NULL; /* this might be seen in nfsd_splice_actor() */
+
+ /* Make arg->head point to first page and arg->pages point to rest */
+ arg->head[0].iov_base = page_address(rqstp->rq_pages[0]);
+ arg->head[0].iov_len = PAGE_SIZE;
+ arg->pages = rqstp->rq_pages + 1;
+ arg->page_base = 0;
+ /* save at least one page for response */
+ arg->page_len = (pages-2)*PAGE_SIZE;
+ arg->len = (pages-1)*PAGE_SIZE;
+ arg->tail[0].iov_len = 0;
+
+ rqstp->rq_xid = xdr_zero;
+ return true;
+}
+
+static bool
+rqst_should_sleep(struct svc_rqst *rqstp)
+{
+ struct svc_pool *pool = rqstp->rq_pool;
+
+ /* did someone call svc_wake_up? */
+ if (test_bit(SP_TASK_PENDING, &pool->sp_flags))
+ return false;
+
+ /* was a socket queued? */
+ if (!list_empty(&pool->sp_sockets))
+ return false;
+
+ /* are we shutting down? */
+ if (kthread_should_stop())
+ return false;
+
+ /* are we freezing? */
+ if (freezing(current))
+ return false;
+
+ return true;
+}
+
+static struct svc_xprt *svc_get_next_xprt(struct svc_rqst *rqstp)
+{
+ struct svc_pool *pool = rqstp->rq_pool;
+
+ /* rq_xprt should be clear on entry */
+ WARN_ON_ONCE(rqstp->rq_xprt);
+
+ rqstp->rq_xprt = svc_xprt_dequeue(pool);
+ if (rqstp->rq_xprt)
+ goto out_found;
+
+ set_current_state(TASK_IDLE);
+ smp_mb__before_atomic();
+ clear_bit(SP_CONGESTED, &pool->sp_flags);
+ clear_bit(RQ_BUSY, &rqstp->rq_flags);
+ smp_mb__after_atomic();
+
+ if (likely(rqst_should_sleep(rqstp)))
+ schedule();
+ else
+ __set_current_state(TASK_RUNNING);
+
+ try_to_freeze();
+
+ set_bit(RQ_BUSY, &rqstp->rq_flags);
+ smp_mb__after_atomic();
+ clear_bit(SP_TASK_PENDING, &pool->sp_flags);
+ rqstp->rq_xprt = svc_xprt_dequeue(pool);
+ if (rqstp->rq_xprt)
+ goto out_found;
+
+ if (kthread_should_stop())
+ return NULL;
+ return NULL;
+out_found:
+ clear_bit(SP_TASK_PENDING, &pool->sp_flags);
+ /* Normally we will wait up to 5 seconds for any required
+ * cache information to be provided.
+ */
+ if (!test_bit(SP_CONGESTED, &pool->sp_flags))
+ rqstp->rq_chandle.thread_wait = 5*HZ;
+ else
+ rqstp->rq_chandle.thread_wait = 1*HZ;
+ trace_svc_xprt_dequeue(rqstp);
+ return rqstp->rq_xprt;
+}
+
+static void svc_add_new_temp_xprt(struct svc_serv *serv, struct svc_xprt *newxpt)
+{
+ spin_lock_bh(&serv->sv_lock);
+ set_bit(XPT_TEMP, &newxpt->xpt_flags);
+ list_add(&newxpt->xpt_list, &serv->sv_tempsocks);
+ serv->sv_tmpcnt++;
+ if (serv->sv_temptimer.function == NULL) {
+ /* setup timer to age temp transports */
+ serv->sv_temptimer.function = svc_age_temp_xprts;
+ mod_timer(&serv->sv_temptimer,
+ jiffies + svc_conn_age_period * HZ);
+ }
+ spin_unlock_bh(&serv->sv_lock);
+ svc_xprt_received(newxpt);
+}
+
+static int svc_handle_xprt(struct svc_rqst *rqstp, struct svc_xprt *xprt)
+{
+ struct svc_serv *serv = rqstp->rq_server;
+ int len = 0;
+
+ if (test_bit(XPT_CLOSE, &xprt->xpt_flags)) {
+ if (test_and_clear_bit(XPT_KILL_TEMP, &xprt->xpt_flags))
+ xprt->xpt_ops->xpo_kill_temp_xprt(xprt);
+ svc_delete_xprt(xprt);
+ /* Leave XPT_BUSY set on the dead xprt: */
+ goto out;
+ }
+ if (test_bit(XPT_LISTENER, &xprt->xpt_flags)) {
+ struct svc_xprt *newxpt;
+ /*
+ * We know this module_get will succeed because the
+ * listener holds a reference too
+ */
+ __module_get(xprt->xpt_class->xcl_owner);
+ svc_check_conn_limits(xprt->xpt_server);
+ newxpt = xprt->xpt_ops->xpo_accept(xprt);
+ if (newxpt) {
+ newxpt->xpt_cred = get_cred(xprt->xpt_cred);
+ svc_add_new_temp_xprt(serv, newxpt);
+ trace_svc_xprt_accept(newxpt, serv->sv_name);
+ } else {
+ module_put(xprt->xpt_class->xcl_owner);
+ }
+ svc_xprt_received(xprt);
+ } else if (test_bit(XPT_HANDSHAKE, &xprt->xpt_flags)) {
+ xprt->xpt_ops->xpo_handshake(xprt);
+ svc_xprt_received(xprt);
+ } else if (svc_xprt_reserve_slot(rqstp, xprt)) {
+ /* XPT_DATA|XPT_DEFERRED case: */
+ rqstp->rq_deferred = svc_deferred_dequeue(xprt);
+ if (rqstp->rq_deferred)
+ len = svc_deferred_recv(rqstp);
+ else
+ len = xprt->xpt_ops->xpo_recvfrom(rqstp);
+ rqstp->rq_reserved = serv->sv_max_mesg;
+ atomic_add(rqstp->rq_reserved, &xprt->xpt_reserved);
+ } else
+ svc_xprt_received(xprt);
+
+out:
+ return len;
+}
+
+/**
+ * svc_recv - Receive and process the next request on any transport
+ * @rqstp: an idle RPC service thread
+ *
+ * This code is carefully organised not to touch any cachelines in
+ * the shared svc_serv structure, only cachelines in the local
+ * svc_pool.
+ */
+void svc_recv(struct svc_rqst *rqstp)
+{
+ struct svc_xprt *xprt = NULL;
+ struct svc_serv *serv = rqstp->rq_server;
+ int len;
+
+ if (!svc_alloc_arg(rqstp))
+ goto out;
+
+ try_to_freeze();
+ cond_resched();
+ if (kthread_should_stop())
+ goto out;
+
+ xprt = svc_get_next_xprt(rqstp);
+ if (!xprt)
+ goto out;
+
+ len = svc_handle_xprt(rqstp, xprt);
+
+ /* No data, incomplete (TCP) read, or accept() */
+ if (len <= 0)
+ goto out_release;
+
+ trace_svc_xdr_recvfrom(&rqstp->rq_arg);
+
+ clear_bit(XPT_OLD, &xprt->xpt_flags);
+
+ rqstp->rq_chandle.defer = svc_defer;
+
+ if (serv->sv_stats)
+ serv->sv_stats->netcnt++;
+ percpu_counter_inc(&rqstp->rq_pool->sp_messages_arrived);
+ rqstp->rq_stime = ktime_get();
+ svc_process(rqstp);
+out:
+ return;
+out_release:
+ rqstp->rq_res.len = 0;
+ svc_xprt_release(rqstp);
+}
+EXPORT_SYMBOL_GPL(svc_recv);
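
In this version svc_recv() both waits for a transport and dispatches the request (it calls svc_process() itself), so a service thread body reduces to a loop like the hedged sketch below. The function name is invented; in-tree services follow roughly this shape:

#include <linux/kthread.h>
#include <linux/sunrpc/svc.h>

/* Hypothetical per-thread main loop for an RPC service. */
static int example_svc_thread(void *data)
{
	struct svc_rqst *rqstp = data;

	while (!kthread_should_stop())
		svc_recv(rqstp);	/* waits, receives and processes */

	svc_exit_thread(rqstp);
	return 0;
}
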
+
+/*
+ * Drop request
+ */
+void svc_drop(struct svc_rqst *rqstp)
+{
+ trace_svc_drop(rqstp);
+ svc_xprt_release(rqstp);
+}
+EXPORT_SYMBOL_GPL(svc_drop);
+
+/**
+ * svc_send - Return reply to client
+ * @rqstp: RPC transaction context
+ *
+ */
+void svc_send(struct svc_rqst *rqstp)
+{
+ struct svc_xprt *xprt;
+ struct xdr_buf *xb;
+ int status;
+
+ xprt = rqstp->rq_xprt;
+ if (!xprt)
+ return;
+
+ /* calculate over-all length */
+ xb = &rqstp->rq_res;
+ xb->len = xb->head[0].iov_len +
+ xb->page_len +
+ xb->tail[0].iov_len;
+ trace_svc_xdr_sendto(rqstp->rq_xid, xb);
+ trace_svc_stats_latency(rqstp);
+
+ status = xprt->xpt_ops->xpo_sendto(rqstp);
+
+ trace_svc_send(rqstp, status);
+ svc_xprt_release(rqstp);
+}
+
+/*
+ * Timer function to close old temporary transports, using
+ * a mark-and-sweep algorithm.
+ */
+static void svc_age_temp_xprts(struct timer_list *t)
+{
+ struct svc_serv *serv = from_timer(serv, t, sv_temptimer);
+ struct svc_xprt *xprt;
+ struct list_head *le, *next;
+
+ dprintk("svc_age_temp_xprts\n");
+
+ if (!spin_trylock_bh(&serv->sv_lock)) {
+ /* busy, try again 1 sec later */
+ dprintk("svc_age_temp_xprts: busy\n");
+ mod_timer(&serv->sv_temptimer, jiffies + HZ);
+ return;
+ }
+
+ list_for_each_safe(le, next, &serv->sv_tempsocks) {
+ xprt = list_entry(le, struct svc_xprt, xpt_list);
+
+ /* First time through, just mark it OLD. Second time
+ * through, close it. */
+ if (!test_and_set_bit(XPT_OLD, &xprt->xpt_flags))
+ continue;
+ if (kref_read(&xprt->xpt_ref) > 1 ||
+ test_bit(XPT_BUSY, &xprt->xpt_flags))
+ continue;
+ list_del_init(le);
+ set_bit(XPT_CLOSE, &xprt->xpt_flags);
+ dprintk("queuing xprt %p for closing\n", xprt);
+
+ /* a thread will dequeue and close it soon */
+ svc_xprt_enqueue(xprt);
+ }
+ spin_unlock_bh(&serv->sv_lock);
+
+ mod_timer(&serv->sv_temptimer, jiffies + svc_conn_age_period * HZ);
+}
+
+/* Close temporary transports whose xpt_local matches server_addr immediately
+ * instead of waiting for them to be picked up by the timer.
+ *
+ * This is meant to be called from a notifier_block that runs when an ip
+ * address is deleted.
+ */
+void svc_age_temp_xprts_now(struct svc_serv *serv, struct sockaddr *server_addr)
+{
+ struct svc_xprt *xprt;
+ struct list_head *le, *next;
+ LIST_HEAD(to_be_closed);
+
+ spin_lock_bh(&serv->sv_lock);
+ list_for_each_safe(le, next, &serv->sv_tempsocks) {
+ xprt = list_entry(le, struct svc_xprt, xpt_list);
+ if (rpc_cmp_addr(server_addr, (struct sockaddr *)
+ &xprt->xpt_local)) {
+ dprintk("svc_age_temp_xprts_now: found %p\n", xprt);
+ list_move(le, &to_be_closed);
+ }
+ }
+ spin_unlock_bh(&serv->sv_lock);
+
+ while (!list_empty(&to_be_closed)) {
+ le = to_be_closed.next;
+ list_del_init(le);
+ xprt = list_entry(le, struct svc_xprt, xpt_list);
+ set_bit(XPT_CLOSE, &xprt->xpt_flags);
+ set_bit(XPT_KILL_TEMP, &xprt->xpt_flags);
+ dprintk("svc_age_temp_xprts_now: queuing xprt %p for closing\n",
+ xprt);
+ svc_xprt_enqueue(xprt);
+ }
+}
+EXPORT_SYMBOL_GPL(svc_age_temp_xprts_now);
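
A hedged sketch of the intended caller: an inetaddr notifier that closes temporary connections bound to an address being removed. The callback and the service pointer here are invented placeholders; in-tree services register callbacks of this general shape:

#include <linux/in.h>
#include <linux/inetdevice.h>
#include <linux/notifier.h>
#include <linux/sunrpc/svc.h>
#include <linux/sunrpc/svc_xprt.h>

static struct svc_serv *example_serv;	/* assumed to be set at start-up */

/* Hypothetical notifier callback for IPv4 address removal. */
static int example_inetaddr_event(struct notifier_block *this,
				  unsigned long event, void *ptr)
{
	struct in_ifaddr *ifa = ptr;
	struct sockaddr_in sin;

	if (event != NETDEV_DOWN || !example_serv)
		return NOTIFY_DONE;

	sin.sin_family = AF_INET;
	sin.sin_addr.s_addr = ifa->ifa_local;
	svc_age_temp_xprts_now(example_serv, (struct sockaddr *)&sin);
	return NOTIFY_DONE;
}

static struct notifier_block example_inetaddr_notifier = {
	.notifier_call = example_inetaddr_event,
};
/* registered with register_inetaddr_notifier(&example_inetaddr_notifier) */
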
+
+static void call_xpt_users(struct svc_xprt *xprt)
+{
+ struct svc_xpt_user *u;
+
+ spin_lock(&xprt->xpt_lock);
+ while (!list_empty(&xprt->xpt_users)) {
+ u = list_first_entry(&xprt->xpt_users, struct svc_xpt_user, list);
+ list_del_init(&u->list);
+ u->callback(u);
+ }
+ spin_unlock(&xprt->xpt_lock);
+}
+
+/*
+ * Remove a dead transport
+ */
+static void svc_delete_xprt(struct svc_xprt *xprt)
+{
+ struct svc_serv *serv = xprt->xpt_server;
+ struct svc_deferred_req *dr;
+
+ if (test_and_set_bit(XPT_DEAD, &xprt->xpt_flags))
+ return;
+
+ trace_svc_xprt_detach(xprt);
+ xprt->xpt_ops->xpo_detach(xprt);
+ if (xprt->xpt_bc_xprt)
+ xprt->xpt_bc_xprt->ops->close(xprt->xpt_bc_xprt);
+
+ spin_lock_bh(&serv->sv_lock);
+ list_del_init(&xprt->xpt_list);
+ WARN_ON_ONCE(!list_empty(&xprt->xpt_ready));
+ if (test_bit(XPT_TEMP, &xprt->xpt_flags))
+ serv->sv_tmpcnt--;
+ spin_unlock_bh(&serv->sv_lock);
+
+ while ((dr = svc_deferred_dequeue(xprt)) != NULL)
+ free_deferred(xprt, dr);
+
+ call_xpt_users(xprt);
+ svc_xprt_put(xprt);
+}
+
+/**
+ * svc_xprt_close - Close a client connection
+ * @xprt: transport to disconnect
+ *
+ */
+void svc_xprt_close(struct svc_xprt *xprt)
+{
+ trace_svc_xprt_close(xprt);
+ set_bit(XPT_CLOSE, &xprt->xpt_flags);
+ if (test_and_set_bit(XPT_BUSY, &xprt->xpt_flags))
+ /* someone else will have to effect the close */
+ return;
+ /*
+	 * We expect svc_xprt_close() to work even when no threads are
+ * running (e.g., while configuring the server before starting
+ * any threads), so if the transport isn't busy, we delete
+ * it ourself:
+ */
+ svc_delete_xprt(xprt);
+}
+EXPORT_SYMBOL_GPL(svc_xprt_close);
+
+static int svc_close_list(struct svc_serv *serv, struct list_head *xprt_list, struct net *net)
+{
+ struct svc_xprt *xprt;
+ int ret = 0;
+
+ spin_lock_bh(&serv->sv_lock);
+ list_for_each_entry(xprt, xprt_list, xpt_list) {
+ if (xprt->xpt_net != net)
+ continue;
+ ret++;
+ set_bit(XPT_CLOSE, &xprt->xpt_flags);
+ svc_xprt_enqueue(xprt);
+ }
+ spin_unlock_bh(&serv->sv_lock);
+ return ret;
+}
+
+static struct svc_xprt *svc_dequeue_net(struct svc_serv *serv, struct net *net)
+{
+ struct svc_pool *pool;
+ struct svc_xprt *xprt;
+ struct svc_xprt *tmp;
+ int i;
+
+ for (i = 0; i < serv->sv_nrpools; i++) {
+ pool = &serv->sv_pools[i];
+
+ spin_lock_bh(&pool->sp_lock);
+ list_for_each_entry_safe(xprt, tmp, &pool->sp_sockets, xpt_ready) {
+ if (xprt->xpt_net != net)
+ continue;
+ list_del_init(&xprt->xpt_ready);
+ spin_unlock_bh(&pool->sp_lock);
+ return xprt;
+ }
+ spin_unlock_bh(&pool->sp_lock);
+ }
+ return NULL;
+}
+
+static void svc_clean_up_xprts(struct svc_serv *serv, struct net *net)
+{
+ struct svc_xprt *xprt;
+
+ while ((xprt = svc_dequeue_net(serv, net))) {
+ set_bit(XPT_CLOSE, &xprt->xpt_flags);
+ svc_delete_xprt(xprt);
+ }
+}
+
+/**
+ * svc_xprt_destroy_all - Destroy transports associated with @serv
+ * @serv: RPC service to be shut down
+ * @net: target network namespace
+ *
+ * Server threads may still be running (especially in the case where the
+ * service is still running in other network namespaces).
+ *
+ * So we shut down sockets the same way we would on a running server, by
+ * setting XPT_CLOSE, enqueuing, and letting a thread pick it up to do
+ * the close. If no other threads are running, svc_clean_up_xprts()
+ * does a simple version of a server's main event loop itself; if other
+ * threads are running, we may need to wait a little while and then
+ * check again to see if they're done.
+ */
+void svc_xprt_destroy_all(struct svc_serv *serv, struct net *net)
+{
+ int delay = 0;
+
+ while (svc_close_list(serv, &serv->sv_permsocks, net) +
+ svc_close_list(serv, &serv->sv_tempsocks, net)) {
+
+ svc_clean_up_xprts(serv, net);
+ msleep(delay++);
+ }
+}
+EXPORT_SYMBOL_GPL(svc_xprt_destroy_all);
+
+/*
+ * Handle defer and revisit of requests
+ */
+
+static void svc_revisit(struct cache_deferred_req *dreq, int too_many)
+{
+ struct svc_deferred_req *dr =
+ container_of(dreq, struct svc_deferred_req, handle);
+ struct svc_xprt *xprt = dr->xprt;
+
+ spin_lock(&xprt->xpt_lock);
+ set_bit(XPT_DEFERRED, &xprt->xpt_flags);
+ if (too_many || test_bit(XPT_DEAD, &xprt->xpt_flags)) {
+ spin_unlock(&xprt->xpt_lock);
+ trace_svc_defer_drop(dr);
+ free_deferred(xprt, dr);
+ svc_xprt_put(xprt);
+ return;
+ }
+ dr->xprt = NULL;
+ list_add(&dr->handle.recent, &xprt->xpt_deferred);
+ spin_unlock(&xprt->xpt_lock);
+ trace_svc_defer_queue(dr);
+ svc_xprt_enqueue(xprt);
+ svc_xprt_put(xprt);
+}
+
+/*
+ * Save the request off for later processing. The request buffer looks
+ * like this:
+ *
+ * <xprt-header><rpc-header><rpc-pagelist><rpc-tail>
+ *
+ * This code can only handle requests that consist of an xprt-header
+ * and rpc-header.
+ */
+static struct cache_deferred_req *svc_defer(struct cache_req *req)
+{
+ struct svc_rqst *rqstp = container_of(req, struct svc_rqst, rq_chandle);
+ struct svc_deferred_req *dr;
+
+ if (rqstp->rq_arg.page_len || !test_bit(RQ_USEDEFERRAL, &rqstp->rq_flags))
+ return NULL; /* if more than a page, give up FIXME */
+ if (rqstp->rq_deferred) {
+ dr = rqstp->rq_deferred;
+ rqstp->rq_deferred = NULL;
+ } else {
+ size_t skip;
+ size_t size;
+ /* FIXME maybe discard if size too large */
+ size = sizeof(struct svc_deferred_req) + rqstp->rq_arg.len;
+ dr = kmalloc(size, GFP_KERNEL);
+ if (dr == NULL)
+ return NULL;
+
+ dr->handle.owner = rqstp->rq_server;
+ dr->prot = rqstp->rq_prot;
+ memcpy(&dr->addr, &rqstp->rq_addr, rqstp->rq_addrlen);
+ dr->addrlen = rqstp->rq_addrlen;
+ dr->daddr = rqstp->rq_daddr;
+ dr->argslen = rqstp->rq_arg.len >> 2;
+
+ /* back up head to the start of the buffer and copy */
+ skip = rqstp->rq_arg.len - rqstp->rq_arg.head[0].iov_len;
+ memcpy(dr->args, rqstp->rq_arg.head[0].iov_base - skip,
+ dr->argslen << 2);
+ }
+ dr->xprt_ctxt = rqstp->rq_xprt_ctxt;
+ rqstp->rq_xprt_ctxt = NULL;
+ trace_svc_defer(rqstp);
+ svc_xprt_get(rqstp->rq_xprt);
+ dr->xprt = rqstp->rq_xprt;
+ set_bit(RQ_DROPME, &rqstp->rq_flags);
+
+ dr->handle.revisit = svc_revisit;
+ return &dr->handle;
+}
+
+/*
+ * recv data from a deferred request into an active one
+ */
+static noinline int svc_deferred_recv(struct svc_rqst *rqstp)
+{
+ struct svc_deferred_req *dr = rqstp->rq_deferred;
+
+ trace_svc_defer_recv(dr);
+
+ /* setup iov_base past transport header */
+ rqstp->rq_arg.head[0].iov_base = dr->args;
+ /* The iov_len does not include the transport header bytes */
+ rqstp->rq_arg.head[0].iov_len = dr->argslen << 2;
+ rqstp->rq_arg.page_len = 0;
+ /* The rq_arg.len includes the transport header bytes */
+ rqstp->rq_arg.len = dr->argslen << 2;
+ rqstp->rq_prot = dr->prot;
+ memcpy(&rqstp->rq_addr, &dr->addr, dr->addrlen);
+ rqstp->rq_addrlen = dr->addrlen;
+	/* Restore the destination address saved when the request was deferred */
+ rqstp->rq_daddr = dr->daddr;
+ rqstp->rq_respages = rqstp->rq_pages;
+ rqstp->rq_xprt_ctxt = dr->xprt_ctxt;
+
+ dr->xprt_ctxt = NULL;
+ svc_xprt_received(rqstp->rq_xprt);
+ return dr->argslen << 2;
+}
+
+
+static struct svc_deferred_req *svc_deferred_dequeue(struct svc_xprt *xprt)
+{
+ struct svc_deferred_req *dr = NULL;
+
+ if (!test_bit(XPT_DEFERRED, &xprt->xpt_flags))
+ return NULL;
+ spin_lock(&xprt->xpt_lock);
+ if (!list_empty(&xprt->xpt_deferred)) {
+ dr = list_entry(xprt->xpt_deferred.next,
+ struct svc_deferred_req,
+ handle.recent);
+ list_del_init(&dr->handle.recent);
+ } else
+ clear_bit(XPT_DEFERRED, &xprt->xpt_flags);
+ spin_unlock(&xprt->xpt_lock);
+ return dr;
+}
+
+/**
+ * svc_find_xprt - find an RPC transport instance
+ * @serv: pointer to svc_serv to search
+ * @xcl_name: C string containing transport's class name
+ * @net: owner net pointer
+ * @af: Address family of transport's local address
+ * @port: transport's IP port number
+ *
+ * Return the transport instance pointer for the endpoint accepting
+ * connections/peer traffic from the specified transport class,
+ * address family and port.
+ *
+ * Specifying 0 for the address family or port is effectively a
+ * wild-card, and will result in matching the first transport in the
+ * service's list that has a matching class name.
+ */
+struct svc_xprt *svc_find_xprt(struct svc_serv *serv, const char *xcl_name,
+ struct net *net, const sa_family_t af,
+ const unsigned short port)
+{
+ struct svc_xprt *xprt;
+ struct svc_xprt *found = NULL;
+
+ /* Sanity check the args */
+ if (serv == NULL || xcl_name == NULL)
+ return found;
+
+ spin_lock_bh(&serv->sv_lock);
+ list_for_each_entry(xprt, &serv->sv_permsocks, xpt_list) {
+ if (xprt->xpt_net != net)
+ continue;
+ if (strcmp(xprt->xpt_class->xcl_name, xcl_name))
+ continue;
+ if (af != AF_UNSPEC && af != xprt->xpt_local.ss_family)
+ continue;
+ if (port != 0 && port != svc_xprt_local_port(xprt))
+ continue;
+ found = xprt;
+ svc_xprt_get(xprt);
+ break;
+ }
+ spin_unlock_bh(&serv->sv_lock);
+ return found;
+}
+EXPORT_SYMBOL_GPL(svc_find_xprt);
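
A usage sketch, assuming the caller wants to discover which port the service's TCP listener actually bound; the helper name is invented. Note that the reference obtained must be dropped with svc_xprt_put():

#include <linux/sunrpc/svc.h>
#include <linux/sunrpc/svc_xprt.h>

/* Hypothetical: report the bound TCP port of @serv, or 0 if none. */
static unsigned short example_tcp_port(struct svc_serv *serv, struct net *net)
{
	struct svc_xprt *xprt;
	unsigned short port = 0;

	xprt = svc_find_xprt(serv, "tcp", net, AF_UNSPEC, 0);
	if (xprt) {
		port = svc_xprt_local_port(xprt);
		svc_xprt_put(xprt);
	}
	return port;
}
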
+
+static int svc_one_xprt_name(const struct svc_xprt *xprt,
+ char *pos, int remaining)
+{
+ int len;
+
+ len = snprintf(pos, remaining, "%s %u\n",
+ xprt->xpt_class->xcl_name,
+ svc_xprt_local_port(xprt));
+ if (len >= remaining)
+ return -ENAMETOOLONG;
+ return len;
+}
+
+/**
+ * svc_xprt_names - format a buffer with a list of transport names
+ * @serv: pointer to an RPC service
+ * @buf: pointer to a buffer to be filled in
+ * @buflen: length of buffer to be filled in
+ *
+ * Fills in @buf with a string containing a list of transport names,
+ * each name terminated with '\n'.
+ *
+ * Returns positive length of the filled-in string on success; otherwise
+ * a negative errno value is returned if an error occurs.
+ */
+int svc_xprt_names(struct svc_serv *serv, char *buf, const int buflen)
+{
+ struct svc_xprt *xprt;
+ int len, totlen;
+ char *pos;
+
+ /* Sanity check args */
+ if (!serv)
+ return 0;
+
+ spin_lock_bh(&serv->sv_lock);
+
+ pos = buf;
+ totlen = 0;
+ list_for_each_entry(xprt, &serv->sv_permsocks, xpt_list) {
+ len = svc_one_xprt_name(xprt, pos, buflen - totlen);
+ if (len < 0) {
+ *buf = '\0';
+ totlen = len;
+ }
+ if (len <= 0)
+ break;
+
+ pos += len;
+ totlen += len;
+ }
+
+ spin_unlock_bh(&serv->sv_lock);
+ return totlen;
+}
+EXPORT_SYMBOL_GPL(svc_xprt_names);
+
+
+/*----------------------------------------------------------------------------*/
+
+static void *svc_pool_stats_start(struct seq_file *m, loff_t *pos)
+{
+ unsigned int pidx = (unsigned int)*pos;
+ struct svc_serv *serv = m->private;
+
+ dprintk("svc_pool_stats_start, *pidx=%u\n", pidx);
+
+ if (!pidx)
+ return SEQ_START_TOKEN;
+ return (pidx > serv->sv_nrpools ? NULL : &serv->sv_pools[pidx-1]);
+}
+
+static void *svc_pool_stats_next(struct seq_file *m, void *p, loff_t *pos)
+{
+ struct svc_pool *pool = p;
+ struct svc_serv *serv = m->private;
+
+ dprintk("svc_pool_stats_next, *pos=%llu\n", *pos);
+
+ if (p == SEQ_START_TOKEN) {
+ pool = &serv->sv_pools[0];
+ } else {
+ unsigned int pidx = (pool - &serv->sv_pools[0]);
+ if (pidx < serv->sv_nrpools-1)
+ pool = &serv->sv_pools[pidx+1];
+ else
+ pool = NULL;
+ }
+ ++*pos;
+ return pool;
+}
+
+static void svc_pool_stats_stop(struct seq_file *m, void *p)
+{
+}
+
+static int svc_pool_stats_show(struct seq_file *m, void *p)
+{
+ struct svc_pool *pool = p;
+
+ if (p == SEQ_START_TOKEN) {
+ seq_puts(m, "# pool packets-arrived sockets-enqueued threads-woken threads-timedout\n");
+ return 0;
+ }
+
+ seq_printf(m, "%u %llu %llu %llu 0\n",
+ pool->sp_id,
+ percpu_counter_sum_positive(&pool->sp_messages_arrived),
+ percpu_counter_sum_positive(&pool->sp_sockets_queued),
+ percpu_counter_sum_positive(&pool->sp_threads_woken));
+
+ return 0;
+}
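+/*
+ * Example of the resulting seq_file output (the counts are made up):
+ *
+ *	# pool packets-arrived sockets-enqueued threads-woken threads-timedout
+ *	0 52719 1342 1342 0
+ *	1 48233 1105 1105 0
+ */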
+
+static const struct seq_operations svc_pool_stats_seq_ops = {
+ .start = svc_pool_stats_start,
+ .next = svc_pool_stats_next,
+ .stop = svc_pool_stats_stop,
+ .show = svc_pool_stats_show,
+};
+
+int svc_pool_stats_open(struct svc_serv *serv, struct file *file)
+{
+ int err;
+
+ err = seq_open(file, &svc_pool_stats_seq_ops);
+ if (!err)
+ ((struct seq_file *) file->private_data)->private = serv;
+ return err;
+}
+EXPORT_SYMBOL(svc_pool_stats_open);
+
+/*----------------------------------------------------------------------------*/
diff --git a/net/sunrpc/svcauth.c b/net/sunrpc/svcauth.c
new file mode 100644
index 0000000000..aa4429d0b8
--- /dev/null
+++ b/net/sunrpc/svcauth.c
@@ -0,0 +1,260 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * linux/net/sunrpc/svcauth.c
+ *
+ * The generic interface for RPC authentication on the server side.
+ *
+ * Copyright (C) 1995, 1996 Olaf Kirch <okir@monad.swb.de>
+ *
+ * CHANGES
+ * 19-Apr-2000 Chris Evans - Security fix
+ */
+
+#include <linux/types.h>
+#include <linux/module.h>
+#include <linux/sunrpc/types.h>
+#include <linux/sunrpc/xdr.h>
+#include <linux/sunrpc/svcsock.h>
+#include <linux/sunrpc/svcauth.h>
+#include <linux/err.h>
+#include <linux/hash.h>
+
+#include <trace/events/sunrpc.h>
+
+#include "sunrpc.h"
+
+#define RPCDBG_FACILITY RPCDBG_AUTH
+
+
+/*
+ * Table of authenticators
+ */
+extern struct auth_ops svcauth_null;
+extern struct auth_ops svcauth_unix;
+extern struct auth_ops svcauth_tls;
+
+static struct auth_ops __rcu *authtab[RPC_AUTH_MAXFLAVOR] = {
+ [RPC_AUTH_NULL] = (struct auth_ops __force __rcu *)&svcauth_null,
+ [RPC_AUTH_UNIX] = (struct auth_ops __force __rcu *)&svcauth_unix,
+ [RPC_AUTH_TLS] = (struct auth_ops __force __rcu *)&svcauth_tls,
+};
+
+static struct auth_ops *
+svc_get_auth_ops(rpc_authflavor_t flavor)
+{
+ struct auth_ops *aops;
+
+ if (flavor >= RPC_AUTH_MAXFLAVOR)
+ return NULL;
+ rcu_read_lock();
+ aops = rcu_dereference(authtab[flavor]);
+ if (aops != NULL && !try_module_get(aops->owner))
+ aops = NULL;
+ rcu_read_unlock();
+ return aops;
+}
+
+static void
+svc_put_auth_ops(struct auth_ops *aops)
+{
+ module_put(aops->owner);
+}
+
+/**
+ * svc_authenticate - Authenticate an incoming RPC Call
+ * @rqstp: RPC execution context
+ *
+ * Return values:
+ * %SVC_OK: XDR encoding of the result can begin
+ * %SVC_DENIED: Credential or verifier is not valid
+ * %SVC_GARBAGE: Failed to decode credential or verifier
+ * %SVC_COMPLETE: GSS context lifetime event; no further action
+ * %SVC_DROP: Drop this request; no further action
+ * %SVC_CLOSE: Like drop, but also close transport connection
+ */
+enum svc_auth_status svc_authenticate(struct svc_rqst *rqstp)
+{
+ struct auth_ops *aops;
+ u32 flavor;
+
+ rqstp->rq_auth_stat = rpc_auth_ok;
+
+ /*
+ * Decode the Call credential's flavor field. The credential's
+ * body field is decoded in the chosen ->accept method below.
+ */
+ if (xdr_stream_decode_u32(&rqstp->rq_arg_stream, &flavor) < 0)
+ return SVC_GARBAGE;
+
+ aops = svc_get_auth_ops(flavor);
+ if (aops == NULL) {
+ rqstp->rq_auth_stat = rpc_autherr_badcred;
+ return SVC_DENIED;
+ }
+
+ rqstp->rq_auth_slack = 0;
+ init_svc_cred(&rqstp->rq_cred);
+
+ rqstp->rq_authop = aops;
+ return aops->accept(rqstp);
+}
+EXPORT_SYMBOL_GPL(svc_authenticate);
+
+/**
+ * svc_set_client - Assign an appropriate 'auth_domain' as the client
+ * @rqstp: RPC execution context
+ *
+ * Return values:
+ * %SVC_OK: Client was found and assigned
+ * %SVC_DENIED: Client was explicitly denied
+ * %SVC_DROP: Ignore this request
+ * %SVC_CLOSE: Ignore this request and close the connection
+ */
+enum svc_auth_status svc_set_client(struct svc_rqst *rqstp)
+{
+ rqstp->rq_client = NULL;
+ return rqstp->rq_authop->set_client(rqstp);
+}
+EXPORT_SYMBOL_GPL(svc_set_client);
+
+/**
+ * svc_authorise - Finalize credentials/verifier and release resources
+ * @rqstp: RPC execution context
+ *
+ * Returns zero on success, or a negative errno.
+ */
+int svc_authorise(struct svc_rqst *rqstp)
+{
+ struct auth_ops *aops = rqstp->rq_authop;
+ int rv = 0;
+
+ rqstp->rq_authop = NULL;
+
+ if (aops) {
+ rv = aops->release(rqstp);
+ svc_put_auth_ops(aops);
+ }
+ return rv;
+}
+
+int
+svc_auth_register(rpc_authflavor_t flavor, struct auth_ops *aops)
+{
+ struct auth_ops *old;
+ int rv = -EINVAL;
+
+ if (flavor < RPC_AUTH_MAXFLAVOR) {
+ old = cmpxchg((struct auth_ops ** __force)&authtab[flavor], NULL, aops);
+ if (old == NULL || old == aops)
+ rv = 0;
+ }
+ return rv;
+}
+EXPORT_SYMBOL_GPL(svc_auth_register);
+
+void
+svc_auth_unregister(rpc_authflavor_t flavor)
+{
+ if (flavor < RPC_AUTH_MAXFLAVOR)
+ rcu_assign_pointer(authtab[flavor], NULL);
+}
+EXPORT_SYMBOL_GPL(svc_auth_unregister);
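+/*
+ * A minimal sketch of how a module providing an additional flavor
+ * registers its ops at init time and unregisters them on exit. The
+ * ops structure and callback names here are hypothetical; see
+ * auth_gss/svcauth_gss.c for a real registration of RPC_AUTH_GSS:
+ *
+ *	static struct auth_ops example_auth_ops = {
+ *		.name		= "example",
+ *		.owner		= THIS_MODULE,
+ *		.flavour	= RPC_AUTH_GSS,
+ *		.accept		= example_accept,
+ *		.release	= example_release,
+ *		.set_client	= example_set_client,
+ *	};
+ *
+ *	err = svc_auth_register(RPC_AUTH_GSS, &example_auth_ops);
+ *	...
+ *	svc_auth_unregister(RPC_AUTH_GSS);
+ */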
+
+/**************************************************
+ * 'auth_domains' are stored in a hash table indexed by name.
+ * When the last reference to an 'auth_domain' is dropped,
+ * the object is unhashed and freed.
+ * If auth_domain_lookup fails to find an entry, it will return
+ * its second argument 'new'. If this is non-null, it will
+ * have been atomically linked into the table.
+ */
+
+#define DN_HASHBITS 6
+#define DN_HASHMAX (1<<DN_HASHBITS)
+
+static struct hlist_head auth_domain_table[DN_HASHMAX];
+static DEFINE_SPINLOCK(auth_domain_lock);
+
+static void auth_domain_release(struct kref *kref)
+ __releases(&auth_domain_lock)
+{
+ struct auth_domain *dom = container_of(kref, struct auth_domain, ref);
+
+ hlist_del_rcu(&dom->hash);
+ dom->flavour->domain_release(dom);
+ spin_unlock(&auth_domain_lock);
+}
+
+void auth_domain_put(struct auth_domain *dom)
+{
+ kref_put_lock(&dom->ref, auth_domain_release, &auth_domain_lock);
+}
+EXPORT_SYMBOL_GPL(auth_domain_put);
+
+struct auth_domain *
+auth_domain_lookup(char *name, struct auth_domain *new)
+{
+ struct auth_domain *hp;
+ struct hlist_head *head;
+
+ head = &auth_domain_table[hash_str(name, DN_HASHBITS)];
+
+ spin_lock(&auth_domain_lock);
+
+ hlist_for_each_entry(hp, head, hash) {
+ if (strcmp(hp->name, name)==0) {
+ kref_get(&hp->ref);
+ spin_unlock(&auth_domain_lock);
+ return hp;
+ }
+ }
+ if (new)
+ hlist_add_head_rcu(&new->hash, head);
+ spin_unlock(&auth_domain_lock);
+ return new;
+}
+EXPORT_SYMBOL_GPL(auth_domain_lookup);
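+/*
+ * Sketch of the lookup-or-insert pattern described above (the
+ * 'candidate' variable is hypothetical; unix_domain_find() in
+ * svcauth_unix.c shows the real thing):
+ *
+ *	dom = auth_domain_lookup(name, &candidate->h);
+ *	if (dom != &candidate->h)
+ *		an existing entry won the race: free the unused candidate;
+ *	otherwise 'candidate' is now hashed and 'dom' points at it.
+ */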
+
+struct auth_domain *auth_domain_find(char *name)
+{
+ struct auth_domain *hp;
+ struct hlist_head *head;
+
+ head = &auth_domain_table[hash_str(name, DN_HASHBITS)];
+
+ rcu_read_lock();
+ hlist_for_each_entry_rcu(hp, head, hash) {
+ if (strcmp(hp->name, name)==0) {
+ if (!kref_get_unless_zero(&hp->ref))
+ hp = NULL;
+ rcu_read_unlock();
+ return hp;
+ }
+ }
+ rcu_read_unlock();
+ return NULL;
+}
+EXPORT_SYMBOL_GPL(auth_domain_find);
+
+/**
+ * auth_domain_cleanup - check that the auth_domain table is empty
+ *
+ * On module unload the auth_domain_table must be empty. To make it
+ * easier to catch bugs which don't clean up domains properly, we
+ * warn if anything remains in the table at cleanup time.
+ *
+ * Note that we cannot proactively remove the domains at this stage.
+ * The ->release() function might be in a module that has already been
+ * unloaded.
+ */
+
+void auth_domain_cleanup(void)
+{
+ int h;
+ struct auth_domain *hp;
+
+ for (h = 0; h < DN_HASHMAX; h++)
+ hlist_for_each_entry(hp, &auth_domain_table[h], hash)
+ pr_warn("svc: domain %s still present at module unload.\n",
+ hp->name);
+}
diff --git a/net/sunrpc/svcauth_unix.c b/net/sunrpc/svcauth_unix.c
new file mode 100644
index 0000000000..04b45588ae
--- /dev/null
+++ b/net/sunrpc/svcauth_unix.c
@@ -0,0 +1,1061 @@
+// SPDX-License-Identifier: GPL-2.0-only
+#include <linux/types.h>
+#include <linux/sched.h>
+#include <linux/module.h>
+#include <linux/sunrpc/types.h>
+#include <linux/sunrpc/xdr.h>
+#include <linux/sunrpc/svcsock.h>
+#include <linux/sunrpc/svcauth.h>
+#include <linux/sunrpc/gss_api.h>
+#include <linux/sunrpc/addr.h>
+#include <linux/err.h>
+#include <linux/seq_file.h>
+#include <linux/hash.h>
+#include <linux/string.h>
+#include <linux/slab.h>
+#include <net/sock.h>
+#include <net/ipv6.h>
+#include <linux/kernel.h>
+#include <linux/user_namespace.h>
+#include <trace/events/sunrpc.h>
+
+#define RPCDBG_FACILITY RPCDBG_AUTH
+
+#include "netns.h"
+
+/*
+ * AUTHUNIX and AUTHNULL credentials are both handled here.
+ * AUTHNULL is treated just like AUTHUNIX except that the uid/gid
+ * are always nobody (-2); i.e. we do the same IP address checks for
+ * AUTHNULL as for AUTHUNIX, and that is done here.
+ */
+
+
+struct unix_domain {
+ struct auth_domain h;
+ /* other stuff later */
+};
+
+extern struct auth_ops svcauth_null;
+extern struct auth_ops svcauth_unix;
+extern struct auth_ops svcauth_tls;
+
+static void svcauth_unix_domain_release_rcu(struct rcu_head *head)
+{
+ struct auth_domain *dom = container_of(head, struct auth_domain, rcu_head);
+ struct unix_domain *ud = container_of(dom, struct unix_domain, h);
+
+ kfree(dom->name);
+ kfree(ud);
+}
+
+static void svcauth_unix_domain_release(struct auth_domain *dom)
+{
+ call_rcu(&dom->rcu_head, svcauth_unix_domain_release_rcu);
+}
+
+struct auth_domain *unix_domain_find(char *name)
+{
+ struct auth_domain *rv;
+ struct unix_domain *new = NULL;
+
+ rv = auth_domain_find(name);
+ while(1) {
+ if (rv) {
+ if (new && rv != &new->h)
+ svcauth_unix_domain_release(&new->h);
+
+ if (rv->flavour != &svcauth_unix) {
+ auth_domain_put(rv);
+ return NULL;
+ }
+ return rv;
+ }
+
+ new = kmalloc(sizeof(*new), GFP_KERNEL);
+ if (new == NULL)
+ return NULL;
+ kref_init(&new->h.ref);
+ new->h.name = kstrdup(name, GFP_KERNEL);
+ if (new->h.name == NULL) {
+ kfree(new);
+ return NULL;
+ }
+ new->h.flavour = &svcauth_unix;
+ rv = auth_domain_lookup(name, &new->h);
+ }
+}
+EXPORT_SYMBOL_GPL(unix_domain_find);
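+/*
+ * Illustrative use by a caller outside this file (names are
+ * placeholders):
+ *
+ *	struct auth_domain *dom = unix_domain_find(client_name);
+ *	if (dom) {
+ *		... attach @dom to an export or client record ...
+ *		auth_domain_put(dom);
+ *	}
+ */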
+
+
+/**************************************************
+ * cache for IP address to unix_domain
+ * as needed by AUTH_UNIX
+ */
+#define IP_HASHBITS 8
+#define IP_HASHMAX (1<<IP_HASHBITS)
+
+struct ip_map {
+ struct cache_head h;
+ char m_class[8]; /* e.g. "nfsd" */
+ struct in6_addr m_addr;
+ struct unix_domain *m_client;
+ struct rcu_head m_rcu;
+};
+
+static void ip_map_put(struct kref *kref)
+{
+ struct cache_head *item = container_of(kref, struct cache_head, ref);
+ struct ip_map *im = container_of(item, struct ip_map,h);
+
+ if (test_bit(CACHE_VALID, &item->flags) &&
+ !test_bit(CACHE_NEGATIVE, &item->flags))
+ auth_domain_put(&im->m_client->h);
+ kfree_rcu(im, m_rcu);
+}
+
+static inline int hash_ip6(const struct in6_addr *ip)
+{
+ return hash_32(ipv6_addr_hash(ip), IP_HASHBITS);
+}
+static int ip_map_match(struct cache_head *corig, struct cache_head *cnew)
+{
+ struct ip_map *orig = container_of(corig, struct ip_map, h);
+ struct ip_map *new = container_of(cnew, struct ip_map, h);
+ return strcmp(orig->m_class, new->m_class) == 0 &&
+ ipv6_addr_equal(&orig->m_addr, &new->m_addr);
+}
+static void ip_map_init(struct cache_head *cnew, struct cache_head *citem)
+{
+ struct ip_map *new = container_of(cnew, struct ip_map, h);
+ struct ip_map *item = container_of(citem, struct ip_map, h);
+
+ strcpy(new->m_class, item->m_class);
+ new->m_addr = item->m_addr;
+}
+static void update(struct cache_head *cnew, struct cache_head *citem)
+{
+ struct ip_map *new = container_of(cnew, struct ip_map, h);
+ struct ip_map *item = container_of(citem, struct ip_map, h);
+
+ kref_get(&item->m_client->h.ref);
+ new->m_client = item->m_client;
+}
+static struct cache_head *ip_map_alloc(void)
+{
+ struct ip_map *i = kmalloc(sizeof(*i), GFP_KERNEL);
+ if (i)
+ return &i->h;
+ else
+ return NULL;
+}
+
+static int ip_map_upcall(struct cache_detail *cd, struct cache_head *h)
+{
+ return sunrpc_cache_pipe_upcall(cd, h);
+}
+
+static void ip_map_request(struct cache_detail *cd,
+ struct cache_head *h,
+ char **bpp, int *blen)
+{
+ char text_addr[40];
+ struct ip_map *im = container_of(h, struct ip_map, h);
+
+ if (ipv6_addr_v4mapped(&(im->m_addr))) {
+ snprintf(text_addr, 20, "%pI4", &im->m_addr.s6_addr32[3]);
+ } else {
+ snprintf(text_addr, 40, "%pI6", &im->m_addr);
+ }
+ qword_add(bpp, blen, im->m_class);
+ qword_add(bpp, blen, text_addr);
+ (*bpp)[-1] = '\n';
+}
+
+static struct ip_map *__ip_map_lookup(struct cache_detail *cd, char *class, struct in6_addr *addr);
+static int __ip_map_update(struct cache_detail *cd, struct ip_map *ipm, struct unix_domain *udom, time64_t expiry);
+
+static int ip_map_parse(struct cache_detail *cd,
+ char *mesg, int mlen)
+{
+ /* class ipaddress [domainname] */
+ /* should be safe just to use the start of the input buffer
+ * for scratch: */
+ char *buf = mesg;
+ int len;
+ char class[8];
+ union {
+ struct sockaddr sa;
+ struct sockaddr_in s4;
+ struct sockaddr_in6 s6;
+ } address;
+ struct sockaddr_in6 sin6;
+ int err;
+
+ struct ip_map *ipmp;
+ struct auth_domain *dom;
+ time64_t expiry;
+
+ if (mesg[mlen-1] != '\n')
+ return -EINVAL;
+ mesg[mlen-1] = 0;
+
+ /* class */
+ len = qword_get(&mesg, class, sizeof(class));
+ if (len <= 0) return -EINVAL;
+
+ /* ip address */
+ len = qword_get(&mesg, buf, mlen);
+ if (len <= 0) return -EINVAL;
+
+ if (rpc_pton(cd->net, buf, len, &address.sa, sizeof(address)) == 0)
+ return -EINVAL;
+ switch (address.sa.sa_family) {
+ case AF_INET:
+ /* Form a mapped IPv4 address in sin6 */
+ sin6.sin6_family = AF_INET6;
+ ipv6_addr_set_v4mapped(address.s4.sin_addr.s_addr,
+ &sin6.sin6_addr);
+ break;
+#if IS_ENABLED(CONFIG_IPV6)
+ case AF_INET6:
+ memcpy(&sin6, &address.s6, sizeof(sin6));
+ break;
+#endif
+ default:
+ return -EINVAL;
+ }
+
+ err = get_expiry(&mesg, &expiry);
+ if (err)
+ return err;
+
+ /* domainname, or empty for NEGATIVE */
+ len = qword_get(&mesg, buf, mlen);
+ if (len < 0) return -EINVAL;
+
+ if (len) {
+ dom = unix_domain_find(buf);
+ if (dom == NULL)
+ return -ENOENT;
+ } else
+ dom = NULL;
+
+ /* IPv6 scope IDs are ignored for now */
+ ipmp = __ip_map_lookup(cd, class, &sin6.sin6_addr);
+ if (ipmp) {
+ err = __ip_map_update(cd, ipmp,
+ container_of(dom, struct unix_domain, h),
+ expiry);
+ } else
+ err = -ENOMEM;
+
+ if (dom)
+ auth_domain_put(dom);
+
+ cache_flush();
+ return err;
+}
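+/*
+ * Shape of the auth.unix.ip channel traffic, with made-up values: the
+ * kernel upcall written by ip_map_request() looks like
+ * "nfsd 192.0.2.53\n"; the reply parsed above is
+ * "<class> <address> <expiry> [<domainname>]\n", for example
+ * "nfsd 192.0.2.53 1700000000 example-client\n". Omitting the domain
+ * name makes the entry negative.
+ */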
+
+static int ip_map_show(struct seq_file *m,
+ struct cache_detail *cd,
+ struct cache_head *h)
+{
+ struct ip_map *im;
+ struct in6_addr addr;
+ char *dom = "-no-domain-";
+
+ if (h == NULL) {
+ seq_puts(m, "#class IP domain\n");
+ return 0;
+ }
+ im = container_of(h, struct ip_map, h);
+ /* class addr domain */
+ addr = im->m_addr;
+
+ if (test_bit(CACHE_VALID, &h->flags) &&
+ !test_bit(CACHE_NEGATIVE, &h->flags))
+ dom = im->m_client->h.name;
+
+ if (ipv6_addr_v4mapped(&addr)) {
+ seq_printf(m, "%s %pI4 %s\n",
+ im->m_class, &addr.s6_addr32[3], dom);
+ } else {
+ seq_printf(m, "%s %pI6 %s\n", im->m_class, &addr, dom);
+ }
+ return 0;
+}
+
+
+static struct ip_map *__ip_map_lookup(struct cache_detail *cd, char *class,
+ struct in6_addr *addr)
+{
+ struct ip_map ip;
+ struct cache_head *ch;
+
+ strcpy(ip.m_class, class);
+ ip.m_addr = *addr;
+ ch = sunrpc_cache_lookup_rcu(cd, &ip.h,
+ hash_str(class, IP_HASHBITS) ^
+ hash_ip6(addr));
+
+ if (ch)
+ return container_of(ch, struct ip_map, h);
+ else
+ return NULL;
+}
+
+static int __ip_map_update(struct cache_detail *cd, struct ip_map *ipm,
+ struct unix_domain *udom, time64_t expiry)
+{
+ struct ip_map ip;
+ struct cache_head *ch;
+
+ ip.m_client = udom;
+ ip.h.flags = 0;
+ if (!udom)
+ set_bit(CACHE_NEGATIVE, &ip.h.flags);
+ ip.h.expiry_time = expiry;
+ ch = sunrpc_cache_update(cd, &ip.h, &ipm->h,
+ hash_str(ipm->m_class, IP_HASHBITS) ^
+ hash_ip6(&ipm->m_addr));
+ if (!ch)
+ return -ENOMEM;
+ cache_put(ch, cd);
+ return 0;
+}
+
+void svcauth_unix_purge(struct net *net)
+{
+ struct sunrpc_net *sn;
+
+ sn = net_generic(net, sunrpc_net_id);
+ cache_purge(sn->ip_map_cache);
+}
+EXPORT_SYMBOL_GPL(svcauth_unix_purge);
+
+static inline struct ip_map *
+ip_map_cached_get(struct svc_xprt *xprt)
+{
+ struct ip_map *ipm = NULL;
+ struct sunrpc_net *sn;
+
+ if (test_bit(XPT_CACHE_AUTH, &xprt->xpt_flags)) {
+ spin_lock(&xprt->xpt_lock);
+ ipm = xprt->xpt_auth_cache;
+ if (ipm != NULL) {
+ sn = net_generic(xprt->xpt_net, sunrpc_net_id);
+ if (cache_is_expired(sn->ip_map_cache, &ipm->h)) {
+ /*
+ * The entry has been invalidated since it was
+ * remembered, e.g. by a second mount from the
+ * same IP address.
+ */
+ xprt->xpt_auth_cache = NULL;
+ spin_unlock(&xprt->xpt_lock);
+ cache_put(&ipm->h, sn->ip_map_cache);
+ return NULL;
+ }
+ cache_get(&ipm->h);
+ }
+ spin_unlock(&xprt->xpt_lock);
+ }
+ return ipm;
+}
+
+static inline void
+ip_map_cached_put(struct svc_xprt *xprt, struct ip_map *ipm)
+{
+ if (test_bit(XPT_CACHE_AUTH, &xprt->xpt_flags)) {
+ spin_lock(&xprt->xpt_lock);
+ if (xprt->xpt_auth_cache == NULL) {
+ /* newly cached, keep the reference */
+ xprt->xpt_auth_cache = ipm;
+ ipm = NULL;
+ }
+ spin_unlock(&xprt->xpt_lock);
+ }
+ if (ipm) {
+ struct sunrpc_net *sn;
+
+ sn = net_generic(xprt->xpt_net, sunrpc_net_id);
+ cache_put(&ipm->h, sn->ip_map_cache);
+ }
+}
+
+void
+svcauth_unix_info_release(struct svc_xprt *xpt)
+{
+ struct ip_map *ipm;
+
+ ipm = xpt->xpt_auth_cache;
+ if (ipm != NULL) {
+ struct sunrpc_net *sn;
+
+ sn = net_generic(xpt->xpt_net, sunrpc_net_id);
+ cache_put(&ipm->h, sn->ip_map_cache);
+ }
+}
+
+/****************************************************************************
+ * auth.unix.gid cache
+ * simple cache to map a UID to a list of GIDs
+ * because AUTH_UNIX aka AUTH_SYS carries at most UNX_NGROUPS groups
+ */
+#define GID_HASHBITS 8
+#define GID_HASHMAX (1<<GID_HASHBITS)
+
+struct unix_gid {
+ struct cache_head h;
+ kuid_t uid;
+ struct group_info *gi;
+ struct rcu_head rcu;
+};
+
+static int unix_gid_hash(kuid_t uid)
+{
+ return hash_long(from_kuid(&init_user_ns, uid), GID_HASHBITS);
+}
+
+static void unix_gid_free(struct rcu_head *rcu)
+{
+ struct unix_gid *ug = container_of(rcu, struct unix_gid, rcu);
+ struct cache_head *item = &ug->h;
+
+ if (test_bit(CACHE_VALID, &item->flags) &&
+ !test_bit(CACHE_NEGATIVE, &item->flags))
+ put_group_info(ug->gi);
+ kfree(ug);
+}
+
+static void unix_gid_put(struct kref *kref)
+{
+ struct cache_head *item = container_of(kref, struct cache_head, ref);
+ struct unix_gid *ug = container_of(item, struct unix_gid, h);
+
+ call_rcu(&ug->rcu, unix_gid_free);
+}
+
+static int unix_gid_match(struct cache_head *corig, struct cache_head *cnew)
+{
+ struct unix_gid *orig = container_of(corig, struct unix_gid, h);
+ struct unix_gid *new = container_of(cnew, struct unix_gid, h);
+ return uid_eq(orig->uid, new->uid);
+}
+static void unix_gid_init(struct cache_head *cnew, struct cache_head *citem)
+{
+ struct unix_gid *new = container_of(cnew, struct unix_gid, h);
+ struct unix_gid *item = container_of(citem, struct unix_gid, h);
+ new->uid = item->uid;
+}
+static void unix_gid_update(struct cache_head *cnew, struct cache_head *citem)
+{
+ struct unix_gid *new = container_of(cnew, struct unix_gid, h);
+ struct unix_gid *item = container_of(citem, struct unix_gid, h);
+
+ get_group_info(item->gi);
+ new->gi = item->gi;
+}
+static struct cache_head *unix_gid_alloc(void)
+{
+ struct unix_gid *g = kmalloc(sizeof(*g), GFP_KERNEL);
+ if (g)
+ return &g->h;
+ else
+ return NULL;
+}
+
+static int unix_gid_upcall(struct cache_detail *cd, struct cache_head *h)
+{
+ return sunrpc_cache_pipe_upcall_timeout(cd, h);
+}
+
+static void unix_gid_request(struct cache_detail *cd,
+ struct cache_head *h,
+ char **bpp, int *blen)
+{
+ char tuid[20];
+ struct unix_gid *ug = container_of(h, struct unix_gid, h);
+
+ snprintf(tuid, 20, "%u", from_kuid(&init_user_ns, ug->uid));
+ qword_add(bpp, blen, tuid);
+ (*bpp)[-1] = '\n';
+}
+
+static struct unix_gid *unix_gid_lookup(struct cache_detail *cd, kuid_t uid);
+
+static int unix_gid_parse(struct cache_detail *cd,
+ char *mesg, int mlen)
+{
+ /* uid expiry Ngid gid0 gid1 ... gidN-1 */
+ int id;
+ kuid_t uid;
+ int gids;
+ int rv;
+ int i;
+ int err;
+ time64_t expiry;
+ struct unix_gid ug, *ugp;
+
+ if (mesg[mlen - 1] != '\n')
+ return -EINVAL;
+ mesg[mlen-1] = 0;
+
+ rv = get_int(&mesg, &id);
+ if (rv)
+ return -EINVAL;
+ uid = make_kuid(current_user_ns(), id);
+ ug.uid = uid;
+
+ err = get_expiry(&mesg, &expiry);
+ if (err)
+ return err;
+
+ rv = get_int(&mesg, &gids);
+ if (rv || gids < 0 || gids > 8192)
+ return -EINVAL;
+
+ ug.gi = groups_alloc(gids);
+ if (!ug.gi)
+ return -ENOMEM;
+
+ for (i = 0 ; i < gids ; i++) {
+ int gid;
+ kgid_t kgid;
+ rv = get_int(&mesg, &gid);
+ err = -EINVAL;
+ if (rv)
+ goto out;
+ kgid = make_kgid(current_user_ns(), gid);
+ if (!gid_valid(kgid))
+ goto out;
+ ug.gi->gid[i] = kgid;
+ }
+
+ groups_sort(ug.gi);
+ ugp = unix_gid_lookup(cd, uid);
+ if (ugp) {
+ struct cache_head *ch;
+ ug.h.flags = 0;
+ ug.h.expiry_time = expiry;
+ ch = sunrpc_cache_update(cd,
+ &ug.h, &ugp->h,
+ unix_gid_hash(uid));
+ if (!ch)
+ err = -ENOMEM;
+ else {
+ err = 0;
+ cache_put(ch, cd);
+ }
+ } else
+ err = -ENOMEM;
+ out:
+ if (ug.gi)
+ put_group_info(ug.gi);
+ return err;
+}
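+/*
+ * Shape of the auth.unix.gid channel traffic, with made-up values:
+ * the upcall written by unix_gid_request() is just the uid, e.g.
+ * "1000\n"; the reply parsed above is
+ * "<uid> <expiry> <Ngids> <gid0> ... <gidN-1>\n", for example
+ * "1000 1700000000 3 100 1001 1002\n".
+ */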
+
+static int unix_gid_show(struct seq_file *m,
+ struct cache_detail *cd,
+ struct cache_head *h)
+{
+ struct user_namespace *user_ns = m->file->f_cred->user_ns;
+ struct unix_gid *ug;
+ int i;
+ int glen;
+
+ if (h == NULL) {
+ seq_puts(m, "#uid cnt: gids...\n");
+ return 0;
+ }
+ ug = container_of(h, struct unix_gid, h);
+ if (test_bit(CACHE_VALID, &h->flags) &&
+ !test_bit(CACHE_NEGATIVE, &h->flags))
+ glen = ug->gi->ngroups;
+ else
+ glen = 0;
+
+ seq_printf(m, "%u %d:", from_kuid_munged(user_ns, ug->uid), glen);
+ for (i = 0; i < glen; i++)
+ seq_printf(m, " %d", from_kgid_munged(user_ns, ug->gi->gid[i]));
+ seq_printf(m, "\n");
+ return 0;
+}
+
+static const struct cache_detail unix_gid_cache_template = {
+ .owner = THIS_MODULE,
+ .hash_size = GID_HASHMAX,
+ .name = "auth.unix.gid",
+ .cache_put = unix_gid_put,
+ .cache_upcall = unix_gid_upcall,
+ .cache_request = unix_gid_request,
+ .cache_parse = unix_gid_parse,
+ .cache_show = unix_gid_show,
+ .match = unix_gid_match,
+ .init = unix_gid_init,
+ .update = unix_gid_update,
+ .alloc = unix_gid_alloc,
+};
+
+int unix_gid_cache_create(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ struct cache_detail *cd;
+ int err;
+
+ cd = cache_create_net(&unix_gid_cache_template, net);
+ if (IS_ERR(cd))
+ return PTR_ERR(cd);
+ err = cache_register_net(cd, net);
+ if (err) {
+ cache_destroy_net(cd, net);
+ return err;
+ }
+ sn->unix_gid_cache = cd;
+ return 0;
+}
+
+void unix_gid_cache_destroy(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ struct cache_detail *cd = sn->unix_gid_cache;
+
+ sn->unix_gid_cache = NULL;
+ cache_purge(cd);
+ cache_unregister_net(cd, net);
+ cache_destroy_net(cd, net);
+}
+
+static struct unix_gid *unix_gid_lookup(struct cache_detail *cd, kuid_t uid)
+{
+ struct unix_gid ug;
+ struct cache_head *ch;
+
+ ug.uid = uid;
+ ch = sunrpc_cache_lookup_rcu(cd, &ug.h, unix_gid_hash(uid));
+ if (ch)
+ return container_of(ch, struct unix_gid, h);
+ else
+ return NULL;
+}
+
+static struct group_info *unix_gid_find(kuid_t uid, struct svc_rqst *rqstp)
+{
+ struct unix_gid *ug;
+ struct group_info *gi;
+ int ret;
+ struct sunrpc_net *sn = net_generic(rqstp->rq_xprt->xpt_net,
+ sunrpc_net_id);
+
+ ug = unix_gid_lookup(sn->unix_gid_cache, uid);
+ if (!ug)
+ return ERR_PTR(-EAGAIN);
+ ret = cache_check(sn->unix_gid_cache, &ug->h, &rqstp->rq_chandle);
+ switch (ret) {
+ case -ENOENT:
+ return ERR_PTR(-ENOENT);
+ case -ETIMEDOUT:
+ return ERR_PTR(-ESHUTDOWN);
+ case 0:
+ gi = get_group_info(ug->gi);
+ cache_put(&ug->h, sn->unix_gid_cache);
+ return gi;
+ default:
+ return ERR_PTR(-EAGAIN);
+ }
+}
+
+enum svc_auth_status
+svcauth_unix_set_client(struct svc_rqst *rqstp)
+{
+ struct sockaddr_in *sin;
+ struct sockaddr_in6 *sin6, sin6_storage;
+ struct ip_map *ipm;
+ struct group_info *gi;
+ struct svc_cred *cred = &rqstp->rq_cred;
+ struct svc_xprt *xprt = rqstp->rq_xprt;
+ struct net *net = xprt->xpt_net;
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+ switch (rqstp->rq_addr.ss_family) {
+ case AF_INET:
+ sin = svc_addr_in(rqstp);
+ sin6 = &sin6_storage;
+ ipv6_addr_set_v4mapped(sin->sin_addr.s_addr, &sin6->sin6_addr);
+ break;
+ case AF_INET6:
+ sin6 = svc_addr_in6(rqstp);
+ break;
+ default:
+ BUG();
+ }
+
+ rqstp->rq_client = NULL;
+ if (rqstp->rq_proc == 0)
+ goto out;
+
+ rqstp->rq_auth_stat = rpc_autherr_badcred;
+ ipm = ip_map_cached_get(xprt);
+ if (ipm == NULL)
+ ipm = __ip_map_lookup(sn->ip_map_cache, rqstp->rq_server->sv_program->pg_class,
+ &sin6->sin6_addr);
+
+ if (ipm == NULL)
+ return SVC_DENIED;
+
+ switch (cache_check(sn->ip_map_cache, &ipm->h, &rqstp->rq_chandle)) {
+ default:
+ BUG();
+ case -ETIMEDOUT:
+ return SVC_CLOSE;
+ case -EAGAIN:
+ return SVC_DROP;
+ case -ENOENT:
+ return SVC_DENIED;
+ case 0:
+ rqstp->rq_client = &ipm->m_client->h;
+ kref_get(&rqstp->rq_client->ref);
+ ip_map_cached_put(xprt, ipm);
+ break;
+ }
+
+ gi = unix_gid_find(cred->cr_uid, rqstp);
+ switch (PTR_ERR(gi)) {
+ case -EAGAIN:
+ return SVC_DROP;
+ case -ESHUTDOWN:
+ return SVC_CLOSE;
+ case -ENOENT:
+ break;
+ default:
+ put_group_info(cred->cr_group_info);
+ cred->cr_group_info = gi;
+ }
+
+out:
+ rqstp->rq_auth_stat = rpc_auth_ok;
+ return SVC_OK;
+}
+EXPORT_SYMBOL_GPL(svcauth_unix_set_client);
+
+/**
+ * svcauth_null_accept - Decode and validate incoming RPC_AUTH_NULL credential
+ * @rqstp: RPC transaction
+ *
+ * Return values:
+ * %SVC_OK: Both credential and verifier are valid
+ * %SVC_DENIED: Credential or verifier is not valid
+ * %SVC_GARBAGE: Failed to decode credential or verifier
+ * %SVC_CLOSE: Temporary failure
+ *
+ * rqstp->rq_auth_stat is set as mandated by RFC 5531.
+ */
+static enum svc_auth_status
+svcauth_null_accept(struct svc_rqst *rqstp)
+{
+ struct xdr_stream *xdr = &rqstp->rq_arg_stream;
+ struct svc_cred *cred = &rqstp->rq_cred;
+ u32 flavor, len;
+ void *body;
+
+ /* Length of Call's credential body field: */
+ if (xdr_stream_decode_u32(xdr, &len) < 0)
+ return SVC_GARBAGE;
+ if (len != 0) {
+ rqstp->rq_auth_stat = rpc_autherr_badcred;
+ return SVC_DENIED;
+ }
+
+ /* Call's verf field: */
+ if (xdr_stream_decode_opaque_auth(xdr, &flavor, &body, &len) < 0)
+ return SVC_GARBAGE;
+ if (flavor != RPC_AUTH_NULL || len != 0) {
+ rqstp->rq_auth_stat = rpc_autherr_badverf;
+ return SVC_DENIED;
+ }
+
+ /* Signal that mapping to nobody uid/gid is required */
+ cred->cr_uid = INVALID_UID;
+ cred->cr_gid = INVALID_GID;
+ cred->cr_group_info = groups_alloc(0);
+ if (cred->cr_group_info == NULL)
+ return SVC_CLOSE; /* kmalloc failure - client must retry */
+
+ if (xdr_stream_encode_opaque_auth(&rqstp->rq_res_stream,
+ RPC_AUTH_NULL, NULL, 0) < 0)
+ return SVC_CLOSE;
+ if (!svcxdr_set_accept_stat(rqstp))
+ return SVC_CLOSE;
+
+ rqstp->rq_cred.cr_flavor = RPC_AUTH_NULL;
+ return SVC_OK;
+}
+
+static int
+svcauth_null_release(struct svc_rqst *rqstp)
+{
+ if (rqstp->rq_client)
+ auth_domain_put(rqstp->rq_client);
+ rqstp->rq_client = NULL;
+ if (rqstp->rq_cred.cr_group_info)
+ put_group_info(rqstp->rq_cred.cr_group_info);
+ rqstp->rq_cred.cr_group_info = NULL;
+
+ return 0; /* don't drop */
+}
+
+
+struct auth_ops svcauth_null = {
+ .name = "null",
+ .owner = THIS_MODULE,
+ .flavour = RPC_AUTH_NULL,
+ .accept = svcauth_null_accept,
+ .release = svcauth_null_release,
+ .set_client = svcauth_unix_set_client,
+};
+
+
+/**
+ * svcauth_tls_accept - Decode and validate incoming RPC_AUTH_TLS credential
+ * @rqstp: RPC transaction
+ *
+ * Return values:
+ * %SVC_OK: Both credential and verifier are valid
+ * %SVC_DENIED: Credential or verifier is not valid
+ * %SVC_GARBAGE: Failed to decode credential or verifier
+ * %SVC_CLOSE: Temporary failure
+ *
+ * rqstp->rq_auth_stat is set as mandated by RFC 5531.
+ */
+static enum svc_auth_status
+svcauth_tls_accept(struct svc_rqst *rqstp)
+{
+ struct xdr_stream *xdr = &rqstp->rq_arg_stream;
+ struct svc_cred *cred = &rqstp->rq_cred;
+ struct svc_xprt *xprt = rqstp->rq_xprt;
+ u32 flavor, len;
+ void *body;
+ __be32 *p;
+
+ /* Length of Call's credential body field: */
+ if (xdr_stream_decode_u32(xdr, &len) < 0)
+ return SVC_GARBAGE;
+ if (len != 0) {
+ rqstp->rq_auth_stat = rpc_autherr_badcred;
+ return SVC_DENIED;
+ }
+
+ /* Call's verf field: */
+ if (xdr_stream_decode_opaque_auth(xdr, &flavor, &body, &len) < 0)
+ return SVC_GARBAGE;
+ if (flavor != RPC_AUTH_NULL || len != 0) {
+ rqstp->rq_auth_stat = rpc_autherr_badverf;
+ return SVC_DENIED;
+ }
+
+ /* AUTH_TLS is not valid on non-NULL procedures */
+ if (rqstp->rq_proc != 0) {
+ rqstp->rq_auth_stat = rpc_autherr_badcred;
+ return SVC_DENIED;
+ }
+
+ /* Signal that mapping to nobody uid/gid is required */
+ cred->cr_uid = INVALID_UID;
+ cred->cr_gid = INVALID_GID;
+ cred->cr_group_info = groups_alloc(0);
+ if (cred->cr_group_info == NULL)
+ return SVC_CLOSE;
+
+ if (xprt->xpt_ops->xpo_handshake) {
+ p = xdr_reserve_space(&rqstp->rq_res_stream, XDR_UNIT * 2 + 8);
+ if (!p)
+ return SVC_CLOSE;
+ trace_svc_tls_start(xprt);
+ *p++ = rpc_auth_null;
+ *p++ = cpu_to_be32(8);
+ memcpy(p, "STARTTLS", 8);
+
+ set_bit(XPT_HANDSHAKE, &xprt->xpt_flags);
+ svc_xprt_enqueue(xprt);
+ } else {
+ trace_svc_tls_unavailable(xprt);
+ if (xdr_stream_encode_opaque_auth(&rqstp->rq_res_stream,
+ RPC_AUTH_NULL, NULL, 0) < 0)
+ return SVC_CLOSE;
+ }
+ if (!svcxdr_set_accept_stat(rqstp))
+ return SVC_CLOSE;
+
+ rqstp->rq_cred.cr_flavor = RPC_AUTH_TLS;
+ return SVC_OK;
+}
+
+struct auth_ops svcauth_tls = {
+ .name = "tls",
+ .owner = THIS_MODULE,
+ .flavour = RPC_AUTH_TLS,
+ .accept = svcauth_tls_accept,
+ .release = svcauth_null_release,
+ .set_client = svcauth_unix_set_client,
+};
+
+
+/**
+ * svcauth_unix_accept - Decode and validate incoming RPC_AUTH_SYS credential
+ * @rqstp: RPC transaction
+ *
+ * Return values:
+ * %SVC_OK: Both credential and verifier are valid
+ * %SVC_DENIED: Credential or verifier is not valid
+ * %SVC_GARBAGE: Failed to decode credential or verifier
+ * %SVC_CLOSE: Temporary failure
+ *
+ * rqstp->rq_auth_stat is set as mandated by RFC 5531.
+ */
+static enum svc_auth_status
+svcauth_unix_accept(struct svc_rqst *rqstp)
+{
+ struct xdr_stream *xdr = &rqstp->rq_arg_stream;
+ struct svc_cred *cred = &rqstp->rq_cred;
+ struct user_namespace *userns;
+ u32 flavor, len, i;
+ void *body;
+ __be32 *p;
+
+ /*
+ * This implementation ignores the length of the Call's
+ * credential body field and the timestamp and machinename
+ * fields.
+ */
+ p = xdr_inline_decode(xdr, XDR_UNIT * 3);
+ if (!p)
+ return SVC_GARBAGE;
+ len = be32_to_cpup(p + 2);
+ if (len > RPC_MAX_MACHINENAME)
+ return SVC_GARBAGE;
+ if (!xdr_inline_decode(xdr, len))
+ return SVC_GARBAGE;
+
+ /*
+ * Note: we skip uid_valid()/gid_valid() checks here for
+	 * backwards compatibility with clients that use -1 ids.
+ * Instead, -1 uid or gid is later mapped to the
+ * (export-specific) anonymous id by nfsd_setuser.
+	 * Supplementary gids will be left alone.
+ */
+ userns = (rqstp->rq_xprt && rqstp->rq_xprt->xpt_cred) ?
+ rqstp->rq_xprt->xpt_cred->user_ns : &init_user_ns;
+ if (xdr_stream_decode_u32(xdr, &i) < 0)
+ return SVC_GARBAGE;
+ cred->cr_uid = make_kuid(userns, i);
+ if (xdr_stream_decode_u32(xdr, &i) < 0)
+ return SVC_GARBAGE;
+ cred->cr_gid = make_kgid(userns, i);
+
+ if (xdr_stream_decode_u32(xdr, &len) < 0)
+ return SVC_GARBAGE;
+ if (len > UNX_NGROUPS)
+ goto badcred;
+ p = xdr_inline_decode(xdr, XDR_UNIT * len);
+ if (!p)
+ return SVC_GARBAGE;
+ cred->cr_group_info = groups_alloc(len);
+ if (cred->cr_group_info == NULL)
+ return SVC_CLOSE;
+ for (i = 0; i < len; i++) {
+ kgid_t kgid = make_kgid(userns, be32_to_cpup(p++));
+ cred->cr_group_info->gid[i] = kgid;
+ }
+ groups_sort(cred->cr_group_info);
+
+ /* Call's verf field: */
+ if (xdr_stream_decode_opaque_auth(xdr, &flavor, &body, &len) < 0)
+ return SVC_GARBAGE;
+ if (flavor != RPC_AUTH_NULL || len != 0) {
+ rqstp->rq_auth_stat = rpc_autherr_badverf;
+ return SVC_DENIED;
+ }
+
+ if (xdr_stream_encode_opaque_auth(&rqstp->rq_res_stream,
+ RPC_AUTH_NULL, NULL, 0) < 0)
+ return SVC_CLOSE;
+ if (!svcxdr_set_accept_stat(rqstp))
+ return SVC_CLOSE;
+
+ rqstp->rq_cred.cr_flavor = RPC_AUTH_UNIX;
+ return SVC_OK;
+
+badcred:
+ rqstp->rq_auth_stat = rpc_autherr_badcred;
+ return SVC_DENIED;
+}
+
+static int
+svcauth_unix_release(struct svc_rqst *rqstp)
+{
+ /* Verifier (such as it is) is already in place.
+ */
+ if (rqstp->rq_client)
+ auth_domain_put(rqstp->rq_client);
+ rqstp->rq_client = NULL;
+ if (rqstp->rq_cred.cr_group_info)
+ put_group_info(rqstp->rq_cred.cr_group_info);
+ rqstp->rq_cred.cr_group_info = NULL;
+
+ return 0;
+}
+
+
+struct auth_ops svcauth_unix = {
+ .name = "unix",
+ .owner = THIS_MODULE,
+ .flavour = RPC_AUTH_UNIX,
+ .accept = svcauth_unix_accept,
+ .release = svcauth_unix_release,
+ .domain_release = svcauth_unix_domain_release,
+ .set_client = svcauth_unix_set_client,
+};
+
+static const struct cache_detail ip_map_cache_template = {
+ .owner = THIS_MODULE,
+ .hash_size = IP_HASHMAX,
+ .name = "auth.unix.ip",
+ .cache_put = ip_map_put,
+ .cache_upcall = ip_map_upcall,
+ .cache_request = ip_map_request,
+ .cache_parse = ip_map_parse,
+ .cache_show = ip_map_show,
+ .match = ip_map_match,
+ .init = ip_map_init,
+ .update = update,
+ .alloc = ip_map_alloc,
+};
+
+int ip_map_cache_create(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ struct cache_detail *cd;
+ int err;
+
+ cd = cache_create_net(&ip_map_cache_template, net);
+ if (IS_ERR(cd))
+ return PTR_ERR(cd);
+ err = cache_register_net(cd, net);
+ if (err) {
+ cache_destroy_net(cd, net);
+ return err;
+ }
+ sn->ip_map_cache = cd;
+ return 0;
+}
+
+void ip_map_cache_destroy(struct net *net)
+{
+ struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ struct cache_detail *cd = sn->ip_map_cache;
+
+ sn->ip_map_cache = NULL;
+ cache_purge(cd);
+ cache_unregister_net(cd, net);
+ cache_destroy_net(cd, net);
+}
diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c
new file mode 100644
index 0000000000..e0ce427627
--- /dev/null
+++ b/net/sunrpc/svcsock.c
@@ -0,0 +1,1644 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * linux/net/sunrpc/svcsock.c
+ *
+ * These are the RPC server socket internals.
+ *
+ * The server scheduling algorithm does not always distribute the load
+ * evenly when servicing a single client. May need to modify the
+ * svc_xprt_enqueue procedure...
+ *
+ * TCP support is largely untested and may be a little slow. The problem
+ * is that we currently do two separate recvfrom's, one for the 4-byte
+ * record length, and the second for the actual record. This could possibly
+ * be improved by always reading a minimum size of around 100 bytes and
+ * tucking any superfluous bytes away in a temporary store. Still, that
+ * leaves write requests out in the rain. An alternative may be to peek at
+ * the first skb in the queue, and if it matches the next TCP sequence
+ * number, to extract the record marker. Yuck.
+ *
+ * Copyright (C) 1995, 1996 Olaf Kirch <okir@monad.swb.de>
+ */
+
+#include <linux/kernel.h>
+#include <linux/sched.h>
+#include <linux/module.h>
+#include <linux/errno.h>
+#include <linux/fcntl.h>
+#include <linux/net.h>
+#include <linux/in.h>
+#include <linux/inet.h>
+#include <linux/udp.h>
+#include <linux/tcp.h>
+#include <linux/unistd.h>
+#include <linux/slab.h>
+#include <linux/netdevice.h>
+#include <linux/skbuff.h>
+#include <linux/file.h>
+#include <linux/freezer.h>
+#include <linux/bvec.h>
+
+#include <net/sock.h>
+#include <net/checksum.h>
+#include <net/ip.h>
+#include <net/ipv6.h>
+#include <net/udp.h>
+#include <net/tcp.h>
+#include <net/tcp_states.h>
+#include <net/tls_prot.h>
+#include <net/handshake.h>
+#include <linux/uaccess.h>
+#include <linux/highmem.h>
+#include <asm/ioctls.h>
+#include <linux/key.h>
+
+#include <linux/sunrpc/types.h>
+#include <linux/sunrpc/clnt.h>
+#include <linux/sunrpc/xdr.h>
+#include <linux/sunrpc/msg_prot.h>
+#include <linux/sunrpc/svcsock.h>
+#include <linux/sunrpc/stats.h>
+#include <linux/sunrpc/xprt.h>
+
+#include <trace/events/sock.h>
+#include <trace/events/sunrpc.h>
+
+#include "socklib.h"
+#include "sunrpc.h"
+
+#define RPCDBG_FACILITY RPCDBG_SVCXPRT
+
+/* To-do: to avoid tying up an nfsd thread while waiting for a
+ * handshake request, the request could instead be deferred.
+ */
+enum {
+ SVC_HANDSHAKE_TO = 5U * HZ
+};
+
+static struct svc_sock *svc_setup_socket(struct svc_serv *, struct socket *,
+ int flags);
+static int svc_udp_recvfrom(struct svc_rqst *);
+static int svc_udp_sendto(struct svc_rqst *);
+static void svc_sock_detach(struct svc_xprt *);
+static void svc_tcp_sock_detach(struct svc_xprt *);
+static void svc_sock_free(struct svc_xprt *);
+
+static struct svc_xprt *svc_create_socket(struct svc_serv *, int,
+ struct net *, struct sockaddr *,
+ int, int);
+#ifdef CONFIG_DEBUG_LOCK_ALLOC
+static struct lock_class_key svc_key[2];
+static struct lock_class_key svc_slock_key[2];
+
+static void svc_reclassify_socket(struct socket *sock)
+{
+ struct sock *sk = sock->sk;
+
+ if (WARN_ON_ONCE(!sock_allow_reclassification(sk)))
+ return;
+
+ switch (sk->sk_family) {
+ case AF_INET:
+ sock_lock_init_class_and_name(sk, "slock-AF_INET-NFSD",
+ &svc_slock_key[0],
+ "sk_xprt.xpt_lock-AF_INET-NFSD",
+ &svc_key[0]);
+ break;
+
+ case AF_INET6:
+ sock_lock_init_class_and_name(sk, "slock-AF_INET6-NFSD",
+ &svc_slock_key[1],
+ "sk_xprt.xpt_lock-AF_INET6-NFSD",
+ &svc_key[1]);
+ break;
+
+ default:
+ BUG();
+ }
+}
+#else
+static void svc_reclassify_socket(struct socket *sock)
+{
+}
+#endif
+
+/**
+ * svc_tcp_release_ctxt - Release transport-related resources
+ * @xprt: the transport which owned the context
+ * @ctxt: the context from rqstp->rq_xprt_ctxt or dr->xprt_ctxt
+ *
+ */
+static void svc_tcp_release_ctxt(struct svc_xprt *xprt, void *ctxt)
+{
+}
+
+/**
+ * svc_udp_release_ctxt - Release transport-related resources
+ * @xprt: the transport which owned the context
+ * @ctxt: the context from rqstp->rq_xprt_ctxt or dr->xprt_ctxt
+ *
+ */
+static void svc_udp_release_ctxt(struct svc_xprt *xprt, void *ctxt)
+{
+ struct sk_buff *skb = ctxt;
+
+ if (skb)
+ consume_skb(skb);
+}
+
+union svc_pktinfo_u {
+ struct in_pktinfo pkti;
+ struct in6_pktinfo pkti6;
+};
+#define SVC_PKTINFO_SPACE \
+ CMSG_SPACE(sizeof(union svc_pktinfo_u))
+
+static void svc_set_cmsg_data(struct svc_rqst *rqstp, struct cmsghdr *cmh)
+{
+ struct svc_sock *svsk =
+ container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt);
+ switch (svsk->sk_sk->sk_family) {
+ case AF_INET: {
+ struct in_pktinfo *pki = CMSG_DATA(cmh);
+
+ cmh->cmsg_level = SOL_IP;
+ cmh->cmsg_type = IP_PKTINFO;
+ pki->ipi_ifindex = 0;
+ pki->ipi_spec_dst.s_addr =
+ svc_daddr_in(rqstp)->sin_addr.s_addr;
+ cmh->cmsg_len = CMSG_LEN(sizeof(*pki));
+ }
+ break;
+
+ case AF_INET6: {
+ struct in6_pktinfo *pki = CMSG_DATA(cmh);
+ struct sockaddr_in6 *daddr = svc_daddr_in6(rqstp);
+
+ cmh->cmsg_level = SOL_IPV6;
+ cmh->cmsg_type = IPV6_PKTINFO;
+ pki->ipi6_ifindex = daddr->sin6_scope_id;
+ pki->ipi6_addr = daddr->sin6_addr;
+ cmh->cmsg_len = CMSG_LEN(sizeof(*pki));
+ }
+ break;
+ }
+}
+
+static int svc_sock_result_payload(struct svc_rqst *rqstp, unsigned int offset,
+ unsigned int length)
+{
+ return 0;
+}
+
+/*
+ * Report socket names for nfsdfs
+ */
+static int svc_one_sock_name(struct svc_sock *svsk, char *buf, int remaining)
+{
+ const struct sock *sk = svsk->sk_sk;
+ const char *proto_name = sk->sk_protocol == IPPROTO_UDP ?
+ "udp" : "tcp";
+ int len;
+
+ switch (sk->sk_family) {
+ case PF_INET:
+ len = snprintf(buf, remaining, "ipv4 %s %pI4 %d\n",
+ proto_name,
+ &inet_sk(sk)->inet_rcv_saddr,
+ inet_sk(sk)->inet_num);
+ break;
+#if IS_ENABLED(CONFIG_IPV6)
+ case PF_INET6:
+ len = snprintf(buf, remaining, "ipv6 %s %pI6 %d\n",
+ proto_name,
+ &sk->sk_v6_rcv_saddr,
+ inet_sk(sk)->inet_num);
+ break;
+#endif
+ default:
+ len = snprintf(buf, remaining, "*unknown-%d*\n",
+ sk->sk_family);
+ }
+
+ if (len >= remaining) {
+ *buf = '\0';
+ return -ENAMETOOLONG;
+ }
+ return len;
+}
+
+static int
+svc_tcp_sock_process_cmsg(struct socket *sock, struct msghdr *msg,
+ struct cmsghdr *cmsg, int ret)
+{
+ u8 content_type = tls_get_record_type(sock->sk, cmsg);
+ u8 level, description;
+
+ switch (content_type) {
+ case 0:
+ break;
+ case TLS_RECORD_TYPE_DATA:
+ /* TLS sets EOR at the end of each application data
+ * record, even though there might be more frames
+ * waiting to be decrypted.
+ */
+ msg->msg_flags &= ~MSG_EOR;
+ break;
+ case TLS_RECORD_TYPE_ALERT:
+ tls_alert_recv(sock->sk, msg, &level, &description);
+ ret = (level == TLS_ALERT_LEVEL_FATAL) ?
+ -ENOTCONN : -EAGAIN;
+ break;
+ default:
+ /* discard this record type */
+ ret = -EAGAIN;
+ }
+ return ret;
+}
+
+static int
+svc_tcp_sock_recv_cmsg(struct svc_sock *svsk, struct msghdr *msg)
+{
+ union {
+ struct cmsghdr cmsg;
+ u8 buf[CMSG_SPACE(sizeof(u8))];
+ } u;
+ struct socket *sock = svsk->sk_sock;
+ int ret;
+
+ msg->msg_control = &u;
+ msg->msg_controllen = sizeof(u);
+ ret = sock_recvmsg(sock, msg, MSG_DONTWAIT);
+ if (unlikely(msg->msg_controllen != sizeof(u)))
+ ret = svc_tcp_sock_process_cmsg(sock, msg, &u.cmsg, ret);
+ return ret;
+}
+
+#if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE
+static void svc_flush_bvec(const struct bio_vec *bvec, size_t size, size_t seek)
+{
+ struct bvec_iter bi = {
+ .bi_size = size + seek,
+ };
+ struct bio_vec bv;
+
+ bvec_iter_advance(bvec, &bi, seek & PAGE_MASK);
+ for_each_bvec(bv, bvec, bi, bi)
+ flush_dcache_page(bv.bv_page);
+}
+#else
+static inline void svc_flush_bvec(const struct bio_vec *bvec, size_t size,
+ size_t seek)
+{
+}
+#endif
+
+/*
+ * Read from @rqstp's transport socket. The incoming message fills whole
+ * pages in @rqstp's rq_pages array until the last page of the message
+ * has been received into a partial page.
+ */
+static ssize_t svc_tcp_read_msg(struct svc_rqst *rqstp, size_t buflen,
+ size_t seek)
+{
+ struct svc_sock *svsk =
+ container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt);
+ struct bio_vec *bvec = rqstp->rq_bvec;
+ struct msghdr msg = { NULL };
+ unsigned int i;
+ ssize_t len;
+ size_t t;
+
+ clear_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags);
+
+ for (i = 0, t = 0; t < buflen; i++, t += PAGE_SIZE)
+ bvec_set_page(&bvec[i], rqstp->rq_pages[i], PAGE_SIZE, 0);
+ rqstp->rq_respages = &rqstp->rq_pages[i];
+ rqstp->rq_next_page = rqstp->rq_respages + 1;
+
+ iov_iter_bvec(&msg.msg_iter, ITER_DEST, bvec, i, buflen);
+ if (seek) {
+ iov_iter_advance(&msg.msg_iter, seek);
+ buflen -= seek;
+ }
+ len = svc_tcp_sock_recv_cmsg(svsk, &msg);
+ if (len > 0)
+ svc_flush_bvec(bvec, len, seek);
+
+ /* If we read a full record, then assume there may be more
+ * data to read (stream based sockets only!)
+ */
+ if (len == buflen)
+ set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags);
+
+ return len;
+}
+
+/*
+ * Set socket snd and rcv buffer lengths
+ */
+static void svc_sock_setbufsize(struct svc_sock *svsk, unsigned int nreqs)
+{
+ unsigned int max_mesg = svsk->sk_xprt.xpt_server->sv_max_mesg;
+ struct socket *sock = svsk->sk_sock;
+
+ nreqs = min(nreqs, INT_MAX / 2 / max_mesg);
+
+ lock_sock(sock->sk);
+ sock->sk->sk_sndbuf = nreqs * max_mesg * 2;
+ sock->sk->sk_rcvbuf = nreqs * max_mesg * 2;
+ sock->sk->sk_write_space(sock->sk);
+ release_sock(sock->sk);
+}
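+/*
+ * Rough illustration of the sizing above (the figures are only an
+ * example): with sv_max_mesg around 32KB and 8 server threads,
+ * svc_udp_recvfrom() passes nreqs = 8 + 3 = 11, so sk_sndbuf and
+ * sk_rcvbuf are each set to about 11 * 32KB * 2 = 704KB.
+ */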
+
+static void svc_sock_secure_port(struct svc_rqst *rqstp)
+{
+ if (svc_port_is_privileged(svc_addr(rqstp)))
+ set_bit(RQ_SECURE, &rqstp->rq_flags);
+ else
+ clear_bit(RQ_SECURE, &rqstp->rq_flags);
+}
+
+/*
+ * INET callback when data has been received on the socket.
+ */
+static void svc_data_ready(struct sock *sk)
+{
+ struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data;
+
+ trace_sk_data_ready(sk);
+
+ if (svsk) {
+ /* Refer to svc_setup_socket() for details. */
+ rmb();
+ svsk->sk_odata(sk);
+ trace_svcsock_data_ready(&svsk->sk_xprt, 0);
+ if (test_bit(XPT_HANDSHAKE, &svsk->sk_xprt.xpt_flags))
+ return;
+ if (!test_and_set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags))
+ svc_xprt_enqueue(&svsk->sk_xprt);
+ }
+}
+
+/*
+ * INET callback when space is newly available on the socket.
+ */
+static void svc_write_space(struct sock *sk)
+{
+ struct svc_sock *svsk = (struct svc_sock *)(sk->sk_user_data);
+
+ if (svsk) {
+ /* Refer to svc_setup_socket() for details. */
+ rmb();
+ trace_svcsock_write_space(&svsk->sk_xprt, 0);
+ svsk->sk_owspace(sk);
+ svc_xprt_enqueue(&svsk->sk_xprt);
+ }
+}
+
+static int svc_tcp_has_wspace(struct svc_xprt *xprt)
+{
+ struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt);
+
+ if (test_bit(XPT_LISTENER, &xprt->xpt_flags))
+ return 1;
+ return !test_bit(SOCK_NOSPACE, &svsk->sk_sock->flags);
+}
+
+static void svc_tcp_kill_temp_xprt(struct svc_xprt *xprt)
+{
+ struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt);
+
+ sock_no_linger(svsk->sk_sock->sk);
+}
+
+/**
+ * svc_tcp_handshake_done - Handshake completion handler
+ * @data: address of xprt to wake
+ * @status: status of handshake
+ * @peerid: serial number of key containing the remote peer's identity
+ *
+ * If a security policy is specified as an export option, we don't
+ * have a specific export here to check. So we set a "TLS session
+ * is present" flag on the xprt and let an upper layer enforce local
+ * security policy.
+ */
+static void svc_tcp_handshake_done(void *data, int status, key_serial_t peerid)
+{
+ struct svc_xprt *xprt = data;
+ struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt);
+
+ if (!status) {
+ if (peerid != TLS_NO_PEERID)
+ set_bit(XPT_PEER_AUTH, &xprt->xpt_flags);
+ set_bit(XPT_TLS_SESSION, &xprt->xpt_flags);
+ }
+ clear_bit(XPT_HANDSHAKE, &xprt->xpt_flags);
+ complete_all(&svsk->sk_handshake_done);
+}
+
+/**
+ * svc_tcp_handshake - Perform a transport-layer security handshake
+ * @xprt: connected transport endpoint
+ *
+ */
+static void svc_tcp_handshake(struct svc_xprt *xprt)
+{
+ struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt);
+ struct sock *sk = svsk->sk_sock->sk;
+ struct tls_handshake_args args = {
+ .ta_sock = svsk->sk_sock,
+ .ta_done = svc_tcp_handshake_done,
+ .ta_data = xprt,
+ };
+ int ret;
+
+ trace_svc_tls_upcall(xprt);
+
+ clear_bit(XPT_TLS_SESSION, &xprt->xpt_flags);
+ init_completion(&svsk->sk_handshake_done);
+
+ ret = tls_server_hello_x509(&args, GFP_KERNEL);
+ if (ret) {
+ trace_svc_tls_not_started(xprt);
+ goto out_failed;
+ }
+
+ ret = wait_for_completion_interruptible_timeout(&svsk->sk_handshake_done,
+ SVC_HANDSHAKE_TO);
+ if (ret <= 0) {
+ if (tls_handshake_cancel(sk)) {
+ trace_svc_tls_timed_out(xprt);
+ goto out_close;
+ }
+ }
+
+ if (!test_bit(XPT_TLS_SESSION, &xprt->xpt_flags)) {
+ trace_svc_tls_unavailable(xprt);
+ goto out_close;
+ }
+
+ /* Mark the transport ready in case the remote sent RPC
+ * traffic before the kernel received the handshake
+ * completion downcall.
+ */
+ set_bit(XPT_DATA, &xprt->xpt_flags);
+ svc_xprt_enqueue(xprt);
+ return;
+
+out_close:
+ set_bit(XPT_CLOSE, &xprt->xpt_flags);
+out_failed:
+ clear_bit(XPT_HANDSHAKE, &xprt->xpt_flags);
+ set_bit(XPT_DATA, &xprt->xpt_flags);
+ svc_xprt_enqueue(xprt);
+}
+
+/*
+ * See net/ipv6/ip_sockglue.c : ip_cmsg_recv_pktinfo
+ */
+static int svc_udp_get_dest_address4(struct svc_rqst *rqstp,
+ struct cmsghdr *cmh)
+{
+ struct in_pktinfo *pki = CMSG_DATA(cmh);
+ struct sockaddr_in *daddr = svc_daddr_in(rqstp);
+
+ if (cmh->cmsg_type != IP_PKTINFO)
+ return 0;
+
+ daddr->sin_family = AF_INET;
+ daddr->sin_addr.s_addr = pki->ipi_spec_dst.s_addr;
+ return 1;
+}
+
+/*
+ * See net/ipv6/datagram.c : ip6_datagram_recv_ctl
+ */
+static int svc_udp_get_dest_address6(struct svc_rqst *rqstp,
+ struct cmsghdr *cmh)
+{
+ struct in6_pktinfo *pki = CMSG_DATA(cmh);
+ struct sockaddr_in6 *daddr = svc_daddr_in6(rqstp);
+
+ if (cmh->cmsg_type != IPV6_PKTINFO)
+ return 0;
+
+ daddr->sin6_family = AF_INET6;
+ daddr->sin6_addr = pki->ipi6_addr;
+ daddr->sin6_scope_id = pki->ipi6_ifindex;
+ return 1;
+}
+
+/*
+ * Copy the UDP datagram's destination address to the rqstp structure.
+ * The 'destination' address in this case is the address to which the
+ * peer sent the datagram, i.e. our local address. For multihomed
+ * hosts, this can change from msg to msg. Note that only the IP
+ * address changes, the port number should remain the same.
+ */
+static int svc_udp_get_dest_address(struct svc_rqst *rqstp,
+ struct cmsghdr *cmh)
+{
+ switch (cmh->cmsg_level) {
+ case SOL_IP:
+ return svc_udp_get_dest_address4(rqstp, cmh);
+ case SOL_IPV6:
+ return svc_udp_get_dest_address6(rqstp, cmh);
+ }
+
+ return 0;
+}
+
+/**
+ * svc_udp_recvfrom - Receive a datagram from a UDP socket.
+ * @rqstp: request structure into which to receive an RPC Call
+ *
+ * Called in a loop when XPT_DATA has been set.
+ *
+ * Returns:
+ * On success, the number of bytes in a received RPC Call, or
+ * %0 if a complete RPC Call message was not ready to return
+ */
+static int svc_udp_recvfrom(struct svc_rqst *rqstp)
+{
+ struct svc_sock *svsk =
+ container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt);
+ struct svc_serv *serv = svsk->sk_xprt.xpt_server;
+ struct sk_buff *skb;
+ union {
+ struct cmsghdr hdr;
+ long all[SVC_PKTINFO_SPACE / sizeof(long)];
+ } buffer;
+ struct cmsghdr *cmh = &buffer.hdr;
+ struct msghdr msg = {
+ .msg_name = svc_addr(rqstp),
+ .msg_control = cmh,
+ .msg_controllen = sizeof(buffer),
+ .msg_flags = MSG_DONTWAIT,
+ };
+ size_t len;
+ int err;
+
+ if (test_and_clear_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags))
+ /* udp sockets need large rcvbuf as all pending
+ * requests are still in that buffer. sndbuf must
+ * also be large enough that there is enough space
+ * for one reply per thread. We count all threads
+ * rather than threads in a particular pool, which
+ * provides an upper bound on the number of threads
+ * which will access the socket.
+ */
+ svc_sock_setbufsize(svsk, serv->sv_nrthreads + 3);
+
+ clear_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags);
+ err = kernel_recvmsg(svsk->sk_sock, &msg, NULL,
+ 0, 0, MSG_PEEK | MSG_DONTWAIT);
+ if (err < 0)
+ goto out_recv_err;
+ skb = skb_recv_udp(svsk->sk_sk, MSG_DONTWAIT, &err);
+ if (!skb)
+ goto out_recv_err;
+
+ len = svc_addr_len(svc_addr(rqstp));
+ rqstp->rq_addrlen = len;
+ if (skb->tstamp == 0) {
+ skb->tstamp = ktime_get_real();
+		/* Don't enable netstamp, sunrpc doesn't
+		 * need that much accuracy
+		 */
+ }
+ sock_write_timestamp(svsk->sk_sk, skb->tstamp);
+ set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); /* there may be more data... */
+
+ len = skb->len;
+ rqstp->rq_arg.len = len;
+ trace_svcsock_udp_recv(&svsk->sk_xprt, len);
+
+ rqstp->rq_prot = IPPROTO_UDP;
+
+ if (!svc_udp_get_dest_address(rqstp, cmh))
+ goto out_cmsg_err;
+ rqstp->rq_daddrlen = svc_addr_len(svc_daddr(rqstp));
+
+ if (skb_is_nonlinear(skb)) {
+ /* we have to copy */
+ local_bh_disable();
+ if (csum_partial_copy_to_xdr(&rqstp->rq_arg, skb))
+ goto out_bh_enable;
+ local_bh_enable();
+ consume_skb(skb);
+ } else {
+ /* we can use it in-place */
+ rqstp->rq_arg.head[0].iov_base = skb->data;
+ rqstp->rq_arg.head[0].iov_len = len;
+ if (skb_checksum_complete(skb))
+ goto out_free;
+ rqstp->rq_xprt_ctxt = skb;
+ }
+
+ rqstp->rq_arg.page_base = 0;
+ if (len <= rqstp->rq_arg.head[0].iov_len) {
+ rqstp->rq_arg.head[0].iov_len = len;
+ rqstp->rq_arg.page_len = 0;
+ rqstp->rq_respages = rqstp->rq_pages+1;
+ } else {
+ rqstp->rq_arg.page_len = len - rqstp->rq_arg.head[0].iov_len;
+ rqstp->rq_respages = rqstp->rq_pages + 1 +
+ DIV_ROUND_UP(rqstp->rq_arg.page_len, PAGE_SIZE);
+ }
+ rqstp->rq_next_page = rqstp->rq_respages+1;
+
+ if (serv->sv_stats)
+ serv->sv_stats->netudpcnt++;
+
+ svc_sock_secure_port(rqstp);
+ svc_xprt_received(rqstp->rq_xprt);
+ return len;
+
+out_recv_err:
+ if (err != -EAGAIN) {
+ /* possibly an icmp error */
+ set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags);
+ }
+ trace_svcsock_udp_recv_err(&svsk->sk_xprt, err);
+ goto out_clear_busy;
+out_cmsg_err:
+ net_warn_ratelimited("svc: received unknown control message %d/%d; dropping RPC reply datagram\n",
+ cmh->cmsg_level, cmh->cmsg_type);
+ goto out_free;
+out_bh_enable:
+ local_bh_enable();
+out_free:
+ kfree_skb(skb);
+out_clear_busy:
+ svc_xprt_received(rqstp->rq_xprt);
+ return 0;
+}
+
+/**
+ * svc_udp_sendto - Send out a reply on a UDP socket
+ * @rqstp: completed svc_rqst
+ *
+ * xpt_mutex ensures @rqstp's whole message is written to the socket
+ * without interruption.
+ *
+ * Returns the number of bytes sent, or a negative errno.
+ */
+static int svc_udp_sendto(struct svc_rqst *rqstp)
+{
+ struct svc_xprt *xprt = rqstp->rq_xprt;
+ struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt);
+ struct xdr_buf *xdr = &rqstp->rq_res;
+ union {
+ struct cmsghdr hdr;
+ long all[SVC_PKTINFO_SPACE / sizeof(long)];
+ } buffer;
+ struct cmsghdr *cmh = &buffer.hdr;
+ struct msghdr msg = {
+ .msg_name = &rqstp->rq_addr,
+ .msg_namelen = rqstp->rq_addrlen,
+ .msg_control = cmh,
+ .msg_flags = MSG_SPLICE_PAGES,
+ .msg_controllen = sizeof(buffer),
+ };
+ unsigned int count;
+ int err;
+
+ svc_udp_release_ctxt(xprt, rqstp->rq_xprt_ctxt);
+ rqstp->rq_xprt_ctxt = NULL;
+
+ svc_set_cmsg_data(rqstp, cmh);
+
+ mutex_lock(&xprt->xpt_mutex);
+
+ if (svc_xprt_is_dead(xprt))
+ goto out_notconn;
+
+ count = xdr_buf_to_bvec(rqstp->rq_bvec,
+ ARRAY_SIZE(rqstp->rq_bvec), xdr);
+
+ iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, rqstp->rq_bvec,
+ count, rqstp->rq_res.len);
+ err = sock_sendmsg(svsk->sk_sock, &msg);
+ if (err == -ECONNREFUSED) {
+ /* ICMP error on earlier request. */
+ iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, rqstp->rq_bvec,
+ count, rqstp->rq_res.len);
+ err = sock_sendmsg(svsk->sk_sock, &msg);
+ }
+
+ trace_svcsock_udp_send(xprt, err);
+
+ mutex_unlock(&xprt->xpt_mutex);
+ return err;
+
+out_notconn:
+ mutex_unlock(&xprt->xpt_mutex);
+ return -ENOTCONN;
+}
+
+static int svc_udp_has_wspace(struct svc_xprt *xprt)
+{
+ struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt);
+ struct svc_serv *serv = xprt->xpt_server;
+ unsigned long required;
+
+ /*
+ * Set the SOCK_NOSPACE flag before checking the available
+ * sock space.
+ */
+ set_bit(SOCK_NOSPACE, &svsk->sk_sock->flags);
+ required = atomic_read(&svsk->sk_xprt.xpt_reserved) + serv->sv_max_mesg;
+ if (required*2 > sock_wspace(svsk->sk_sk))
+ return 0;
+ clear_bit(SOCK_NOSPACE, &svsk->sk_sock->flags);
+ return 1;
+}
+
+static struct svc_xprt *svc_udp_accept(struct svc_xprt *xprt)
+{
+ BUG();
+ return NULL;
+}
+
+static void svc_udp_kill_temp_xprt(struct svc_xprt *xprt)
+{
+}
+
+static struct svc_xprt *svc_udp_create(struct svc_serv *serv,
+ struct net *net,
+ struct sockaddr *sa, int salen,
+ int flags)
+{
+ return svc_create_socket(serv, IPPROTO_UDP, net, sa, salen, flags);
+}
+
+static const struct svc_xprt_ops svc_udp_ops = {
+ .xpo_create = svc_udp_create,
+ .xpo_recvfrom = svc_udp_recvfrom,
+ .xpo_sendto = svc_udp_sendto,
+ .xpo_result_payload = svc_sock_result_payload,
+ .xpo_release_ctxt = svc_udp_release_ctxt,
+ .xpo_detach = svc_sock_detach,
+ .xpo_free = svc_sock_free,
+ .xpo_has_wspace = svc_udp_has_wspace,
+ .xpo_accept = svc_udp_accept,
+ .xpo_kill_temp_xprt = svc_udp_kill_temp_xprt,
+};
+
+static struct svc_xprt_class svc_udp_class = {
+ .xcl_name = "udp",
+ .xcl_owner = THIS_MODULE,
+ .xcl_ops = &svc_udp_ops,
+ .xcl_max_payload = RPCSVC_MAXPAYLOAD_UDP,
+ .xcl_ident = XPRT_TRANSPORT_UDP,
+};
+
+static void svc_udp_init(struct svc_sock *svsk, struct svc_serv *serv)
+{
+ svc_xprt_init(sock_net(svsk->sk_sock->sk), &svc_udp_class,
+ &svsk->sk_xprt, serv);
+ clear_bit(XPT_CACHE_AUTH, &svsk->sk_xprt.xpt_flags);
+ svsk->sk_sk->sk_data_ready = svc_data_ready;
+ svsk->sk_sk->sk_write_space = svc_write_space;
+
+	/* the initial setting must have enough space to
+ * receive and respond to one request.
+ * svc_udp_recvfrom will re-adjust if necessary
+ */
+ svc_sock_setbufsize(svsk, 3);
+
+	/* data might have come in before data_ready was set up */
+ set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags);
+ set_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags);
+
+ /* make sure we get destination address info */
+ switch (svsk->sk_sk->sk_family) {
+ case AF_INET:
+ ip_sock_set_pktinfo(svsk->sk_sock->sk);
+ break;
+ case AF_INET6:
+ ip6_sock_set_recvpktinfo(svsk->sk_sock->sk);
+ break;
+ default:
+ BUG();
+ }
+}
+
+/*
+ * A data_ready event on a listening socket means there's a connection
+ * pending. Do not use state_change as a substitute for it.
+ */
+static void svc_tcp_listen_data_ready(struct sock *sk)
+{
+ struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data;
+
+ trace_sk_data_ready(sk);
+
+ /*
+	 * This callback may be called twice when a new connection
+	 * is established as a child socket inherits everything
+	 * from a parent LISTEN socket.
+	 * 1) data_ready method of the parent socket will be called
+	 *    when one of the child sockets becomes ESTABLISHED.
+ * 2) data_ready method of the child socket may be called
+ * when it receives data before the socket is accepted.
+ * In case of 2, we should ignore it silently and DO NOT
+ * dereference svsk.
+ */
+ if (sk->sk_state != TCP_LISTEN)
+ return;
+
+ if (svsk) {
+ /* Refer to svc_setup_socket() for details. */
+ rmb();
+ svsk->sk_odata(sk);
+ set_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags);
+ svc_xprt_enqueue(&svsk->sk_xprt);
+ }
+}
+
+/*
+ * A state change on a connected socket means it's dying or dead.
+ */
+static void svc_tcp_state_change(struct sock *sk)
+{
+ struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data;
+
+ if (svsk) {
+ /* Refer to svc_setup_socket() for details. */
+ rmb();
+ svsk->sk_ostate(sk);
+ trace_svcsock_tcp_state(&svsk->sk_xprt, svsk->sk_sock);
+ if (sk->sk_state != TCP_ESTABLISHED)
+ svc_xprt_deferred_close(&svsk->sk_xprt);
+ }
+}
+
+/*
+ * Accept a TCP connection
+ */
+static struct svc_xprt *svc_tcp_accept(struct svc_xprt *xprt)
+{
+ struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt);
+ struct sockaddr_storage addr;
+ struct sockaddr *sin = (struct sockaddr *) &addr;
+ struct svc_serv *serv = svsk->sk_xprt.xpt_server;
+ struct socket *sock = svsk->sk_sock;
+ struct socket *newsock;
+ struct svc_sock *newsvsk;
+ int err, slen;
+
+ if (!sock)
+ return NULL;
+
+ clear_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags);
+ err = kernel_accept(sock, &newsock, O_NONBLOCK);
+ if (err < 0) {
+ if (err != -EAGAIN)
+ trace_svcsock_accept_err(xprt, serv->sv_name, err);
+ return NULL;
+ }
+ if (IS_ERR(sock_alloc_file(newsock, O_NONBLOCK, NULL)))
+ return NULL;
+
+ set_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags);
+
+ err = kernel_getpeername(newsock, sin);
+ if (err < 0) {
+ trace_svcsock_getpeername_err(xprt, serv->sv_name, err);
+ goto failed; /* aborted connection or whatever */
+ }
+ slen = err;
+
+ /* Reset the inherited callbacks before calling svc_setup_socket */
+ newsock->sk->sk_state_change = svsk->sk_ostate;
+ newsock->sk->sk_data_ready = svsk->sk_odata;
+ newsock->sk->sk_write_space = svsk->sk_owspace;
+
+ /* make sure that a write doesn't block forever when
+ * low on memory
+ */
+ newsock->sk->sk_sndtimeo = HZ*30;
+
+ newsvsk = svc_setup_socket(serv, newsock,
+ (SVC_SOCK_ANONYMOUS | SVC_SOCK_TEMPORARY));
+ if (IS_ERR(newsvsk))
+ goto failed;
+ svc_xprt_set_remote(&newsvsk->sk_xprt, sin, slen);
+ err = kernel_getsockname(newsock, sin);
+ slen = err;
+ if (unlikely(err < 0))
+ slen = offsetof(struct sockaddr, sa_data);
+ svc_xprt_set_local(&newsvsk->sk_xprt, sin, slen);
+
+ if (sock_is_loopback(newsock->sk))
+ set_bit(XPT_LOCAL, &newsvsk->sk_xprt.xpt_flags);
+ else
+ clear_bit(XPT_LOCAL, &newsvsk->sk_xprt.xpt_flags);
+ if (serv->sv_stats)
+ serv->sv_stats->nettcpconn++;
+
+ return &newsvsk->sk_xprt;
+
+failed:
+ sockfd_put(newsock);
+ return NULL;
+}
+
+static size_t svc_tcp_restore_pages(struct svc_sock *svsk,
+ struct svc_rqst *rqstp)
+{
+ size_t len = svsk->sk_datalen;
+ unsigned int i, npages;
+
+ if (!len)
+ return 0;
+ npages = (len + PAGE_SIZE - 1) >> PAGE_SHIFT;
+ for (i = 0; i < npages; i++) {
+ if (rqstp->rq_pages[i] != NULL)
+ put_page(rqstp->rq_pages[i]);
+ BUG_ON(svsk->sk_pages[i] == NULL);
+ rqstp->rq_pages[i] = svsk->sk_pages[i];
+ svsk->sk_pages[i] = NULL;
+ }
+ rqstp->rq_arg.head[0].iov_base = page_address(rqstp->rq_pages[0]);
+ return len;
+}
+
+static void svc_tcp_save_pages(struct svc_sock *svsk, struct svc_rqst *rqstp)
+{
+ unsigned int i, len, npages;
+
+ if (svsk->sk_datalen == 0)
+ return;
+ len = svsk->sk_datalen;
+ npages = (len + PAGE_SIZE - 1) >> PAGE_SHIFT;
+ for (i = 0; i < npages; i++) {
+ svsk->sk_pages[i] = rqstp->rq_pages[i];
+ rqstp->rq_pages[i] = NULL;
+ }
+}
+
+static void svc_tcp_clear_pages(struct svc_sock *svsk)
+{
+ unsigned int i, len, npages;
+
+ if (svsk->sk_datalen == 0)
+ goto out;
+ len = svsk->sk_datalen;
+ npages = (len + PAGE_SIZE - 1) >> PAGE_SHIFT;
+ for (i = 0; i < npages; i++) {
+ if (svsk->sk_pages[i] == NULL) {
+ WARN_ON_ONCE(1);
+ continue;
+ }
+ put_page(svsk->sk_pages[i]);
+ svsk->sk_pages[i] = NULL;
+ }
+out:
+ svsk->sk_tcplen = 0;
+ svsk->sk_datalen = 0;
+}
+
+/*
+ * Receive fragment record header into sk_marker.
+ */
+static ssize_t svc_tcp_read_marker(struct svc_sock *svsk,
+ struct svc_rqst *rqstp)
+{
+ ssize_t want, len;
+
+ /* If we haven't gotten the record length yet,
+ * get the next four bytes.
+ */
+ if (svsk->sk_tcplen < sizeof(rpc_fraghdr)) {
+ struct msghdr msg = { NULL };
+ struct kvec iov;
+
+ want = sizeof(rpc_fraghdr) - svsk->sk_tcplen;
+ iov.iov_base = ((char *)&svsk->sk_marker) + svsk->sk_tcplen;
+ iov.iov_len = want;
+ iov_iter_kvec(&msg.msg_iter, ITER_DEST, &iov, 1, want);
+ len = svc_tcp_sock_recv_cmsg(svsk, &msg);
+ if (len < 0)
+ return len;
+ svsk->sk_tcplen += len;
+ if (len < want) {
+ /* call again to read the remaining bytes */
+ goto err_short;
+ }
+ trace_svcsock_marker(&svsk->sk_xprt, svsk->sk_marker);
+ if (svc_sock_reclen(svsk) + svsk->sk_datalen >
+ svsk->sk_xprt.xpt_server->sv_max_mesg)
+ goto err_too_large;
+ }
+ return svc_sock_reclen(svsk);
+
+err_too_large:
+ net_notice_ratelimited("svc: %s %s RPC fragment too large: %d\n",
+ __func__, svsk->sk_xprt.xpt_server->sv_name,
+ svc_sock_reclen(svsk));
+ svc_xprt_deferred_close(&svsk->sk_xprt);
+err_short:
+ return -EAGAIN;
+}
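+
+/*
+ * The four bytes read above are the RPC-over-TCP record mark
+ * (RFC 5531): the most significant bit flags the final fragment of a
+ * record, and the low 31 bits carry the fragment length.  A sketch of
+ * how the helpers used here (svc_sock_reclen(), svc_sock_final_rec())
+ * decode it:
+ *
+ *	u32 mark = be32_to_cpu(svsk->sk_marker);
+ *	bool last = mark & RPC_LAST_STREAM_FRAGMENT;
+ *	u32 len = mark & RPC_FRAGMENT_SIZE_MASK;
+ *
+ * For example, a final 200-byte fragment carries the marker
+ * 0x800000c8; svc_tcp_sendto() below builds its marker the same way.
+ */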
+
+static int receive_cb_reply(struct svc_sock *svsk, struct svc_rqst *rqstp)
+{
+ struct rpc_xprt *bc_xprt = svsk->sk_xprt.xpt_bc_xprt;
+ struct rpc_rqst *req = NULL;
+ struct kvec *src, *dst;
+ __be32 *p = (__be32 *)rqstp->rq_arg.head[0].iov_base;
+ __be32 xid;
+ __be32 calldir;
+
+ xid = *p++;
+ calldir = *p;
+
+ if (!bc_xprt)
+ return -EAGAIN;
+ spin_lock(&bc_xprt->queue_lock);
+ req = xprt_lookup_rqst(bc_xprt, xid);
+ if (!req)
+ goto unlock_notfound;
+
+ memcpy(&req->rq_private_buf, &req->rq_rcv_buf, sizeof(struct xdr_buf));
+ /*
+ * XXX!: cheating for now! Only copying HEAD.
+ * But we know this is good enough for now (in fact, for any
+	 * callback reply in the foreseeable future).
+ */
+ dst = &req->rq_private_buf.head[0];
+ src = &rqstp->rq_arg.head[0];
+ if (dst->iov_len < src->iov_len)
+ goto unlock_eagain; /* whatever; just giving up. */
+ memcpy(dst->iov_base, src->iov_base, src->iov_len);
+ xprt_complete_rqst(req->rq_task, rqstp->rq_arg.len);
+ rqstp->rq_arg.len = 0;
+ spin_unlock(&bc_xprt->queue_lock);
+ return 0;
+unlock_notfound:
+ printk(KERN_NOTICE
+ "%s: Got unrecognized reply: "
+ "calldir 0x%x xpt_bc_xprt %p xid %08x\n",
+ __func__, ntohl(calldir),
+ bc_xprt, ntohl(xid));
+unlock_eagain:
+ spin_unlock(&bc_xprt->queue_lock);
+ return -EAGAIN;
+}
+
+static void svc_tcp_fragment_received(struct svc_sock *svsk)
+{
+ /* If we have more data, signal svc_xprt_enqueue() to try again */
+ svsk->sk_tcplen = 0;
+ svsk->sk_marker = xdr_zero;
+
+ smp_wmb();
+ tcp_set_rcvlowat(svsk->sk_sk, 1);
+}
+
+/**
+ * svc_tcp_recvfrom - Receive data from a TCP socket
+ * @rqstp: request structure into which to receive an RPC Call
+ *
+ * Called in a loop when XPT_DATA has been set.
+ *
+ * Read the 4-byte stream record marker, then use the record length
+ * in that marker to set up exactly the resources needed to receive
+ * the next RPC message into @rqstp.
+ *
+ * Returns:
+ * On success, the number of bytes in a received RPC Call, or
+ * %0 if a complete RPC Call message was not ready to return
+ *
+ * The zero return case handles partial receives and callback Replies.
+ * The state of a partial receive is preserved in the svc_sock for
+ * the next call to svc_tcp_recvfrom.
+ */
+static int svc_tcp_recvfrom(struct svc_rqst *rqstp)
+{
+ struct svc_sock *svsk =
+ container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt);
+ struct svc_serv *serv = svsk->sk_xprt.xpt_server;
+ size_t want, base;
+ ssize_t len;
+ __be32 *p;
+ __be32 calldir;
+
+ clear_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags);
+ len = svc_tcp_read_marker(svsk, rqstp);
+ if (len < 0)
+ goto error;
+
+ base = svc_tcp_restore_pages(svsk, rqstp);
+ want = len - (svsk->sk_tcplen - sizeof(rpc_fraghdr));
+ len = svc_tcp_read_msg(rqstp, base + want, base);
+ if (len >= 0) {
+ trace_svcsock_tcp_recv(&svsk->sk_xprt, len);
+ svsk->sk_tcplen += len;
+ svsk->sk_datalen += len;
+ }
+ if (len != want || !svc_sock_final_rec(svsk))
+ goto err_incomplete;
+ if (svsk->sk_datalen < 8)
+ goto err_nuts;
+
+ rqstp->rq_arg.len = svsk->sk_datalen;
+ rqstp->rq_arg.page_base = 0;
+ if (rqstp->rq_arg.len <= rqstp->rq_arg.head[0].iov_len) {
+ rqstp->rq_arg.head[0].iov_len = rqstp->rq_arg.len;
+ rqstp->rq_arg.page_len = 0;
+ } else
+ rqstp->rq_arg.page_len = rqstp->rq_arg.len - rqstp->rq_arg.head[0].iov_len;
+
+ rqstp->rq_xprt_ctxt = NULL;
+ rqstp->rq_prot = IPPROTO_TCP;
+ if (test_bit(XPT_LOCAL, &svsk->sk_xprt.xpt_flags))
+ set_bit(RQ_LOCAL, &rqstp->rq_flags);
+ else
+ clear_bit(RQ_LOCAL, &rqstp->rq_flags);
+
+ p = (__be32 *)rqstp->rq_arg.head[0].iov_base;
+ calldir = p[1];
+ if (calldir)
+ len = receive_cb_reply(svsk, rqstp);
+
+ /* Reset TCP read info */
+ svsk->sk_datalen = 0;
+ svc_tcp_fragment_received(svsk);
+
+ if (len < 0)
+ goto error;
+
+ svc_xprt_copy_addrs(rqstp, &svsk->sk_xprt);
+ if (serv->sv_stats)
+ serv->sv_stats->nettcpcnt++;
+
+ svc_sock_secure_port(rqstp);
+ svc_xprt_received(rqstp->rq_xprt);
+ return rqstp->rq_arg.len;
+
+err_incomplete:
+ svc_tcp_save_pages(svsk, rqstp);
+ if (len < 0 && len != -EAGAIN)
+ goto err_delete;
+ if (len == want)
+ svc_tcp_fragment_received(svsk);
+ else {
+ /* Avoid more ->sk_data_ready() calls until the rest
+ * of the message has arrived. This reduces service
+ * thread wake-ups on large incoming messages. */
+ tcp_set_rcvlowat(svsk->sk_sk,
+ svc_sock_reclen(svsk) - svsk->sk_tcplen);
+
+ trace_svcsock_tcp_recv_short(&svsk->sk_xprt,
+ svc_sock_reclen(svsk),
+ svsk->sk_tcplen - sizeof(rpc_fraghdr));
+ }
+ goto err_noclose;
+error:
+ if (len != -EAGAIN)
+ goto err_delete;
+ trace_svcsock_tcp_recv_eagain(&svsk->sk_xprt, 0);
+ goto err_noclose;
+err_nuts:
+ svsk->sk_datalen = 0;
+err_delete:
+ trace_svcsock_tcp_recv_err(&svsk->sk_xprt, len);
+ svc_xprt_deferred_close(&svsk->sk_xprt);
+err_noclose:
+ svc_xprt_received(rqstp->rq_xprt);
+ return 0; /* record not complete */
+}
+
+/*
+ * MSG_SPLICE_PAGES is used exclusively to reduce the number of
+ * copy operations in this path. Therefore the caller must ensure
+ * that the pages backing @xdr are unchanging.
+ *
+ * Note that the send is non-blocking. The caller has incremented
+ * the reference count on each page backing the RPC message, and
+ * the network layer will "put" these pages when transmission is
+ * complete.
+ *
+ * This is safe for our RPC services because the memory backing
+ * the head and tail components is never kmalloc'd. These always
+ * come from pages in the svc_rqst::rq_pages array.
+ */
+static int svc_tcp_sendmsg(struct svc_sock *svsk, struct svc_rqst *rqstp,
+ rpc_fraghdr marker, unsigned int *sentp)
+{
+ struct msghdr msg = {
+ .msg_flags = MSG_SPLICE_PAGES,
+ };
+ unsigned int count;
+ void *buf;
+ int ret;
+
+ *sentp = 0;
+
+ /* The stream record marker is copied into a temporary page
+ * fragment buffer so that it can be included in rq_bvec.
+ */
+ buf = page_frag_alloc(&svsk->sk_frag_cache, sizeof(marker),
+ GFP_KERNEL);
+ if (!buf)
+ return -ENOMEM;
+ memcpy(buf, &marker, sizeof(marker));
+ bvec_set_virt(rqstp->rq_bvec, buf, sizeof(marker));
+
+ count = xdr_buf_to_bvec(rqstp->rq_bvec + 1,
+ ARRAY_SIZE(rqstp->rq_bvec) - 1, &rqstp->rq_res);
+
+ iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, rqstp->rq_bvec,
+ 1 + count, sizeof(marker) + rqstp->rq_res.len);
+ ret = sock_sendmsg(svsk->sk_sock, &msg);
+ if (ret < 0)
+ return ret;
+ *sentp += ret;
+ return 0;
+}
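+
+/*
+ * Note on the record marker above: because the send uses
+ * MSG_SPLICE_PAGES, the marker cannot safely live on the stack.  It is
+ * carved out of svsk->sk_frag_cache so that the page backing it stays
+ * referenced until the network layer has finished transmitting it;
+ * the cache itself is drained later in svc_sock_free().
+ */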
+
+/**
+ * svc_tcp_sendto - Send out a reply on a TCP socket
+ * @rqstp: completed svc_rqst
+ *
+ * xpt_mutex ensures @rqstp's whole message is written to the socket
+ * without interruption.
+ *
+ * Returns the number of bytes sent, or a negative errno.
+ */
+static int svc_tcp_sendto(struct svc_rqst *rqstp)
+{
+ struct svc_xprt *xprt = rqstp->rq_xprt;
+ struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt);
+ struct xdr_buf *xdr = &rqstp->rq_res;
+ rpc_fraghdr marker = cpu_to_be32(RPC_LAST_STREAM_FRAGMENT |
+ (u32)xdr->len);
+ unsigned int sent;
+ int err;
+
+ svc_tcp_release_ctxt(xprt, rqstp->rq_xprt_ctxt);
+ rqstp->rq_xprt_ctxt = NULL;
+
+ mutex_lock(&xprt->xpt_mutex);
+ if (svc_xprt_is_dead(xprt))
+ goto out_notconn;
+ err = svc_tcp_sendmsg(svsk, rqstp, marker, &sent);
+ trace_svcsock_tcp_send(xprt, err < 0 ? (long)err : sent);
+ if (err < 0 || sent != (xdr->len + sizeof(marker)))
+ goto out_close;
+ mutex_unlock(&xprt->xpt_mutex);
+ return sent;
+
+out_notconn:
+ mutex_unlock(&xprt->xpt_mutex);
+ return -ENOTCONN;
+out_close:
+ pr_notice("rpc-srv/tcp: %s: %s %d when sending %d bytes - shutting down socket\n",
+ xprt->xpt_server->sv_name,
+ (err < 0) ? "got error" : "sent",
+ (err < 0) ? err : sent, xdr->len);
+ svc_xprt_deferred_close(xprt);
+ mutex_unlock(&xprt->xpt_mutex);
+ return -EAGAIN;
+}
+
+static struct svc_xprt *svc_tcp_create(struct svc_serv *serv,
+ struct net *net,
+ struct sockaddr *sa, int salen,
+ int flags)
+{
+ return svc_create_socket(serv, IPPROTO_TCP, net, sa, salen, flags);
+}
+
+static const struct svc_xprt_ops svc_tcp_ops = {
+ .xpo_create = svc_tcp_create,
+ .xpo_recvfrom = svc_tcp_recvfrom,
+ .xpo_sendto = svc_tcp_sendto,
+ .xpo_result_payload = svc_sock_result_payload,
+ .xpo_release_ctxt = svc_tcp_release_ctxt,
+ .xpo_detach = svc_tcp_sock_detach,
+ .xpo_free = svc_sock_free,
+ .xpo_has_wspace = svc_tcp_has_wspace,
+ .xpo_accept = svc_tcp_accept,
+ .xpo_kill_temp_xprt = svc_tcp_kill_temp_xprt,
+ .xpo_handshake = svc_tcp_handshake,
+};
+
+static struct svc_xprt_class svc_tcp_class = {
+ .xcl_name = "tcp",
+ .xcl_owner = THIS_MODULE,
+ .xcl_ops = &svc_tcp_ops,
+ .xcl_max_payload = RPCSVC_MAXPAYLOAD_TCP,
+ .xcl_ident = XPRT_TRANSPORT_TCP,
+};
+
+void svc_init_xprt_sock(void)
+{
+ svc_reg_xprt_class(&svc_tcp_class);
+ svc_reg_xprt_class(&svc_udp_class);
+}
+
+void svc_cleanup_xprt_sock(void)
+{
+ svc_unreg_xprt_class(&svc_tcp_class);
+ svc_unreg_xprt_class(&svc_udp_class);
+}
+
+static void svc_tcp_init(struct svc_sock *svsk, struct svc_serv *serv)
+{
+ struct sock *sk = svsk->sk_sk;
+
+ svc_xprt_init(sock_net(svsk->sk_sock->sk), &svc_tcp_class,
+ &svsk->sk_xprt, serv);
+ set_bit(XPT_CACHE_AUTH, &svsk->sk_xprt.xpt_flags);
+ set_bit(XPT_CONG_CTRL, &svsk->sk_xprt.xpt_flags);
+ if (sk->sk_state == TCP_LISTEN) {
+ strcpy(svsk->sk_xprt.xpt_remotebuf, "listener");
+ set_bit(XPT_LISTENER, &svsk->sk_xprt.xpt_flags);
+ sk->sk_data_ready = svc_tcp_listen_data_ready;
+ set_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags);
+ } else {
+ sk->sk_state_change = svc_tcp_state_change;
+ sk->sk_data_ready = svc_data_ready;
+ sk->sk_write_space = svc_write_space;
+
+ svsk->sk_marker = xdr_zero;
+ svsk->sk_tcplen = 0;
+ svsk->sk_datalen = 0;
+ memset(&svsk->sk_pages[0], 0, sizeof(svsk->sk_pages));
+
+ tcp_sock_set_nodelay(sk);
+
+ set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags);
+ switch (sk->sk_state) {
+ case TCP_SYN_RECV:
+ case TCP_ESTABLISHED:
+ break;
+ default:
+ svc_xprt_deferred_close(&svsk->sk_xprt);
+ }
+ }
+}
+
+void svc_sock_update_bufs(struct svc_serv *serv)
+{
+ /*
+ * The number of server threads has changed. Update
+ * rcvbuf and sndbuf accordingly on all sockets
+ */
+ struct svc_sock *svsk;
+
+ spin_lock_bh(&serv->sv_lock);
+ list_for_each_entry(svsk, &serv->sv_permsocks, sk_xprt.xpt_list)
+ set_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags);
+ spin_unlock_bh(&serv->sv_lock);
+}
+EXPORT_SYMBOL_GPL(svc_sock_update_bufs);
+
+/*
+ * Initialize socket for RPC use and create svc_sock struct
+ */
+static struct svc_sock *svc_setup_socket(struct svc_serv *serv,
+ struct socket *sock,
+ int flags)
+{
+ struct svc_sock *svsk;
+ struct sock *inet;
+ int pmap_register = !(flags & SVC_SOCK_ANONYMOUS);
+
+ svsk = kzalloc(sizeof(*svsk), GFP_KERNEL);
+ if (!svsk)
+ return ERR_PTR(-ENOMEM);
+
+ inet = sock->sk;
+
+ if (pmap_register) {
+ int err;
+
+ err = svc_register(serv, sock_net(sock->sk), inet->sk_family,
+ inet->sk_protocol,
+ ntohs(inet_sk(inet)->inet_sport));
+ if (err < 0) {
+ kfree(svsk);
+ return ERR_PTR(err);
+ }
+ }
+
+ svsk->sk_sock = sock;
+ svsk->sk_sk = inet;
+ svsk->sk_ostate = inet->sk_state_change;
+ svsk->sk_odata = inet->sk_data_ready;
+ svsk->sk_owspace = inet->sk_write_space;
+ /*
+ * This barrier is necessary in order to prevent race condition
+ * with svc_data_ready(), svc_tcp_listen_data_ready(), and others
+ * when calling callbacks above.
+ */
+ wmb();
+ inet->sk_user_data = svsk;
+
+ /* Initialize the socket */
+ if (sock->type == SOCK_DGRAM)
+ svc_udp_init(svsk, serv);
+ else
+ svc_tcp_init(svsk, serv);
+
+ trace_svcsock_new(svsk, sock);
+ return svsk;
+}
+
+/**
+ * svc_addsock - add a listener socket to an RPC service
+ * @serv: pointer to RPC service to which to add a new listener
+ * @net: caller's network namespace
+ * @fd: file descriptor of the new listener
+ * @name_return: pointer to buffer to fill in with name of listener
+ * @len: size of the buffer
+ * @cred: credential
+ *
+ * Fills in socket name and returns positive length of name if successful.
+ * Name is terminated with '\n'. On error, returns a negative errno
+ * value.
+ */
+int svc_addsock(struct svc_serv *serv, struct net *net, const int fd,
+ char *name_return, const size_t len, const struct cred *cred)
+{
+ int err = 0;
+ struct socket *so = sockfd_lookup(fd, &err);
+ struct svc_sock *svsk = NULL;
+ struct sockaddr_storage addr;
+ struct sockaddr *sin = (struct sockaddr *)&addr;
+ int salen;
+
+ if (!so)
+ return err;
+ err = -EINVAL;
+ if (sock_net(so->sk) != net)
+ goto out;
+ err = -EAFNOSUPPORT;
+ if ((so->sk->sk_family != PF_INET) && (so->sk->sk_family != PF_INET6))
+ goto out;
+ err = -EPROTONOSUPPORT;
+ if (so->sk->sk_protocol != IPPROTO_TCP &&
+ so->sk->sk_protocol != IPPROTO_UDP)
+ goto out;
+ err = -EISCONN;
+ if (so->state > SS_UNCONNECTED)
+ goto out;
+ err = -ENOENT;
+ if (!try_module_get(THIS_MODULE))
+ goto out;
+ svsk = svc_setup_socket(serv, so, SVC_SOCK_DEFAULTS);
+ if (IS_ERR(svsk)) {
+ module_put(THIS_MODULE);
+ err = PTR_ERR(svsk);
+ goto out;
+ }
+ salen = kernel_getsockname(svsk->sk_sock, sin);
+ if (salen >= 0)
+ svc_xprt_set_local(&svsk->sk_xprt, sin, salen);
+ svsk->sk_xprt.xpt_cred = get_cred(cred);
+ svc_add_new_perm_xprt(serv, &svsk->sk_xprt);
+ return svc_one_sock_name(svsk, name_return, len);
+out:
+ sockfd_put(so);
+ return err;
+}
+EXPORT_SYMBOL_GPL(svc_addsock);
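+
+/*
+ * Usage sketch (the exact caller lives outside this file): a service
+ * such as nfsd is handed a listener fd by user space, which has
+ * already created, bound and, for TCP, listened on the socket, and
+ * then calls something like:
+ *
+ *	char name[64];			(arbitrary example buffer size)
+ *	int ret = svc_addsock(serv, net, fd, name, sizeof(name), cred);
+ *
+ * A positive return value is the length of the '\n'-terminated
+ * listener name written to the buffer; a negative value is an errno.
+ */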
+
+/*
+ * Create socket for RPC service.
+ */
+static struct svc_xprt *svc_create_socket(struct svc_serv *serv,
+ int protocol,
+ struct net *net,
+ struct sockaddr *sin, int len,
+ int flags)
+{
+ struct svc_sock *svsk;
+ struct socket *sock;
+ int error;
+ int type;
+ struct sockaddr_storage addr;
+ struct sockaddr *newsin = (struct sockaddr *)&addr;
+ int newlen;
+ int family;
+
+ if (protocol != IPPROTO_UDP && protocol != IPPROTO_TCP) {
+ printk(KERN_WARNING "svc: only UDP and TCP "
+ "sockets supported\n");
+ return ERR_PTR(-EINVAL);
+ }
+
+ type = (protocol == IPPROTO_UDP)? SOCK_DGRAM : SOCK_STREAM;
+ switch (sin->sa_family) {
+ case AF_INET6:
+ family = PF_INET6;
+ break;
+ case AF_INET:
+ family = PF_INET;
+ break;
+ default:
+ return ERR_PTR(-EINVAL);
+ }
+
+ error = __sock_create(net, family, type, protocol, &sock, 1);
+ if (error < 0)
+ return ERR_PTR(error);
+
+ svc_reclassify_socket(sock);
+
+ /*
+	 * If this is a PF_INET6 listener, we want to avoid
+ * getting requests from IPv4 remotes. Those should
+ * be shunted to a PF_INET listener via rpcbind.
+ */
+ if (family == PF_INET6)
+ ip6_sock_set_v6only(sock->sk);
+ if (type == SOCK_STREAM)
+ sock->sk->sk_reuse = SK_CAN_REUSE; /* allow address reuse */
+ error = kernel_bind(sock, sin, len);
+ if (error < 0)
+ goto bummer;
+
+ error = kernel_getsockname(sock, newsin);
+ if (error < 0)
+ goto bummer;
+ newlen = error;
+
+ if (protocol == IPPROTO_TCP) {
+ if ((error = kernel_listen(sock, 64)) < 0)
+ goto bummer;
+ }
+
+ svsk = svc_setup_socket(serv, sock, flags);
+ if (IS_ERR(svsk)) {
+ error = PTR_ERR(svsk);
+ goto bummer;
+ }
+ svc_xprt_set_local(&svsk->sk_xprt, newsin, newlen);
+ return (struct svc_xprt *)svsk;
+bummer:
+ sock_release(sock);
+ return ERR_PTR(error);
+}
+
+/*
+ * Detach the svc_sock from the socket so that no
+ * more callbacks occur.
+ */
+static void svc_sock_detach(struct svc_xprt *xprt)
+{
+ struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt);
+ struct sock *sk = svsk->sk_sk;
+
+ /* put back the old socket callbacks */
+ lock_sock(sk);
+ sk->sk_state_change = svsk->sk_ostate;
+ sk->sk_data_ready = svsk->sk_odata;
+ sk->sk_write_space = svsk->sk_owspace;
+ sk->sk_user_data = NULL;
+ release_sock(sk);
+}
+
+/*
+ * Disconnect the socket, and reset the callbacks
+ */
+static void svc_tcp_sock_detach(struct svc_xprt *xprt)
+{
+ struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt);
+
+ tls_handshake_close(svsk->sk_sock);
+
+ svc_sock_detach(xprt);
+
+ if (!test_bit(XPT_LISTENER, &xprt->xpt_flags)) {
+ svc_tcp_clear_pages(svsk);
+ kernel_sock_shutdown(svsk->sk_sock, SHUT_RDWR);
+ }
+}
+
+/*
+ * Free the svc_sock's socket resources and the svc_sock itself.
+ */
+static void svc_sock_free(struct svc_xprt *xprt)
+{
+ struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt);
+ struct page_frag_cache *pfc = &svsk->sk_frag_cache;
+ struct socket *sock = svsk->sk_sock;
+
+ trace_svcsock_free(svsk, sock);
+
+ tls_handshake_cancel(sock->sk);
+ if (sock->file)
+ sockfd_put(sock);
+ else
+ sock_release(sock);
+ if (pfc->va)
+ __page_frag_cache_drain(virt_to_head_page(pfc->va),
+ pfc->pagecnt_bias);
+ kfree(svsk);
+}
diff --git a/net/sunrpc/sysctl.c b/net/sunrpc/sysctl.c
new file mode 100644
index 0000000000..93941ab125
--- /dev/null
+++ b/net/sunrpc/sysctl.c
@@ -0,0 +1,181 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * linux/net/sunrpc/sysctl.c
+ *
+ * Sysctl interface to sunrpc module.
+ *
+ * I would prefer to register the sunrpc table below sys/net, but that's
+ * impossible at the moment.
+ */
+
+#include <linux/types.h>
+#include <linux/linkage.h>
+#include <linux/ctype.h>
+#include <linux/fs.h>
+#include <linux/sysctl.h>
+#include <linux/module.h>
+
+#include <linux/uaccess.h>
+#include <linux/sunrpc/types.h>
+#include <linux/sunrpc/sched.h>
+#include <linux/sunrpc/stats.h>
+#include <linux/sunrpc/svc_xprt.h>
+
+#include "netns.h"
+
+/*
+ * Declare the debug flags here
+ */
+unsigned int rpc_debug;
+EXPORT_SYMBOL_GPL(rpc_debug);
+
+unsigned int nfs_debug;
+EXPORT_SYMBOL_GPL(nfs_debug);
+
+unsigned int nfsd_debug;
+EXPORT_SYMBOL_GPL(nfsd_debug);
+
+unsigned int nlm_debug;
+EXPORT_SYMBOL_GPL(nlm_debug);
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+
+static int proc_do_xprt(struct ctl_table *table, int write,
+ void *buffer, size_t *lenp, loff_t *ppos)
+{
+ char tmpbuf[256];
+ ssize_t len;
+
+ if (write || *ppos) {
+ *lenp = 0;
+ return 0;
+ }
+ len = svc_print_xprts(tmpbuf, sizeof(tmpbuf));
+ len = memory_read_from_buffer(buffer, *lenp, ppos, tmpbuf, len);
+
+ if (len < 0) {
+ *lenp = 0;
+ return -EINVAL;
+ }
+ *lenp = len;
+ return 0;
+}
+
+static int
+proc_dodebug(struct ctl_table *table, int write, void *buffer, size_t *lenp,
+ loff_t *ppos)
+{
+ char tmpbuf[20], *s = NULL;
+ char *p;
+ unsigned int value;
+ size_t left, len;
+
+ if ((*ppos && !write) || !*lenp) {
+ *lenp = 0;
+ return 0;
+ }
+
+ left = *lenp;
+
+ if (write) {
+ p = buffer;
+ while (left && isspace(*p)) {
+ left--;
+ p++;
+ }
+ if (!left)
+ goto done;
+
+ if (left > sizeof(tmpbuf) - 1)
+ return -EINVAL;
+ memcpy(tmpbuf, p, left);
+ tmpbuf[left] = '\0';
+
+ value = simple_strtol(tmpbuf, &s, 0);
+ if (s) {
+ left -= (s - tmpbuf);
+ if (left && !isspace(*s))
+ return -EINVAL;
+ while (left && isspace(*s)) {
+ left--;
+ s++;
+ }
+ } else
+ left = 0;
+ *(unsigned int *) table->data = value;
+ /* Display the RPC tasks on writing to rpc_debug */
+ if (strcmp(table->procname, "rpc_debug") == 0)
+ rpc_show_tasks(&init_net);
+ } else {
+ len = sprintf(tmpbuf, "0x%04x", *(unsigned int *) table->data);
+ if (len > left)
+ len = left;
+ memcpy(buffer, tmpbuf, len);
+ if ((left -= len) > 0) {
+ *((char *)buffer + len) = '\n';
+ left--;
+ }
+ }
+
+done:
+ *lenp -= left;
+ *ppos += *lenp;
+ return 0;
+}
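+
+/*
+ * Usage sketch: the handler above backs the files registered further
+ * down under /proc/sys/sunrpc/ (assuming the standard procfs layout).
+ * Writing a numeric mask in any base accepted by simple_strtol()
+ * updates the corresponding debug flags, and reading returns the mask
+ * formatted as "0x%04x", e.g.:
+ *
+ *	# echo 0xffff > /proc/sys/sunrpc/rpc_debug	(also dumps RPC tasks)
+ *	# cat /proc/sys/sunrpc/nfs_debug
+ *	0x0000
+ *
+ * The individual flag values are defined in linux/sunrpc/debug.h.
+ */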
+
+static struct ctl_table_header *sunrpc_table_header;
+
+static struct ctl_table debug_table[] = {
+ {
+ .procname = "rpc_debug",
+ .data = &rpc_debug,
+ .maxlen = sizeof(int),
+ .mode = 0644,
+ .proc_handler = proc_dodebug
+ },
+ {
+ .procname = "nfs_debug",
+ .data = &nfs_debug,
+ .maxlen = sizeof(int),
+ .mode = 0644,
+ .proc_handler = proc_dodebug
+ },
+ {
+ .procname = "nfsd_debug",
+ .data = &nfsd_debug,
+ .maxlen = sizeof(int),
+ .mode = 0644,
+ .proc_handler = proc_dodebug
+ },
+ {
+ .procname = "nlm_debug",
+ .data = &nlm_debug,
+ .maxlen = sizeof(int),
+ .mode = 0644,
+ .proc_handler = proc_dodebug
+ },
+ {
+ .procname = "transports",
+ .maxlen = 256,
+ .mode = 0444,
+ .proc_handler = proc_do_xprt,
+ },
+ { }
+};
+
+void
+rpc_register_sysctl(void)
+{
+ if (!sunrpc_table_header)
+ sunrpc_table_header = register_sysctl("sunrpc", debug_table);
+}
+
+void
+rpc_unregister_sysctl(void)
+{
+ if (sunrpc_table_header) {
+ unregister_sysctl_table(sunrpc_table_header);
+ sunrpc_table_header = NULL;
+ }
+}
+#endif
diff --git a/net/sunrpc/sysfs.c b/net/sunrpc/sysfs.c
new file mode 100644
index 0000000000..5c8ecdaaa9
--- /dev/null
+++ b/net/sunrpc/sysfs.c
@@ -0,0 +1,627 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (c) 2020 Anna Schumaker <Anna.Schumaker@Netapp.com>
+ */
+#include <linux/sunrpc/clnt.h>
+#include <linux/kobject.h>
+#include <linux/sunrpc/addr.h>
+#include <linux/sunrpc/xprtsock.h>
+
+#include "sysfs.h"
+
+struct xprt_addr {
+ const char *addr;
+ struct rcu_head rcu;
+};
+
+static void free_xprt_addr(struct rcu_head *head)
+{
+ struct xprt_addr *addr = container_of(head, struct xprt_addr, rcu);
+
+ kfree(addr->addr);
+ kfree(addr);
+}
+
+static struct kset *rpc_sunrpc_kset;
+static struct kobject *rpc_sunrpc_client_kobj, *rpc_sunrpc_xprt_switch_kobj;
+
+static void rpc_sysfs_object_release(struct kobject *kobj)
+{
+ kfree(kobj);
+}
+
+static const struct kobj_ns_type_operations *
+rpc_sysfs_object_child_ns_type(const struct kobject *kobj)
+{
+ return &net_ns_type_operations;
+}
+
+static const struct kobj_type rpc_sysfs_object_type = {
+ .release = rpc_sysfs_object_release,
+ .sysfs_ops = &kobj_sysfs_ops,
+ .child_ns_type = rpc_sysfs_object_child_ns_type,
+};
+
+static struct kobject *rpc_sysfs_object_alloc(const char *name,
+ struct kset *kset,
+ struct kobject *parent)
+{
+ struct kobject *kobj;
+
+ kobj = kzalloc(sizeof(*kobj), GFP_KERNEL);
+ if (kobj) {
+ kobj->kset = kset;
+ if (kobject_init_and_add(kobj, &rpc_sysfs_object_type,
+ parent, "%s", name) == 0)
+ return kobj;
+ kobject_put(kobj);
+ }
+ return NULL;
+}
+
+static inline struct rpc_xprt *
+rpc_sysfs_xprt_kobj_get_xprt(struct kobject *kobj)
+{
+ struct rpc_sysfs_xprt *x = container_of(kobj,
+ struct rpc_sysfs_xprt, kobject);
+
+ return xprt_get(x->xprt);
+}
+
+static inline struct rpc_xprt_switch *
+rpc_sysfs_xprt_kobj_get_xprt_switch(struct kobject *kobj)
+{
+ struct rpc_sysfs_xprt *x = container_of(kobj,
+ struct rpc_sysfs_xprt, kobject);
+
+ return xprt_switch_get(x->xprt_switch);
+}
+
+static inline struct rpc_xprt_switch *
+rpc_sysfs_xprt_switch_kobj_get_xprt(struct kobject *kobj)
+{
+ struct rpc_sysfs_xprt_switch *x = container_of(kobj,
+ struct rpc_sysfs_xprt_switch, kobject);
+
+ return xprt_switch_get(x->xprt_switch);
+}
+
+static ssize_t rpc_sysfs_xprt_dstaddr_show(struct kobject *kobj,
+ struct kobj_attribute *attr,
+ char *buf)
+{
+ struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj);
+ ssize_t ret;
+
+ if (!xprt) {
+ ret = sprintf(buf, "<closed>\n");
+ goto out;
+ }
+ ret = sprintf(buf, "%s\n", xprt->address_strings[RPC_DISPLAY_ADDR]);
+ xprt_put(xprt);
+out:
+ return ret;
+}
+
+static ssize_t rpc_sysfs_xprt_srcaddr_show(struct kobject *kobj,
+ struct kobj_attribute *attr,
+ char *buf)
+{
+ struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj);
+ size_t buflen = PAGE_SIZE;
+ ssize_t ret;
+
+ if (!xprt || !xprt_connected(xprt)) {
+ ret = sprintf(buf, "<closed>\n");
+ } else if (xprt->ops->get_srcaddr) {
+ ret = xprt->ops->get_srcaddr(xprt, buf, buflen);
+ if (ret > 0) {
+ if (ret < buflen - 1) {
+ buf[ret] = '\n';
+ ret++;
+ buf[ret] = '\0';
+ }
+ } else
+ ret = sprintf(buf, "<closed>\n");
+ } else
+ ret = sprintf(buf, "<not a socket>\n");
+ xprt_put(xprt);
+ return ret;
+}
+
+static ssize_t rpc_sysfs_xprt_info_show(struct kobject *kobj,
+ struct kobj_attribute *attr, char *buf)
+{
+ struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj);
+ unsigned short srcport = 0;
+ size_t buflen = PAGE_SIZE;
+ ssize_t ret;
+
+ if (!xprt || !xprt_connected(xprt)) {
+ ret = sprintf(buf, "<closed>\n");
+ goto out;
+ }
+
+ if (xprt->ops->get_srcport)
+ srcport = xprt->ops->get_srcport(xprt);
+
+ ret = snprintf(buf, buflen,
+ "last_used=%lu\ncur_cong=%lu\ncong_win=%lu\n"
+ "max_num_slots=%u\nmin_num_slots=%u\nnum_reqs=%u\n"
+ "binding_q_len=%u\nsending_q_len=%u\npending_q_len=%u\n"
+ "backlog_q_len=%u\nmain_xprt=%d\nsrc_port=%u\n"
+ "tasks_queuelen=%ld\ndst_port=%s\n",
+ xprt->last_used, xprt->cong, xprt->cwnd, xprt->max_reqs,
+ xprt->min_reqs, xprt->num_reqs, xprt->binding.qlen,
+ xprt->sending.qlen, xprt->pending.qlen,
+ xprt->backlog.qlen, xprt->main, srcport,
+ atomic_long_read(&xprt->queuelen),
+ xprt->address_strings[RPC_DISPLAY_PORT]);
+out:
+ xprt_put(xprt);
+ return ret;
+}
+
+static ssize_t rpc_sysfs_xprt_state_show(struct kobject *kobj,
+ struct kobj_attribute *attr,
+ char *buf)
+{
+ struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj);
+ ssize_t ret;
+ int locked, connected, connecting, close_wait, bound, binding,
+ closing, congested, cwnd_wait, write_space, offline, remove;
+
+ if (!(xprt && xprt->state)) {
+ ret = sprintf(buf, "state=CLOSED\n");
+ } else {
+ locked = test_bit(XPRT_LOCKED, &xprt->state);
+ connected = test_bit(XPRT_CONNECTED, &xprt->state);
+ connecting = test_bit(XPRT_CONNECTING, &xprt->state);
+ close_wait = test_bit(XPRT_CLOSE_WAIT, &xprt->state);
+ bound = test_bit(XPRT_BOUND, &xprt->state);
+ binding = test_bit(XPRT_BINDING, &xprt->state);
+ closing = test_bit(XPRT_CLOSING, &xprt->state);
+ congested = test_bit(XPRT_CONGESTED, &xprt->state);
+ cwnd_wait = test_bit(XPRT_CWND_WAIT, &xprt->state);
+ write_space = test_bit(XPRT_WRITE_SPACE, &xprt->state);
+ offline = test_bit(XPRT_OFFLINE, &xprt->state);
+ remove = test_bit(XPRT_REMOVE, &xprt->state);
+
+ ret = sprintf(buf, "state=%s %s %s %s %s %s %s %s %s %s %s %s\n",
+ locked ? "LOCKED" : "",
+ connected ? "CONNECTED" : "",
+ connecting ? "CONNECTING" : "",
+ close_wait ? "CLOSE_WAIT" : "",
+ bound ? "BOUND" : "",
+ binding ? "BOUNDING" : "",
+ closing ? "CLOSING" : "",
+ congested ? "CONGESTED" : "",
+ cwnd_wait ? "CWND_WAIT" : "",
+ write_space ? "WRITE_SPACE" : "",
+ offline ? "OFFLINE" : "",
+ remove ? "REMOVE" : "");
+ }
+
+ xprt_put(xprt);
+ return ret;
+}
+
+static ssize_t rpc_sysfs_xprt_switch_info_show(struct kobject *kobj,
+ struct kobj_attribute *attr,
+ char *buf)
+{
+ struct rpc_xprt_switch *xprt_switch =
+ rpc_sysfs_xprt_switch_kobj_get_xprt(kobj);
+ ssize_t ret;
+
+ if (!xprt_switch)
+ return 0;
+ ret = sprintf(buf, "num_xprts=%u\nnum_active=%u\n"
+ "num_unique_destaddr=%u\nqueue_len=%ld\n",
+ xprt_switch->xps_nxprts, xprt_switch->xps_nactive,
+ xprt_switch->xps_nunique_destaddr_xprts,
+ atomic_long_read(&xprt_switch->xps_queuelen));
+ xprt_switch_put(xprt_switch);
+ return ret;
+}
+
+static ssize_t rpc_sysfs_xprt_dstaddr_store(struct kobject *kobj,
+ struct kobj_attribute *attr,
+ const char *buf, size_t count)
+{
+ struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj);
+ struct sockaddr *saddr;
+ char *dst_addr;
+ int port;
+ struct xprt_addr *saved_addr;
+ size_t buf_len;
+
+ if (!xprt)
+ return 0;
+ if (!(xprt->xprt_class->ident == XPRT_TRANSPORT_TCP ||
+ xprt->xprt_class->ident == XPRT_TRANSPORT_TCP_TLS ||
+ xprt->xprt_class->ident == XPRT_TRANSPORT_RDMA)) {
+ xprt_put(xprt);
+ return -EOPNOTSUPP;
+ }
+
+ if (wait_on_bit_lock(&xprt->state, XPRT_LOCKED, TASK_KILLABLE)) {
+ count = -EINTR;
+ goto out_put;
+ }
+ saddr = (struct sockaddr *)&xprt->addr;
+ port = rpc_get_port(saddr);
+
+	/* buf_len is the length up to the first occurrence of either
+	 * '\n' or '\0'.
+ */
+ buf_len = strcspn(buf, "\n");
+
+ dst_addr = kstrndup(buf, buf_len, GFP_KERNEL);
+ if (!dst_addr)
+ goto out_err;
+ saved_addr = kzalloc(sizeof(*saved_addr), GFP_KERNEL);
+ if (!saved_addr)
+ goto out_err_free;
+ saved_addr->addr =
+ rcu_dereference_raw(xprt->address_strings[RPC_DISPLAY_ADDR]);
+ rcu_assign_pointer(xprt->address_strings[RPC_DISPLAY_ADDR], dst_addr);
+ call_rcu(&saved_addr->rcu, free_xprt_addr);
+ xprt->addrlen = rpc_pton(xprt->xprt_net, buf, buf_len, saddr,
+ sizeof(*saddr));
+ rpc_set_port(saddr, port);
+
+ xprt_force_disconnect(xprt);
+out:
+ xprt_release_write(xprt, NULL);
+out_put:
+ xprt_put(xprt);
+ return count;
+out_err_free:
+ kfree(dst_addr);
+out_err:
+ count = -ENOMEM;
+ goto out;
+}
+
+static ssize_t rpc_sysfs_xprt_state_change(struct kobject *kobj,
+ struct kobj_attribute *attr,
+ const char *buf, size_t count)
+{
+ struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj);
+ int offline = 0, online = 0, remove = 0;
+ struct rpc_xprt_switch *xps = rpc_sysfs_xprt_kobj_get_xprt_switch(kobj);
+
+ if (!xprt || !xps) {
+ count = 0;
+ goto out_put;
+ }
+
+ if (!strncmp(buf, "offline", 7))
+ offline = 1;
+ else if (!strncmp(buf, "online", 6))
+ online = 1;
+ else if (!strncmp(buf, "remove", 6))
+ remove = 1;
+ else {
+ count = -EINVAL;
+ goto out_put;
+ }
+
+ if (wait_on_bit_lock(&xprt->state, XPRT_LOCKED, TASK_KILLABLE)) {
+ count = -EINTR;
+ goto out_put;
+ }
+ if (xprt->main) {
+ count = -EINVAL;
+ goto release_tasks;
+ }
+ if (offline) {
+ xprt_set_offline_locked(xprt, xps);
+ } else if (online) {
+ xprt_set_online_locked(xprt, xps);
+ } else if (remove) {
+ if (test_bit(XPRT_OFFLINE, &xprt->state))
+ xprt_delete_locked(xprt, xps);
+ else
+ count = -EINVAL;
+ }
+
+release_tasks:
+ xprt_release_write(xprt, NULL);
+out_put:
+ xprt_put(xprt);
+ xprt_switch_put(xps);
+ return count;
+}
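+
+/*
+ * Usage sketch (paths assume the default sysfs mount): the kset and
+ * kobjects created below place these attributes under
+ * /sys/kernel/sunrpc/, for example
+ *
+ *	/sys/kernel/sunrpc/xprt-switches/switch-0/xprt-1-tcp/xprt_state
+ *
+ * so an administrator can take a spare transport out of rotation with
+ *
+ *	# echo offline > .../xprt-1-tcp/xprt_state
+ *	# echo remove > .../xprt-1-tcp/xprt_state
+ *
+ * where "remove" is honoured only once the transport is offline, and
+ * the main transport (xprt->main) is never changed this way.  Writing
+ * a presentation-format address to the sibling dstaddr attribute
+ * re-points a TCP, TLS or RDMA transport and forces a reconnect.
+ */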
+
+int rpc_sysfs_init(void)
+{
+ rpc_sunrpc_kset = kset_create_and_add("sunrpc", NULL, kernel_kobj);
+ if (!rpc_sunrpc_kset)
+ return -ENOMEM;
+ rpc_sunrpc_client_kobj =
+ rpc_sysfs_object_alloc("rpc-clients", rpc_sunrpc_kset, NULL);
+ if (!rpc_sunrpc_client_kobj)
+ goto err_client;
+ rpc_sunrpc_xprt_switch_kobj =
+ rpc_sysfs_object_alloc("xprt-switches", rpc_sunrpc_kset, NULL);
+ if (!rpc_sunrpc_xprt_switch_kobj)
+ goto err_switch;
+ return 0;
+err_switch:
+ kobject_put(rpc_sunrpc_client_kobj);
+ rpc_sunrpc_client_kobj = NULL;
+err_client:
+ kset_unregister(rpc_sunrpc_kset);
+ rpc_sunrpc_kset = NULL;
+ return -ENOMEM;
+}
+
+static void rpc_sysfs_client_release(struct kobject *kobj)
+{
+ struct rpc_sysfs_client *c;
+
+ c = container_of(kobj, struct rpc_sysfs_client, kobject);
+ kfree(c);
+}
+
+static void rpc_sysfs_xprt_switch_release(struct kobject *kobj)
+{
+ struct rpc_sysfs_xprt_switch *xprt_switch;
+
+ xprt_switch = container_of(kobj, struct rpc_sysfs_xprt_switch, kobject);
+ kfree(xprt_switch);
+}
+
+static void rpc_sysfs_xprt_release(struct kobject *kobj)
+{
+ struct rpc_sysfs_xprt *xprt;
+
+ xprt = container_of(kobj, struct rpc_sysfs_xprt, kobject);
+ kfree(xprt);
+}
+
+static const void *rpc_sysfs_client_namespace(const struct kobject *kobj)
+{
+ return container_of(kobj, struct rpc_sysfs_client, kobject)->net;
+}
+
+static const void *rpc_sysfs_xprt_switch_namespace(const struct kobject *kobj)
+{
+ return container_of(kobj, struct rpc_sysfs_xprt_switch, kobject)->net;
+}
+
+static const void *rpc_sysfs_xprt_namespace(const struct kobject *kobj)
+{
+ return container_of(kobj, struct rpc_sysfs_xprt,
+ kobject)->xprt->xprt_net;
+}
+
+static struct kobj_attribute rpc_sysfs_xprt_dstaddr = __ATTR(dstaddr,
+ 0644, rpc_sysfs_xprt_dstaddr_show, rpc_sysfs_xprt_dstaddr_store);
+
+static struct kobj_attribute rpc_sysfs_xprt_srcaddr = __ATTR(srcaddr,
+ 0644, rpc_sysfs_xprt_srcaddr_show, NULL);
+
+static struct kobj_attribute rpc_sysfs_xprt_info = __ATTR(xprt_info,
+ 0444, rpc_sysfs_xprt_info_show, NULL);
+
+static struct kobj_attribute rpc_sysfs_xprt_change_state = __ATTR(xprt_state,
+ 0644, rpc_sysfs_xprt_state_show, rpc_sysfs_xprt_state_change);
+
+static struct attribute *rpc_sysfs_xprt_attrs[] = {
+ &rpc_sysfs_xprt_dstaddr.attr,
+ &rpc_sysfs_xprt_srcaddr.attr,
+ &rpc_sysfs_xprt_info.attr,
+ &rpc_sysfs_xprt_change_state.attr,
+ NULL,
+};
+ATTRIBUTE_GROUPS(rpc_sysfs_xprt);
+
+static struct kobj_attribute rpc_sysfs_xprt_switch_info =
+ __ATTR(xprt_switch_info, 0444, rpc_sysfs_xprt_switch_info_show, NULL);
+
+static struct attribute *rpc_sysfs_xprt_switch_attrs[] = {
+ &rpc_sysfs_xprt_switch_info.attr,
+ NULL,
+};
+ATTRIBUTE_GROUPS(rpc_sysfs_xprt_switch);
+
+static const struct kobj_type rpc_sysfs_client_type = {
+ .release = rpc_sysfs_client_release,
+ .sysfs_ops = &kobj_sysfs_ops,
+ .namespace = rpc_sysfs_client_namespace,
+};
+
+static const struct kobj_type rpc_sysfs_xprt_switch_type = {
+ .release = rpc_sysfs_xprt_switch_release,
+ .default_groups = rpc_sysfs_xprt_switch_groups,
+ .sysfs_ops = &kobj_sysfs_ops,
+ .namespace = rpc_sysfs_xprt_switch_namespace,
+};
+
+static const struct kobj_type rpc_sysfs_xprt_type = {
+ .release = rpc_sysfs_xprt_release,
+ .default_groups = rpc_sysfs_xprt_groups,
+ .sysfs_ops = &kobj_sysfs_ops,
+ .namespace = rpc_sysfs_xprt_namespace,
+};
+
+void rpc_sysfs_exit(void)
+{
+ kobject_put(rpc_sunrpc_client_kobj);
+ kobject_put(rpc_sunrpc_xprt_switch_kobj);
+ kset_unregister(rpc_sunrpc_kset);
+}
+
+static struct rpc_sysfs_client *rpc_sysfs_client_alloc(struct kobject *parent,
+ struct net *net,
+ int clid)
+{
+ struct rpc_sysfs_client *p;
+
+ p = kzalloc(sizeof(*p), GFP_KERNEL);
+ if (p) {
+ p->net = net;
+ p->kobject.kset = rpc_sunrpc_kset;
+ if (kobject_init_and_add(&p->kobject, &rpc_sysfs_client_type,
+ parent, "clnt-%d", clid) == 0)
+ return p;
+ kobject_put(&p->kobject);
+ }
+ return NULL;
+}
+
+static struct rpc_sysfs_xprt_switch *
+rpc_sysfs_xprt_switch_alloc(struct kobject *parent,
+ struct rpc_xprt_switch *xprt_switch,
+ struct net *net,
+ gfp_t gfp_flags)
+{
+ struct rpc_sysfs_xprt_switch *p;
+
+ p = kzalloc(sizeof(*p), gfp_flags);
+ if (p) {
+ p->net = net;
+ p->kobject.kset = rpc_sunrpc_kset;
+ if (kobject_init_and_add(&p->kobject,
+ &rpc_sysfs_xprt_switch_type,
+ parent, "switch-%d",
+ xprt_switch->xps_id) == 0)
+ return p;
+ kobject_put(&p->kobject);
+ }
+ return NULL;
+}
+
+static struct rpc_sysfs_xprt *rpc_sysfs_xprt_alloc(struct kobject *parent,
+ struct rpc_xprt *xprt,
+ gfp_t gfp_flags)
+{
+ struct rpc_sysfs_xprt *p;
+
+ p = kzalloc(sizeof(*p), gfp_flags);
+ if (!p)
+ goto out;
+ p->kobject.kset = rpc_sunrpc_kset;
+ if (kobject_init_and_add(&p->kobject, &rpc_sysfs_xprt_type,
+ parent, "xprt-%d-%s", xprt->id,
+ xprt->address_strings[RPC_DISPLAY_PROTO]) == 0)
+ return p;
+ kobject_put(&p->kobject);
+out:
+ return NULL;
+}
+
+void rpc_sysfs_client_setup(struct rpc_clnt *clnt,
+ struct rpc_xprt_switch *xprt_switch,
+ struct net *net)
+{
+ struct rpc_sysfs_client *rpc_client;
+ struct rpc_sysfs_xprt_switch *xswitch =
+ (struct rpc_sysfs_xprt_switch *)xprt_switch->xps_sysfs;
+
+ if (!xswitch)
+ return;
+
+ rpc_client = rpc_sysfs_client_alloc(rpc_sunrpc_client_kobj,
+ net, clnt->cl_clid);
+ if (rpc_client) {
+ char name[] = "switch";
+ int ret;
+
+ clnt->cl_sysfs = rpc_client;
+ rpc_client->clnt = clnt;
+ rpc_client->xprt_switch = xprt_switch;
+ kobject_uevent(&rpc_client->kobject, KOBJ_ADD);
+ ret = sysfs_create_link_nowarn(&rpc_client->kobject,
+ &xswitch->kobject, name);
+ if (ret)
+ pr_warn("can't create link to %s in sysfs (%d)\n",
+ name, ret);
+ }
+}
+
+void rpc_sysfs_xprt_switch_setup(struct rpc_xprt_switch *xprt_switch,
+ struct rpc_xprt *xprt,
+ gfp_t gfp_flags)
+{
+ struct rpc_sysfs_xprt_switch *rpc_xprt_switch;
+ struct net *net;
+
+ if (xprt_switch->xps_net)
+ net = xprt_switch->xps_net;
+ else
+ net = xprt->xprt_net;
+ rpc_xprt_switch =
+ rpc_sysfs_xprt_switch_alloc(rpc_sunrpc_xprt_switch_kobj,
+ xprt_switch, net, gfp_flags);
+ if (rpc_xprt_switch) {
+ xprt_switch->xps_sysfs = rpc_xprt_switch;
+ rpc_xprt_switch->xprt_switch = xprt_switch;
+ rpc_xprt_switch->xprt = xprt;
+ kobject_uevent(&rpc_xprt_switch->kobject, KOBJ_ADD);
+ } else {
+ xprt_switch->xps_sysfs = NULL;
+ }
+}
+
+void rpc_sysfs_xprt_setup(struct rpc_xprt_switch *xprt_switch,
+ struct rpc_xprt *xprt,
+ gfp_t gfp_flags)
+{
+ struct rpc_sysfs_xprt *rpc_xprt;
+ struct rpc_sysfs_xprt_switch *switch_obj =
+ (struct rpc_sysfs_xprt_switch *)xprt_switch->xps_sysfs;
+
+ if (!switch_obj)
+ return;
+
+ rpc_xprt = rpc_sysfs_xprt_alloc(&switch_obj->kobject, xprt, gfp_flags);
+ if (rpc_xprt) {
+ xprt->xprt_sysfs = rpc_xprt;
+ rpc_xprt->xprt = xprt;
+ rpc_xprt->xprt_switch = xprt_switch;
+ kobject_uevent(&rpc_xprt->kobject, KOBJ_ADD);
+ }
+}
+
+void rpc_sysfs_client_destroy(struct rpc_clnt *clnt)
+{
+ struct rpc_sysfs_client *rpc_client = clnt->cl_sysfs;
+
+ if (rpc_client) {
+ char name[] = "switch";
+
+ sysfs_remove_link(&rpc_client->kobject, name);
+ kobject_uevent(&rpc_client->kobject, KOBJ_REMOVE);
+ kobject_del(&rpc_client->kobject);
+ kobject_put(&rpc_client->kobject);
+ clnt->cl_sysfs = NULL;
+ }
+}
+
+void rpc_sysfs_xprt_switch_destroy(struct rpc_xprt_switch *xprt_switch)
+{
+ struct rpc_sysfs_xprt_switch *rpc_xprt_switch = xprt_switch->xps_sysfs;
+
+ if (rpc_xprt_switch) {
+ kobject_uevent(&rpc_xprt_switch->kobject, KOBJ_REMOVE);
+ kobject_del(&rpc_xprt_switch->kobject);
+ kobject_put(&rpc_xprt_switch->kobject);
+ xprt_switch->xps_sysfs = NULL;
+ }
+}
+
+void rpc_sysfs_xprt_destroy(struct rpc_xprt *xprt)
+{
+ struct rpc_sysfs_xprt *rpc_xprt = xprt->xprt_sysfs;
+
+ if (rpc_xprt) {
+ kobject_uevent(&rpc_xprt->kobject, KOBJ_REMOVE);
+ kobject_del(&rpc_xprt->kobject);
+ kobject_put(&rpc_xprt->kobject);
+ xprt->xprt_sysfs = NULL;
+ }
+}
diff --git a/net/sunrpc/sysfs.h b/net/sunrpc/sysfs.h
new file mode 100644
index 0000000000..d2dd77a0a0
--- /dev/null
+++ b/net/sunrpc/sysfs.h
@@ -0,0 +1,35 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (c) 2020 Anna Schumaker <Anna.Schumaker@Netapp.com>
+ */
+#ifndef __SUNRPC_SYSFS_H
+#define __SUNRPC_SYSFS_H
+
+struct rpc_sysfs_xprt_switch {
+ struct kobject kobject;
+ struct net *net;
+ struct rpc_xprt_switch *xprt_switch;
+ struct rpc_xprt *xprt;
+};
+
+struct rpc_sysfs_xprt {
+ struct kobject kobject;
+ struct rpc_xprt *xprt;
+ struct rpc_xprt_switch *xprt_switch;
+};
+
+int rpc_sysfs_init(void);
+void rpc_sysfs_exit(void);
+
+void rpc_sysfs_client_setup(struct rpc_clnt *clnt,
+ struct rpc_xprt_switch *xprt_switch,
+ struct net *net);
+void rpc_sysfs_client_destroy(struct rpc_clnt *clnt);
+void rpc_sysfs_xprt_switch_setup(struct rpc_xprt_switch *xprt_switch,
+ struct rpc_xprt *xprt, gfp_t gfp_flags);
+void rpc_sysfs_xprt_switch_destroy(struct rpc_xprt_switch *xprt);
+void rpc_sysfs_xprt_setup(struct rpc_xprt_switch *xprt_switch,
+ struct rpc_xprt *xprt, gfp_t gfp_flags);
+void rpc_sysfs_xprt_destroy(struct rpc_xprt *xprt);
+
+#endif
diff --git a/net/sunrpc/timer.c b/net/sunrpc/timer.c
new file mode 100644
index 0000000000..81ae35b376
--- /dev/null
+++ b/net/sunrpc/timer.c
@@ -0,0 +1,123 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * linux/net/sunrpc/timer.c
+ *
+ * Estimate RPC request round trip time.
+ *
+ * Based on packet round-trip and variance estimator algorithms described
+ * in appendix A of "Congestion Avoidance and Control" by Van Jacobson
+ * and Michael J. Karels (ACM Computer Communication Review; Proceedings
+ * of the Sigcomm '88 Symposium in Stanford, CA, August, 1988).
+ *
+ * This RTT estimator is used only for RPC over datagram protocols.
+ *
+ * Copyright (C) 2002 Trond Myklebust <trond.myklebust@fys.uio.no>
+ */
+
+#include <asm/param.h>
+
+#include <linux/types.h>
+#include <linux/unistd.h>
+#include <linux/module.h>
+
+#include <linux/sunrpc/clnt.h>
+
+#define RPC_RTO_MAX (60*HZ)
+#define RPC_RTO_INIT (HZ/5)
+#define RPC_RTO_MIN (HZ/10)
+
+/**
+ * rpc_init_rtt - Initialize an RPC RTT estimator context
+ * @rt: context to initialize
+ * @timeo: initial timeout value, in jiffies
+ *
+ */
+void rpc_init_rtt(struct rpc_rtt *rt, unsigned long timeo)
+{
+ unsigned long init = 0;
+ unsigned int i;
+
+ rt->timeo = timeo;
+
+ if (timeo > RPC_RTO_INIT)
+ init = (timeo - RPC_RTO_INIT) << 3;
+ for (i = 0; i < 5; i++) {
+ rt->srtt[i] = init;
+ rt->sdrtt[i] = RPC_RTO_INIT;
+ rt->ntimeouts[i] = 0;
+ }
+}
+EXPORT_SYMBOL_GPL(rpc_init_rtt);
+
+/**
+ * rpc_update_rtt - Update an RPC RTT estimator context
+ * @rt: context to update
+ * @timer: timer array index (request type)
+ * @m: recent actual RTT, in jiffies
+ *
+ * NB: When computing the smoothed RTT and standard deviation,
+ * be careful not to produce negative intermediate results.
+ */
+void rpc_update_rtt(struct rpc_rtt *rt, unsigned int timer, long m)
+{
+ long *srtt, *sdrtt;
+
+ if (timer-- == 0)
+ return;
+
+ /* jiffies wrapped; ignore this one */
+ if (m < 0)
+ return;
+
+ if (m == 0)
+ m = 1L;
+
+ srtt = (long *)&rt->srtt[timer];
+ m -= *srtt >> 3;
+ *srtt += m;
+
+ if (m < 0)
+ m = -m;
+
+ sdrtt = (long *)&rt->sdrtt[timer];
+ m -= *sdrtt >> 2;
+ *sdrtt += m;
+
+ /* Set lower bound on the variance */
+ if (*sdrtt < RPC_RTO_MIN)
+ *sdrtt = RPC_RTO_MIN;
+}
+EXPORT_SYMBOL_GPL(rpc_update_rtt);
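+
+/*
+ * The update above is Jacobson's scaled-integer estimator: srtt[]
+ * holds eight times the smoothed RTT and sdrtt[] holds four times the
+ * mean deviation, so both running averages need only shifts:
+ *
+ *	err = m - srtt/8;	srtt += err;		(gain 1/8)
+ *	sdrtt += |err| - sdrtt/4;			(gain 1/4)
+ *
+ * Everything stays in integer jiffies with no divisions.
+ */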
+
+/**
+ * rpc_calc_rto - Provide an estimated timeout value
+ * @rt: context to use for calculation
+ * @timer: timer array index (request type)
+ *
+ * Estimate RTO for an NFS RPC sent via an unreliable datagram. Use
+ * the mean and mean deviation of RTT for the appropriate type of RPC
+ * for frequently issued RPCs, and a fixed default for the others.
+ *
+ * The justification for doing "other" this way is that these RPCs
+ * happen so infrequently that timer estimation would probably be
+ * stale. Also, since many of these RPCs are non-idempotent, a
+ * conservative timeout is desired.
+ *
+ * getattr, lookup,
+ * read, write, commit - A+4D
+ * other - timeo
+ */
+unsigned long rpc_calc_rto(struct rpc_rtt *rt, unsigned int timer)
+{
+ unsigned long res;
+
+ if (timer-- == 0)
+ return rt->timeo;
+
+ res = ((rt->srtt[timer] + 7) >> 3) + rt->sdrtt[timer];
+ if (res > RPC_RTO_MAX)
+ res = RPC_RTO_MAX;
+
+ return res;
+}
+EXPORT_SYMBOL_GPL(rpc_calc_rto);
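+
+/*
+ * Worked example with assumed values: srtt[timer] = 800 and
+ * sdrtt[timer] = 80 in the scaled units described above correspond to
+ * a smoothed RTT of 100 jiffies and a mean deviation of 20 jiffies, so
+ *
+ *	rpc_calc_rto() = ((800 + 7) >> 3) + 80 = 100 + 80 = 180 jiffies,
+ *
+ * i.e. A + 4D, clamped to RPC_RTO_MAX (60 seconds) at the upper end.
+ */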
diff --git a/net/sunrpc/xdr.c b/net/sunrpc/xdr.c
new file mode 100644
index 0000000000..62e07c330a
--- /dev/null
+++ b/net/sunrpc/xdr.c
@@ -0,0 +1,2413 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * linux/net/sunrpc/xdr.c
+ *
+ * Generic XDR support.
+ *
+ * Copyright (C) 1995, 1996 Olaf Kirch <okir@monad.swb.de>
+ */
+
+#include <linux/module.h>
+#include <linux/slab.h>
+#include <linux/types.h>
+#include <linux/string.h>
+#include <linux/kernel.h>
+#include <linux/pagemap.h>
+#include <linux/errno.h>
+#include <linux/sunrpc/xdr.h>
+#include <linux/sunrpc/msg_prot.h>
+#include <linux/bvec.h>
+#include <trace/events/sunrpc.h>
+
+static void _copy_to_pages(struct page **, size_t, const char *, size_t);
+
+
+/*
+ * XDR functions for basic NFS types
+ */
+__be32 *
+xdr_encode_netobj(__be32 *p, const struct xdr_netobj *obj)
+{
+ unsigned int quadlen = XDR_QUADLEN(obj->len);
+
+ p[quadlen] = 0; /* zero trailing bytes */
+ *p++ = cpu_to_be32(obj->len);
+ memcpy(p, obj->data, obj->len);
+ return p + XDR_QUADLEN(obj->len);
+}
+EXPORT_SYMBOL_GPL(xdr_encode_netobj);
+
+__be32 *
+xdr_decode_netobj(__be32 *p, struct xdr_netobj *obj)
+{
+ unsigned int len;
+
+ if ((len = be32_to_cpu(*p++)) > XDR_MAX_NETOBJ)
+ return NULL;
+ obj->len = len;
+ obj->data = (u8 *) p;
+ return p + XDR_QUADLEN(len);
+}
+EXPORT_SYMBOL_GPL(xdr_decode_netobj);
+
+/**
+ * xdr_encode_opaque_fixed - Encode fixed length opaque data
+ * @p: pointer to current position in XDR buffer.
+ * @ptr: pointer to data to encode (or NULL)
+ * @nbytes: size of data.
+ *
+ * Copy the array of data of length nbytes at ptr to the XDR buffer
+ * at position p, then align to the next 32-bit boundary by padding
+ * with zero bytes (see RFC1832).
+ * Note: if ptr is NULL, only the padding is performed.
+ *
+ * Returns the updated current XDR buffer position
+ *
+ */
+__be32 *xdr_encode_opaque_fixed(__be32 *p, const void *ptr, unsigned int nbytes)
+{
+ if (likely(nbytes != 0)) {
+ unsigned int quadlen = XDR_QUADLEN(nbytes);
+ unsigned int padding = (quadlen << 2) - nbytes;
+
+ if (ptr != NULL)
+ memcpy(p, ptr, nbytes);
+ if (padding != 0)
+ memset((char *)p + nbytes, 0, padding);
+ p += quadlen;
+ }
+ return p;
+}
+EXPORT_SYMBOL_GPL(xdr_encode_opaque_fixed);
+
+/**
+ * xdr_encode_opaque - Encode variable length opaque data
+ * @p: pointer to current position in XDR buffer.
+ * @ptr: pointer to data to encode (or NULL)
+ * @nbytes: size of data.
+ *
+ * Returns the updated current XDR buffer position
+ */
+__be32 *xdr_encode_opaque(__be32 *p, const void *ptr, unsigned int nbytes)
+{
+ *p++ = cpu_to_be32(nbytes);
+ return xdr_encode_opaque_fixed(p, ptr, nbytes);
+}
+EXPORT_SYMBOL_GPL(xdr_encode_opaque);
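+
+/*
+ * Example of the XDR alignment rule applied by the helpers above:
+ * encoding a 5-byte opaque with xdr_encode_opaque() emits a 4-byte
+ * length word, the 5 data bytes and 3 zero pad bytes, 12 bytes in all,
+ * so the returned position stays 32-bit aligned.  With
+ * xdr_encode_opaque_fixed() the length word is omitted and only the
+ * padded 8 bytes are written.
+ */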
+
+__be32 *
+xdr_encode_string(__be32 *p, const char *string)
+{
+ return xdr_encode_array(p, string, strlen(string));
+}
+EXPORT_SYMBOL_GPL(xdr_encode_string);
+
+__be32 *
+xdr_decode_string_inplace(__be32 *p, char **sp,
+ unsigned int *lenp, unsigned int maxlen)
+{
+ u32 len;
+
+ len = be32_to_cpu(*p++);
+ if (len > maxlen)
+ return NULL;
+ *lenp = len;
+ *sp = (char *) p;
+ return p + XDR_QUADLEN(len);
+}
+EXPORT_SYMBOL_GPL(xdr_decode_string_inplace);
+
+/**
+ * xdr_terminate_string - '\0'-terminate a string residing in an xdr_buf
+ * @buf: XDR buffer where string resides
+ * @len: length of string, in bytes
+ *
+ */
+void xdr_terminate_string(const struct xdr_buf *buf, const u32 len)
+{
+ char *kaddr;
+
+ kaddr = kmap_atomic(buf->pages[0]);
+ kaddr[buf->page_base + len] = '\0';
+ kunmap_atomic(kaddr);
+}
+EXPORT_SYMBOL_GPL(xdr_terminate_string);
+
+size_t xdr_buf_pagecount(const struct xdr_buf *buf)
+{
+ if (!buf->page_len)
+ return 0;
+ return (buf->page_base + buf->page_len + PAGE_SIZE - 1) >> PAGE_SHIFT;
+}
+
+int
+xdr_alloc_bvec(struct xdr_buf *buf, gfp_t gfp)
+{
+ size_t i, n = xdr_buf_pagecount(buf);
+
+ if (n != 0 && buf->bvec == NULL) {
+ buf->bvec = kmalloc_array(n, sizeof(buf->bvec[0]), gfp);
+ if (!buf->bvec)
+ return -ENOMEM;
+ for (i = 0; i < n; i++) {
+ bvec_set_page(&buf->bvec[i], buf->pages[i], PAGE_SIZE,
+ 0);
+ }
+ }
+ return 0;
+}
+
+void
+xdr_free_bvec(struct xdr_buf *buf)
+{
+ kfree(buf->bvec);
+ buf->bvec = NULL;
+}
+
+/**
+ * xdr_buf_to_bvec - Copy components of an xdr_buf into a bio_vec array
+ * @bvec: bio_vec array to populate
+ * @bvec_size: element count of @bio_vec
+ * @xdr: xdr_buf to be copied
+ *
+ * Returns the number of entries consumed in @bvec.
+ */
+unsigned int xdr_buf_to_bvec(struct bio_vec *bvec, unsigned int bvec_size,
+ const struct xdr_buf *xdr)
+{
+ const struct kvec *head = xdr->head;
+ const struct kvec *tail = xdr->tail;
+ unsigned int count = 0;
+
+ if (head->iov_len) {
+ bvec_set_virt(bvec++, head->iov_base, head->iov_len);
+ ++count;
+ }
+
+ if (xdr->page_len) {
+ unsigned int offset, len, remaining;
+ struct page **pages = xdr->pages;
+
+ offset = offset_in_page(xdr->page_base);
+ remaining = xdr->page_len;
+ while (remaining > 0) {
+ len = min_t(unsigned int, remaining,
+ PAGE_SIZE - offset);
+ bvec_set_page(bvec++, *pages++, len, offset);
+ remaining -= len;
+ offset = 0;
+ if (unlikely(++count > bvec_size))
+ goto bvec_overflow;
+ }
+ }
+
+ if (tail->iov_len) {
+ bvec_set_virt(bvec, tail->iov_base, tail->iov_len);
+ if (unlikely(++count > bvec_size))
+ goto bvec_overflow;
+ }
+
+ return count;
+
+bvec_overflow:
+ pr_warn_once("%s: bio_vec array overflow\n", __func__);
+ return count - 1;
+}
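+
+/*
+ * Entry budget for the conversion above: one bio_vec for a non-empty
+ * head, one per page touched by the page_len region, and one for a
+ * non-empty tail.  Callers that prepend their own element, such as
+ * svc_tcp_sendmsg() with its record marker, pass bvec + 1 and
+ * bvec_size - 1 so the first slot stays reserved.
+ */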
+
+/**
+ * xdr_inline_pages - Prepare receive buffer for a large reply
+ * @xdr: xdr_buf into which reply will be placed
+ * @offset: expected offset where data payload will start, in bytes
+ * @pages: vector of struct page pointers
+ * @base: offset in first page where receive should start, in bytes
+ * @len: expected size of the upper layer data payload, in bytes
+ *
+ */
+void
+xdr_inline_pages(struct xdr_buf *xdr, unsigned int offset,
+ struct page **pages, unsigned int base, unsigned int len)
+{
+ struct kvec *head = xdr->head;
+ struct kvec *tail = xdr->tail;
+ char *buf = (char *)head->iov_base;
+ unsigned int buflen = head->iov_len;
+
+ head->iov_len = offset;
+
+ xdr->pages = pages;
+ xdr->page_base = base;
+ xdr->page_len = len;
+
+ tail->iov_base = buf + offset;
+ tail->iov_len = buflen - offset;
+ xdr->buflen += len;
+}
+EXPORT_SYMBOL_GPL(xdr_inline_pages);
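+
+/*
+ * After xdr_inline_pages() the xdr_buf has the usual three-part
+ * layout: head[0] covers the protocol header up to @offset, the page
+ * vector receives @len bytes of bulk payload starting at @base, and
+ * tail[0] points at the remainder of the original head buffer for any
+ * trailing XDR items.  An illustrative READ-style reply set-up:
+ *
+ *	xdr_inline_pages(&req->rq_rcv_buf, hdrlen << 2,
+ *			 args->pages, args->pgbase, count);
+ *
+ * where hdrlen is the expected reply header length in XDR words and
+ * the remaining arguments name the caller's receive pages.
+ */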
+
+/*
+ * Helper routines for doing 'memmove' like operations on a struct xdr_buf
+ */
+
+/**
+ * _shift_data_left_pages
+ * @pages: vector of pages containing both the source and dest memory area.
+ * @pgto_base: page vector address of destination
+ * @pgfrom_base: page vector address of source
+ * @len: number of bytes to copy
+ *
+ * Note: the addresses pgto_base and pgfrom_base are both calculated in
+ * the same way:
+ * if a memory area starts at byte 'base' in page 'pages[i]',
+ *	then its address is given as (i << PAGE_SHIFT) + base
+ * Also note: pgto_base must be < pgfrom_base, but the memory areas
+ * they point to may overlap.
+ */
+static void
+_shift_data_left_pages(struct page **pages, size_t pgto_base,
+ size_t pgfrom_base, size_t len)
+{
+ struct page **pgfrom, **pgto;
+ char *vfrom, *vto;
+ size_t copy;
+
+ BUG_ON(pgfrom_base <= pgto_base);
+
+ if (!len)
+ return;
+
+ pgto = pages + (pgto_base >> PAGE_SHIFT);
+ pgfrom = pages + (pgfrom_base >> PAGE_SHIFT);
+
+ pgto_base &= ~PAGE_MASK;
+ pgfrom_base &= ~PAGE_MASK;
+
+ do {
+ if (pgto_base >= PAGE_SIZE) {
+ pgto_base = 0;
+ pgto++;
+ }
+ if (pgfrom_base >= PAGE_SIZE){
+ pgfrom_base = 0;
+ pgfrom++;
+ }
+
+ copy = len;
+ if (copy > (PAGE_SIZE - pgto_base))
+ copy = PAGE_SIZE - pgto_base;
+ if (copy > (PAGE_SIZE - pgfrom_base))
+ copy = PAGE_SIZE - pgfrom_base;
+
+ vto = kmap_atomic(*pgto);
+ if (*pgto != *pgfrom) {
+ vfrom = kmap_atomic(*pgfrom);
+ memcpy(vto + pgto_base, vfrom + pgfrom_base, copy);
+ kunmap_atomic(vfrom);
+ } else
+ memmove(vto + pgto_base, vto + pgfrom_base, copy);
+ flush_dcache_page(*pgto);
+ kunmap_atomic(vto);
+
+ pgto_base += copy;
+ pgfrom_base += copy;
+
+ } while ((len -= copy) != 0);
+}
+
+/**
+ * _shift_data_right_pages
+ * @pages: vector of pages containing both the source and dest memory area.
+ * @pgto_base: page vector address of destination
+ * @pgfrom_base: page vector address of source
+ * @len: number of bytes to copy
+ *
+ * Note: the addresses pgto_base and pgfrom_base are both calculated in
+ * the same way:
+ * if a memory area starts at byte 'base' in page 'pages[i]',
+ * then its address is given as (i << PAGE_SHIFT) + base
+ * Also note: pgfrom_base must be < pgto_base, but the memory areas
+ * they point to may overlap.
+ */
+static void
+_shift_data_right_pages(struct page **pages, size_t pgto_base,
+ size_t pgfrom_base, size_t len)
+{
+ struct page **pgfrom, **pgto;
+ char *vfrom, *vto;
+ size_t copy;
+
+ BUG_ON(pgto_base <= pgfrom_base);
+
+ if (!len)
+ return;
+
+ pgto_base += len;
+ pgfrom_base += len;
+
+ pgto = pages + (pgto_base >> PAGE_SHIFT);
+ pgfrom = pages + (pgfrom_base >> PAGE_SHIFT);
+
+ pgto_base &= ~PAGE_MASK;
+ pgfrom_base &= ~PAGE_MASK;
+
+ do {
+ /* Are any pointers crossing a page boundary? */
+ if (pgto_base == 0) {
+ pgto_base = PAGE_SIZE;
+ pgto--;
+ }
+ if (pgfrom_base == 0) {
+ pgfrom_base = PAGE_SIZE;
+ pgfrom--;
+ }
+
+ copy = len;
+ if (copy > pgto_base)
+ copy = pgto_base;
+ if (copy > pgfrom_base)
+ copy = pgfrom_base;
+ pgto_base -= copy;
+ pgfrom_base -= copy;
+
+ vto = kmap_atomic(*pgto);
+ if (*pgto != *pgfrom) {
+ vfrom = kmap_atomic(*pgfrom);
+ memcpy(vto + pgto_base, vfrom + pgfrom_base, copy);
+ kunmap_atomic(vfrom);
+ } else
+ memmove(vto + pgto_base, vto + pgfrom_base, copy);
+ flush_dcache_page(*pgto);
+ kunmap_atomic(vto);
+
+ } while ((len -= copy) != 0);
+}
+
+/**
+ * _copy_to_pages
+ * @pages: array of pages
+ * @pgbase: page vector address of destination
+ * @p: pointer to source data
+ * @len: length
+ *
+ * Copies data from an arbitrary memory location into an array of pages.
+ * The copy is assumed to be non-overlapping.
+ */
+static void
+_copy_to_pages(struct page **pages, size_t pgbase, const char *p, size_t len)
+{
+ struct page **pgto;
+ char *vto;
+ size_t copy;
+
+ if (!len)
+ return;
+
+ pgto = pages + (pgbase >> PAGE_SHIFT);
+ pgbase &= ~PAGE_MASK;
+
+ for (;;) {
+ copy = PAGE_SIZE - pgbase;
+ if (copy > len)
+ copy = len;
+
+ vto = kmap_atomic(*pgto);
+ memcpy(vto + pgbase, p, copy);
+ kunmap_atomic(vto);
+
+ len -= copy;
+ if (len == 0)
+ break;
+
+ pgbase += copy;
+ if (pgbase == PAGE_SIZE) {
+ flush_dcache_page(*pgto);
+ pgbase = 0;
+ pgto++;
+ }
+ p += copy;
+ }
+ flush_dcache_page(*pgto);
+}
+
+/**
+ * _copy_from_pages
+ * @p: pointer to destination
+ * @pages: array of pages
+ * @pgbase: offset of source data
+ * @len: length
+ *
+ * Copies data into an arbitrary memory location from an array of pages.
+ * The copy is assumed to be non-overlapping.
+ */
+void
+_copy_from_pages(char *p, struct page **pages, size_t pgbase, size_t len)
+{
+ struct page **pgfrom;
+ char *vfrom;
+ size_t copy;
+
+ if (!len)
+ return;
+
+ pgfrom = pages + (pgbase >> PAGE_SHIFT);
+ pgbase &= ~PAGE_MASK;
+
+ do {
+ copy = PAGE_SIZE - pgbase;
+ if (copy > len)
+ copy = len;
+
+ vfrom = kmap_atomic(*pgfrom);
+ memcpy(p, vfrom + pgbase, copy);
+ kunmap_atomic(vfrom);
+
+ pgbase += copy;
+ if (pgbase == PAGE_SIZE) {
+ pgbase = 0;
+ pgfrom++;
+ }
+ p += copy;
+
+ } while ((len -= copy) != 0);
+}
+EXPORT_SYMBOL_GPL(_copy_from_pages);
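+
+/*
+ * Example of the page-vector addressing convention used by the helpers
+ * above (an illustrative sketch; the "pages" array and the offsets are
+ * hypothetical).  Byte 100 of pages[2] lives at page-vector address
+ * (2 << PAGE_SHIFT) + 100, so copying 16 bytes from there into a local
+ * buffer looks like:
+ *
+ *	char scratch[16];
+ *
+ *	_copy_from_pages(scratch, pages, (2 << PAGE_SHIFT) + 100,
+ *			 sizeof(scratch));
+ */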
+
+static void xdr_buf_iov_zero(const struct kvec *iov, unsigned int base,
+ unsigned int len)
+{
+ if (base >= iov->iov_len)
+ return;
+ if (len > iov->iov_len - base)
+ len = iov->iov_len - base;
+ memset(iov->iov_base + base, 0, len);
+}
+
+/**
+ * xdr_buf_pages_zero - zero a range of an xdr_buf's page data
+ * @buf: xdr_buf
+ * @pgbase: beginning offset
+ * @len: length
+ */
+static void xdr_buf_pages_zero(const struct xdr_buf *buf, unsigned int pgbase,
+ unsigned int len)
+{
+ struct page **pages = buf->pages;
+ struct page **page;
+ char *vpage;
+ unsigned int zero;
+
+ if (!len)
+ return;
+ if (pgbase >= buf->page_len) {
+ xdr_buf_iov_zero(buf->tail, pgbase - buf->page_len, len);
+ return;
+ }
+ if (pgbase + len > buf->page_len) {
+ xdr_buf_iov_zero(buf->tail, 0, pgbase + len - buf->page_len);
+ len = buf->page_len - pgbase;
+ }
+
+ pgbase += buf->page_base;
+
+ page = pages + (pgbase >> PAGE_SHIFT);
+ pgbase &= ~PAGE_MASK;
+
+ do {
+ zero = PAGE_SIZE - pgbase;
+ if (zero > len)
+ zero = len;
+
+ vpage = kmap_atomic(*page);
+ memset(vpage + pgbase, 0, zero);
+ kunmap_atomic(vpage);
+
+ flush_dcache_page(*page);
+ pgbase = 0;
+ page++;
+
+ } while ((len -= zero) != 0);
+}
+
+static unsigned int xdr_buf_pages_fill_sparse(const struct xdr_buf *buf,
+ unsigned int buflen, gfp_t gfp)
+{
+ unsigned int i, npages, pagelen;
+
+ if (!(buf->flags & XDRBUF_SPARSE_PAGES))
+ return buflen;
+ if (buflen <= buf->head->iov_len)
+ return buflen;
+ pagelen = buflen - buf->head->iov_len;
+ if (pagelen > buf->page_len)
+ pagelen = buf->page_len;
+ npages = (pagelen + buf->page_base + PAGE_SIZE - 1) >> PAGE_SHIFT;
+ for (i = 0; i < npages; i++) {
+		/* Only fill in page slots that are still unallocated */
+		if (buf->pages[i])
+			continue;
+		buf->pages[i] = alloc_page(gfp);
+ if (likely(buf->pages[i]))
+ continue;
+ buflen -= pagelen;
+ pagelen = i << PAGE_SHIFT;
+ if (pagelen > buf->page_base)
+ buflen += pagelen - buf->page_base;
+ break;
+ }
+ return buflen;
+}
+
+static void xdr_buf_try_expand(struct xdr_buf *buf, unsigned int len)
+{
+ struct kvec *head = buf->head;
+ struct kvec *tail = buf->tail;
+ unsigned int sum = head->iov_len + buf->page_len + tail->iov_len;
+ unsigned int free_space, newlen;
+
+ if (sum > buf->len) {
+ free_space = min_t(unsigned int, sum - buf->len, len);
+ newlen = xdr_buf_pages_fill_sparse(buf, buf->len + free_space,
+ GFP_KERNEL);
+ free_space = newlen - buf->len;
+ buf->len = newlen;
+ len -= free_space;
+ if (!len)
+ return;
+ }
+
+ if (buf->buflen > sum) {
+ /* Expand the tail buffer */
+ free_space = min_t(unsigned int, buf->buflen - sum, len);
+ tail->iov_len += free_space;
+ buf->len += free_space;
+ }
+}
+
+static void xdr_buf_tail_copy_right(const struct xdr_buf *buf,
+ unsigned int base, unsigned int len,
+ unsigned int shift)
+{
+ const struct kvec *tail = buf->tail;
+ unsigned int to = base + shift;
+
+ if (to >= tail->iov_len)
+ return;
+ if (len + to > tail->iov_len)
+ len = tail->iov_len - to;
+ memmove(tail->iov_base + to, tail->iov_base + base, len);
+}
+
+static void xdr_buf_pages_copy_right(const struct xdr_buf *buf,
+ unsigned int base, unsigned int len,
+ unsigned int shift)
+{
+ const struct kvec *tail = buf->tail;
+ unsigned int to = base + shift;
+ unsigned int pglen = 0;
+ unsigned int talen = 0, tato = 0;
+
+ if (base >= buf->page_len)
+ return;
+ if (len > buf->page_len - base)
+ len = buf->page_len - base;
+ if (to >= buf->page_len) {
+ tato = to - buf->page_len;
+ if (tail->iov_len >= len + tato)
+ talen = len;
+ else if (tail->iov_len > tato)
+ talen = tail->iov_len - tato;
+ } else if (len + to >= buf->page_len) {
+ pglen = buf->page_len - to;
+ talen = len - pglen;
+ if (talen > tail->iov_len)
+ talen = tail->iov_len;
+ } else
+ pglen = len;
+
+ _copy_from_pages(tail->iov_base + tato, buf->pages,
+ buf->page_base + base + pglen, talen);
+ _shift_data_right_pages(buf->pages, buf->page_base + to,
+ buf->page_base + base, pglen);
+}
+
+static void xdr_buf_head_copy_right(const struct xdr_buf *buf,
+ unsigned int base, unsigned int len,
+ unsigned int shift)
+{
+ const struct kvec *head = buf->head;
+ const struct kvec *tail = buf->tail;
+ unsigned int to = base + shift;
+ unsigned int pglen = 0, pgto = 0;
+ unsigned int talen = 0, tato = 0;
+
+ if (base >= head->iov_len)
+ return;
+ if (len > head->iov_len - base)
+ len = head->iov_len - base;
+ if (to >= buf->page_len + head->iov_len) {
+ tato = to - buf->page_len - head->iov_len;
+ talen = len;
+ } else if (to >= head->iov_len) {
+ pgto = to - head->iov_len;
+ pglen = len;
+ if (pgto + pglen > buf->page_len) {
+ talen = pgto + pglen - buf->page_len;
+ pglen -= talen;
+ }
+ } else {
+ pglen = len - to;
+ if (pglen > buf->page_len) {
+ talen = pglen - buf->page_len;
+ pglen = buf->page_len;
+ }
+ }
+
+ len -= talen;
+ base += len;
+ if (talen + tato > tail->iov_len)
+ talen = tail->iov_len > tato ? tail->iov_len - tato : 0;
+ memcpy(tail->iov_base + tato, head->iov_base + base, talen);
+
+ len -= pglen;
+ base -= pglen;
+ _copy_to_pages(buf->pages, buf->page_base + pgto, head->iov_base + base,
+ pglen);
+
+ base -= len;
+ memmove(head->iov_base + to, head->iov_base + base, len);
+}
+
+static void xdr_buf_tail_shift_right(const struct xdr_buf *buf,
+ unsigned int base, unsigned int len,
+ unsigned int shift)
+{
+ const struct kvec *tail = buf->tail;
+
+ if (base >= tail->iov_len || !shift || !len)
+ return;
+ xdr_buf_tail_copy_right(buf, base, len, shift);
+}
+
+static void xdr_buf_pages_shift_right(const struct xdr_buf *buf,
+ unsigned int base, unsigned int len,
+ unsigned int shift)
+{
+ if (!shift || !len)
+ return;
+ if (base >= buf->page_len) {
+ xdr_buf_tail_shift_right(buf, base - buf->page_len, len, shift);
+ return;
+ }
+ if (base + len > buf->page_len)
+ xdr_buf_tail_shift_right(buf, 0, base + len - buf->page_len,
+ shift);
+ xdr_buf_pages_copy_right(buf, base, len, shift);
+}
+
+static void xdr_buf_head_shift_right(const struct xdr_buf *buf,
+ unsigned int base, unsigned int len,
+ unsigned int shift)
+{
+ const struct kvec *head = buf->head;
+
+ if (!shift)
+ return;
+ if (base >= head->iov_len) {
+		xdr_buf_pages_shift_right(buf, base - head->iov_len, len,
+ shift);
+ return;
+ }
+ if (base + len > head->iov_len)
+ xdr_buf_pages_shift_right(buf, 0, base + len - head->iov_len,
+ shift);
+ xdr_buf_head_copy_right(buf, base, len, shift);
+}
+
+static void xdr_buf_tail_copy_left(const struct xdr_buf *buf, unsigned int base,
+ unsigned int len, unsigned int shift)
+{
+ const struct kvec *tail = buf->tail;
+
+ if (base >= tail->iov_len)
+ return;
+ if (len > tail->iov_len - base)
+ len = tail->iov_len - base;
+ /* Shift data into head */
+ if (shift > buf->page_len + base) {
+ const struct kvec *head = buf->head;
+ unsigned int hdto =
+ head->iov_len + buf->page_len + base - shift;
+ unsigned int hdlen = len;
+
+ if (WARN_ONCE(shift > head->iov_len + buf->page_len + base,
+ "SUNRPC: Misaligned data.\n"))
+ return;
+ if (hdto + hdlen > head->iov_len)
+ hdlen = head->iov_len - hdto;
+ memcpy(head->iov_base + hdto, tail->iov_base + base, hdlen);
+ base += hdlen;
+ len -= hdlen;
+ if (!len)
+ return;
+ }
+ /* Shift data into pages */
+ if (shift > base) {
+ unsigned int pgto = buf->page_len + base - shift;
+ unsigned int pglen = len;
+
+ if (pgto + pglen > buf->page_len)
+ pglen = buf->page_len - pgto;
+ _copy_to_pages(buf->pages, buf->page_base + pgto,
+ tail->iov_base + base, pglen);
+ base += pglen;
+ len -= pglen;
+ if (!len)
+ return;
+ }
+ memmove(tail->iov_base + base - shift, tail->iov_base + base, len);
+}
+
+static void xdr_buf_pages_copy_left(const struct xdr_buf *buf,
+ unsigned int base, unsigned int len,
+ unsigned int shift)
+{
+ unsigned int pgto;
+
+ if (base >= buf->page_len)
+ return;
+ if (len > buf->page_len - base)
+ len = buf->page_len - base;
+ /* Shift data into head */
+ if (shift > base) {
+ const struct kvec *head = buf->head;
+ unsigned int hdto = head->iov_len + base - shift;
+ unsigned int hdlen = len;
+
+ if (WARN_ONCE(shift > head->iov_len + base,
+ "SUNRPC: Misaligned data.\n"))
+ return;
+ if (hdto + hdlen > head->iov_len)
+ hdlen = head->iov_len - hdto;
+ _copy_from_pages(head->iov_base + hdto, buf->pages,
+ buf->page_base + base, hdlen);
+ base += hdlen;
+ len -= hdlen;
+ if (!len)
+ return;
+ }
+ pgto = base - shift;
+ _shift_data_left_pages(buf->pages, buf->page_base + pgto,
+ buf->page_base + base, len);
+}
+
+static void xdr_buf_tail_shift_left(const struct xdr_buf *buf,
+ unsigned int base, unsigned int len,
+ unsigned int shift)
+{
+ if (!shift || !len)
+ return;
+ xdr_buf_tail_copy_left(buf, base, len, shift);
+}
+
+static void xdr_buf_pages_shift_left(const struct xdr_buf *buf,
+ unsigned int base, unsigned int len,
+ unsigned int shift)
+{
+ if (!shift || !len)
+ return;
+ if (base >= buf->page_len) {
+ xdr_buf_tail_shift_left(buf, base - buf->page_len, len, shift);
+ return;
+ }
+ xdr_buf_pages_copy_left(buf, base, len, shift);
+ len += base;
+ if (len <= buf->page_len)
+ return;
+ xdr_buf_tail_copy_left(buf, 0, len - buf->page_len, shift);
+}
+
+static void xdr_buf_head_shift_left(const struct xdr_buf *buf,
+ unsigned int base, unsigned int len,
+ unsigned int shift)
+{
+ const struct kvec *head = buf->head;
+ unsigned int bytes;
+
+ if (!shift || !len)
+ return;
+
+ if (shift > base) {
+ bytes = (shift - base);
+ if (bytes >= len)
+ return;
+ base += bytes;
+ len -= bytes;
+ }
+
+ if (base < head->iov_len) {
+ bytes = min_t(unsigned int, len, head->iov_len - base);
+ memmove(head->iov_base + (base - shift),
+ head->iov_base + base, bytes);
+ base += bytes;
+ len -= bytes;
+ }
+ xdr_buf_pages_shift_left(buf, base - head->iov_len, len, shift);
+}
+
+/**
+ * xdr_shrink_bufhead - shrinks buf->head[0] to @len bytes
+ * @buf: xdr_buf
+ * @len: new length of buf->head[0]
+ *
+ * Shrinks XDR buffer's header kvec buf->head[0], setting it to
+ * 'len' bytes. The extra data is not lost, but is instead
+ * moved into the inlined pages and/or the tail.
+ */
+static unsigned int xdr_shrink_bufhead(struct xdr_buf *buf, unsigned int len)
+{
+ struct kvec *head = buf->head;
+ unsigned int shift, buflen = max(buf->len, len);
+
+ WARN_ON_ONCE(len > head->iov_len);
+ if (head->iov_len > buflen) {
+ buf->buflen -= head->iov_len - buflen;
+ head->iov_len = buflen;
+ }
+ if (len >= head->iov_len)
+ return 0;
+ shift = head->iov_len - len;
+ xdr_buf_try_expand(buf, shift);
+ xdr_buf_head_shift_right(buf, len, buflen - len, shift);
+ head->iov_len = len;
+ buf->buflen -= shift;
+ buf->len -= shift;
+ return shift;
+}
+
+/**
+ * xdr_shrink_pagelen - shrinks buf->pages to @len bytes
+ * @buf: xdr_buf
+ * @len: new page buffer length
+ *
+ * The extra data is not lost, but is instead moved into buf->tail.
+ * Returns the actual number of bytes moved.
+ */
+static unsigned int xdr_shrink_pagelen(struct xdr_buf *buf, unsigned int len)
+{
+ unsigned int shift, buflen = buf->len - buf->head->iov_len;
+
+ WARN_ON_ONCE(len > buf->page_len);
+ if (buf->head->iov_len >= buf->len || len > buflen)
+ buflen = len;
+ if (buf->page_len > buflen) {
+ buf->buflen -= buf->page_len - buflen;
+ buf->page_len = buflen;
+ }
+ if (len >= buf->page_len)
+ return 0;
+ shift = buf->page_len - len;
+ xdr_buf_try_expand(buf, shift);
+ xdr_buf_pages_shift_right(buf, len, buflen - len, shift);
+ buf->page_len = len;
+ buf->len -= shift;
+ buf->buflen -= shift;
+ return shift;
+}
+
+/**
+ * xdr_stream_pos - Return the current offset from the start of the xdr_stream
+ * @xdr: pointer to struct xdr_stream
+ */
+unsigned int xdr_stream_pos(const struct xdr_stream *xdr)
+{
+ return (unsigned int)(XDR_QUADLEN(xdr->buf->len) - xdr->nwords) << 2;
+}
+EXPORT_SYMBOL_GPL(xdr_stream_pos);
+
+static void xdr_stream_set_pos(struct xdr_stream *xdr, unsigned int pos)
+{
+ unsigned int blen = xdr->buf->len;
+
+ xdr->nwords = blen > pos ? XDR_QUADLEN(blen) - XDR_QUADLEN(pos) : 0;
+}
+
+static void xdr_stream_page_set_pos(struct xdr_stream *xdr, unsigned int pos)
+{
+ xdr_stream_set_pos(xdr, pos + xdr->buf->head[0].iov_len);
+}
+
+/**
+ * xdr_page_pos - Return the current offset from the start of the xdr pages
+ * @xdr: pointer to struct xdr_stream
+ */
+unsigned int xdr_page_pos(const struct xdr_stream *xdr)
+{
+ unsigned int pos = xdr_stream_pos(xdr);
+
+ WARN_ON(pos < xdr->buf->head[0].iov_len);
+ return pos - xdr->buf->head[0].iov_len;
+}
+EXPORT_SYMBOL_GPL(xdr_page_pos);
+
+/**
+ * xdr_init_encode - Initialize a struct xdr_stream for sending data.
+ * @xdr: pointer to xdr_stream struct
+ * @buf: pointer to XDR buffer in which to encode data
+ * @p: current pointer inside XDR buffer
+ * @rqst: pointer to controlling rpc_rqst, for debugging
+ *
+ * Note: at the moment the RPC client only passes the length of our
+ * scratch buffer in the xdr_buf's header kvec. Previously this
+ * meant we needed to call xdr_adjust_iovec() after encoding the
+ * data. With the new scheme, the xdr_stream manages the details
+ * of the buffer length, and takes care of adjusting the kvec
+ * length for us.
+ */
+void xdr_init_encode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p,
+ struct rpc_rqst *rqst)
+{
+ struct kvec *iov = buf->head;
+ int scratch_len = buf->buflen - buf->page_len - buf->tail[0].iov_len;
+
+ xdr_reset_scratch_buffer(xdr);
+ BUG_ON(scratch_len < 0);
+ xdr->buf = buf;
+ xdr->iov = iov;
+ xdr->p = (__be32 *)((char *)iov->iov_base + iov->iov_len);
+ xdr->end = (__be32 *)((char *)iov->iov_base + scratch_len);
+ BUG_ON(iov->iov_len > scratch_len);
+
+ if (p != xdr->p && p != NULL) {
+ size_t len;
+
+ BUG_ON(p < xdr->p || p > xdr->end);
+ len = (char *)p - (char *)xdr->p;
+ xdr->p = p;
+ buf->len += len;
+ iov->iov_len += len;
+ }
+ xdr->rqst = rqst;
+}
+EXPORT_SYMBOL_GPL(xdr_init_encode);
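+
+/*
+ * Typical encode setup (an illustrative sketch; "req" is a hypothetical
+ * rpc_rqst).  Passing NULL for @p leaves the stream positioned at the
+ * current end of the head kvec:
+ *
+ *	struct xdr_stream xdr;
+ *
+ *	xdr_init_encode(&xdr, &req->rq_snd_buf, NULL, req);
+ *	... encode calls such as xdr_reserve_space() follow ...
+ */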
+
+/**
+ * xdr_init_encode_pages - Initialize an xdr_stream for encoding into pages
+ * @xdr: pointer to xdr_stream struct
+ * @buf: pointer to XDR buffer into which to encode data
+ * @pages: list of pages to encode into
+ * @rqst: pointer to controlling rpc_rqst, for debugging
+ *
+ */
+void xdr_init_encode_pages(struct xdr_stream *xdr, struct xdr_buf *buf,
+ struct page **pages, struct rpc_rqst *rqst)
+{
+ xdr_reset_scratch_buffer(xdr);
+
+ xdr->buf = buf;
+ xdr->page_ptr = pages;
+ xdr->iov = NULL;
+ xdr->p = page_address(*pages);
+ xdr->end = (void *)xdr->p + min_t(u32, buf->buflen, PAGE_SIZE);
+ xdr->rqst = rqst;
+}
+EXPORT_SYMBOL_GPL(xdr_init_encode_pages);
+
+/**
+ * __xdr_commit_encode - Ensure all data is written to buffer
+ * @xdr: pointer to xdr_stream
+ *
+ * We handle encoding across page boundaries by giving the caller a
+ * temporary location to write to, then later copying the data into
+ * place; xdr_commit_encode does that copying.
+ *
+ * Normally the caller doesn't need to call this directly, as the
+ * following xdr_reserve_space will do it. But an explicit call may be
+ * required at the end of encoding, or any other time when the xdr_buf
+ * data might be read.
+ */
+void __xdr_commit_encode(struct xdr_stream *xdr)
+{
+ size_t shift = xdr->scratch.iov_len;
+ void *page;
+
+ page = page_address(*xdr->page_ptr);
+ memcpy(xdr->scratch.iov_base, page, shift);
+ memmove(page, page + shift, (void *)xdr->p - page);
+ xdr_reset_scratch_buffer(xdr);
+}
+EXPORT_SYMBOL_GPL(__xdr_commit_encode);
+
+/*
+ * The buffer space to be reserved crosses the boundary between
+ * xdr->buf->head and xdr->buf->pages, or between two pages
+ * in xdr->buf->pages.
+ */
+static noinline __be32 *xdr_get_next_encode_buffer(struct xdr_stream *xdr,
+ size_t nbytes)
+{
+ int space_left;
+ int frag1bytes, frag2bytes;
+ void *p;
+
+ if (nbytes > PAGE_SIZE)
+ goto out_overflow; /* Bigger buffers require special handling */
+ if (xdr->buf->len + nbytes > xdr->buf->buflen)
+ goto out_overflow; /* Sorry, we're totally out of space */
+ frag1bytes = (xdr->end - xdr->p) << 2;
+ frag2bytes = nbytes - frag1bytes;
+ if (xdr->iov)
+ xdr->iov->iov_len += frag1bytes;
+ else
+ xdr->buf->page_len += frag1bytes;
+ xdr->page_ptr++;
+ xdr->iov = NULL;
+
+ /*
+ * If the last encode didn't end exactly on a page boundary, the
+ * next one will straddle boundaries. Encode into the next
+ * page, then copy it back later in xdr_commit_encode. We use
+ * the "scratch" iov to track any temporarily unused fragment of
+ * space at the end of the previous buffer:
+ */
+ xdr_set_scratch_buffer(xdr, xdr->p, frag1bytes);
+
+ /*
+ * xdr->p is where the next encode will start after
+ * xdr_commit_encode() has shifted this one back:
+ */
+ p = page_address(*xdr->page_ptr);
+ xdr->p = p + frag2bytes;
+ space_left = xdr->buf->buflen - xdr->buf->len;
+ if (space_left - frag1bytes >= PAGE_SIZE)
+ xdr->end = p + PAGE_SIZE;
+ else
+ xdr->end = p + space_left - frag1bytes;
+
+ xdr->buf->page_len += frag2bytes;
+ xdr->buf->len += nbytes;
+ return p;
+out_overflow:
+ trace_rpc_xdr_overflow(xdr, nbytes);
+ return NULL;
+}
+
+/**
+ * xdr_reserve_space - Reserve buffer space for sending
+ * @xdr: pointer to xdr_stream
+ * @nbytes: number of bytes to reserve
+ *
+ * Checks that we have enough buffer space to encode 'nbytes' more
+ * bytes of data. If so, update the total xdr_buf length, and
+ * adjust the length of the current kvec.
+ */
+__be32 * xdr_reserve_space(struct xdr_stream *xdr, size_t nbytes)
+{
+ __be32 *p = xdr->p;
+ __be32 *q;
+
+ xdr_commit_encode(xdr);
+ /* align nbytes on the next 32-bit boundary */
+ nbytes += 3;
+ nbytes &= ~3;
+ q = p + (nbytes >> 2);
+ if (unlikely(q > xdr->end || q < p))
+ return xdr_get_next_encode_buffer(xdr, nbytes);
+ xdr->p = q;
+ if (xdr->iov)
+ xdr->iov->iov_len += nbytes;
+ else
+ xdr->buf->page_len += nbytes;
+ xdr->buf->len += nbytes;
+ return p;
+}
+EXPORT_SYMBOL_GPL(xdr_reserve_space);
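+
+/*
+ * Typical usage (an illustrative sketch; the "status" and "count" values
+ * are hypothetical):
+ *
+ *	__be32 *p;
+ *
+ *	p = xdr_reserve_space(xdr, 2 * XDR_UNIT);
+ *	if (!p)
+ *		return -EMSGSIZE;
+ *	*p++ = cpu_to_be32(status);
+ *	*p = cpu_to_be32(count);
+ */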
+
+/**
+ * xdr_reserve_space_vec - Reserves a large amount of buffer space for sending
+ * @xdr: pointer to xdr_stream
+ * @nbytes: number of bytes to reserve
+ *
+ * The size argument passed to xdr_reserve_space() is determined based
+ * on the number of bytes remaining in the current page to avoid
+ * invalidating iov_base pointers when xdr_commit_encode() is called.
+ *
+ * Return values:
+ * %0: success
+ * %-EMSGSIZE: not enough space is available in @xdr
+ */
+int xdr_reserve_space_vec(struct xdr_stream *xdr, size_t nbytes)
+{
+ size_t thislen;
+ __be32 *p;
+
+ /*
+ * svcrdma requires every READ payload to start somewhere
+ * in xdr->pages.
+ */
+ if (xdr->iov == xdr->buf->head) {
+ xdr->iov = NULL;
+ xdr->end = xdr->p;
+ }
+
+ /* XXX: Let's find a way to make this more efficient */
+ while (nbytes) {
+ thislen = xdr->buf->page_len % PAGE_SIZE;
+ thislen = min_t(size_t, nbytes, PAGE_SIZE - thislen);
+
+ p = xdr_reserve_space(xdr, thislen);
+ if (!p)
+ return -EMSGSIZE;
+
+ nbytes -= thislen;
+ }
+
+ return 0;
+}
+EXPORT_SYMBOL_GPL(xdr_reserve_space_vec);
+
+/**
+ * xdr_truncate_encode - truncate an encode buffer
+ * @xdr: pointer to xdr_stream
+ * @len: new length of buffer
+ *
+ * Truncates the xdr stream, so that xdr->buf->len == len,
+ * and xdr->p points at offset len from the start of the buffer, and
+ * head, tail, and page lengths are adjusted to correspond.
+ *
+ * If this means moving xdr->p to a different buffer, we assume that
+ * the end pointer should be set to the end of the current page,
+ * except in the case of the head buffer, where we assume the head
+ * buffer's current length represents the end of the available buffer.
+ *
+ * This is *not* safe to use on a buffer that already has inlined page
+ * cache pages (as in a zero-copy server read reply), except for the
+ * simple case of truncating from one position in the tail to another.
+ *
+ */
+void xdr_truncate_encode(struct xdr_stream *xdr, size_t len)
+{
+ struct xdr_buf *buf = xdr->buf;
+ struct kvec *head = buf->head;
+ struct kvec *tail = buf->tail;
+ int fraglen;
+ int new;
+
+ if (len > buf->len) {
+ WARN_ON_ONCE(1);
+ return;
+ }
+ xdr_commit_encode(xdr);
+
+ fraglen = min_t(int, buf->len - len, tail->iov_len);
+ tail->iov_len -= fraglen;
+ buf->len -= fraglen;
+ if (tail->iov_len) {
+ xdr->p = tail->iov_base + tail->iov_len;
+ WARN_ON_ONCE(!xdr->end);
+ WARN_ON_ONCE(!xdr->iov);
+ return;
+ }
+ WARN_ON_ONCE(fraglen);
+ fraglen = min_t(int, buf->len - len, buf->page_len);
+ buf->page_len -= fraglen;
+ buf->len -= fraglen;
+
+ new = buf->page_base + buf->page_len;
+
+ xdr->page_ptr = buf->pages + (new >> PAGE_SHIFT);
+
+ if (buf->page_len) {
+ xdr->p = page_address(*xdr->page_ptr);
+ xdr->end = (void *)xdr->p + PAGE_SIZE;
+ xdr->p = (void *)xdr->p + (new % PAGE_SIZE);
+ WARN_ON_ONCE(xdr->iov);
+ return;
+ }
+ if (fraglen)
+ xdr->end = head->iov_base + head->iov_len;
+ /* (otherwise assume xdr->end is already set) */
+ xdr->page_ptr--;
+ head->iov_len = len;
+ buf->len = len;
+ xdr->p = head->iov_base + head->iov_len;
+ xdr->iov = buf->head;
+}
+EXPORT_SYMBOL(xdr_truncate_encode);
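+
+/*
+ * A common pattern (an illustrative sketch; encode_optional_item() is a
+ * hypothetical helper): record the buffer length before encoding an
+ * optional item, then roll the stream back if encoding it fails.
+ *
+ *	unsigned int start = xdr->buf->len;
+ *
+ *	if (encode_optional_item(xdr) < 0)
+ *		xdr_truncate_encode(xdr, start);
+ */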
+
+/**
+ * xdr_truncate_decode - Truncate a decoding stream
+ * @xdr: pointer to struct xdr_stream
+ * @len: Number of bytes to remove
+ *
+ */
+void xdr_truncate_decode(struct xdr_stream *xdr, size_t len)
+{
+ unsigned int nbytes = xdr_align_size(len);
+
+ xdr->buf->len -= nbytes;
+ xdr->nwords -= XDR_QUADLEN(nbytes);
+}
+EXPORT_SYMBOL_GPL(xdr_truncate_decode);
+
+/**
+ * xdr_restrict_buflen - decrease available buffer space
+ * @xdr: pointer to xdr_stream
+ * @newbuflen: new maximum number of bytes available
+ *
+ * Adjust our idea of how much space is available in the buffer.
+ * If we've already used too much space in the buffer, returns -1.
+ * If the available space is already smaller than newbuflen, returns 0
+ * and does nothing. Otherwise, adjusts xdr->buf->buflen to newbuflen
+ * and ensures xdr->end is set at most offset newbuflen from the start
+ * of the buffer.
+ */
+int xdr_restrict_buflen(struct xdr_stream *xdr, int newbuflen)
+{
+ struct xdr_buf *buf = xdr->buf;
+ int left_in_this_buf = (void *)xdr->end - (void *)xdr->p;
+ int end_offset = buf->len + left_in_this_buf;
+
+ if (newbuflen < 0 || newbuflen < buf->len)
+ return -1;
+ if (newbuflen > buf->buflen)
+ return 0;
+ if (newbuflen < end_offset)
+ xdr->end = (void *)xdr->end + newbuflen - end_offset;
+ buf->buflen = newbuflen;
+ return 0;
+}
+EXPORT_SYMBOL(xdr_restrict_buflen);
+
+/**
+ * xdr_write_pages - Insert a list of pages into an XDR buffer for sending
+ * @xdr: pointer to xdr_stream
+ * @pages: array of pages to insert
+ * @base: starting offset of first data byte in @pages
+ * @len: number of data bytes in @pages to insert
+ *
+ * After the @pages are added, the tail iovec is instantiated pointing to
+ * the end of the head buffer, and the stream is set up to encode subsequent
+ * items into the tail.
+ */
+void xdr_write_pages(struct xdr_stream *xdr, struct page **pages, unsigned int base,
+ unsigned int len)
+{
+ struct xdr_buf *buf = xdr->buf;
+ struct kvec *tail = buf->tail;
+
+ buf->pages = pages;
+ buf->page_base = base;
+ buf->page_len = len;
+
+ tail->iov_base = xdr->p;
+ tail->iov_len = 0;
+ xdr->iov = tail;
+
+ if (len & 3) {
+ unsigned int pad = 4 - (len & 3);
+
+ BUG_ON(xdr->p >= xdr->end);
+ tail->iov_base = (char *)xdr->p + (len & 3);
+ tail->iov_len += pad;
+ len += pad;
+ *xdr->p++ = 0;
+ }
+ buf->buflen += len;
+ buf->len += len;
+}
+EXPORT_SYMBOL_GPL(xdr_write_pages);
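+
+/*
+ * Typical usage when encoding a request that carries page data (an
+ * illustrative sketch; the "args" fields are hypothetical):
+ *
+ *	xdr_write_pages(xdr, args->pages, args->pgbase, args->count);
+ *	... any further items are encoded into the tail ...
+ */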
+
+static unsigned int xdr_set_iov(struct xdr_stream *xdr, struct kvec *iov,
+ unsigned int base, unsigned int len)
+{
+ if (len > iov->iov_len)
+ len = iov->iov_len;
+ if (unlikely(base > len))
+ base = len;
+ xdr->p = (__be32*)(iov->iov_base + base);
+ xdr->end = (__be32*)(iov->iov_base + len);
+ xdr->iov = iov;
+ xdr->page_ptr = NULL;
+ return len - base;
+}
+
+static unsigned int xdr_set_tail_base(struct xdr_stream *xdr,
+ unsigned int base, unsigned int len)
+{
+ struct xdr_buf *buf = xdr->buf;
+
+ xdr_stream_set_pos(xdr, base + buf->page_len + buf->head->iov_len);
+ return xdr_set_iov(xdr, buf->tail, base, len);
+}
+
+static void xdr_stream_unmap_current_page(struct xdr_stream *xdr)
+{
+ if (xdr->page_kaddr) {
+ kunmap_local(xdr->page_kaddr);
+ xdr->page_kaddr = NULL;
+ }
+}
+
+static unsigned int xdr_set_page_base(struct xdr_stream *xdr,
+ unsigned int base, unsigned int len)
+{
+ unsigned int pgnr;
+ unsigned int maxlen;
+ unsigned int pgoff;
+ unsigned int pgend;
+ void *kaddr;
+
+ maxlen = xdr->buf->page_len;
+ if (base >= maxlen)
+ return 0;
+ else
+ maxlen -= base;
+ if (len > maxlen)
+ len = maxlen;
+
+ xdr_stream_unmap_current_page(xdr);
+ xdr_stream_page_set_pos(xdr, base);
+ base += xdr->buf->page_base;
+
+ pgnr = base >> PAGE_SHIFT;
+ xdr->page_ptr = &xdr->buf->pages[pgnr];
+
+ if (PageHighMem(*xdr->page_ptr)) {
+ xdr->page_kaddr = kmap_local_page(*xdr->page_ptr);
+ kaddr = xdr->page_kaddr;
+ } else
+ kaddr = page_address(*xdr->page_ptr);
+
+ pgoff = base & ~PAGE_MASK;
+ xdr->p = (__be32*)(kaddr + pgoff);
+
+ pgend = pgoff + len;
+ if (pgend > PAGE_SIZE)
+ pgend = PAGE_SIZE;
+ xdr->end = (__be32*)(kaddr + pgend);
+ xdr->iov = NULL;
+ return len;
+}
+
+static void xdr_set_page(struct xdr_stream *xdr, unsigned int base,
+ unsigned int len)
+{
+ if (xdr_set_page_base(xdr, base, len) == 0) {
+ base -= xdr->buf->page_len;
+ xdr_set_tail_base(xdr, base, len);
+ }
+}
+
+static void xdr_set_next_page(struct xdr_stream *xdr)
+{
+ unsigned int newbase;
+
+ newbase = (1 + xdr->page_ptr - xdr->buf->pages) << PAGE_SHIFT;
+ newbase -= xdr->buf->page_base;
+ if (newbase < xdr->buf->page_len)
+ xdr_set_page_base(xdr, newbase, xdr_stream_remaining(xdr));
+ else
+ xdr_set_tail_base(xdr, 0, xdr_stream_remaining(xdr));
+}
+
+static bool xdr_set_next_buffer(struct xdr_stream *xdr)
+{
+ if (xdr->page_ptr != NULL)
+ xdr_set_next_page(xdr);
+ else if (xdr->iov == xdr->buf->head)
+ xdr_set_page(xdr, 0, xdr_stream_remaining(xdr));
+ return xdr->p != xdr->end;
+}
+
+/**
+ * xdr_init_decode - Initialize an xdr_stream for decoding data.
+ * @xdr: pointer to xdr_stream struct
+ * @buf: pointer to XDR buffer from which to decode data
+ * @p: current pointer inside XDR buffer
+ * @rqst: pointer to controlling rpc_rqst, for debugging
+ */
+void xdr_init_decode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p,
+ struct rpc_rqst *rqst)
+{
+ xdr->buf = buf;
+ xdr->page_kaddr = NULL;
+ xdr_reset_scratch_buffer(xdr);
+ xdr->nwords = XDR_QUADLEN(buf->len);
+ if (xdr_set_iov(xdr, buf->head, 0, buf->len) == 0 &&
+ xdr_set_page_base(xdr, 0, buf->len) == 0)
+ xdr_set_iov(xdr, buf->tail, 0, buf->len);
+ if (p != NULL && p > xdr->p && xdr->end >= p) {
+ xdr->nwords -= p - xdr->p;
+ xdr->p = p;
+ }
+ xdr->rqst = rqst;
+}
+EXPORT_SYMBOL_GPL(xdr_init_decode);
+
+/**
+ * xdr_init_decode_pages - Initialize an xdr_stream for decoding into pages
+ * @xdr: pointer to xdr_stream struct
+ * @buf: pointer to XDR buffer from which to decode data
+ * @pages: list of pages to decode into
+ * @len: length in bytes of buffer in pages
+ */
+void xdr_init_decode_pages(struct xdr_stream *xdr, struct xdr_buf *buf,
+ struct page **pages, unsigned int len)
+{
+ memset(buf, 0, sizeof(*buf));
+ buf->pages = pages;
+ buf->page_len = len;
+ buf->buflen = len;
+ buf->len = len;
+ xdr_init_decode(xdr, buf, NULL, NULL);
+}
+EXPORT_SYMBOL_GPL(xdr_init_decode_pages);
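+
+/*
+ * Typical usage (an illustrative sketch; "pages" and "len" describe a
+ * previously received buffer):
+ *
+ *	struct xdr_stream xdr;
+ *	struct xdr_buf buf;
+ *
+ *	xdr_init_decode_pages(&xdr, &buf, pages, len);
+ *	... xdr_inline_decode() calls ...
+ *	xdr_finish_decode(&xdr);
+ */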
+
+/**
+ * xdr_finish_decode - Clean up the xdr_stream after decoding data.
+ * @xdr: pointer to xdr_stream struct
+ */
+void xdr_finish_decode(struct xdr_stream *xdr)
+{
+ xdr_stream_unmap_current_page(xdr);
+}
+EXPORT_SYMBOL(xdr_finish_decode);
+
+static __be32 * __xdr_inline_decode(struct xdr_stream *xdr, size_t nbytes)
+{
+ unsigned int nwords = XDR_QUADLEN(nbytes);
+ __be32 *p = xdr->p;
+ __be32 *q = p + nwords;
+
+ if (unlikely(nwords > xdr->nwords || q > xdr->end || q < p))
+ return NULL;
+ xdr->p = q;
+ xdr->nwords -= nwords;
+ return p;
+}
+
+static __be32 *xdr_copy_to_scratch(struct xdr_stream *xdr, size_t nbytes)
+{
+ __be32 *p;
+ char *cpdest = xdr->scratch.iov_base;
+ size_t cplen = (char *)xdr->end - (char *)xdr->p;
+
+ if (nbytes > xdr->scratch.iov_len)
+ goto out_overflow;
+ p = __xdr_inline_decode(xdr, cplen);
+ if (p == NULL)
+ return NULL;
+ memcpy(cpdest, p, cplen);
+ if (!xdr_set_next_buffer(xdr))
+ goto out_overflow;
+ cpdest += cplen;
+ nbytes -= cplen;
+ p = __xdr_inline_decode(xdr, nbytes);
+ if (p == NULL)
+ return NULL;
+ memcpy(cpdest, p, nbytes);
+ return xdr->scratch.iov_base;
+out_overflow:
+ trace_rpc_xdr_overflow(xdr, nbytes);
+ return NULL;
+}
+
+/**
+ * xdr_inline_decode - Retrieve XDR data to decode
+ * @xdr: pointer to xdr_stream struct
+ * @nbytes: number of bytes of data to decode
+ *
+ * Check if the input buffer is long enough to enable us to decode
+ * 'nbytes' more bytes of data starting at the current position.
+ * If so return the current pointer, then update the current
+ * pointer position.
+ */
+__be32 * xdr_inline_decode(struct xdr_stream *xdr, size_t nbytes)
+{
+ __be32 *p;
+
+ if (unlikely(nbytes == 0))
+ return xdr->p;
+ if (xdr->p == xdr->end && !xdr_set_next_buffer(xdr))
+ goto out_overflow;
+ p = __xdr_inline_decode(xdr, nbytes);
+ if (p != NULL)
+ return p;
+ return xdr_copy_to_scratch(xdr, nbytes);
+out_overflow:
+ trace_rpc_xdr_overflow(xdr, nbytes);
+ return NULL;
+}
+EXPORT_SYMBOL_GPL(xdr_inline_decode);
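+
+/*
+ * Typical usage (an illustrative sketch): decode a 32-bit count followed
+ * by that many opaque bytes.
+ *
+ *	__be32 *p;
+ *	u32 count;
+ *
+ *	p = xdr_inline_decode(xdr, XDR_UNIT);
+ *	if (!p)
+ *		return -EBADMSG;
+ *	count = be32_to_cpup(p);
+ *	p = xdr_inline_decode(xdr, count);
+ *	if (!p)
+ *		return -EBADMSG;
+ */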
+
+static void xdr_realign_pages(struct xdr_stream *xdr)
+{
+ struct xdr_buf *buf = xdr->buf;
+ struct kvec *iov = buf->head;
+ unsigned int cur = xdr_stream_pos(xdr);
+ unsigned int copied;
+
+ /* Realign pages to current pointer position */
+ if (iov->iov_len > cur) {
+ copied = xdr_shrink_bufhead(buf, cur);
+ trace_rpc_xdr_alignment(xdr, cur, copied);
+ xdr_set_page(xdr, 0, buf->page_len);
+ }
+}
+
+static unsigned int xdr_align_pages(struct xdr_stream *xdr, unsigned int len)
+{
+ struct xdr_buf *buf = xdr->buf;
+ unsigned int nwords = XDR_QUADLEN(len);
+ unsigned int copied;
+
+ if (xdr->nwords == 0)
+ return 0;
+
+ xdr_realign_pages(xdr);
+ if (nwords > xdr->nwords) {
+ nwords = xdr->nwords;
+ len = nwords << 2;
+ }
+ if (buf->page_len <= len)
+ len = buf->page_len;
+ else if (nwords < xdr->nwords) {
+ /* Truncate page data and move it into the tail */
+ copied = xdr_shrink_pagelen(buf, len);
+ trace_rpc_xdr_alignment(xdr, len, copied);
+ }
+ return len;
+}
+
+/**
+ * xdr_read_pages - align page-based XDR data to current pointer position
+ * @xdr: pointer to xdr_stream struct
+ * @len: number of bytes of page data
+ *
+ * Moves data beyond the current pointer position from the XDR head[] buffer
+ * into the page list. Any data that lies beyond current position + @len
+ * bytes is moved into the XDR tail[]. The xdr_stream current position is
+ * then advanced past that data to align to the next XDR object in the tail.
+ *
+ * Returns the number of XDR encoded bytes now contained in the pages
+ */
+unsigned int xdr_read_pages(struct xdr_stream *xdr, unsigned int len)
+{
+ unsigned int nwords = XDR_QUADLEN(len);
+ unsigned int base, end, pglen;
+
+ pglen = xdr_align_pages(xdr, nwords << 2);
+ if (pglen == 0)
+ return 0;
+
+ base = (nwords << 2) - pglen;
+ end = xdr_stream_remaining(xdr) - pglen;
+
+ xdr_set_tail_base(xdr, base, end);
+ return len <= pglen ? len : pglen;
+}
+EXPORT_SYMBOL_GPL(xdr_read_pages);
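+
+/*
+ * Typical usage when decoding a reply whose payload lands in the page
+ * list (an illustrative sketch; "count" was decoded from the reply).
+ * Clamp the count if the received data turned out to be short:
+ *
+ *	unsigned int recvd;
+ *
+ *	recvd = xdr_read_pages(xdr, count);
+ *	if (count > recvd)
+ *		count = recvd;
+ */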
+
+/**
+ * xdr_set_pagelen - Sets the length of the XDR pages
+ * @xdr: pointer to xdr_stream struct
+ * @len: new length of the XDR page data
+ *
+ * Either grows or shrinks the length of the xdr pages by setting pagelen to
+ * @len bytes. When shrinking, any extra data is moved into buf->tail, whereas
+ * when growing any data beyond the current pointer is moved into the tail.
+ */
+void xdr_set_pagelen(struct xdr_stream *xdr, unsigned int len)
+{
+ struct xdr_buf *buf = xdr->buf;
+ size_t remaining = xdr_stream_remaining(xdr);
+ size_t base = 0;
+
+ if (len < buf->page_len) {
+ base = buf->page_len - len;
+ xdr_shrink_pagelen(buf, len);
+ } else {
+ xdr_buf_head_shift_right(buf, xdr_stream_pos(xdr),
+ buf->page_len, remaining);
+ if (len > buf->page_len)
+ xdr_buf_try_expand(buf, len - buf->page_len);
+ }
+ xdr_set_tail_base(xdr, base, remaining);
+}
+EXPORT_SYMBOL_GPL(xdr_set_pagelen);
+
+/**
+ * xdr_enter_page - decode data from the XDR page
+ * @xdr: pointer to xdr_stream struct
+ * @len: number of bytes of page data
+ *
+ * Moves data beyond the current pointer position from the XDR head[] buffer
+ * into the page list. Any data that lies beyond current position + "len"
+ * bytes is moved into the XDR tail[]. The current pointer is then
+ * repositioned at the beginning of the first XDR page.
+ */
+void xdr_enter_page(struct xdr_stream *xdr, unsigned int len)
+{
+ len = xdr_align_pages(xdr, len);
+ /*
+	 * Position current pointer at beginning of page data, and
+ * set remaining message length.
+ */
+ if (len != 0)
+ xdr_set_page_base(xdr, 0, len);
+}
+EXPORT_SYMBOL_GPL(xdr_enter_page);
+
+static const struct kvec empty_iov = {.iov_base = NULL, .iov_len = 0};
+
+void xdr_buf_from_iov(const struct kvec *iov, struct xdr_buf *buf)
+{
+ buf->head[0] = *iov;
+ buf->tail[0] = empty_iov;
+ buf->page_len = 0;
+ buf->buflen = buf->len = iov->iov_len;
+}
+EXPORT_SYMBOL_GPL(xdr_buf_from_iov);
+
+/**
+ * xdr_buf_subsegment - set subbuf to a portion of buf
+ * @buf: an xdr buffer
+ * @subbuf: the result buffer
+ * @base: beginning of range in bytes
+ * @len: length of range in bytes
+ *
+ * sets @subbuf to an xdr buffer representing the portion of @buf of
+ * length @len starting at offset @base.
+ *
+ * @buf and @subbuf may be pointers to the same struct xdr_buf.
+ *
+ * Returns -1 if base or length are out of bounds.
+ */
+int xdr_buf_subsegment(const struct xdr_buf *buf, struct xdr_buf *subbuf,
+ unsigned int base, unsigned int len)
+{
+ subbuf->buflen = subbuf->len = len;
+ if (base < buf->head[0].iov_len) {
+ subbuf->head[0].iov_base = buf->head[0].iov_base + base;
+ subbuf->head[0].iov_len = min_t(unsigned int, len,
+ buf->head[0].iov_len - base);
+ len -= subbuf->head[0].iov_len;
+ base = 0;
+ } else {
+ base -= buf->head[0].iov_len;
+ subbuf->head[0].iov_base = buf->head[0].iov_base;
+ subbuf->head[0].iov_len = 0;
+ }
+
+ if (base < buf->page_len) {
+ subbuf->page_len = min(buf->page_len - base, len);
+ base += buf->page_base;
+ subbuf->page_base = base & ~PAGE_MASK;
+ subbuf->pages = &buf->pages[base >> PAGE_SHIFT];
+ len -= subbuf->page_len;
+ base = 0;
+ } else {
+ base -= buf->page_len;
+ subbuf->pages = buf->pages;
+ subbuf->page_base = 0;
+ subbuf->page_len = 0;
+ }
+
+ if (base < buf->tail[0].iov_len) {
+ subbuf->tail[0].iov_base = buf->tail[0].iov_base + base;
+ subbuf->tail[0].iov_len = min_t(unsigned int, len,
+ buf->tail[0].iov_len - base);
+ len -= subbuf->tail[0].iov_len;
+ base = 0;
+ } else {
+ base -= buf->tail[0].iov_len;
+ subbuf->tail[0].iov_base = buf->tail[0].iov_base;
+ subbuf->tail[0].iov_len = 0;
+ }
+
+ if (base || len)
+ return -1;
+ return 0;
+}
+EXPORT_SYMBOL_GPL(xdr_buf_subsegment);
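+
+/*
+ * Typical usage (an illustrative sketch): carve out the region of @buf
+ * that holds a payload of "length" bytes at "offset" so it can be
+ * processed on its own.
+ *
+ *	struct xdr_buf payload;
+ *
+ *	if (xdr_buf_subsegment(buf, &payload, offset, length) < 0)
+ *		return -EMSGSIZE;
+ */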
+
+/**
+ * xdr_stream_subsegment - set @subbuf to a portion of @xdr
+ * @xdr: an xdr_stream set up for decoding
+ * @subbuf: the result buffer
+ * @nbytes: length of @xdr to extract, in bytes
+ *
+ * Sets up @subbuf to represent a portion of @xdr. The portion
+ * starts at the current offset in @xdr, and extends for a length
+ * of @nbytes. If this is successful, @xdr is advanced to the next
+ * XDR data item following that portion.
+ *
+ * Return values:
+ * %true: @subbuf has been initialized, and @xdr has been advanced.
+ * %false: a bounds error has occurred
+ */
+bool xdr_stream_subsegment(struct xdr_stream *xdr, struct xdr_buf *subbuf,
+ unsigned int nbytes)
+{
+ unsigned int start = xdr_stream_pos(xdr);
+ unsigned int remaining, len;
+
+ /* Extract @subbuf and bounds-check the fn arguments */
+ if (xdr_buf_subsegment(xdr->buf, subbuf, start, nbytes))
+ return false;
+
+ /* Advance @xdr by @nbytes */
+ for (remaining = nbytes; remaining;) {
+ if (xdr->p == xdr->end && !xdr_set_next_buffer(xdr))
+ return false;
+
+ len = (char *)xdr->end - (char *)xdr->p;
+ if (remaining <= len) {
+ xdr->p = (__be32 *)((char *)xdr->p +
+ (remaining + xdr_pad_size(nbytes)));
+ break;
+ }
+
+ xdr->p = (__be32 *)((char *)xdr->p + len);
+ xdr->end = xdr->p;
+ remaining -= len;
+ }
+
+ xdr_stream_set_pos(xdr, start + nbytes);
+ return true;
+}
+EXPORT_SYMBOL_GPL(xdr_stream_subsegment);
+
+/**
+ * xdr_stream_move_subsegment - Move part of a stream to another position
+ * @xdr: the source xdr_stream
+ * @offset: the source offset of the segment
+ * @target: the target offset of the segment
+ * @length: the number of bytes to move
+ *
+ * Moves @length bytes from @offset to @target in the xdr_stream, overwriting
+ * anything in its space. Returns the number of bytes in the segment.
+ */
+unsigned int xdr_stream_move_subsegment(struct xdr_stream *xdr, unsigned int offset,
+ unsigned int target, unsigned int length)
+{
+ struct xdr_buf buf;
+ unsigned int shift;
+
+ if (offset < target) {
+ shift = target - offset;
+ if (xdr_buf_subsegment(xdr->buf, &buf, offset, shift + length) < 0)
+ return 0;
+ xdr_buf_head_shift_right(&buf, 0, length, shift);
+ } else if (offset > target) {
+ shift = offset - target;
+ if (xdr_buf_subsegment(xdr->buf, &buf, target, shift + length) < 0)
+ return 0;
+ xdr_buf_head_shift_left(&buf, shift, length, shift);
+ }
+ return length;
+}
+EXPORT_SYMBOL_GPL(xdr_stream_move_subsegment);
+
+/**
+ * xdr_stream_zero - zero out a portion of an xdr_stream
+ * @xdr: an xdr_stream to zero out
+ * @offset: the starting point in the stream
+ * @length: the number of bytes to zero
+ */
+unsigned int xdr_stream_zero(struct xdr_stream *xdr, unsigned int offset,
+ unsigned int length)
+{
+ struct xdr_buf buf;
+
+ if (xdr_buf_subsegment(xdr->buf, &buf, offset, length) < 0)
+ return 0;
+ if (buf.head[0].iov_len)
+ xdr_buf_iov_zero(buf.head, 0, buf.head[0].iov_len);
+ if (buf.page_len > 0)
+ xdr_buf_pages_zero(&buf, 0, buf.page_len);
+ if (buf.tail[0].iov_len)
+ xdr_buf_iov_zero(buf.tail, 0, buf.tail[0].iov_len);
+ return length;
+}
+EXPORT_SYMBOL_GPL(xdr_stream_zero);
+
+/**
+ * xdr_buf_trim - lop at most "len" bytes off the end of "buf"
+ * @buf: buf to be trimmed
+ * @len: number of bytes to reduce "buf" by
+ *
+ * Trim an xdr_buf by the given number of bytes by fixing up the lengths. Note
+ * that it's possible that we'll trim less than that amount if the xdr_buf is
+ * too small, or if (for instance) it's all in the head and the parser has
+ * already read too far into it.
+ */
+void xdr_buf_trim(struct xdr_buf *buf, unsigned int len)
+{
+ size_t cur;
+ unsigned int trim = len;
+
+ if (buf->tail[0].iov_len) {
+ cur = min_t(size_t, buf->tail[0].iov_len, trim);
+ buf->tail[0].iov_len -= cur;
+ trim -= cur;
+ if (!trim)
+ goto fix_len;
+ }
+
+ if (buf->page_len) {
+ cur = min_t(unsigned int, buf->page_len, trim);
+ buf->page_len -= cur;
+ trim -= cur;
+ if (!trim)
+ goto fix_len;
+ }
+
+ if (buf->head[0].iov_len) {
+ cur = min_t(size_t, buf->head[0].iov_len, trim);
+ buf->head[0].iov_len -= cur;
+ trim -= cur;
+ }
+fix_len:
+ buf->len -= (len - trim);
+}
+EXPORT_SYMBOL_GPL(xdr_buf_trim);
+
+static void __read_bytes_from_xdr_buf(const struct xdr_buf *subbuf,
+ void *obj, unsigned int len)
+{
+ unsigned int this_len;
+
+ this_len = min_t(unsigned int, len, subbuf->head[0].iov_len);
+ memcpy(obj, subbuf->head[0].iov_base, this_len);
+ len -= this_len;
+ obj += this_len;
+ this_len = min_t(unsigned int, len, subbuf->page_len);
+ _copy_from_pages(obj, subbuf->pages, subbuf->page_base, this_len);
+ len -= this_len;
+ obj += this_len;
+ this_len = min_t(unsigned int, len, subbuf->tail[0].iov_len);
+ memcpy(obj, subbuf->tail[0].iov_base, this_len);
+}
+
+/* obj is assumed to point to allocated memory of size at least len: */
+int read_bytes_from_xdr_buf(const struct xdr_buf *buf, unsigned int base,
+ void *obj, unsigned int len)
+{
+ struct xdr_buf subbuf;
+ int status;
+
+ status = xdr_buf_subsegment(buf, &subbuf, base, len);
+ if (status != 0)
+ return status;
+ __read_bytes_from_xdr_buf(&subbuf, obj, len);
+ return 0;
+}
+EXPORT_SYMBOL_GPL(read_bytes_from_xdr_buf);
+
+static void __write_bytes_to_xdr_buf(const struct xdr_buf *subbuf,
+ void *obj, unsigned int len)
+{
+ unsigned int this_len;
+
+ this_len = min_t(unsigned int, len, subbuf->head[0].iov_len);
+ memcpy(subbuf->head[0].iov_base, obj, this_len);
+ len -= this_len;
+ obj += this_len;
+ this_len = min_t(unsigned int, len, subbuf->page_len);
+ _copy_to_pages(subbuf->pages, subbuf->page_base, obj, this_len);
+ len -= this_len;
+ obj += this_len;
+ this_len = min_t(unsigned int, len, subbuf->tail[0].iov_len);
+ memcpy(subbuf->tail[0].iov_base, obj, this_len);
+}
+
+/* obj is assumed to point to allocated memory of size at least len: */
+int write_bytes_to_xdr_buf(const struct xdr_buf *buf, unsigned int base,
+ void *obj, unsigned int len)
+{
+ struct xdr_buf subbuf;
+ int status;
+
+ status = xdr_buf_subsegment(buf, &subbuf, base, len);
+ if (status != 0)
+ return status;
+ __write_bytes_to_xdr_buf(&subbuf, obj, len);
+ return 0;
+}
+EXPORT_SYMBOL_GPL(write_bytes_to_xdr_buf);
+
+int xdr_decode_word(const struct xdr_buf *buf, unsigned int base, u32 *obj)
+{
+ __be32 raw;
+ int status;
+
+ status = read_bytes_from_xdr_buf(buf, base, &raw, sizeof(*obj));
+ if (status)
+ return status;
+ *obj = be32_to_cpu(raw);
+ return 0;
+}
+EXPORT_SYMBOL_GPL(xdr_decode_word);
+
+int xdr_encode_word(const struct xdr_buf *buf, unsigned int base, u32 obj)
+{
+ __be32 raw = cpu_to_be32(obj);
+
+ return write_bytes_to_xdr_buf(buf, base, &raw, sizeof(obj));
+}
+EXPORT_SYMBOL_GPL(xdr_encode_word);
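+
+/*
+ * Usage sketch: read and then update a 32-bit word at byte offset "base"
+ * within an xdr_buf (the offset and the increment are hypothetical):
+ *
+ *	u32 val;
+ *
+ *	if (xdr_decode_word(buf, base, &val))
+ *		return -EINVAL;
+ *	if (xdr_encode_word(buf, base, val + 1))
+ *		return -EINVAL;
+ */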
+
+/* Returns 0 on success, or else a negative error code. */
+static int xdr_xcode_array2(const struct xdr_buf *buf, unsigned int base,
+ struct xdr_array2_desc *desc, int encode)
+{
+ char *elem = NULL, *c;
+ unsigned int copied = 0, todo, avail_here;
+ struct page **ppages = NULL;
+ int err;
+
+ if (encode) {
+ if (xdr_encode_word(buf, base, desc->array_len) != 0)
+ return -EINVAL;
+ } else {
+ if (xdr_decode_word(buf, base, &desc->array_len) != 0 ||
+ desc->array_len > desc->array_maxlen ||
+ (unsigned long) base + 4 + desc->array_len *
+ desc->elem_size > buf->len)
+ return -EINVAL;
+ }
+ base += 4;
+
+ if (!desc->xcode)
+ return 0;
+
+ todo = desc->array_len * desc->elem_size;
+
+ /* process head */
+ if (todo && base < buf->head->iov_len) {
+ c = buf->head->iov_base + base;
+ avail_here = min_t(unsigned int, todo,
+ buf->head->iov_len - base);
+ todo -= avail_here;
+
+ while (avail_here >= desc->elem_size) {
+ err = desc->xcode(desc, c);
+ if (err)
+ goto out;
+ c += desc->elem_size;
+ avail_here -= desc->elem_size;
+ }
+ if (avail_here) {
+ if (!elem) {
+ elem = kmalloc(desc->elem_size, GFP_KERNEL);
+ err = -ENOMEM;
+ if (!elem)
+ goto out;
+ }
+ if (encode) {
+ err = desc->xcode(desc, elem);
+ if (err)
+ goto out;
+ memcpy(c, elem, avail_here);
+ } else
+ memcpy(elem, c, avail_here);
+ copied = avail_here;
+ }
+ base = buf->head->iov_len; /* align to start of pages */
+ }
+
+ /* process pages array */
+ base -= buf->head->iov_len;
+ if (todo && base < buf->page_len) {
+ unsigned int avail_page;
+
+ avail_here = min(todo, buf->page_len - base);
+ todo -= avail_here;
+
+ base += buf->page_base;
+ ppages = buf->pages + (base >> PAGE_SHIFT);
+ base &= ~PAGE_MASK;
+ avail_page = min_t(unsigned int, PAGE_SIZE - base,
+ avail_here);
+ c = kmap(*ppages) + base;
+
+ while (avail_here) {
+ avail_here -= avail_page;
+ if (copied || avail_page < desc->elem_size) {
+ unsigned int l = min(avail_page,
+ desc->elem_size - copied);
+ if (!elem) {
+ elem = kmalloc(desc->elem_size,
+ GFP_KERNEL);
+ err = -ENOMEM;
+ if (!elem)
+ goto out;
+ }
+ if (encode) {
+ if (!copied) {
+ err = desc->xcode(desc, elem);
+ if (err)
+ goto out;
+ }
+ memcpy(c, elem + copied, l);
+ copied += l;
+ if (copied == desc->elem_size)
+ copied = 0;
+ } else {
+ memcpy(elem + copied, c, l);
+ copied += l;
+ if (copied == desc->elem_size) {
+ err = desc->xcode(desc, elem);
+ if (err)
+ goto out;
+ copied = 0;
+ }
+ }
+ avail_page -= l;
+ c += l;
+ }
+ while (avail_page >= desc->elem_size) {
+ err = desc->xcode(desc, c);
+ if (err)
+ goto out;
+ c += desc->elem_size;
+ avail_page -= desc->elem_size;
+ }
+ if (avail_page) {
+ unsigned int l = min(avail_page,
+ desc->elem_size - copied);
+ if (!elem) {
+ elem = kmalloc(desc->elem_size,
+ GFP_KERNEL);
+ err = -ENOMEM;
+ if (!elem)
+ goto out;
+ }
+ if (encode) {
+ if (!copied) {
+ err = desc->xcode(desc, elem);
+ if (err)
+ goto out;
+ }
+ memcpy(c, elem + copied, l);
+ copied += l;
+ if (copied == desc->elem_size)
+ copied = 0;
+ } else {
+ memcpy(elem + copied, c, l);
+ copied += l;
+ if (copied == desc->elem_size) {
+ err = desc->xcode(desc, elem);
+ if (err)
+ goto out;
+ copied = 0;
+ }
+ }
+ }
+ if (avail_here) {
+ kunmap(*ppages);
+ ppages++;
+ c = kmap(*ppages);
+ }
+
+ avail_page = min(avail_here,
+ (unsigned int) PAGE_SIZE);
+ }
+ base = buf->page_len; /* align to start of tail */
+ }
+
+ /* process tail */
+ base -= buf->page_len;
+ if (todo) {
+ c = buf->tail->iov_base + base;
+ if (copied) {
+ unsigned int l = desc->elem_size - copied;
+
+ if (encode)
+ memcpy(c, elem + copied, l);
+ else {
+ memcpy(elem + copied, c, l);
+ err = desc->xcode(desc, elem);
+ if (err)
+ goto out;
+ }
+ todo -= l;
+ c += l;
+ }
+ while (todo) {
+ err = desc->xcode(desc, c);
+ if (err)
+ goto out;
+ c += desc->elem_size;
+ todo -= desc->elem_size;
+ }
+ }
+ err = 0;
+
+out:
+ kfree(elem);
+ if (ppages)
+ kunmap(*ppages);
+ return err;
+}
+
+int xdr_decode_array2(const struct xdr_buf *buf, unsigned int base,
+ struct xdr_array2_desc *desc)
+{
+ if (base >= buf->len)
+ return -EINVAL;
+
+ return xdr_xcode_array2(buf, base, desc, 0);
+}
+EXPORT_SYMBOL_GPL(xdr_decode_array2);
+
+int xdr_encode_array2(const struct xdr_buf *buf, unsigned int base,
+ struct xdr_array2_desc *desc)
+{
+ if ((unsigned long) base + 4 + desc->array_len * desc->elem_size >
+ buf->head->iov_len + buf->page_len + buf->tail->iov_len)
+ return -EINVAL;
+
+ return xdr_xcode_array2(buf, base, desc, 1);
+}
+EXPORT_SYMBOL_GPL(xdr_encode_array2);
+
+int xdr_process_buf(const struct xdr_buf *buf, unsigned int offset,
+ unsigned int len,
+ int (*actor)(struct scatterlist *, void *), void *data)
+{
+ int i, ret = 0;
+ unsigned int page_len, thislen, page_offset;
+ struct scatterlist sg[1];
+
+ sg_init_table(sg, 1);
+
+ if (offset >= buf->head[0].iov_len) {
+ offset -= buf->head[0].iov_len;
+ } else {
+ thislen = buf->head[0].iov_len - offset;
+ if (thislen > len)
+ thislen = len;
+ sg_set_buf(sg, buf->head[0].iov_base + offset, thislen);
+ ret = actor(sg, data);
+ if (ret)
+ goto out;
+ offset = 0;
+ len -= thislen;
+ }
+ if (len == 0)
+ goto out;
+
+ if (offset >= buf->page_len) {
+ offset -= buf->page_len;
+ } else {
+ page_len = buf->page_len - offset;
+ if (page_len > len)
+ page_len = len;
+ len -= page_len;
+ page_offset = (offset + buf->page_base) & (PAGE_SIZE - 1);
+ i = (offset + buf->page_base) >> PAGE_SHIFT;
+ thislen = PAGE_SIZE - page_offset;
+ do {
+ if (thislen > page_len)
+ thislen = page_len;
+ sg_set_page(sg, buf->pages[i], thislen, page_offset);
+ ret = actor(sg, data);
+ if (ret)
+ goto out;
+ page_len -= thislen;
+ i++;
+ page_offset = 0;
+ thislen = PAGE_SIZE;
+ } while (page_len != 0);
+ offset = 0;
+ }
+ if (len == 0)
+ goto out;
+ if (offset < buf->tail[0].iov_len) {
+ thislen = buf->tail[0].iov_len - offset;
+ if (thislen > len)
+ thislen = len;
+ sg_set_buf(sg, buf->tail[0].iov_base + offset, thislen);
+ ret = actor(sg, data);
+ len -= thislen;
+ }
+ if (len != 0)
+ ret = -EINVAL;
+out:
+ return ret;
+}
+EXPORT_SYMBOL_GPL(xdr_process_buf);
+
+/**
+ * xdr_stream_decode_opaque - Decode variable length opaque
+ * @xdr: pointer to xdr_stream
+ * @ptr: location to store opaque data
+ * @size: size of storage buffer @ptr
+ *
+ * Return values:
+ * On success, returns size of object stored in *@ptr
+ * %-EBADMSG on XDR buffer overflow
+ * %-EMSGSIZE on overflow of storage buffer @ptr
+ */
+ssize_t xdr_stream_decode_opaque(struct xdr_stream *xdr, void *ptr, size_t size)
+{
+ ssize_t ret;
+ void *p;
+
+ ret = xdr_stream_decode_opaque_inline(xdr, &p, size);
+ if (ret <= 0)
+ return ret;
+ memcpy(ptr, p, ret);
+ return ret;
+}
+EXPORT_SYMBOL_GPL(xdr_stream_decode_opaque);
+
+/**
+ * xdr_stream_decode_opaque_dup - Decode and duplicate variable length opaque
+ * @xdr: pointer to xdr_stream
+ * @ptr: location to store pointer to opaque data
+ * @maxlen: maximum acceptable object size
+ * @gfp_flags: GFP mask to use
+ *
+ * Return values:
+ * On success, returns size of object stored in *@ptr
+ * %-EBADMSG on XDR buffer overflow
+ * %-EMSGSIZE if the size of the object would exceed @maxlen
+ * %-ENOMEM on memory allocation failure
+ */
+ssize_t xdr_stream_decode_opaque_dup(struct xdr_stream *xdr, void **ptr,
+ size_t maxlen, gfp_t gfp_flags)
+{
+ ssize_t ret;
+ void *p;
+
+ ret = xdr_stream_decode_opaque_inline(xdr, &p, maxlen);
+ if (ret > 0) {
+ *ptr = kmemdup(p, ret, gfp_flags);
+ if (*ptr != NULL)
+ return ret;
+ ret = -ENOMEM;
+ }
+ *ptr = NULL;
+ return ret;
+}
+EXPORT_SYMBOL_GPL(xdr_stream_decode_opaque_dup);
+
+/**
+ * xdr_stream_decode_string - Decode variable length string
+ * @xdr: pointer to xdr_stream
+ * @str: location to store string
+ * @size: size of storage buffer @str
+ *
+ * Return values:
+ * On success, returns length of NUL-terminated string stored in *@str
+ * %-EBADMSG on XDR buffer overflow
+ * %-EMSGSIZE on overflow of storage buffer @str
+ */
+ssize_t xdr_stream_decode_string(struct xdr_stream *xdr, char *str, size_t size)
+{
+ ssize_t ret;
+ void *p;
+
+ ret = xdr_stream_decode_opaque_inline(xdr, &p, size);
+ if (ret > 0) {
+ memcpy(str, p, ret);
+ str[ret] = '\0';
+ return strlen(str);
+ }
+ *str = '\0';
+ return ret;
+}
+EXPORT_SYMBOL_GPL(xdr_stream_decode_string);
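+
+/*
+ * Usage sketch (the 64-byte limit is hypothetical):
+ *
+ *	char name[64];
+ *	ssize_t len;
+ *
+ *	len = xdr_stream_decode_string(xdr, name, sizeof(name));
+ *	if (len < 0)
+ *		return len;
+ */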
+
+/**
+ * xdr_stream_decode_string_dup - Decode and duplicate variable length string
+ * @xdr: pointer to xdr_stream
+ * @str: location to store pointer to string
+ * @maxlen: maximum acceptable string length
+ * @gfp_flags: GFP mask to use
+ *
+ * Return values:
+ *   On success, returns length of NUL-terminated string stored in *@str
+ * %-EBADMSG on XDR buffer overflow
+ * %-EMSGSIZE if the size of the string would exceed @maxlen
+ * %-ENOMEM on memory allocation failure
+ */
+ssize_t xdr_stream_decode_string_dup(struct xdr_stream *xdr, char **str,
+ size_t maxlen, gfp_t gfp_flags)
+{
+ void *p;
+ ssize_t ret;
+
+ ret = xdr_stream_decode_opaque_inline(xdr, &p, maxlen);
+ if (ret > 0) {
+ char *s = kmemdup_nul(p, ret, gfp_flags);
+ if (s != NULL) {
+ *str = s;
+ return strlen(s);
+ }
+ ret = -ENOMEM;
+ }
+ *str = NULL;
+ return ret;
+}
+EXPORT_SYMBOL_GPL(xdr_stream_decode_string_dup);
+
+/**
+ * xdr_stream_decode_opaque_auth - Decode struct opaque_auth (RFC5531 S8.2)
+ * @xdr: pointer to xdr_stream
+ * @flavor: location to store decoded flavor
+ * @body: location to store decoded body
+ * @body_len: location to store length of decoded body
+ *
+ * Return values:
+ * On success, returns the number of buffer bytes consumed
+ * %-EBADMSG on XDR buffer overflow
+ * %-EMSGSIZE if the decoded size of the body field exceeds 400 octets
+ */
+ssize_t xdr_stream_decode_opaque_auth(struct xdr_stream *xdr, u32 *flavor,
+ void **body, unsigned int *body_len)
+{
+ ssize_t ret, len;
+
+ len = xdr_stream_decode_u32(xdr, flavor);
+ if (unlikely(len < 0))
+ return len;
+ ret = xdr_stream_decode_opaque_inline(xdr, body, RPC_MAX_AUTH_SIZE);
+ if (unlikely(ret < 0))
+ return ret;
+ *body_len = ret;
+ return len + ret;
+}
+EXPORT_SYMBOL_GPL(xdr_stream_decode_opaque_auth);
+
+/**
+ * xdr_stream_encode_opaque_auth - Encode struct opaque_auth (RFC5531 S8.2)
+ * @xdr: pointer to xdr_stream
+ * @flavor: verifier flavor to encode
+ * @body: content of body to encode
+ * @body_len: length of body to encode
+ *
+ * Return values:
+ * On success, returns length in bytes of XDR buffer consumed
+ * %-EBADMSG on XDR buffer overflow
+ * %-EMSGSIZE if the size of @body exceeds 400 octets
+ */
+ssize_t xdr_stream_encode_opaque_auth(struct xdr_stream *xdr, u32 flavor,
+ void *body, unsigned int body_len)
+{
+ ssize_t ret, len;
+
+ if (unlikely(body_len > RPC_MAX_AUTH_SIZE))
+ return -EMSGSIZE;
+ len = xdr_stream_encode_u32(xdr, flavor);
+ if (unlikely(len < 0))
+ return len;
+ ret = xdr_stream_encode_opaque(xdr, body, body_len);
+ if (unlikely(ret < 0))
+ return ret;
+ return len + ret;
+}
+EXPORT_SYMBOL_GPL(xdr_stream_encode_opaque_auth);
diff --git a/net/sunrpc/xprt.c b/net/sunrpc/xprt.c
new file mode 100644
index 0000000000..ab453ede54
--- /dev/null
+++ b/net/sunrpc/xprt.c
@@ -0,0 +1,2192 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * linux/net/sunrpc/xprt.c
+ *
+ * This is a generic RPC call interface supporting congestion avoidance,
+ * and asynchronous calls.
+ *
+ * The interface works like this:
+ *
+ * - When a process places a call, it allocates a request slot if
+ * one is available. Otherwise, it sleeps on the backlog queue
+ * (xprt_reserve).
+ * - Next, the caller puts together the RPC message, stuffs it into
+ * the request struct, and calls xprt_transmit().
+ * - xprt_transmit sends the message and installs the caller on the
+ * transport's wait list. At the same time, if a reply is expected,
+ * it installs a timer that is run after the packet's timeout has
+ * expired.
+ * - When a packet arrives, the data_ready handler walks the list of
+ * pending requests for that transport. If a matching XID is found, the
+ * caller is woken up, and the timer removed.
+ * - When no reply arrives within the timeout interval, the timer is
+ * fired by the kernel and runs xprt_timer(). It either adjusts the
+ * timeout values (minor timeout) or wakes up the caller with a status
+ * of -ETIMEDOUT.
+ * - When the caller receives a notification from RPC that a reply arrived,
+ * it should release the RPC slot, and process the reply.
+ * If the call timed out, it may choose to retry the operation by
+ * adjusting the initial timeout value, and simply calling rpc_call
+ * again.
+ *
+ * Support for async RPC is done through a set of RPC-specific scheduling
+ * primitives that `transparently' work for processes as well as async
+ * tasks that rely on callbacks.
+ *
+ * Copyright (C) 1995-1997, Olaf Kirch <okir@monad.swb.de>
+ *
+ * Transport switch API copyright (C) 2005, Chuck Lever <cel@netapp.com>
+ */
+
+#include <linux/module.h>
+
+#include <linux/types.h>
+#include <linux/interrupt.h>
+#include <linux/workqueue.h>
+#include <linux/net.h>
+#include <linux/ktime.h>
+
+#include <linux/sunrpc/clnt.h>
+#include <linux/sunrpc/metrics.h>
+#include <linux/sunrpc/bc_xprt.h>
+#include <linux/rcupdate.h>
+#include <linux/sched/mm.h>
+
+#include <trace/events/sunrpc.h>
+
+#include "sunrpc.h"
+#include "sysfs.h"
+#include "fail.h"
+
+/*
+ * Local variables
+ */
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+# define RPCDBG_FACILITY RPCDBG_XPRT
+#endif
+
+/*
+ * Local functions
+ */
+static void xprt_init(struct rpc_xprt *xprt, struct net *net);
+static __be32 xprt_alloc_xid(struct rpc_xprt *xprt);
+static void xprt_destroy(struct rpc_xprt *xprt);
+static void xprt_request_init(struct rpc_task *task);
+static int xprt_request_prepare(struct rpc_rqst *req, struct xdr_buf *buf);
+
+static DEFINE_SPINLOCK(xprt_list_lock);
+static LIST_HEAD(xprt_list);
+
+static unsigned long xprt_request_timeout(const struct rpc_rqst *req)
+{
+ unsigned long timeout = jiffies + req->rq_timeout;
+
+ if (time_before(timeout, req->rq_majortimeo))
+ return timeout;
+ return req->rq_majortimeo;
+}
+
+/**
+ * xprt_register_transport - register a transport implementation
+ * @transport: transport to register
+ *
+ * If a transport implementation is loaded as a kernel module, it can
+ * call this interface to make itself known to the RPC client.
+ *
+ * Returns:
+ * 0: transport successfully registered
+ * -EEXIST: transport already registered
+ * -EINVAL: transport module being unloaded
+ */
+int xprt_register_transport(struct xprt_class *transport)
+{
+ struct xprt_class *t;
+ int result;
+
+ result = -EEXIST;
+ spin_lock(&xprt_list_lock);
+ list_for_each_entry(t, &xprt_list, list) {
+ /* don't register the same transport class twice */
+ if (t->ident == transport->ident)
+ goto out;
+ }
+
+ list_add_tail(&transport->list, &xprt_list);
+ printk(KERN_INFO "RPC: Registered %s transport module.\n",
+ transport->name);
+ result = 0;
+
+out:
+ spin_unlock(&xprt_list_lock);
+ return result;
+}
+EXPORT_SYMBOL_GPL(xprt_register_transport);
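+
+/*
+ * A transport module typically registers itself from its module_init
+ * hook (an illustrative sketch; the xprt_class field values and the
+ * example_setup() callback are hypothetical):
+ *
+ *	static struct xprt_class example_transport = {
+ *		.list	= LIST_HEAD_INIT(example_transport.list),
+ *		.name	= "example",
+ *		.owner	= THIS_MODULE,
+ *		.ident	= XPRT_TRANSPORT_TCP,
+ *		.setup	= example_setup,
+ *		.netid	= { "example", "" },
+ *	};
+ *
+ *	err = xprt_register_transport(&example_transport);
+ */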
+
+/**
+ * xprt_unregister_transport - unregister a transport implementation
+ * @transport: transport to unregister
+ *
+ * Returns:
+ * 0: transport successfully unregistered
+ * -ENOENT: transport never registered
+ */
+int xprt_unregister_transport(struct xprt_class *transport)
+{
+ struct xprt_class *t;
+ int result;
+
+ result = 0;
+ spin_lock(&xprt_list_lock);
+ list_for_each_entry(t, &xprt_list, list) {
+ if (t == transport) {
+ printk(KERN_INFO
+ "RPC: Unregistered %s transport module.\n",
+ transport->name);
+ list_del_init(&transport->list);
+ goto out;
+ }
+ }
+ result = -ENOENT;
+
+out:
+ spin_unlock(&xprt_list_lock);
+ return result;
+}
+EXPORT_SYMBOL_GPL(xprt_unregister_transport);
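+
+/*
+ * Usage sketch (illustrative only; the names and the identifier below
+ * are hypothetical): a transport module normally registers its class
+ * from module_init() and unregisters it from module_exit(), roughly:
+ *
+ *     static struct xprt_class example_transport = {
+ *             .list   = LIST_HEAD_INIT(example_transport.list),
+ *             .name   = "example",
+ *             .owner  = THIS_MODULE,
+ *             .ident  = 0x1234,               // hypothetical identifier
+ *             .setup  = example_setup,        // hypothetical setup callback
+ *             .netid  = { "example", "" },
+ *     };
+ *
+ *     // module_init(): err = xprt_register_transport(&example_transport);
+ *     // module_exit(): xprt_unregister_transport(&example_transport);
+ */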
+
+static void
+xprt_class_release(const struct xprt_class *t)
+{
+ module_put(t->owner);
+}
+
+static const struct xprt_class *
+xprt_class_find_by_ident_locked(int ident)
+{
+ const struct xprt_class *t;
+
+ list_for_each_entry(t, &xprt_list, list) {
+ if (t->ident != ident)
+ continue;
+ if (!try_module_get(t->owner))
+ continue;
+ return t;
+ }
+ return NULL;
+}
+
+static const struct xprt_class *
+xprt_class_find_by_ident(int ident)
+{
+ const struct xprt_class *t;
+
+ spin_lock(&xprt_list_lock);
+ t = xprt_class_find_by_ident_locked(ident);
+ spin_unlock(&xprt_list_lock);
+ return t;
+}
+
+static const struct xprt_class *
+xprt_class_find_by_netid_locked(const char *netid)
+{
+ const struct xprt_class *t;
+ unsigned int i;
+
+ list_for_each_entry(t, &xprt_list, list) {
+ for (i = 0; t->netid[i][0] != '\0'; i++) {
+ if (strcmp(t->netid[i], netid) != 0)
+ continue;
+ if (!try_module_get(t->owner))
+ continue;
+ return t;
+ }
+ }
+ return NULL;
+}
+
+static const struct xprt_class *
+xprt_class_find_by_netid(const char *netid)
+{
+ const struct xprt_class *t;
+
+ spin_lock(&xprt_list_lock);
+ t = xprt_class_find_by_netid_locked(netid);
+ if (!t) {
+ spin_unlock(&xprt_list_lock);
+ request_module("rpc%s", netid);
+ spin_lock(&xprt_list_lock);
+ t = xprt_class_find_by_netid_locked(netid);
+ }
+ spin_unlock(&xprt_list_lock);
+ return t;
+}
+
+/**
+ * xprt_find_transport_ident - convert a netid into a transport identifier
+ * @netid: transport to load
+ *
+ * Returns:
+ * > 0: transport identifier
+ * -ENOENT: transport module not available
+ */
+int xprt_find_transport_ident(const char *netid)
+{
+ const struct xprt_class *t;
+ int ret;
+
+ t = xprt_class_find_by_netid(netid);
+ if (!t)
+ return -ENOENT;
+ ret = t->ident;
+ xprt_class_release(t);
+ return ret;
+}
+EXPORT_SYMBOL_GPL(xprt_find_transport_ident);
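+
+/*
+ * Example (illustrative): xprt_find_transport_ident("tcp") returns the
+ * identifier registered for the "tcp" netid (e.g. XPRT_TRANSPORT_TCP)
+ * once the socket transport is available.  If no registered class
+ * matches, the lookup above first tries request_module("rpctcp") before
+ * giving up with -ENOENT.
+ */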
+
+static void xprt_clear_locked(struct rpc_xprt *xprt)
+{
+ xprt->snd_task = NULL;
+ if (!test_bit(XPRT_CLOSE_WAIT, &xprt->state))
+ clear_bit_unlock(XPRT_LOCKED, &xprt->state);
+ else
+ queue_work(xprtiod_workqueue, &xprt->task_cleanup);
+}
+
+/**
+ * xprt_reserve_xprt - serialize write access to transports
+ * @task: task that is requesting access to the transport
+ * @xprt: pointer to the target transport
+ *
+ * This prevents mixing the payload of separate requests, and prevents
+ * transport connects from colliding with writes. No congestion control
+ * is provided.
+ */
+int xprt_reserve_xprt(struct rpc_xprt *xprt, struct rpc_task *task)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+
+ if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) {
+ if (task == xprt->snd_task)
+ goto out_locked;
+ goto out_sleep;
+ }
+ if (test_bit(XPRT_WRITE_SPACE, &xprt->state))
+ goto out_unlock;
+ xprt->snd_task = task;
+
+out_locked:
+ trace_xprt_reserve_xprt(xprt, task);
+ return 1;
+
+out_unlock:
+ xprt_clear_locked(xprt);
+out_sleep:
+ task->tk_status = -EAGAIN;
+ if (RPC_IS_SOFT(task))
+ rpc_sleep_on_timeout(&xprt->sending, task, NULL,
+ xprt_request_timeout(req));
+ else
+ rpc_sleep_on(&xprt->sending, task, NULL);
+ return 0;
+}
+EXPORT_SYMBOL_GPL(xprt_reserve_xprt);
+
+static bool
+xprt_need_congestion_window_wait(struct rpc_xprt *xprt)
+{
+ return test_bit(XPRT_CWND_WAIT, &xprt->state);
+}
+
+static void
+xprt_set_congestion_window_wait(struct rpc_xprt *xprt)
+{
+ if (!list_empty(&xprt->xmit_queue)) {
+ /* Peek at head of queue to see if it can make progress */
+ if (list_first_entry(&xprt->xmit_queue, struct rpc_rqst,
+ rq_xmit)->rq_cong)
+ return;
+ }
+ set_bit(XPRT_CWND_WAIT, &xprt->state);
+}
+
+static void
+xprt_test_and_clear_congestion_window_wait(struct rpc_xprt *xprt)
+{
+ if (!RPCXPRT_CONGESTED(xprt))
+ clear_bit(XPRT_CWND_WAIT, &xprt->state);
+}
+
+/**
+ * xprt_reserve_xprt_cong - serialize write access to transports
+ * @xprt: pointer to the target transport
+ * @task: task that is requesting access to the transport
+ *
+ * Same as xprt_reserve_xprt, but Van Jacobson congestion control is
+ * integrated into the decision of whether a request is allowed to be
+ * woken up and given access to the transport.
+ * Note that the lock is only granted if we know there are free slots.
+ */
+int xprt_reserve_xprt_cong(struct rpc_xprt *xprt, struct rpc_task *task)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+
+ if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) {
+ if (task == xprt->snd_task)
+ goto out_locked;
+ goto out_sleep;
+ }
+ if (req == NULL) {
+ xprt->snd_task = task;
+ goto out_locked;
+ }
+ if (test_bit(XPRT_WRITE_SPACE, &xprt->state))
+ goto out_unlock;
+ if (!xprt_need_congestion_window_wait(xprt)) {
+ xprt->snd_task = task;
+ goto out_locked;
+ }
+out_unlock:
+ xprt_clear_locked(xprt);
+out_sleep:
+ task->tk_status = -EAGAIN;
+ if (RPC_IS_SOFT(task))
+ rpc_sleep_on_timeout(&xprt->sending, task, NULL,
+ xprt_request_timeout(req));
+ else
+ rpc_sleep_on(&xprt->sending, task, NULL);
+ return 0;
+out_locked:
+ trace_xprt_reserve_cong(xprt, task);
+ return 1;
+}
+EXPORT_SYMBOL_GPL(xprt_reserve_xprt_cong);
+
+static inline int xprt_lock_write(struct rpc_xprt *xprt, struct rpc_task *task)
+{
+ int retval;
+
+ if (test_bit(XPRT_LOCKED, &xprt->state) && xprt->snd_task == task)
+ return 1;
+ spin_lock(&xprt->transport_lock);
+ retval = xprt->ops->reserve_xprt(xprt, task);
+ spin_unlock(&xprt->transport_lock);
+ return retval;
+}
+
+static bool __xprt_lock_write_func(struct rpc_task *task, void *data)
+{
+ struct rpc_xprt *xprt = data;
+
+ xprt->snd_task = task;
+ return true;
+}
+
+static void __xprt_lock_write_next(struct rpc_xprt *xprt)
+{
+ if (test_and_set_bit(XPRT_LOCKED, &xprt->state))
+ return;
+ if (test_bit(XPRT_WRITE_SPACE, &xprt->state))
+ goto out_unlock;
+ if (rpc_wake_up_first_on_wq(xprtiod_workqueue, &xprt->sending,
+ __xprt_lock_write_func, xprt))
+ return;
+out_unlock:
+ xprt_clear_locked(xprt);
+}
+
+static void __xprt_lock_write_next_cong(struct rpc_xprt *xprt)
+{
+ if (test_and_set_bit(XPRT_LOCKED, &xprt->state))
+ return;
+ if (test_bit(XPRT_WRITE_SPACE, &xprt->state))
+ goto out_unlock;
+ if (xprt_need_congestion_window_wait(xprt))
+ goto out_unlock;
+ if (rpc_wake_up_first_on_wq(xprtiod_workqueue, &xprt->sending,
+ __xprt_lock_write_func, xprt))
+ return;
+out_unlock:
+ xprt_clear_locked(xprt);
+}
+
+/**
+ * xprt_release_xprt - allow other requests to use a transport
+ * @xprt: transport with other tasks potentially waiting
+ * @task: task that is releasing access to the transport
+ *
+ * Note that "task" can be NULL. No congestion control is provided.
+ */
+void xprt_release_xprt(struct rpc_xprt *xprt, struct rpc_task *task)
+{
+ if (xprt->snd_task == task) {
+ xprt_clear_locked(xprt);
+ __xprt_lock_write_next(xprt);
+ }
+ trace_xprt_release_xprt(xprt, task);
+}
+EXPORT_SYMBOL_GPL(xprt_release_xprt);
+
+/**
+ * xprt_release_xprt_cong - allow other requests to use a transport
+ * @xprt: transport with other tasks potentially waiting
+ * @task: task that is releasing access to the transport
+ *
+ * Note that "task" can be NULL. Another task is awoken to use the
+ * transport if the transport's congestion window allows it.
+ */
+void xprt_release_xprt_cong(struct rpc_xprt *xprt, struct rpc_task *task)
+{
+ if (xprt->snd_task == task) {
+ xprt_clear_locked(xprt);
+ __xprt_lock_write_next_cong(xprt);
+ }
+ trace_xprt_release_cong(xprt, task);
+}
+EXPORT_SYMBOL_GPL(xprt_release_xprt_cong);
+
+void xprt_release_write(struct rpc_xprt *xprt, struct rpc_task *task)
+{
+ if (xprt->snd_task != task)
+ return;
+ spin_lock(&xprt->transport_lock);
+ xprt->ops->release_xprt(xprt, task);
+ spin_unlock(&xprt->transport_lock);
+}
+
+/*
+ * Van Jacobson congestion avoidance. Check if the congestion window
+ * overflowed. Put the task to sleep if this is the case.
+ */
+static int
+__xprt_get_cong(struct rpc_xprt *xprt, struct rpc_rqst *req)
+{
+ if (req->rq_cong)
+ return 1;
+ trace_xprt_get_cong(xprt, req->rq_task);
+ if (RPCXPRT_CONGESTED(xprt)) {
+ xprt_set_congestion_window_wait(xprt);
+ return 0;
+ }
+ req->rq_cong = 1;
+ xprt->cong += RPC_CWNDSCALE;
+ return 1;
+}
+
+/*
+ * Adjust the congestion window, and wake up the next task
+ * that has been sleeping due to congestion
+ */
+static void
+__xprt_put_cong(struct rpc_xprt *xprt, struct rpc_rqst *req)
+{
+ if (!req->rq_cong)
+ return;
+ req->rq_cong = 0;
+ xprt->cong -= RPC_CWNDSCALE;
+ xprt_test_and_clear_congestion_window_wait(xprt);
+ trace_xprt_put_cong(xprt, req->rq_task);
+ __xprt_lock_write_next_cong(xprt);
+}
+
+/**
+ * xprt_request_get_cong - Request congestion control credits
+ * @xprt: pointer to transport
+ * @req: pointer to RPC request
+ *
+ * Useful for transports that require congestion control.
+ */
+bool
+xprt_request_get_cong(struct rpc_xprt *xprt, struct rpc_rqst *req)
+{
+ bool ret = false;
+
+ if (req->rq_cong)
+ return true;
+ spin_lock(&xprt->transport_lock);
+ ret = __xprt_get_cong(xprt, req) != 0;
+ spin_unlock(&xprt->transport_lock);
+ return ret;
+}
+EXPORT_SYMBOL_GPL(xprt_request_get_cong);
+
+/**
+ * xprt_release_rqst_cong - housekeeping when request is complete
+ * @task: RPC request that recently completed
+ *
+ * Useful for transports that require congestion control.
+ */
+void xprt_release_rqst_cong(struct rpc_task *task)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+
+ __xprt_put_cong(req->rq_xprt, req);
+}
+EXPORT_SYMBOL_GPL(xprt_release_rqst_cong);
+
+static void xprt_clear_congestion_window_wait_locked(struct rpc_xprt *xprt)
+{
+ if (test_and_clear_bit(XPRT_CWND_WAIT, &xprt->state))
+ __xprt_lock_write_next_cong(xprt);
+}
+
+/*
+ * Clear the congestion window wait flag and wake up the next
+ * entry on xprt->sending
+ */
+static void
+xprt_clear_congestion_window_wait(struct rpc_xprt *xprt)
+{
+ if (test_and_clear_bit(XPRT_CWND_WAIT, &xprt->state)) {
+ spin_lock(&xprt->transport_lock);
+ __xprt_lock_write_next_cong(xprt);
+ spin_unlock(&xprt->transport_lock);
+ }
+}
+
+/**
+ * xprt_adjust_cwnd - adjust transport congestion window
+ * @xprt: pointer to xprt
+ * @task: recently completed RPC request used to adjust window
+ * @result: result code of completed RPC request
+ *
+ * The transport code maintains an estimate of the maximum number of
+ * outstanding RPC requests, using a smoothed version of the congestion
+ * avoidance implemented in 44BSD. This is basically the Van Jacobson
+ * congestion algorithm: If a retransmit occurs, the congestion window is
+ * halved; otherwise, it is incremented by 1/cwnd when
+ *
+ * - a reply is received and
+ * - a full number of requests are outstanding and
+ * - the congestion window hasn't been updated recently.
+ */
+void xprt_adjust_cwnd(struct rpc_xprt *xprt, struct rpc_task *task, int result)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+ unsigned long cwnd = xprt->cwnd;
+
+ if (result >= 0 && cwnd <= xprt->cong) {
+ /* The (cwnd >> 1) term makes sure
+ * the result gets rounded properly. */
+ cwnd += (RPC_CWNDSCALE * RPC_CWNDSCALE + (cwnd >> 1)) / cwnd;
+ if (cwnd > RPC_MAXCWND(xprt))
+ cwnd = RPC_MAXCWND(xprt);
+ __xprt_lock_write_next_cong(xprt);
+ } else if (result == -ETIMEDOUT) {
+ cwnd >>= 1;
+ if (cwnd < RPC_CWNDSCALE)
+ cwnd = RPC_CWNDSCALE;
+ }
+ dprintk("RPC: cong %ld, cwnd was %ld, now %ld\n",
+ xprt->cong, xprt->cwnd, cwnd);
+ xprt->cwnd = cwnd;
+ __xprt_put_cong(xprt, req);
+}
+EXPORT_SYMBOL_GPL(xprt_adjust_cwnd);
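+
+/*
+ * Worked example (assuming RPC_CWNDSCALE == 256, i.e. one slot == 256):
+ * with cwnd == 512 (two slots), a successful reply adds
+ * (256 * 256 + 256) / 512 == 128, i.e. half a slot (1/cwnd in slot
+ * units), giving 640.  A timeout instead halves cwnd to 256, and the
+ * window is never reduced below RPC_CWNDSCALE nor grown past
+ * RPC_MAXCWND(xprt).
+ */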
+
+/**
+ * xprt_wake_pending_tasks - wake all tasks on a transport's pending queue
+ * @xprt: transport with waiting tasks
+ * @status: result code to plant in each task before waking it
+ *
+ */
+void xprt_wake_pending_tasks(struct rpc_xprt *xprt, int status)
+{
+ if (status < 0)
+ rpc_wake_up_status(&xprt->pending, status);
+ else
+ rpc_wake_up(&xprt->pending);
+}
+EXPORT_SYMBOL_GPL(xprt_wake_pending_tasks);
+
+/**
+ * xprt_wait_for_buffer_space - wait for transport output buffer to clear
+ * @xprt: transport
+ *
+ * Note that we only set the timer for the case of RPC_IS_SOFT(), since
+ * we don't in general want to force a socket disconnection due to
+ * an incomplete RPC call transmission.
+ */
+void xprt_wait_for_buffer_space(struct rpc_xprt *xprt)
+{
+ set_bit(XPRT_WRITE_SPACE, &xprt->state);
+}
+EXPORT_SYMBOL_GPL(xprt_wait_for_buffer_space);
+
+static bool
+xprt_clear_write_space_locked(struct rpc_xprt *xprt)
+{
+ if (test_and_clear_bit(XPRT_WRITE_SPACE, &xprt->state)) {
+ __xprt_lock_write_next(xprt);
+ dprintk("RPC: write space: waking waiting task on "
+ "xprt %p\n", xprt);
+ return true;
+ }
+ return false;
+}
+
+/**
+ * xprt_write_space - wake the task waiting for transport output buffer space
+ * @xprt: transport with waiting tasks
+ *
+ * Can be called in a soft IRQ context, so xprt_write_space never sleeps.
+ */
+bool xprt_write_space(struct rpc_xprt *xprt)
+{
+ bool ret;
+
+ if (!test_bit(XPRT_WRITE_SPACE, &xprt->state))
+ return false;
+ spin_lock(&xprt->transport_lock);
+ ret = xprt_clear_write_space_locked(xprt);
+ spin_unlock(&xprt->transport_lock);
+ return ret;
+}
+EXPORT_SYMBOL_GPL(xprt_write_space);
+
+static unsigned long xprt_abs_ktime_to_jiffies(ktime_t abstime)
+{
+ s64 delta = ktime_to_ns(ktime_get() - abstime);
+ return likely(delta >= 0) ?
+ jiffies - nsecs_to_jiffies(delta) :
+ jiffies + nsecs_to_jiffies(-delta);
+}
+
+static unsigned long xprt_calc_majortimeo(struct rpc_rqst *req)
+{
+ const struct rpc_timeout *to = req->rq_task->tk_client->cl_timeout;
+ unsigned long majortimeo = req->rq_timeout;
+
+ if (to->to_exponential)
+ majortimeo <<= to->to_retries;
+ else
+ majortimeo += to->to_increment * to->to_retries;
+ if (majortimeo > to->to_maxval || majortimeo == 0)
+ majortimeo = to->to_maxval;
+ return majortimeo;
+}
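+
+/*
+ * For illustration: with rq_timeout == 6 * HZ and to_retries == 3, an
+ * exponential schedule yields 6 * HZ << 3 == 48 * HZ, while a linear one
+ * with to_increment == 5 * HZ yields 6 * HZ + 3 * 5 * HZ == 21 * HZ;
+ * either result is then clamped to to_maxval.
+ */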
+
+static void xprt_reset_majortimeo(struct rpc_rqst *req)
+{
+ req->rq_majortimeo += xprt_calc_majortimeo(req);
+}
+
+static void xprt_reset_minortimeo(struct rpc_rqst *req)
+{
+ req->rq_minortimeo += req->rq_timeout;
+}
+
+static void xprt_init_majortimeo(struct rpc_task *task, struct rpc_rqst *req)
+{
+ unsigned long time_init;
+ struct rpc_xprt *xprt = req->rq_xprt;
+
+ if (likely(xprt && xprt_connected(xprt)))
+ time_init = jiffies;
+ else
+ time_init = xprt_abs_ktime_to_jiffies(task->tk_start);
+ req->rq_timeout = task->tk_client->cl_timeout->to_initval;
+ req->rq_majortimeo = time_init + xprt_calc_majortimeo(req);
+ req->rq_minortimeo = time_init + req->rq_timeout;
+}
+
+/**
+ * xprt_adjust_timeout - adjust timeout values for next retransmit
+ * @req: RPC request containing parameters to use for the adjustment
+ *
+ */
+int xprt_adjust_timeout(struct rpc_rqst *req)
+{
+ struct rpc_xprt *xprt = req->rq_xprt;
+ const struct rpc_timeout *to = req->rq_task->tk_client->cl_timeout;
+ int status = 0;
+
+ if (time_before(jiffies, req->rq_majortimeo)) {
+ if (time_before(jiffies, req->rq_minortimeo))
+ return status;
+ if (to->to_exponential)
+ req->rq_timeout <<= 1;
+ else
+ req->rq_timeout += to->to_increment;
+ if (to->to_maxval && req->rq_timeout >= to->to_maxval)
+ req->rq_timeout = to->to_maxval;
+ req->rq_retries++;
+ } else {
+ req->rq_timeout = to->to_initval;
+ req->rq_retries = 0;
+ xprt_reset_majortimeo(req);
+ /* Reset the RTT counters == "slow start" */
+ spin_lock(&xprt->transport_lock);
+ rpc_init_rtt(req->rq_task->tk_client->cl_rtt, to->to_initval);
+ spin_unlock(&xprt->transport_lock);
+ status = -ETIMEDOUT;
+ }
+ xprt_reset_minortimeo(req);
+
+ if (req->rq_timeout == 0) {
+ printk(KERN_WARNING "xprt_adjust_timeout: rq_timeout = 0!\n");
+ req->rq_timeout = 5 * HZ;
+ }
+ return status;
+}
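+
+/*
+ * For illustration (exponential backoff, to_initval == 10s, to_retries
+ * == 2): rq_majortimeo is 10s << 2 == 40s past the start of the request,
+ * successive minor timeouts double rq_timeout (10s -> 20s -> 40s, capped
+ * at to_maxval), and once the major timeout expires rq_timeout is reset
+ * to to_initval, the RTT estimator is re-initialized ("slow start"), and
+ * -ETIMEDOUT is returned so the caller can decide whether to retransmit.
+ */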
+
+static void xprt_autoclose(struct work_struct *work)
+{
+ struct rpc_xprt *xprt =
+ container_of(work, struct rpc_xprt, task_cleanup);
+ unsigned int pflags = memalloc_nofs_save();
+
+ trace_xprt_disconnect_auto(xprt);
+ xprt->connect_cookie++;
+ smp_mb__before_atomic();
+ clear_bit(XPRT_CLOSE_WAIT, &xprt->state);
+ xprt->ops->close(xprt);
+ xprt_release_write(xprt, NULL);
+ wake_up_bit(&xprt->state, XPRT_LOCKED);
+ memalloc_nofs_restore(pflags);
+}
+
+/**
+ * xprt_disconnect_done - mark a transport as disconnected
+ * @xprt: transport to flag for disconnect
+ *
+ */
+void xprt_disconnect_done(struct rpc_xprt *xprt)
+{
+ trace_xprt_disconnect_done(xprt);
+ spin_lock(&xprt->transport_lock);
+ xprt_clear_connected(xprt);
+ xprt_clear_write_space_locked(xprt);
+ xprt_clear_congestion_window_wait_locked(xprt);
+ xprt_wake_pending_tasks(xprt, -ENOTCONN);
+ spin_unlock(&xprt->transport_lock);
+}
+EXPORT_SYMBOL_GPL(xprt_disconnect_done);
+
+/**
+ * xprt_schedule_autoclose_locked - Try to schedule an autoclose RPC call
+ * @xprt: transport to disconnect
+ */
+static void xprt_schedule_autoclose_locked(struct rpc_xprt *xprt)
+{
+ if (test_and_set_bit(XPRT_CLOSE_WAIT, &xprt->state))
+ return;
+ if (test_and_set_bit(XPRT_LOCKED, &xprt->state) == 0)
+ queue_work(xprtiod_workqueue, &xprt->task_cleanup);
+ else if (xprt->snd_task && !test_bit(XPRT_SND_IS_COOKIE, &xprt->state))
+ rpc_wake_up_queued_task_set_status(&xprt->pending,
+ xprt->snd_task, -ENOTCONN);
+}
+
+/**
+ * xprt_force_disconnect - force a transport to disconnect
+ * @xprt: transport to disconnect
+ *
+ */
+void xprt_force_disconnect(struct rpc_xprt *xprt)
+{
+ trace_xprt_disconnect_force(xprt);
+
+ /* Don't race with the test_bit() in xprt_clear_locked() */
+ spin_lock(&xprt->transport_lock);
+ xprt_schedule_autoclose_locked(xprt);
+ spin_unlock(&xprt->transport_lock);
+}
+EXPORT_SYMBOL_GPL(xprt_force_disconnect);
+
+static unsigned int
+xprt_connect_cookie(struct rpc_xprt *xprt)
+{
+ return READ_ONCE(xprt->connect_cookie);
+}
+
+static bool
+xprt_request_retransmit_after_disconnect(struct rpc_task *task)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+ struct rpc_xprt *xprt = req->rq_xprt;
+
+ return req->rq_connect_cookie != xprt_connect_cookie(xprt) ||
+ !xprt_connected(xprt);
+}
+
+/**
+ * xprt_conditional_disconnect - force a transport to disconnect
+ * @xprt: transport to disconnect
+ * @cookie: 'connection cookie'
+ *
+ * This attempts to break the connection if and only if 'cookie' matches
+ * the current transport 'connection cookie'. It ensures that we don't
+ * try to break the connection more than once when we need to retransmit
+ * a batch of RPC requests.
+ *
+ */
+void xprt_conditional_disconnect(struct rpc_xprt *xprt, unsigned int cookie)
+{
+ /* Don't race with the test_bit() in xprt_clear_locked() */
+ spin_lock(&xprt->transport_lock);
+ if (cookie != xprt->connect_cookie)
+ goto out;
+ if (test_bit(XPRT_CLOSING, &xprt->state))
+ goto out;
+ xprt_schedule_autoclose_locked(xprt);
+out:
+ spin_unlock(&xprt->transport_lock);
+}
+
+static bool
+xprt_has_timer(const struct rpc_xprt *xprt)
+{
+ return xprt->idle_timeout != 0;
+}
+
+static void
+xprt_schedule_autodisconnect(struct rpc_xprt *xprt)
+ __must_hold(&xprt->transport_lock)
+{
+ xprt->last_used = jiffies;
+ if (RB_EMPTY_ROOT(&xprt->recv_queue) && xprt_has_timer(xprt))
+ mod_timer(&xprt->timer, xprt->last_used + xprt->idle_timeout);
+}
+
+static void
+xprt_init_autodisconnect(struct timer_list *t)
+{
+ struct rpc_xprt *xprt = from_timer(xprt, t, timer);
+
+ if (!RB_EMPTY_ROOT(&xprt->recv_queue))
+ return;
+ /* Reset xprt->last_used to avoid connect/autodisconnect cycling */
+ xprt->last_used = jiffies;
+ if (test_and_set_bit(XPRT_LOCKED, &xprt->state))
+ return;
+ queue_work(xprtiod_workqueue, &xprt->task_cleanup);
+}
+
+#if IS_ENABLED(CONFIG_FAIL_SUNRPC)
+static void xprt_inject_disconnect(struct rpc_xprt *xprt)
+{
+ if (!fail_sunrpc.ignore_client_disconnect &&
+ should_fail(&fail_sunrpc.attr, 1))
+ xprt->ops->inject_disconnect(xprt);
+}
+#else
+static inline void xprt_inject_disconnect(struct rpc_xprt *xprt)
+{
+}
+#endif
+
+bool xprt_lock_connect(struct rpc_xprt *xprt,
+ struct rpc_task *task,
+ void *cookie)
+{
+ bool ret = false;
+
+ spin_lock(&xprt->transport_lock);
+ if (!test_bit(XPRT_LOCKED, &xprt->state))
+ goto out;
+ if (xprt->snd_task != task)
+ goto out;
+ set_bit(XPRT_SND_IS_COOKIE, &xprt->state);
+ xprt->snd_task = cookie;
+ ret = true;
+out:
+ spin_unlock(&xprt->transport_lock);
+ return ret;
+}
+EXPORT_SYMBOL_GPL(xprt_lock_connect);
+
+void xprt_unlock_connect(struct rpc_xprt *xprt, void *cookie)
+{
+ spin_lock(&xprt->transport_lock);
+ if (xprt->snd_task != cookie)
+ goto out;
+ if (!test_bit(XPRT_LOCKED, &xprt->state))
+ goto out;
+	xprt->snd_task = NULL;
+ clear_bit(XPRT_SND_IS_COOKIE, &xprt->state);
+ xprt->ops->release_xprt(xprt, NULL);
+ xprt_schedule_autodisconnect(xprt);
+out:
+ spin_unlock(&xprt->transport_lock);
+ wake_up_bit(&xprt->state, XPRT_LOCKED);
+}
+EXPORT_SYMBOL_GPL(xprt_unlock_connect);
+
+/**
+ * xprt_connect - schedule a transport connect operation
+ * @task: RPC task that is requesting the connect
+ *
+ */
+void xprt_connect(struct rpc_task *task)
+{
+ struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt;
+
+ trace_xprt_connect(xprt);
+
+ if (!xprt_bound(xprt)) {
+ task->tk_status = -EAGAIN;
+ return;
+ }
+ if (!xprt_lock_write(xprt, task))
+ return;
+
+ if (!xprt_connected(xprt) && !test_bit(XPRT_CLOSE_WAIT, &xprt->state)) {
+ task->tk_rqstp->rq_connect_cookie = xprt->connect_cookie;
+ rpc_sleep_on_timeout(&xprt->pending, task, NULL,
+ xprt_request_timeout(task->tk_rqstp));
+
+ if (test_bit(XPRT_CLOSING, &xprt->state))
+ return;
+ if (xprt_test_and_set_connecting(xprt))
+ return;
+ /* Race breaker */
+ if (!xprt_connected(xprt)) {
+ xprt->stat.connect_start = jiffies;
+ xprt->ops->connect(xprt, task);
+ } else {
+ xprt_clear_connecting(xprt);
+ task->tk_status = 0;
+ rpc_wake_up_queued_task(&xprt->pending, task);
+ }
+ }
+ xprt_release_write(xprt, task);
+}
+
+/**
+ * xprt_reconnect_delay - compute the wait before scheduling a connect
+ * @xprt: transport instance
+ *
+ */
+unsigned long xprt_reconnect_delay(const struct rpc_xprt *xprt)
+{
+ unsigned long start, now = jiffies;
+
+ start = xprt->stat.connect_start + xprt->reestablish_timeout;
+ if (time_after(start, now))
+ return start - now;
+ return 0;
+}
+EXPORT_SYMBOL_GPL(xprt_reconnect_delay);
+
+/**
+ * xprt_reconnect_backoff - compute the new re-establish timeout
+ * @xprt: transport instance
+ * @init_to: initial reestablish timeout
+ *
+ */
+void xprt_reconnect_backoff(struct rpc_xprt *xprt, unsigned long init_to)
+{
+ xprt->reestablish_timeout <<= 1;
+ if (xprt->reestablish_timeout > xprt->max_reconnect_timeout)
+ xprt->reestablish_timeout = xprt->max_reconnect_timeout;
+ if (xprt->reestablish_timeout < init_to)
+ xprt->reestablish_timeout = init_to;
+}
+EXPORT_SYMBOL_GPL(xprt_reconnect_backoff);
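+
+/*
+ * For illustration: starting from a 3 second re-establish timeout,
+ * repeated backoff calls yield 6s, 12s, 24s, ... until the value is
+ * clamped at xprt->max_reconnect_timeout; it is never allowed to drop
+ * below init_to.
+ */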
+
+enum xprt_xid_rb_cmp {
+ XID_RB_EQUAL,
+ XID_RB_LEFT,
+ XID_RB_RIGHT,
+};
+static enum xprt_xid_rb_cmp
+xprt_xid_cmp(__be32 xid1, __be32 xid2)
+{
+ if (xid1 == xid2)
+ return XID_RB_EQUAL;
+ if ((__force u32)xid1 < (__force u32)xid2)
+ return XID_RB_LEFT;
+ return XID_RB_RIGHT;
+}
+
+static struct rpc_rqst *
+xprt_request_rb_find(struct rpc_xprt *xprt, __be32 xid)
+{
+ struct rb_node *n = xprt->recv_queue.rb_node;
+ struct rpc_rqst *req;
+
+ while (n != NULL) {
+ req = rb_entry(n, struct rpc_rqst, rq_recv);
+ switch (xprt_xid_cmp(xid, req->rq_xid)) {
+ case XID_RB_LEFT:
+ n = n->rb_left;
+ break;
+ case XID_RB_RIGHT:
+ n = n->rb_right;
+ break;
+ case XID_RB_EQUAL:
+ return req;
+ }
+ }
+ return NULL;
+}
+
+static void
+xprt_request_rb_insert(struct rpc_xprt *xprt, struct rpc_rqst *new)
+{
+ struct rb_node **p = &xprt->recv_queue.rb_node;
+ struct rb_node *n = NULL;
+ struct rpc_rqst *req;
+
+ while (*p != NULL) {
+ n = *p;
+ req = rb_entry(n, struct rpc_rqst, rq_recv);
+		switch (xprt_xid_cmp(new->rq_xid, req->rq_xid)) {
+ case XID_RB_LEFT:
+ p = &n->rb_left;
+ break;
+ case XID_RB_RIGHT:
+ p = &n->rb_right;
+ break;
+ case XID_RB_EQUAL:
+ WARN_ON_ONCE(new != req);
+ return;
+ }
+ }
+ rb_link_node(&new->rq_recv, n, p);
+ rb_insert_color(&new->rq_recv, &xprt->recv_queue);
+}
+
+static void
+xprt_request_rb_remove(struct rpc_xprt *xprt, struct rpc_rqst *req)
+{
+ rb_erase(&req->rq_recv, &xprt->recv_queue);
+}
+
+/**
+ * xprt_lookup_rqst - find an RPC request corresponding to an XID
+ * @xprt: transport on which the original request was transmitted
+ * @xid: RPC XID of incoming reply
+ *
+ * Caller holds xprt->queue_lock.
+ */
+struct rpc_rqst *xprt_lookup_rqst(struct rpc_xprt *xprt, __be32 xid)
+{
+ struct rpc_rqst *entry;
+
+ entry = xprt_request_rb_find(xprt, xid);
+ if (entry != NULL) {
+ trace_xprt_lookup_rqst(xprt, xid, 0);
+ entry->rq_rtt = ktime_sub(ktime_get(), entry->rq_xtime);
+ return entry;
+ }
+
+ dprintk("RPC: xprt_lookup_rqst did not find xid %08x\n",
+ ntohl(xid));
+ trace_xprt_lookup_rqst(xprt, xid, -ENOENT);
+ xprt->stat.bad_xids++;
+ return NULL;
+}
+EXPORT_SYMBOL_GPL(xprt_lookup_rqst);
+
+static bool
+xprt_is_pinned_rqst(struct rpc_rqst *req)
+{
+ return atomic_read(&req->rq_pin) != 0;
+}
+
+/**
+ * xprt_pin_rqst - Pin a request on the transport receive list
+ * @req: Request to pin
+ *
+ * Caller must ensure this is atomic with the call to xprt_lookup_rqst()
+ * so should be holding xprt->queue_lock.
+ */
+void xprt_pin_rqst(struct rpc_rqst *req)
+{
+ atomic_inc(&req->rq_pin);
+}
+EXPORT_SYMBOL_GPL(xprt_pin_rqst);
+
+/**
+ * xprt_unpin_rqst - Unpin a request on the transport receive list
+ * @req: Request to unpin
+ *
+ * Caller should be holding xprt->queue_lock.
+ */
+void xprt_unpin_rqst(struct rpc_rqst *req)
+{
+ if (!test_bit(RPC_TASK_MSG_PIN_WAIT, &req->rq_task->tk_runstate)) {
+ atomic_dec(&req->rq_pin);
+ return;
+ }
+ if (atomic_dec_and_test(&req->rq_pin))
+ wake_up_var(&req->rq_pin);
+}
+EXPORT_SYMBOL_GPL(xprt_unpin_rqst);
+
+static void xprt_wait_on_pinned_rqst(struct rpc_rqst *req)
+{
+ wait_var_event(&req->rq_pin, !xprt_is_pinned_rqst(req));
+}
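+
+/*
+ * Typical (illustrative) use of the pin/unpin helpers in a transport's
+ * receive path, which must drop xprt->queue_lock while copying data:
+ *
+ *     spin_lock(&xprt->queue_lock);
+ *     req = xprt_lookup_rqst(xprt, xid);
+ *     if (!req)
+ *             goto out_unlock;
+ *     xprt_pin_rqst(req);
+ *     spin_unlock(&xprt->queue_lock);
+ *
+ *     ... copy the reply into req->rq_private_buf ...
+ *
+ *     spin_lock(&xprt->queue_lock);
+ *     xprt_complete_rqst(req->rq_task, copied);
+ *     xprt_unpin_rqst(req);
+ * out_unlock:
+ *     spin_unlock(&xprt->queue_lock);
+ *
+ * The pin keeps the request from being released while the lock is
+ * dropped; xprt_request_dequeue_xprt() waits for the pin count to reach
+ * zero before the slot is freed.
+ */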
+
+static bool
+xprt_request_data_received(struct rpc_task *task)
+{
+ return !test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate) &&
+ READ_ONCE(task->tk_rqstp->rq_reply_bytes_recvd) != 0;
+}
+
+static bool
+xprt_request_need_enqueue_receive(struct rpc_task *task, struct rpc_rqst *req)
+{
+ return !test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate) &&
+ READ_ONCE(task->tk_rqstp->rq_reply_bytes_recvd) == 0;
+}
+
+/**
+ * xprt_request_enqueue_receive - Add a request to the receive queue
+ * @task: RPC task
+ *
+ */
+int
+xprt_request_enqueue_receive(struct rpc_task *task)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+ struct rpc_xprt *xprt = req->rq_xprt;
+ int ret;
+
+ if (!xprt_request_need_enqueue_receive(task, req))
+ return 0;
+
+ ret = xprt_request_prepare(task->tk_rqstp, &req->rq_rcv_buf);
+ if (ret)
+ return ret;
+ spin_lock(&xprt->queue_lock);
+
+ /* Update the softirq receive buffer */
+ memcpy(&req->rq_private_buf, &req->rq_rcv_buf,
+ sizeof(req->rq_private_buf));
+
+ /* Add request to the receive list */
+ xprt_request_rb_insert(xprt, req);
+ set_bit(RPC_TASK_NEED_RECV, &task->tk_runstate);
+ spin_unlock(&xprt->queue_lock);
+
+ /* Turn off autodisconnect */
+ del_timer_sync(&xprt->timer);
+ return 0;
+}
+
+/**
+ * xprt_request_dequeue_receive_locked - Remove a request from the receive queue
+ * @task: RPC task
+ *
+ * Caller must hold xprt->queue_lock.
+ */
+static void
+xprt_request_dequeue_receive_locked(struct rpc_task *task)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+
+ if (test_and_clear_bit(RPC_TASK_NEED_RECV, &task->tk_runstate))
+ xprt_request_rb_remove(req->rq_xprt, req);
+}
+
+/**
+ * xprt_update_rtt - Update RPC RTT statistics
+ * @task: RPC request that recently completed
+ *
+ * Caller holds xprt->queue_lock.
+ */
+void xprt_update_rtt(struct rpc_task *task)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+ struct rpc_rtt *rtt = task->tk_client->cl_rtt;
+ unsigned int timer = task->tk_msg.rpc_proc->p_timer;
+ long m = usecs_to_jiffies(ktime_to_us(req->rq_rtt));
+
+ if (timer) {
+ if (req->rq_ntrans == 1)
+ rpc_update_rtt(rtt, timer, m);
+ rpc_set_timeo(rtt, timer, req->rq_ntrans - 1);
+ }
+}
+EXPORT_SYMBOL_GPL(xprt_update_rtt);
+
+/**
+ * xprt_complete_rqst - called when reply processing is complete
+ * @task: RPC request that recently completed
+ * @copied: actual number of bytes received from the transport
+ *
+ * Caller holds xprt->queue_lock.
+ */
+void xprt_complete_rqst(struct rpc_task *task, int copied)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+ struct rpc_xprt *xprt = req->rq_xprt;
+
+ xprt->stat.recvs++;
+
+ xdr_free_bvec(&req->rq_rcv_buf);
+ req->rq_private_buf.bvec = NULL;
+ req->rq_private_buf.len = copied;
+ /* Ensure all writes are done before we update */
+ /* req->rq_reply_bytes_recvd */
+ smp_wmb();
+ req->rq_reply_bytes_recvd = copied;
+ xprt_request_dequeue_receive_locked(task);
+ rpc_wake_up_queued_task(&xprt->pending, task);
+}
+EXPORT_SYMBOL_GPL(xprt_complete_rqst);
+
+static void xprt_timer(struct rpc_task *task)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+ struct rpc_xprt *xprt = req->rq_xprt;
+
+ if (task->tk_status != -ETIMEDOUT)
+ return;
+
+ trace_xprt_timer(xprt, req->rq_xid, task->tk_status);
+ if (!req->rq_reply_bytes_recvd) {
+ if (xprt->ops->timer)
+ xprt->ops->timer(xprt, task);
+ } else
+ task->tk_status = 0;
+}
+
+/**
+ * xprt_wait_for_reply_request_def - wait for reply
+ * @task: pointer to rpc_task
+ *
+ * Set a request's retransmit timeout based on the transport's
+ * default timeout parameters, and put the task to sleep on the
+ * pending queue. Used by transports that don't adjust the
+ * retransmit timeout based on round-trip time estimation.
+ */
+void xprt_wait_for_reply_request_def(struct rpc_task *task)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+
+ rpc_sleep_on_timeout(&req->rq_xprt->pending, task, xprt_timer,
+ xprt_request_timeout(req));
+}
+EXPORT_SYMBOL_GPL(xprt_wait_for_reply_request_def);
+
+/**
+ * xprt_wait_for_reply_request_rtt - wait for reply using RTT estimator
+ * @task: pointer to rpc_task
+ *
+ * Set a request's retransmit timeout using the RTT estimator,
+ * and put the task to sleep on the pending queue.
+ */
+void xprt_wait_for_reply_request_rtt(struct rpc_task *task)
+{
+ int timer = task->tk_msg.rpc_proc->p_timer;
+ struct rpc_clnt *clnt = task->tk_client;
+ struct rpc_rtt *rtt = clnt->cl_rtt;
+ struct rpc_rqst *req = task->tk_rqstp;
+ unsigned long max_timeout = clnt->cl_timeout->to_maxval;
+ unsigned long timeout;
+
+ timeout = rpc_calc_rto(rtt, timer);
+ timeout <<= rpc_ntimeo(rtt, timer) + req->rq_retries;
+ if (timeout > max_timeout || timeout == 0)
+ timeout = max_timeout;
+ rpc_sleep_on_timeout(&req->rq_xprt->pending, task, xprt_timer,
+ jiffies + timeout);
+}
+EXPORT_SYMBOL_GPL(xprt_wait_for_reply_request_rtt);
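+
+/*
+ * For illustration: if rpc_calc_rto() currently estimates HZ / 4 for
+ * this procedure's timer class and rpc_ntimeo() + rq_retries == 2, the
+ * task sleeps for (HZ / 4) << 2 == HZ, but never longer than the
+ * client's to_maxval.
+ */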
+
+/**
+ * xprt_request_wait_receive - wait for the reply to an RPC request
+ * @task: RPC task about to send a request
+ *
+ */
+void xprt_request_wait_receive(struct rpc_task *task)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+ struct rpc_xprt *xprt = req->rq_xprt;
+
+ if (!test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate))
+ return;
+ /*
+ * Sleep on the pending queue if we're expecting a reply.
+ * The spinlock ensures atomicity between the test of
+ * req->rq_reply_bytes_recvd, and the call to rpc_sleep_on().
+ */
+ spin_lock(&xprt->queue_lock);
+ if (test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate)) {
+ xprt->ops->wait_for_reply_request(task);
+ /*
+ * Send an extra queue wakeup call if the
+ * connection was dropped in case the call to
+ * rpc_sleep_on() raced.
+ */
+ if (xprt_request_retransmit_after_disconnect(task))
+ rpc_wake_up_queued_task_set_status(&xprt->pending,
+ task, -ENOTCONN);
+ }
+ spin_unlock(&xprt->queue_lock);
+}
+
+static bool
+xprt_request_need_enqueue_transmit(struct rpc_task *task, struct rpc_rqst *req)
+{
+ return !test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate);
+}
+
+/**
+ * xprt_request_enqueue_transmit - queue a task for transmission
+ * @task: pointer to rpc_task
+ *
+ * Add a task to the transmission queue.
+ */
+void
+xprt_request_enqueue_transmit(struct rpc_task *task)
+{
+ struct rpc_rqst *pos, *req = task->tk_rqstp;
+ struct rpc_xprt *xprt = req->rq_xprt;
+ int ret;
+
+ if (xprt_request_need_enqueue_transmit(task, req)) {
+ ret = xprt_request_prepare(task->tk_rqstp, &req->rq_snd_buf);
+ if (ret) {
+ task->tk_status = ret;
+ return;
+ }
+ req->rq_bytes_sent = 0;
+ spin_lock(&xprt->queue_lock);
+ /*
+ * Requests that carry congestion control credits are added
+ * to the head of the list to avoid starvation issues.
+ */
+ if (req->rq_cong) {
+ xprt_clear_congestion_window_wait(xprt);
+ list_for_each_entry(pos, &xprt->xmit_queue, rq_xmit) {
+ if (pos->rq_cong)
+ continue;
+ /* Note: req is added _before_ pos */
+ list_add_tail(&req->rq_xmit, &pos->rq_xmit);
+ INIT_LIST_HEAD(&req->rq_xmit2);
+ goto out;
+ }
+ } else if (!req->rq_seqno) {
+ list_for_each_entry(pos, &xprt->xmit_queue, rq_xmit) {
+ if (pos->rq_task->tk_owner != task->tk_owner)
+ continue;
+ list_add_tail(&req->rq_xmit2, &pos->rq_xmit2);
+ INIT_LIST_HEAD(&req->rq_xmit);
+ goto out;
+ }
+ }
+ list_add_tail(&req->rq_xmit, &xprt->xmit_queue);
+ INIT_LIST_HEAD(&req->rq_xmit2);
+out:
+ atomic_long_inc(&xprt->xmit_queuelen);
+ set_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate);
+ spin_unlock(&xprt->queue_lock);
+ }
+}
+
+/**
+ * xprt_request_dequeue_transmit_locked - remove a task from the transmission queue
+ * @task: pointer to rpc_task
+ *
+ * Remove a task from the transmission queue
+ * Caller must hold xprt->queue_lock
+ */
+static void
+xprt_request_dequeue_transmit_locked(struct rpc_task *task)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+
+ if (!test_and_clear_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate))
+ return;
+ if (!list_empty(&req->rq_xmit)) {
+ list_del(&req->rq_xmit);
+ if (!list_empty(&req->rq_xmit2)) {
+ struct rpc_rqst *next = list_first_entry(&req->rq_xmit2,
+ struct rpc_rqst, rq_xmit2);
+ list_del(&req->rq_xmit2);
+ list_add_tail(&next->rq_xmit, &next->rq_xprt->xmit_queue);
+ }
+ } else
+ list_del(&req->rq_xmit2);
+ atomic_long_dec(&req->rq_xprt->xmit_queuelen);
+ xdr_free_bvec(&req->rq_snd_buf);
+}
+
+/**
+ * xprt_request_dequeue_transmit - remove a task from the transmission queue
+ * @task: pointer to rpc_task
+ *
+ * Remove a task from the transmission queue
+ */
+static void
+xprt_request_dequeue_transmit(struct rpc_task *task)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+ struct rpc_xprt *xprt = req->rq_xprt;
+
+ spin_lock(&xprt->queue_lock);
+ xprt_request_dequeue_transmit_locked(task);
+ spin_unlock(&xprt->queue_lock);
+}
+
+/**
+ * xprt_request_dequeue_xprt - remove a task from the transmit+receive queue
+ * @task: pointer to rpc_task
+ *
+ * Remove a task from the transmit and receive queues, and ensure that
+ * it is not pinned by the receive work item.
+ */
+void
+xprt_request_dequeue_xprt(struct rpc_task *task)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+ struct rpc_xprt *xprt = req->rq_xprt;
+
+ if (test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate) ||
+ test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate) ||
+ xprt_is_pinned_rqst(req)) {
+ spin_lock(&xprt->queue_lock);
+ while (xprt_is_pinned_rqst(req)) {
+ set_bit(RPC_TASK_MSG_PIN_WAIT, &task->tk_runstate);
+ spin_unlock(&xprt->queue_lock);
+ xprt_wait_on_pinned_rqst(req);
+ spin_lock(&xprt->queue_lock);
+ clear_bit(RPC_TASK_MSG_PIN_WAIT, &task->tk_runstate);
+ }
+ xprt_request_dequeue_transmit_locked(task);
+ xprt_request_dequeue_receive_locked(task);
+ spin_unlock(&xprt->queue_lock);
+ xdr_free_bvec(&req->rq_rcv_buf);
+ }
+}
+
+/**
+ * xprt_request_prepare - prepare an encoded request for transport
+ * @req: pointer to rpc_rqst
+ * @buf: pointer to send/rcv xdr_buf
+ *
+ * Calls into the transport layer to do whatever is needed to prepare
+ * the request for transmission or receive.
+ * Returns error, or zero.
+ */
+static int
+xprt_request_prepare(struct rpc_rqst *req, struct xdr_buf *buf)
+{
+ struct rpc_xprt *xprt = req->rq_xprt;
+
+ if (xprt->ops->prepare_request)
+ return xprt->ops->prepare_request(req, buf);
+ return 0;
+}
+
+/**
+ * xprt_request_need_retransmit - Test if a task needs retransmission
+ * @task: pointer to rpc_task
+ *
+ * Test for whether a connection breakage requires the task to retransmit
+ */
+bool
+xprt_request_need_retransmit(struct rpc_task *task)
+{
+ return xprt_request_retransmit_after_disconnect(task);
+}
+
+/**
+ * xprt_prepare_transmit - reserve the transport before sending a request
+ * @task: RPC task about to send a request
+ *
+ */
+bool xprt_prepare_transmit(struct rpc_task *task)
+{
+ struct rpc_rqst *req = task->tk_rqstp;
+ struct rpc_xprt *xprt = req->rq_xprt;
+
+ if (!xprt_lock_write(xprt, task)) {
+		/* Race breaker: someone else may already have transmitted this request */
+ if (!test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate))
+ rpc_wake_up_queued_task_set_status(&xprt->sending,
+ task, 0);
+ return false;
+
+ }
+ if (atomic_read(&xprt->swapper))
+ /* This will be clear in __rpc_execute */
+ current->flags |= PF_MEMALLOC;
+ return true;
+}
+
+void xprt_end_transmit(struct rpc_task *task)
+{
+ struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt;
+
+ xprt_inject_disconnect(xprt);
+ xprt_release_write(xprt, task);
+}
+
+/**
+ * xprt_request_transmit - send an RPC request on a transport
+ * @req: pointer to request to transmit
+ * @snd_task: RPC task that owns the transport lock
+ *
+ * This performs the transmission of a single request.
+ * Note that if the request is not the same as snd_task, then it
+ * does need to be pinned.
+ * Returns '0' on success.
+ */
+static int
+xprt_request_transmit(struct rpc_rqst *req, struct rpc_task *snd_task)
+{
+ struct rpc_xprt *xprt = req->rq_xprt;
+ struct rpc_task *task = req->rq_task;
+ unsigned int connect_cookie;
+ int is_retrans = RPC_WAS_SENT(task);
+ int status;
+
+ if (!req->rq_bytes_sent) {
+ if (xprt_request_data_received(task)) {
+ status = 0;
+ goto out_dequeue;
+ }
+ /* Verify that our message lies in the RPCSEC_GSS window */
+ if (rpcauth_xmit_need_reencode(task)) {
+ status = -EBADMSG;
+ goto out_dequeue;
+ }
+ if (RPC_SIGNALLED(task)) {
+ status = -ERESTARTSYS;
+ goto out_dequeue;
+ }
+ }
+
+ /*
+ * Update req->rq_ntrans before transmitting to avoid races with
+ * xprt_update_rtt(), which needs to know that it is recording a
+ * reply to the first transmission.
+ */
+ req->rq_ntrans++;
+
+ trace_rpc_xdr_sendto(task, &req->rq_snd_buf);
+ connect_cookie = xprt->connect_cookie;
+ status = xprt->ops->send_request(req);
+ if (status != 0) {
+ req->rq_ntrans--;
+ trace_xprt_transmit(req, status);
+ return status;
+ }
+
+ if (is_retrans) {
+ task->tk_client->cl_stats->rpcretrans++;
+ trace_xprt_retransmit(req);
+ }
+
+ xprt_inject_disconnect(xprt);
+
+ task->tk_flags |= RPC_TASK_SENT;
+ spin_lock(&xprt->transport_lock);
+
+ xprt->stat.sends++;
+ xprt->stat.req_u += xprt->stat.sends - xprt->stat.recvs;
+ xprt->stat.bklog_u += xprt->backlog.qlen;
+ xprt->stat.sending_u += xprt->sending.qlen;
+ xprt->stat.pending_u += xprt->pending.qlen;
+ spin_unlock(&xprt->transport_lock);
+
+ req->rq_connect_cookie = connect_cookie;
+out_dequeue:
+ trace_xprt_transmit(req, status);
+ xprt_request_dequeue_transmit(task);
+ rpc_wake_up_queued_task_set_status(&xprt->sending, task, status);
+ return status;
+}
+
+/**
+ * xprt_transmit - send an RPC request on a transport
+ * @task: controlling RPC task
+ *
+ * Attempts to drain the transmit queue. On exit, either the transport
+ * signalled an error that needs to be handled before transmission can
+ * resume, or @task finished transmitting, and detected that it already
+ * received a reply.
+ */
+void
+xprt_transmit(struct rpc_task *task)
+{
+ struct rpc_rqst *next, *req = task->tk_rqstp;
+ struct rpc_xprt *xprt = req->rq_xprt;
+ int status;
+
+ spin_lock(&xprt->queue_lock);
+ for (;;) {
+ next = list_first_entry_or_null(&xprt->xmit_queue,
+ struct rpc_rqst, rq_xmit);
+ if (!next)
+ break;
+ xprt_pin_rqst(next);
+ spin_unlock(&xprt->queue_lock);
+ status = xprt_request_transmit(next, task);
+ if (status == -EBADMSG && next != req)
+ status = 0;
+ spin_lock(&xprt->queue_lock);
+ xprt_unpin_rqst(next);
+ if (status < 0) {
+ if (test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate))
+ task->tk_status = status;
+ break;
+ }
+ /* Was @task transmitted, and has it received a reply? */
+ if (xprt_request_data_received(task) &&
+ !test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate))
+ break;
+ cond_resched_lock(&xprt->queue_lock);
+ }
+ spin_unlock(&xprt->queue_lock);
+}
+
+static void xprt_complete_request_init(struct rpc_task *task)
+{
+ if (task->tk_rqstp)
+ xprt_request_init(task);
+}
+
+void xprt_add_backlog(struct rpc_xprt *xprt, struct rpc_task *task)
+{
+ set_bit(XPRT_CONGESTED, &xprt->state);
+ rpc_sleep_on(&xprt->backlog, task, xprt_complete_request_init);
+}
+EXPORT_SYMBOL_GPL(xprt_add_backlog);
+
+static bool __xprt_set_rq(struct rpc_task *task, void *data)
+{
+ struct rpc_rqst *req = data;
+
+ if (task->tk_rqstp == NULL) {
+ memset(req, 0, sizeof(*req)); /* mark unused */
+ task->tk_rqstp = req;
+ return true;
+ }
+ return false;
+}
+
+bool xprt_wake_up_backlog(struct rpc_xprt *xprt, struct rpc_rqst *req)
+{
+ if (rpc_wake_up_first(&xprt->backlog, __xprt_set_rq, req) == NULL) {
+ clear_bit(XPRT_CONGESTED, &xprt->state);
+ return false;
+ }
+ return true;
+}
+EXPORT_SYMBOL_GPL(xprt_wake_up_backlog);
+
+static bool xprt_throttle_congested(struct rpc_xprt *xprt, struct rpc_task *task)
+{
+ bool ret = false;
+
+ if (!test_bit(XPRT_CONGESTED, &xprt->state))
+ goto out;
+ spin_lock(&xprt->reserve_lock);
+ if (test_bit(XPRT_CONGESTED, &xprt->state)) {
+ xprt_add_backlog(xprt, task);
+ ret = true;
+ }
+ spin_unlock(&xprt->reserve_lock);
+out:
+ return ret;
+}
+
+static struct rpc_rqst *xprt_dynamic_alloc_slot(struct rpc_xprt *xprt)
+{
+ struct rpc_rqst *req = ERR_PTR(-EAGAIN);
+
+ if (xprt->num_reqs >= xprt->max_reqs)
+ goto out;
+ ++xprt->num_reqs;
+ spin_unlock(&xprt->reserve_lock);
+ req = kzalloc(sizeof(*req), rpc_task_gfp_mask());
+ spin_lock(&xprt->reserve_lock);
+ if (req != NULL)
+ goto out;
+ --xprt->num_reqs;
+ req = ERR_PTR(-ENOMEM);
+out:
+ return req;
+}
+
+static bool xprt_dynamic_free_slot(struct rpc_xprt *xprt, struct rpc_rqst *req)
+{
+ if (xprt->num_reqs > xprt->min_reqs) {
+ --xprt->num_reqs;
+ kfree(req);
+ return true;
+ }
+ return false;
+}
+
+void xprt_alloc_slot(struct rpc_xprt *xprt, struct rpc_task *task)
+{
+ struct rpc_rqst *req;
+
+ spin_lock(&xprt->reserve_lock);
+ if (!list_empty(&xprt->free)) {
+ req = list_entry(xprt->free.next, struct rpc_rqst, rq_list);
+ list_del(&req->rq_list);
+ goto out_init_req;
+ }
+ req = xprt_dynamic_alloc_slot(xprt);
+ if (!IS_ERR(req))
+ goto out_init_req;
+ switch (PTR_ERR(req)) {
+ case -ENOMEM:
+ dprintk("RPC: dynamic allocation of request slot "
+ "failed! Retrying\n");
+ task->tk_status = -ENOMEM;
+ break;
+ case -EAGAIN:
+ xprt_add_backlog(xprt, task);
+ dprintk("RPC: waiting for request slot\n");
+ fallthrough;
+ default:
+ task->tk_status = -EAGAIN;
+ }
+ spin_unlock(&xprt->reserve_lock);
+ return;
+out_init_req:
+ xprt->stat.max_slots = max_t(unsigned int, xprt->stat.max_slots,
+ xprt->num_reqs);
+ spin_unlock(&xprt->reserve_lock);
+
+ task->tk_status = 0;
+ task->tk_rqstp = req;
+}
+EXPORT_SYMBOL_GPL(xprt_alloc_slot);
+
+void xprt_free_slot(struct rpc_xprt *xprt, struct rpc_rqst *req)
+{
+ spin_lock(&xprt->reserve_lock);
+ if (!xprt_wake_up_backlog(xprt, req) &&
+ !xprt_dynamic_free_slot(xprt, req)) {
+ memset(req, 0, sizeof(*req)); /* mark unused */
+ list_add(&req->rq_list, &xprt->free);
+ }
+ spin_unlock(&xprt->reserve_lock);
+}
+EXPORT_SYMBOL_GPL(xprt_free_slot);
+
+static void xprt_free_all_slots(struct rpc_xprt *xprt)
+{
+	struct rpc_rqst *req;
+
+	while (!list_empty(&xprt->free)) {
+ req = list_first_entry(&xprt->free, struct rpc_rqst, rq_list);
+ list_del(&req->rq_list);
+ kfree(req);
+ }
+}
+
+static DEFINE_IDA(rpc_xprt_ids);
+
+void xprt_cleanup_ids(void)
+{
+ ida_destroy(&rpc_xprt_ids);
+}
+
+static int xprt_alloc_id(struct rpc_xprt *xprt)
+{
+ int id;
+
+ id = ida_alloc(&rpc_xprt_ids, GFP_KERNEL);
+ if (id < 0)
+ return id;
+
+ xprt->id = id;
+ return 0;
+}
+
+static void xprt_free_id(struct rpc_xprt *xprt)
+{
+ ida_free(&rpc_xprt_ids, xprt->id);
+}
+
+struct rpc_xprt *xprt_alloc(struct net *net, size_t size,
+ unsigned int num_prealloc,
+ unsigned int max_alloc)
+{
+ struct rpc_xprt *xprt;
+ struct rpc_rqst *req;
+ int i;
+
+ xprt = kzalloc(size, GFP_KERNEL);
+ if (xprt == NULL)
+ goto out;
+
+ xprt_alloc_id(xprt);
+ xprt_init(xprt, net);
+
+ for (i = 0; i < num_prealloc; i++) {
+ req = kzalloc(sizeof(struct rpc_rqst), GFP_KERNEL);
+ if (!req)
+ goto out_free;
+ list_add(&req->rq_list, &xprt->free);
+ }
+ xprt->max_reqs = max_t(unsigned int, max_alloc, num_prealloc);
+ xprt->min_reqs = num_prealloc;
+ xprt->num_reqs = num_prealloc;
+
+ return xprt;
+
+out_free:
+ xprt_free(xprt);
+out:
+ return NULL;
+}
+EXPORT_SYMBOL_GPL(xprt_alloc);
+
+void xprt_free(struct rpc_xprt *xprt)
+{
+ put_net_track(xprt->xprt_net, &xprt->ns_tracker);
+ xprt_free_all_slots(xprt);
+ xprt_free_id(xprt);
+ rpc_sysfs_xprt_destroy(xprt);
+ kfree_rcu(xprt, rcu);
+}
+EXPORT_SYMBOL_GPL(xprt_free);
+
+static void
+xprt_init_connect_cookie(struct rpc_rqst *req, struct rpc_xprt *xprt)
+{
+ req->rq_connect_cookie = xprt_connect_cookie(xprt) - 1;
+}
+
+static __be32
+xprt_alloc_xid(struct rpc_xprt *xprt)
+{
+ __be32 xid;
+
+ spin_lock(&xprt->reserve_lock);
+ xid = (__force __be32)xprt->xid++;
+ spin_unlock(&xprt->reserve_lock);
+ return xid;
+}
+
+static void
+xprt_init_xid(struct rpc_xprt *xprt)
+{
+ xprt->xid = get_random_u32();
+}
+
+static void
+xprt_request_init(struct rpc_task *task)
+{
+ struct rpc_xprt *xprt = task->tk_xprt;
+ struct rpc_rqst *req = task->tk_rqstp;
+
+ req->rq_task = task;
+ req->rq_xprt = xprt;
+ req->rq_buffer = NULL;
+ req->rq_xid = xprt_alloc_xid(xprt);
+ xprt_init_connect_cookie(req, xprt);
+ req->rq_snd_buf.len = 0;
+ req->rq_snd_buf.buflen = 0;
+ req->rq_rcv_buf.len = 0;
+ req->rq_rcv_buf.buflen = 0;
+ req->rq_snd_buf.bvec = NULL;
+ req->rq_rcv_buf.bvec = NULL;
+ req->rq_release_snd_buf = NULL;
+ xprt_init_majortimeo(task, req);
+
+ trace_xprt_reserve(req);
+}
+
+static void
+xprt_do_reserve(struct rpc_xprt *xprt, struct rpc_task *task)
+{
+ xprt->ops->alloc_slot(xprt, task);
+ if (task->tk_rqstp != NULL)
+ xprt_request_init(task);
+}
+
+/**
+ * xprt_reserve - allocate an RPC request slot
+ * @task: RPC task requesting a slot allocation
+ *
+ * If the transport is marked as being congested, or if no more
+ * slots are available, place the task on the transport's
+ * backlog queue.
+ */
+void xprt_reserve(struct rpc_task *task)
+{
+ struct rpc_xprt *xprt = task->tk_xprt;
+
+ task->tk_status = 0;
+ if (task->tk_rqstp != NULL)
+ return;
+
+ task->tk_status = -EAGAIN;
+ if (!xprt_throttle_congested(xprt, task))
+ xprt_do_reserve(xprt, task);
+}
+
+/**
+ * xprt_retry_reserve - allocate an RPC request slot
+ * @task: RPC task requesting a slot allocation
+ *
+ * If no more slots are available, place the task on the transport's
+ * backlog queue.
+ * Note that the only difference with xprt_reserve is that we now
+ * ignore the value of the XPRT_CONGESTED flag.
+ */
+void xprt_retry_reserve(struct rpc_task *task)
+{
+ struct rpc_xprt *xprt = task->tk_xprt;
+
+ task->tk_status = 0;
+ if (task->tk_rqstp != NULL)
+ return;
+
+ task->tk_status = -EAGAIN;
+ xprt_do_reserve(xprt, task);
+}
+
+/**
+ * xprt_release - release an RPC request slot
+ * @task: task which is finished with the slot
+ *
+ */
+void xprt_release(struct rpc_task *task)
+{
+ struct rpc_xprt *xprt;
+ struct rpc_rqst *req = task->tk_rqstp;
+
+ if (req == NULL) {
+ if (task->tk_client) {
+ xprt = task->tk_xprt;
+ xprt_release_write(xprt, task);
+ }
+ return;
+ }
+
+ xprt = req->rq_xprt;
+ xprt_request_dequeue_xprt(task);
+ spin_lock(&xprt->transport_lock);
+ xprt->ops->release_xprt(xprt, task);
+ if (xprt->ops->release_request)
+ xprt->ops->release_request(task);
+ xprt_schedule_autodisconnect(xprt);
+ spin_unlock(&xprt->transport_lock);
+ if (req->rq_buffer)
+ xprt->ops->buf_free(task);
+ if (req->rq_cred != NULL)
+ put_rpccred(req->rq_cred);
+ if (req->rq_release_snd_buf)
+ req->rq_release_snd_buf(req);
+
+ task->tk_rqstp = NULL;
+ if (likely(!bc_prealloc(req)))
+ xprt->ops->free_slot(xprt, req);
+ else
+ xprt_free_bc_request(req);
+}
+
+#ifdef CONFIG_SUNRPC_BACKCHANNEL
+void
+xprt_init_bc_request(struct rpc_rqst *req, struct rpc_task *task)
+{
+ struct xdr_buf *xbufp = &req->rq_snd_buf;
+
+ task->tk_rqstp = req;
+ req->rq_task = task;
+ xprt_init_connect_cookie(req, req->rq_xprt);
+ /*
+ * Set up the xdr_buf length.
+ * This also indicates that the buffer is XDR encoded already.
+ */
+ xbufp->len = xbufp->head[0].iov_len + xbufp->page_len +
+ xbufp->tail[0].iov_len;
+}
+#endif
+
+static void xprt_init(struct rpc_xprt *xprt, struct net *net)
+{
+ kref_init(&xprt->kref);
+
+ spin_lock_init(&xprt->transport_lock);
+ spin_lock_init(&xprt->reserve_lock);
+ spin_lock_init(&xprt->queue_lock);
+
+ INIT_LIST_HEAD(&xprt->free);
+ xprt->recv_queue = RB_ROOT;
+ INIT_LIST_HEAD(&xprt->xmit_queue);
+#if defined(CONFIG_SUNRPC_BACKCHANNEL)
+ spin_lock_init(&xprt->bc_pa_lock);
+ INIT_LIST_HEAD(&xprt->bc_pa_list);
+#endif /* CONFIG_SUNRPC_BACKCHANNEL */
+ INIT_LIST_HEAD(&xprt->xprt_switch);
+
+ xprt->last_used = jiffies;
+ xprt->cwnd = RPC_INITCWND;
+ xprt->bind_index = 0;
+
+ rpc_init_wait_queue(&xprt->binding, "xprt_binding");
+ rpc_init_wait_queue(&xprt->pending, "xprt_pending");
+ rpc_init_wait_queue(&xprt->sending, "xprt_sending");
+ rpc_init_priority_wait_queue(&xprt->backlog, "xprt_backlog");
+
+ xprt_init_xid(xprt);
+
+ xprt->xprt_net = get_net_track(net, &xprt->ns_tracker, GFP_KERNEL);
+}
+
+/**
+ * xprt_create_transport - create an RPC transport
+ * @args: rpc transport creation arguments
+ *
+ */
+struct rpc_xprt *xprt_create_transport(struct xprt_create *args)
+{
+ struct rpc_xprt *xprt;
+ const struct xprt_class *t;
+
+ t = xprt_class_find_by_ident(args->ident);
+ if (!t) {
+ dprintk("RPC: transport (%d) not supported\n", args->ident);
+ return ERR_PTR(-EIO);
+ }
+
+ xprt = t->setup(args);
+ xprt_class_release(t);
+
+ if (IS_ERR(xprt))
+ goto out;
+ if (args->flags & XPRT_CREATE_NO_IDLE_TIMEOUT)
+ xprt->idle_timeout = 0;
+ INIT_WORK(&xprt->task_cleanup, xprt_autoclose);
+ if (xprt_has_timer(xprt))
+ timer_setup(&xprt->timer, xprt_init_autodisconnect, 0);
+ else
+ timer_setup(&xprt->timer, NULL, 0);
+
+ if (strlen(args->servername) > RPC_MAXNETNAMELEN) {
+ xprt_destroy(xprt);
+ return ERR_PTR(-EINVAL);
+ }
+ xprt->servername = kstrdup(args->servername, GFP_KERNEL);
+ if (xprt->servername == NULL) {
+ xprt_destroy(xprt);
+ return ERR_PTR(-ENOMEM);
+ }
+
+ rpc_xprt_debugfs_register(xprt);
+
+ trace_xprt_create(xprt);
+out:
+ return xprt;
+}
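+
+/*
+ * Usage sketch (illustrative; callers normally go through rpc_create(),
+ * and "srv" below is a hypothetical, already-filled sockaddr_in):
+ *
+ *     struct xprt_create args = {
+ *             .ident      = XPRT_TRANSPORT_TCP,
+ *             .net        = &init_net,
+ *             .dstaddr    = (struct sockaddr *)&srv,
+ *             .addrlen    = sizeof(srv),
+ *             .servername = "server.example.com",
+ *     };
+ *     struct rpc_xprt *xprt = xprt_create_transport(&args);
+ */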
+
+static void xprt_destroy_cb(struct work_struct *work)
+{
+ struct rpc_xprt *xprt =
+ container_of(work, struct rpc_xprt, task_cleanup);
+
+ trace_xprt_destroy(xprt);
+
+ rpc_xprt_debugfs_unregister(xprt);
+ rpc_destroy_wait_queue(&xprt->binding);
+ rpc_destroy_wait_queue(&xprt->pending);
+ rpc_destroy_wait_queue(&xprt->sending);
+ rpc_destroy_wait_queue(&xprt->backlog);
+ kfree(xprt->servername);
+ /*
+ * Destroy any existing back channel
+ */
+ xprt_destroy_backchannel(xprt, UINT_MAX);
+
+ /*
+ * Tear down transport state and free the rpc_xprt
+ */
+ xprt->ops->destroy(xprt);
+}
+
+/**
+ * xprt_destroy - destroy an RPC transport, killing off all requests.
+ * @xprt: transport to destroy
+ *
+ */
+static void xprt_destroy(struct rpc_xprt *xprt)
+{
+ /*
+ * Exclude transport connect/disconnect handlers and autoclose
+ */
+ wait_on_bit_lock(&xprt->state, XPRT_LOCKED, TASK_UNINTERRUPTIBLE);
+
+ /*
+ * xprt_schedule_autodisconnect() can run after XPRT_LOCKED
+ * is cleared. We use ->transport_lock to ensure the mod_timer()
+	 * can only run *before* del_timer_sync(), never after.
+ */
+ spin_lock(&xprt->transport_lock);
+ del_timer_sync(&xprt->timer);
+ spin_unlock(&xprt->transport_lock);
+
+ /*
+ * Destroy sockets etc from the system workqueue so they can
+ * safely flush receive work running on rpciod.
+ */
+ INIT_WORK(&xprt->task_cleanup, xprt_destroy_cb);
+ schedule_work(&xprt->task_cleanup);
+}
+
+static void xprt_destroy_kref(struct kref *kref)
+{
+ xprt_destroy(container_of(kref, struct rpc_xprt, kref));
+}
+
+/**
+ * xprt_get - return a reference to an RPC transport.
+ * @xprt: pointer to the transport
+ *
+ */
+struct rpc_xprt *xprt_get(struct rpc_xprt *xprt)
+{
+ if (xprt != NULL && kref_get_unless_zero(&xprt->kref))
+ return xprt;
+ return NULL;
+}
+EXPORT_SYMBOL_GPL(xprt_get);
+
+/**
+ * xprt_put - release a reference to an RPC transport.
+ * @xprt: pointer to the transport
+ *
+ */
+void xprt_put(struct rpc_xprt *xprt)
+{
+ if (xprt != NULL)
+ kref_put(&xprt->kref, xprt_destroy_kref);
+}
+EXPORT_SYMBOL_GPL(xprt_put);
+
+void xprt_set_offline_locked(struct rpc_xprt *xprt, struct rpc_xprt_switch *xps)
+{
+ if (!test_and_set_bit(XPRT_OFFLINE, &xprt->state)) {
+ spin_lock(&xps->xps_lock);
+ xps->xps_nactive--;
+ spin_unlock(&xps->xps_lock);
+ }
+}
+
+void xprt_set_online_locked(struct rpc_xprt *xprt, struct rpc_xprt_switch *xps)
+{
+ if (test_and_clear_bit(XPRT_OFFLINE, &xprt->state)) {
+ spin_lock(&xps->xps_lock);
+ xps->xps_nactive++;
+ spin_unlock(&xps->xps_lock);
+ }
+}
+
+void xprt_delete_locked(struct rpc_xprt *xprt, struct rpc_xprt_switch *xps)
+{
+ if (test_and_set_bit(XPRT_REMOVE, &xprt->state))
+ return;
+
+ xprt_force_disconnect(xprt);
+ if (!test_bit(XPRT_CONNECTED, &xprt->state))
+ return;
+
+ if (!xprt->sending.qlen && !xprt->pending.qlen &&
+ !xprt->backlog.qlen && !atomic_long_read(&xprt->queuelen))
+ rpc_xprt_switch_remove_xprt(xps, xprt, true);
+}
diff --git a/net/sunrpc/xprtmultipath.c b/net/sunrpc/xprtmultipath.c
new file mode 100644
index 0000000000..74ee227125
--- /dev/null
+++ b/net/sunrpc/xprtmultipath.c
@@ -0,0 +1,655 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Multipath support for RPC
+ *
+ * Copyright (c) 2015, 2016, Primary Data, Inc. All rights reserved.
+ *
+ * Trond Myklebust <trond.myklebust@primarydata.com>
+ *
+ */
+#include <linux/atomic.h>
+#include <linux/types.h>
+#include <linux/kref.h>
+#include <linux/list.h>
+#include <linux/rcupdate.h>
+#include <linux/rculist.h>
+#include <linux/slab.h>
+#include <linux/spinlock.h>
+#include <linux/sunrpc/xprt.h>
+#include <linux/sunrpc/addr.h>
+#include <linux/sunrpc/xprtmultipath.h>
+
+#include "sysfs.h"
+
+typedef struct rpc_xprt *(*xprt_switch_find_xprt_t)(struct rpc_xprt_switch *xps,
+ const struct rpc_xprt *cur);
+
+static const struct rpc_xprt_iter_ops rpc_xprt_iter_singular;
+static const struct rpc_xprt_iter_ops rpc_xprt_iter_roundrobin;
+static const struct rpc_xprt_iter_ops rpc_xprt_iter_listall;
+static const struct rpc_xprt_iter_ops rpc_xprt_iter_listoffline;
+
+static void xprt_switch_add_xprt_locked(struct rpc_xprt_switch *xps,
+ struct rpc_xprt *xprt)
+{
+ if (unlikely(xprt_get(xprt) == NULL))
+ return;
+ list_add_tail_rcu(&xprt->xprt_switch, &xps->xps_xprt_list);
+ smp_wmb();
+ if (xps->xps_nxprts == 0)
+ xps->xps_net = xprt->xprt_net;
+ xps->xps_nxprts++;
+ xps->xps_nactive++;
+}
+
+/**
+ * rpc_xprt_switch_add_xprt - Add a new rpc_xprt to an rpc_xprt_switch
+ * @xps: pointer to struct rpc_xprt_switch
+ * @xprt: pointer to struct rpc_xprt
+ *
+ * Adds xprt to the end of the list of struct rpc_xprt in xps.
+ */
+void rpc_xprt_switch_add_xprt(struct rpc_xprt_switch *xps,
+ struct rpc_xprt *xprt)
+{
+ if (xprt == NULL)
+ return;
+ spin_lock(&xps->xps_lock);
+ if (xps->xps_net == xprt->xprt_net || xps->xps_net == NULL)
+ xprt_switch_add_xprt_locked(xps, xprt);
+ spin_unlock(&xps->xps_lock);
+ rpc_sysfs_xprt_setup(xps, xprt, GFP_KERNEL);
+}
+
+static void xprt_switch_remove_xprt_locked(struct rpc_xprt_switch *xps,
+ struct rpc_xprt *xprt, bool offline)
+{
+ if (unlikely(xprt == NULL))
+ return;
+ if (!test_bit(XPRT_OFFLINE, &xprt->state) && offline)
+ xps->xps_nactive--;
+ xps->xps_nxprts--;
+ if (xps->xps_nxprts == 0)
+ xps->xps_net = NULL;
+ smp_wmb();
+ list_del_rcu(&xprt->xprt_switch);
+}
+
+/**
+ * rpc_xprt_switch_remove_xprt - Removes an rpc_xprt from a rpc_xprt_switch
+ * @xps: pointer to struct rpc_xprt_switch
+ * @xprt: pointer to struct rpc_xprt
+ * @offline: indicates if the xprt that's being removed is in an offline state
+ *
+ * Removes xprt from the list of struct rpc_xprt in xps.
+ */
+void rpc_xprt_switch_remove_xprt(struct rpc_xprt_switch *xps,
+ struct rpc_xprt *xprt, bool offline)
+{
+ spin_lock(&xps->xps_lock);
+ xprt_switch_remove_xprt_locked(xps, xprt, offline);
+ spin_unlock(&xps->xps_lock);
+ xprt_put(xprt);
+}
+
+static DEFINE_IDA(rpc_xprtswitch_ids);
+
+void xprt_multipath_cleanup_ids(void)
+{
+ ida_destroy(&rpc_xprtswitch_ids);
+}
+
+static int xprt_switch_alloc_id(struct rpc_xprt_switch *xps, gfp_t gfp_flags)
+{
+ int id;
+
+ id = ida_alloc(&rpc_xprtswitch_ids, gfp_flags);
+ if (id < 0)
+ return id;
+
+ xps->xps_id = id;
+ return 0;
+}
+
+static void xprt_switch_free_id(struct rpc_xprt_switch *xps)
+{
+ ida_free(&rpc_xprtswitch_ids, xps->xps_id);
+}
+
+/**
+ * xprt_switch_alloc - Allocate a new struct rpc_xprt_switch
+ * @xprt: pointer to struct rpc_xprt
+ * @gfp_flags: allocation flags
+ *
+ * On success, returns an initialised struct rpc_xprt_switch, containing
+ * the entry xprt. Returns NULL on failure.
+ */
+struct rpc_xprt_switch *xprt_switch_alloc(struct rpc_xprt *xprt,
+ gfp_t gfp_flags)
+{
+ struct rpc_xprt_switch *xps;
+
+ xps = kmalloc(sizeof(*xps), gfp_flags);
+ if (xps != NULL) {
+ spin_lock_init(&xps->xps_lock);
+ kref_init(&xps->xps_kref);
+ xprt_switch_alloc_id(xps, gfp_flags);
+ xps->xps_nxprts = xps->xps_nactive = 0;
+ atomic_long_set(&xps->xps_queuelen, 0);
+ xps->xps_net = NULL;
+ INIT_LIST_HEAD(&xps->xps_xprt_list);
+ xps->xps_iter_ops = &rpc_xprt_iter_singular;
+ rpc_sysfs_xprt_switch_setup(xps, xprt, gfp_flags);
+ xprt_switch_add_xprt_locked(xps, xprt);
+ xps->xps_nunique_destaddr_xprts = 1;
+ rpc_sysfs_xprt_setup(xps, xprt, gfp_flags);
+ }
+
+ return xps;
+}
+
+static void xprt_switch_free_entries(struct rpc_xprt_switch *xps)
+{
+ spin_lock(&xps->xps_lock);
+ while (!list_empty(&xps->xps_xprt_list)) {
+ struct rpc_xprt *xprt;
+
+ xprt = list_first_entry(&xps->xps_xprt_list,
+ struct rpc_xprt, xprt_switch);
+ xprt_switch_remove_xprt_locked(xps, xprt, true);
+ spin_unlock(&xps->xps_lock);
+ xprt_put(xprt);
+ spin_lock(&xps->xps_lock);
+ }
+ spin_unlock(&xps->xps_lock);
+}
+
+static void xprt_switch_free(struct kref *kref)
+{
+ struct rpc_xprt_switch *xps = container_of(kref,
+ struct rpc_xprt_switch, xps_kref);
+
+ xprt_switch_free_entries(xps);
+ rpc_sysfs_xprt_switch_destroy(xps);
+ xprt_switch_free_id(xps);
+ kfree_rcu(xps, xps_rcu);
+}
+
+/**
+ * xprt_switch_get - Return a reference to a rpc_xprt_switch
+ * @xps: pointer to struct rpc_xprt_switch
+ *
+ * Returns a reference to xps unless the refcount is already zero.
+ */
+struct rpc_xprt_switch *xprt_switch_get(struct rpc_xprt_switch *xps)
+{
+ if (xps != NULL && kref_get_unless_zero(&xps->xps_kref))
+ return xps;
+ return NULL;
+}
+
+/**
+ * xprt_switch_put - Release a reference to a rpc_xprt_switch
+ * @xps: pointer to struct rpc_xprt_switch
+ *
+ * Release the reference to xps, and free it once the refcount is zero.
+ */
+void xprt_switch_put(struct rpc_xprt_switch *xps)
+{
+ if (xps != NULL)
+ kref_put(&xps->xps_kref, xprt_switch_free);
+}
+
+/**
+ * rpc_xprt_switch_set_roundrobin - Set a round-robin policy on rpc_xprt_switch
+ * @xps: pointer to struct rpc_xprt_switch
+ *
+ * Sets a round-robin default policy for iterators acting on xps.
+ */
+void rpc_xprt_switch_set_roundrobin(struct rpc_xprt_switch *xps)
+{
+ if (READ_ONCE(xps->xps_iter_ops) != &rpc_xprt_iter_roundrobin)
+ WRITE_ONCE(xps->xps_iter_ops, &rpc_xprt_iter_roundrobin);
+}
+
+static
+const struct rpc_xprt_iter_ops *xprt_iter_ops(const struct rpc_xprt_iter *xpi)
+{
+ if (xpi->xpi_ops != NULL)
+ return xpi->xpi_ops;
+ return rcu_dereference(xpi->xpi_xpswitch)->xps_iter_ops;
+}
+
+static
+void xprt_iter_no_rewind(struct rpc_xprt_iter *xpi)
+{
+}
+
+static
+void xprt_iter_default_rewind(struct rpc_xprt_iter *xpi)
+{
+ WRITE_ONCE(xpi->xpi_cursor, NULL);
+}
+
+static
+bool xprt_is_active(const struct rpc_xprt *xprt)
+{
+ return (kref_read(&xprt->kref) != 0 &&
+ !test_bit(XPRT_OFFLINE, &xprt->state));
+}
+
+static
+struct rpc_xprt *xprt_switch_find_first_entry(struct list_head *head)
+{
+ struct rpc_xprt *pos;
+
+ list_for_each_entry_rcu(pos, head, xprt_switch) {
+ if (xprt_is_active(pos))
+ return pos;
+ }
+ return NULL;
+}
+
+static
+struct rpc_xprt *xprt_switch_find_first_entry_offline(struct list_head *head)
+{
+ struct rpc_xprt *pos;
+
+ list_for_each_entry_rcu(pos, head, xprt_switch) {
+ if (!xprt_is_active(pos))
+ return pos;
+ }
+ return NULL;
+}
+
+static
+struct rpc_xprt *xprt_iter_first_entry(struct rpc_xprt_iter *xpi)
+{
+ struct rpc_xprt_switch *xps = rcu_dereference(xpi->xpi_xpswitch);
+
+ if (xps == NULL)
+ return NULL;
+ return xprt_switch_find_first_entry(&xps->xps_xprt_list);
+}
+
+static
+struct rpc_xprt *_xprt_switch_find_current_entry(struct list_head *head,
+ const struct rpc_xprt *cur,
+ bool find_active)
+{
+ struct rpc_xprt *pos;
+ bool found = false;
+
+ list_for_each_entry_rcu(pos, head, xprt_switch) {
+ if (cur == pos)
+ found = true;
+ if (found && ((find_active && xprt_is_active(pos)) ||
+ (!find_active && !xprt_is_active(pos))))
+ return pos;
+ }
+ return NULL;
+}
+
+static
+struct rpc_xprt *xprt_switch_find_current_entry(struct list_head *head,
+ const struct rpc_xprt *cur)
+{
+ return _xprt_switch_find_current_entry(head, cur, true);
+}
+
+static
+struct rpc_xprt * _xprt_iter_current_entry(struct rpc_xprt_iter *xpi,
+ struct rpc_xprt *first_entry(struct list_head *head),
+ struct rpc_xprt *current_entry(struct list_head *head,
+ const struct rpc_xprt *cur))
+{
+ struct rpc_xprt_switch *xps = rcu_dereference(xpi->xpi_xpswitch);
+ struct list_head *head;
+
+ if (xps == NULL)
+ return NULL;
+ head = &xps->xps_xprt_list;
+ if (xpi->xpi_cursor == NULL || xps->xps_nxprts < 2)
+ return first_entry(head);
+ return current_entry(head, xpi->xpi_cursor);
+}
+
+static
+struct rpc_xprt *xprt_iter_current_entry(struct rpc_xprt_iter *xpi)
+{
+ return _xprt_iter_current_entry(xpi, xprt_switch_find_first_entry,
+ xprt_switch_find_current_entry);
+}
+
+static
+struct rpc_xprt *xprt_switch_find_current_entry_offline(struct list_head *head,
+ const struct rpc_xprt *cur)
+{
+ return _xprt_switch_find_current_entry(head, cur, false);
+}
+
+static
+struct rpc_xprt *xprt_iter_current_entry_offline(struct rpc_xprt_iter *xpi)
+{
+ return _xprt_iter_current_entry(xpi,
+ xprt_switch_find_first_entry_offline,
+ xprt_switch_find_current_entry_offline);
+}
+
+bool rpc_xprt_switch_has_addr(struct rpc_xprt_switch *xps,
+ const struct sockaddr *sap)
+{
+ struct list_head *head;
+ struct rpc_xprt *pos;
+
+ if (xps == NULL || sap == NULL)
+ return false;
+
+ head = &xps->xps_xprt_list;
+ list_for_each_entry_rcu(pos, head, xprt_switch) {
+ if (rpc_cmp_addr_port(sap, (struct sockaddr *)&pos->addr)) {
+ pr_info("RPC: addr %s already in xprt switch\n",
+ pos->address_strings[RPC_DISPLAY_ADDR]);
+ return true;
+ }
+ }
+ return false;
+}
+
+static
+struct rpc_xprt *xprt_switch_find_next_entry(struct list_head *head,
+ const struct rpc_xprt *cur, bool check_active)
+{
+ struct rpc_xprt *pos, *prev = NULL;
+ bool found = false;
+
+ list_for_each_entry_rcu(pos, head, xprt_switch) {
+ if (cur == prev)
+ found = true;
+		/* For requests for active transports, return only
+		 * active ones; for requests for offline transports,
+		 * return only offline ones.
+		 */
+ if (found && ((check_active && xprt_is_active(pos)) ||
+ (!check_active && !xprt_is_active(pos))))
+ return pos;
+ prev = pos;
+ }
+ return NULL;
+}
+
+static
+struct rpc_xprt *xprt_switch_set_next_cursor(struct rpc_xprt_switch *xps,
+ struct rpc_xprt **cursor,
+ xprt_switch_find_xprt_t find_next)
+{
+ struct rpc_xprt *pos, *old;
+
+ old = smp_load_acquire(cursor);
+ pos = find_next(xps, old);
+ smp_store_release(cursor, pos);
+ return pos;
+}
+
+static
+struct rpc_xprt *xprt_iter_next_entry_multiple(struct rpc_xprt_iter *xpi,
+ xprt_switch_find_xprt_t find_next)
+{
+ struct rpc_xprt_switch *xps = rcu_dereference(xpi->xpi_xpswitch);
+
+ if (xps == NULL)
+ return NULL;
+ return xprt_switch_set_next_cursor(xps, &xpi->xpi_cursor, find_next);
+}
+
+static
+struct rpc_xprt *__xprt_switch_find_next_entry_roundrobin(struct list_head *head,
+ const struct rpc_xprt *cur)
+{
+ struct rpc_xprt *ret;
+
+ ret = xprt_switch_find_next_entry(head, cur, true);
+ if (ret != NULL)
+ return ret;
+ return xprt_switch_find_first_entry(head);
+}
+
+static
+struct rpc_xprt *xprt_switch_find_next_entry_roundrobin(struct rpc_xprt_switch *xps,
+ const struct rpc_xprt *cur)
+{
+ struct list_head *head = &xps->xps_xprt_list;
+ struct rpc_xprt *xprt;
+ unsigned int nactive;
+
+ for (;;) {
+ unsigned long xprt_queuelen, xps_queuelen;
+
+ xprt = __xprt_switch_find_next_entry_roundrobin(head, cur);
+ if (!xprt)
+ break;
+ xprt_queuelen = atomic_long_read(&xprt->queuelen);
+ xps_queuelen = atomic_long_read(&xps->xps_queuelen);
+ nactive = READ_ONCE(xps->xps_nactive);
+ /* Exit loop if xprt_queuelen <= average queue length */
+ if (xprt_queuelen * nactive <= xps_queuelen)
+ break;
+ cur = xprt;
+ }
+ return xprt;
+}
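The check inside the loop above compares a transport's own queue length against the switch-wide average without dividing: multiplying by the number of active transports keeps the arithmetic in integers. A minimal standalone sketch of that comparison, using hypothetical queue lengths:

#include <stdbool.h>
#include <stdio.h>

static bool below_average(unsigned long xprt_queuelen,
			  unsigned long switch_queuelen,
			  unsigned int nactive)
{
	/* Same test as above: true when this transport's queue length
	 * does not exceed switch_queuelen / nactive, without division.
	 */
	return xprt_queuelen * nactive <= switch_queuelen;
}

int main(void)
{
	/* Hypothetical: 3 active transports, 12 requests queued in total */
	printf("%d\n", below_average(4, 12, 3));	/* 1: 4 <= 12 / 3 */
	printf("%d\n", below_average(5, 12, 3));	/* 0: 5 >  12 / 3 */
	return 0;
}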
+
+static
+struct rpc_xprt *xprt_iter_next_entry_roundrobin(struct rpc_xprt_iter *xpi)
+{
+ return xprt_iter_next_entry_multiple(xpi,
+ xprt_switch_find_next_entry_roundrobin);
+}
+
+static
+struct rpc_xprt *xprt_switch_find_next_entry_all(struct rpc_xprt_switch *xps,
+ const struct rpc_xprt *cur)
+{
+ return xprt_switch_find_next_entry(&xps->xps_xprt_list, cur, true);
+}
+
+static
+struct rpc_xprt *xprt_switch_find_next_entry_offline(struct rpc_xprt_switch *xps,
+ const struct rpc_xprt *cur)
+{
+ return xprt_switch_find_next_entry(&xps->xps_xprt_list, cur, false);
+}
+
+static
+struct rpc_xprt *xprt_iter_next_entry_all(struct rpc_xprt_iter *xpi)
+{
+ return xprt_iter_next_entry_multiple(xpi,
+ xprt_switch_find_next_entry_all);
+}
+
+static
+struct rpc_xprt *xprt_iter_next_entry_offline(struct rpc_xprt_iter *xpi)
+{
+ return xprt_iter_next_entry_multiple(xpi,
+ xprt_switch_find_next_entry_offline);
+}
+
+/**
+ * xprt_iter_rewind - Resets the xprt iterator
+ * @xpi: pointer to rpc_xprt_iter
+ *
+ * Resets xpi to ensure that it points to the first entry in the list
+ * of transports.
+ */
+void xprt_iter_rewind(struct rpc_xprt_iter *xpi)
+{
+ rcu_read_lock();
+ xprt_iter_ops(xpi)->xpi_rewind(xpi);
+ rcu_read_unlock();
+}
+
+static void __xprt_iter_init(struct rpc_xprt_iter *xpi,
+ struct rpc_xprt_switch *xps,
+ const struct rpc_xprt_iter_ops *ops)
+{
+ rcu_assign_pointer(xpi->xpi_xpswitch, xprt_switch_get(xps));
+ xpi->xpi_cursor = NULL;
+ xpi->xpi_ops = ops;
+}
+
+/**
+ * xprt_iter_init - Initialise an xprt iterator
+ * @xpi: pointer to rpc_xprt_iter
+ * @xps: pointer to rpc_xprt_switch
+ *
+ * Initialises the iterator to use the default iterator ops
+ * as set in xps. This function is mainly intended for internal
+ * use in the rpc_client.
+ */
+void xprt_iter_init(struct rpc_xprt_iter *xpi,
+ struct rpc_xprt_switch *xps)
+{
+ __xprt_iter_init(xpi, xps, NULL);
+}
+
+/**
+ * xprt_iter_init_listall - Initialise an xprt iterator
+ * @xpi: pointer to rpc_xprt_iter
+ * @xps: pointer to rpc_xprt_switch
+ *
+ * Initialises the iterator to iterate once through the entire list
+ * of entries in xps.
+ */
+void xprt_iter_init_listall(struct rpc_xprt_iter *xpi,
+ struct rpc_xprt_switch *xps)
+{
+ __xprt_iter_init(xpi, xps, &rpc_xprt_iter_listall);
+}
+
+void xprt_iter_init_listoffline(struct rpc_xprt_iter *xpi,
+ struct rpc_xprt_switch *xps)
+{
+ __xprt_iter_init(xpi, xps, &rpc_xprt_iter_listoffline);
+}
+
+/**
+ * xprt_iter_xchg_switch - Atomically swap out the rpc_xprt_switch
+ * @xpi: pointer to rpc_xprt_iter
+ * @newswitch: pointer to a new rpc_xprt_switch or NULL
+ *
+ * Swaps out the existing xpi->xpi_xpswitch with a new value.
+ */
+struct rpc_xprt_switch *xprt_iter_xchg_switch(struct rpc_xprt_iter *xpi,
+ struct rpc_xprt_switch *newswitch)
+{
+ struct rpc_xprt_switch __rcu *oldswitch;
+
+ /* Atomically swap out the old xpswitch */
+ oldswitch = xchg(&xpi->xpi_xpswitch, RCU_INITIALIZER(newswitch));
+ if (newswitch != NULL)
+ xprt_iter_rewind(xpi);
+ return rcu_dereference_protected(oldswitch, true);
+}
+
+/**
+ * xprt_iter_destroy - Destroys the xprt iterator
+ * @xpi: pointer to rpc_xprt_iter
+ */
+void xprt_iter_destroy(struct rpc_xprt_iter *xpi)
+{
+ xprt_switch_put(xprt_iter_xchg_switch(xpi, NULL));
+}
+
+/**
+ * xprt_iter_xprt - Returns the rpc_xprt pointed to by the cursor
+ * @xpi: pointer to rpc_xprt_iter
+ *
+ * Returns a pointer to the struct rpc_xprt that is currently
+ * pointed to by the cursor.
+ * Caller must be holding rcu_read_lock().
+ */
+struct rpc_xprt *xprt_iter_xprt(struct rpc_xprt_iter *xpi)
+{
+ WARN_ON_ONCE(!rcu_read_lock_held());
+ return xprt_iter_ops(xpi)->xpi_xprt(xpi);
+}
+
+static
+struct rpc_xprt *xprt_iter_get_helper(struct rpc_xprt_iter *xpi,
+ struct rpc_xprt *(*fn)(struct rpc_xprt_iter *))
+{
+ struct rpc_xprt *ret;
+
+ do {
+ ret = fn(xpi);
+ if (ret == NULL)
+ break;
+ ret = xprt_get(ret);
+ } while (ret == NULL);
+ return ret;
+}
+
+/**
+ * xprt_iter_get_xprt - Returns the rpc_xprt pointed to by the cursor
+ * @xpi: pointer to rpc_xprt_iter
+ *
+ * Returns a reference to the struct rpc_xprt that is currently
+ * pointed to by the cursor.
+ */
+struct rpc_xprt *xprt_iter_get_xprt(struct rpc_xprt_iter *xpi)
+{
+ struct rpc_xprt *xprt;
+
+ rcu_read_lock();
+ xprt = xprt_iter_get_helper(xpi, xprt_iter_ops(xpi)->xpi_xprt);
+ rcu_read_unlock();
+ return xprt;
+}
+
+/**
+ * xprt_iter_get_next - Returns the next rpc_xprt following the cursor
+ * @xpi: pointer to rpc_xprt_iter
+ *
+ * Returns a reference to the struct rpc_xprt that immediately follows the
+ * entry pointed to by the cursor.
+ */
+struct rpc_xprt *xprt_iter_get_next(struct rpc_xprt_iter *xpi)
+{
+ struct rpc_xprt *xprt;
+
+ rcu_read_lock();
+ xprt = xprt_iter_get_helper(xpi, xprt_iter_ops(xpi)->xpi_next);
+ rcu_read_unlock();
+ return xprt;
+}
+
+/* Policy for always returning the first entry in the rpc_xprt_switch */
+static
+const struct rpc_xprt_iter_ops rpc_xprt_iter_singular = {
+ .xpi_rewind = xprt_iter_no_rewind,
+ .xpi_xprt = xprt_iter_first_entry,
+ .xpi_next = xprt_iter_first_entry,
+};
+
+/* Policy for round-robin iteration of entries in the rpc_xprt_switch */
+static
+const struct rpc_xprt_iter_ops rpc_xprt_iter_roundrobin = {
+ .xpi_rewind = xprt_iter_default_rewind,
+ .xpi_xprt = xprt_iter_current_entry,
+ .xpi_next = xprt_iter_next_entry_roundrobin,
+};
+
+/* Policy for once-through iteration of entries in the rpc_xprt_switch */
+static
+const struct rpc_xprt_iter_ops rpc_xprt_iter_listall = {
+ .xpi_rewind = xprt_iter_default_rewind,
+ .xpi_xprt = xprt_iter_current_entry,
+ .xpi_next = xprt_iter_next_entry_all,
+};
+
+static
+const struct rpc_xprt_iter_ops rpc_xprt_iter_listoffline = {
+ .xpi_rewind = xprt_iter_default_rewind,
+ .xpi_xprt = xprt_iter_current_entry_offline,
+ .xpi_next = xprt_iter_next_entry_offline,
+};
diff --git a/net/sunrpc/xprtrdma/Makefile b/net/sunrpc/xprtrdma/Makefile
new file mode 100644
index 0000000000..55b21bae86
--- /dev/null
+++ b/net/sunrpc/xprtrdma/Makefile
@@ -0,0 +1,8 @@
+# SPDX-License-Identifier: GPL-2.0
+obj-$(CONFIG_SUNRPC_XPRT_RDMA) += rpcrdma.o
+
+rpcrdma-y := transport.o rpc_rdma.o verbs.o frwr_ops.o \
+ svc_rdma.o svc_rdma_backchannel.o svc_rdma_transport.o \
+ svc_rdma_sendto.o svc_rdma_recvfrom.o svc_rdma_rw.o \
+ svc_rdma_pcl.o module.o
+rpcrdma-$(CONFIG_SUNRPC_BACKCHANNEL) += backchannel.o
diff --git a/net/sunrpc/xprtrdma/backchannel.c b/net/sunrpc/xprtrdma/backchannel.c
new file mode 100644
index 0000000000..e4d84a13c5
--- /dev/null
+++ b/net/sunrpc/xprtrdma/backchannel.c
@@ -0,0 +1,282 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (c) 2015-2020, Oracle and/or its affiliates.
+ *
+ * Support for reverse-direction RPCs on RPC/RDMA.
+ */
+
+#include <linux/sunrpc/xprt.h>
+#include <linux/sunrpc/svc.h>
+#include <linux/sunrpc/svc_xprt.h>
+#include <linux/sunrpc/svc_rdma.h>
+
+#include "xprt_rdma.h"
+#include <trace/events/rpcrdma.h>
+
+#undef RPCRDMA_BACKCHANNEL_DEBUG
+
+/**
+ * xprt_rdma_bc_setup - Pre-allocate resources for handling backchannel requests
+ * @xprt: transport associated with these backchannel resources
+ * @reqs: number of concurrent incoming requests to expect
+ *
+ * Returns 0 on success; otherwise a negative errno
+ */
+int xprt_rdma_bc_setup(struct rpc_xprt *xprt, unsigned int reqs)
+{
+ struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt);
+
+ r_xprt->rx_buf.rb_bc_srv_max_requests = RPCRDMA_BACKWARD_WRS >> 1;
+ trace_xprtrdma_cb_setup(r_xprt, reqs);
+ return 0;
+}
+
+/**
+ * xprt_rdma_bc_maxpayload - Return maximum backchannel message size
+ * @xprt: transport
+ *
+ * Returns maximum size, in bytes, of a backchannel message
+ */
+size_t xprt_rdma_bc_maxpayload(struct rpc_xprt *xprt)
+{
+ struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt);
+ struct rpcrdma_ep *ep = r_xprt->rx_ep;
+ size_t maxmsg;
+
+ maxmsg = min_t(unsigned int, ep->re_inline_send, ep->re_inline_recv);
+ maxmsg = min_t(unsigned int, maxmsg, PAGE_SIZE);
+ return maxmsg - RPCRDMA_HDRLEN_MIN;
+}
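As a rough worked example of the arithmetic above, assuming hypothetical 4096-byte inline thresholds and pages, and the 28-byte (seven XDR words) minimum header that rpcrdma_bc_marshal_reply() reserves below:

#include <stdio.h>

#define RPCRDMA_HDRLEN_MIN	(7 * 4)	/* seven 32-bit XDR words */

static unsigned int min_u(unsigned int a, unsigned int b)
{
	return a < b ? a : b;
}

int main(void)
{
	unsigned int page_size = 4096;		/* hypothetical */
	unsigned int inline_send = 4096;	/* hypothetical re_inline_send */
	unsigned int inline_recv = 4096;	/* hypothetical re_inline_recv */
	unsigned int maxmsg;

	maxmsg = min_u(min_u(inline_send, inline_recv), page_size);
	printf("backchannel max payload = %u bytes\n",
	       maxmsg - RPCRDMA_HDRLEN_MIN);	/* prints 4068 */
	return 0;
}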
+
+unsigned int xprt_rdma_bc_max_slots(struct rpc_xprt *xprt)
+{
+ return RPCRDMA_BACKWARD_WRS >> 1;
+}
+
+static int rpcrdma_bc_marshal_reply(struct rpc_rqst *rqst)
+{
+ struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(rqst->rq_xprt);
+ struct rpcrdma_req *req = rpcr_to_rdmar(rqst);
+ __be32 *p;
+
+ rpcrdma_set_xdrlen(&req->rl_hdrbuf, 0);
+ xdr_init_encode(&req->rl_stream, &req->rl_hdrbuf,
+ rdmab_data(req->rl_rdmabuf), rqst);
+
+ p = xdr_reserve_space(&req->rl_stream, 28);
+ if (unlikely(!p))
+ return -EIO;
+ *p++ = rqst->rq_xid;
+ *p++ = rpcrdma_version;
+ *p++ = cpu_to_be32(r_xprt->rx_buf.rb_bc_srv_max_requests);
+ *p++ = rdma_msg;
+ *p++ = xdr_zero;
+ *p++ = xdr_zero;
+ *p = xdr_zero;
+
+ if (rpcrdma_prepare_send_sges(r_xprt, req, RPCRDMA_HDRLEN_MIN,
+ &rqst->rq_snd_buf, rpcrdma_noch_pullup))
+ return -EIO;
+
+ trace_xprtrdma_cb_reply(r_xprt, rqst);
+ return 0;
+}
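For reference, a standalone sketch of the seven-word header that the function above builds. The values are illustrative and shown in host order for readability, whereas the kernel encodes them as big-endian XDR words:

#include <stdint.h>
#include <stdio.h>

int main(void)
{
	/* Seven 32-bit words, 28 bytes total */
	uint32_t hdr[7] = {
		0x12345678,	/* XID copied from rqst->rq_xid (illustrative) */
		1,		/* rpcrdma_version */
		8,		/* credits (rb_bc_srv_max_requests), illustrative */
		0,		/* proc: rdma_msg */
		0,		/* Read list: not present */
		0,		/* Write list: not present */
		0,		/* Reply chunk: not present */
	};

	printf("transport header: %zu bytes\n", sizeof(hdr));	/* 28 */
	return 0;
}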
+
+/**
+ * xprt_rdma_bc_send_reply - marshal and send a backchannel reply
+ * @rqst: RPC rqst with a backchannel RPC reply in rq_snd_buf
+ *
+ * Caller holds the transport's write lock.
+ *
+ * Returns:
+ * %0 if the RPC message has been sent
+ * %-ENOTCONN if the caller should reconnect and call again
+ * %-EIO if a permanent error occurred and the request was not
+ * sent. Do not try to send this message again.
+ */
+int xprt_rdma_bc_send_reply(struct rpc_rqst *rqst)
+{
+ struct rpc_xprt *xprt = rqst->rq_xprt;
+ struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt);
+ struct rpcrdma_req *req = rpcr_to_rdmar(rqst);
+ int rc;
+
+ if (!xprt_connected(xprt))
+ return -ENOTCONN;
+
+ if (!xprt_request_get_cong(xprt, rqst))
+ return -EBADSLT;
+
+ rc = rpcrdma_bc_marshal_reply(rqst);
+ if (rc < 0)
+ goto failed_marshal;
+
+ if (frwr_send(r_xprt, req))
+ goto drop_connection;
+ return 0;
+
+failed_marshal:
+ if (rc != -ENOTCONN)
+ return rc;
+drop_connection:
+ xprt_rdma_close(xprt);
+ return -ENOTCONN;
+}
+
+/**
+ * xprt_rdma_bc_destroy - Release resources for handling backchannel requests
+ * @xprt: transport associated with these backchannel resources
+ * @reqs: number of incoming requests to destroy; ignored
+ */
+void xprt_rdma_bc_destroy(struct rpc_xprt *xprt, unsigned int reqs)
+{
+ struct rpc_rqst *rqst, *tmp;
+
+ spin_lock(&xprt->bc_pa_lock);
+ list_for_each_entry_safe(rqst, tmp, &xprt->bc_pa_list, rq_bc_pa_list) {
+ list_del(&rqst->rq_bc_pa_list);
+ spin_unlock(&xprt->bc_pa_lock);
+
+ rpcrdma_req_destroy(rpcr_to_rdmar(rqst));
+
+ spin_lock(&xprt->bc_pa_lock);
+ }
+ spin_unlock(&xprt->bc_pa_lock);
+}
+
+/**
+ * xprt_rdma_bc_free_rqst - Release a backchannel rqst
+ * @rqst: request to release
+ */
+void xprt_rdma_bc_free_rqst(struct rpc_rqst *rqst)
+{
+ struct rpcrdma_req *req = rpcr_to_rdmar(rqst);
+ struct rpcrdma_rep *rep = req->rl_reply;
+ struct rpc_xprt *xprt = rqst->rq_xprt;
+ struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt);
+
+ rpcrdma_rep_put(&r_xprt->rx_buf, rep);
+ req->rl_reply = NULL;
+
+ spin_lock(&xprt->bc_pa_lock);
+ list_add_tail(&rqst->rq_bc_pa_list, &xprt->bc_pa_list);
+ spin_unlock(&xprt->bc_pa_lock);
+ xprt_put(xprt);
+}
+
+static struct rpc_rqst *rpcrdma_bc_rqst_get(struct rpcrdma_xprt *r_xprt)
+{
+ struct rpc_xprt *xprt = &r_xprt->rx_xprt;
+ struct rpcrdma_req *req;
+ struct rpc_rqst *rqst;
+ size_t size;
+
+ spin_lock(&xprt->bc_pa_lock);
+ rqst = list_first_entry_or_null(&xprt->bc_pa_list, struct rpc_rqst,
+ rq_bc_pa_list);
+ if (!rqst)
+ goto create_req;
+ list_del(&rqst->rq_bc_pa_list);
+ spin_unlock(&xprt->bc_pa_lock);
+ return rqst;
+
+create_req:
+ spin_unlock(&xprt->bc_pa_lock);
+
+ /* Set a limit to prevent a remote from overrunning our resources.
+ */
+ if (xprt->bc_alloc_count >= RPCRDMA_BACKWARD_WRS)
+ return NULL;
+
+ size = min_t(size_t, r_xprt->rx_ep->re_inline_recv, PAGE_SIZE);
+ req = rpcrdma_req_create(r_xprt, size);
+ if (!req)
+ return NULL;
+ if (rpcrdma_req_setup(r_xprt, req)) {
+ rpcrdma_req_destroy(req);
+ return NULL;
+ }
+
+ xprt->bc_alloc_count++;
+ rqst = &req->rl_slot;
+ rqst->rq_xprt = xprt;
+ __set_bit(RPC_BC_PA_IN_USE, &rqst->rq_bc_pa_state);
+ xdr_buf_init(&rqst->rq_snd_buf, rdmab_data(req->rl_sendbuf), size);
+ return rqst;
+}
+
+/**
+ * rpcrdma_bc_receive_call - Handle a reverse-direction Call
+ * @r_xprt: transport receiving the call
+ * @rep: receive buffer containing the call
+ *
+ * Operational assumptions:
+ *    o Backchannel credits are ignored, just as the NFS server's
+ *      forechannel currently ignores them
+ *    o The ULP manages a replay cache (e.g., NFSv4.1 sessions).
+ * No replay detection is done at the transport level
+ */
+void rpcrdma_bc_receive_call(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_rep *rep)
+{
+ struct rpc_xprt *xprt = &r_xprt->rx_xprt;
+ struct svc_serv *bc_serv;
+ struct rpcrdma_req *req;
+ struct rpc_rqst *rqst;
+ struct xdr_buf *buf;
+ size_t size;
+ __be32 *p;
+
+ p = xdr_inline_decode(&rep->rr_stream, 0);
+ size = xdr_stream_remaining(&rep->rr_stream);
+
+#ifdef RPCRDMA_BACKCHANNEL_DEBUG
+ pr_info("RPC: %s: callback XID %08x, length=%u\n",
+ __func__, be32_to_cpup(p), size);
+ pr_info("RPC: %s: %*ph\n", __func__, size, p);
+#endif
+
+ rqst = rpcrdma_bc_rqst_get(r_xprt);
+ if (!rqst)
+ goto out_overflow;
+
+ rqst->rq_reply_bytes_recvd = 0;
+ rqst->rq_xid = *p;
+
+ rqst->rq_private_buf.len = size;
+
+ buf = &rqst->rq_rcv_buf;
+ memset(buf, 0, sizeof(*buf));
+ buf->head[0].iov_base = p;
+ buf->head[0].iov_len = size;
+ buf->len = size;
+
+ /* The receive buffer has to be hooked to the rpcrdma_req
+ * so that it is not released while the req is pointing
+ * to its buffer, and so that it can be reposted after
+ * the Upper Layer is done decoding it.
+ */
+ req = rpcr_to_rdmar(rqst);
+ req->rl_reply = rep;
+ trace_xprtrdma_cb_call(r_xprt, rqst);
+
+ /* Queue rqst for ULP's callback service */
+ bc_serv = xprt->bc_serv;
+ xprt_get(xprt);
+ spin_lock(&bc_serv->sv_cb_lock);
+ list_add(&rqst->rq_bc_list, &bc_serv->sv_cb_list);
+ spin_unlock(&bc_serv->sv_cb_lock);
+
+ wake_up(&bc_serv->sv_cb_waitq);
+
+ r_xprt->rx_stats.bcall_count++;
+ return;
+
+out_overflow:
+ pr_warn("RPC/RDMA backchannel overflow\n");
+ xprt_force_disconnect(xprt);
+ /* This receive buffer gets reposted automatically
+ * when the connection is re-established.
+ */
+ return;
+}
diff --git a/net/sunrpc/xprtrdma/frwr_ops.c b/net/sunrpc/xprtrdma/frwr_ops.c
new file mode 100644
index 0000000000..ffbf998949
--- /dev/null
+++ b/net/sunrpc/xprtrdma/frwr_ops.c
@@ -0,0 +1,696 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (c) 2015, 2017 Oracle. All rights reserved.
+ * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved.
+ */
+
+/* Lightweight memory registration using Fast Registration Work
+ * Requests (FRWR).
+ *
+ * FRWR features ordered asynchronous registration and invalidation
+ * of arbitrarily-sized memory regions. This is the fastest and safest
+ * but most complex memory registration mode.
+ */
+
+/* Normal operation
+ *
+ * A Memory Region is prepared for RDMA Read or Write using a FAST_REG
+ * Work Request (frwr_map). When the RDMA operation is finished, this
+ * Memory Region is invalidated using a LOCAL_INV Work Request
+ * (frwr_unmap_async and frwr_unmap_sync).
+ *
+ * Typically FAST_REG Work Requests are not signaled, and neither are
+ * RDMA Send Work Requests (with the exception of signaling occasionally
+ * to prevent provider work queue overflows). This greatly reduces HCA
+ * interrupt workload.
+ */
+
+/* Transport recovery
+ *
+ * frwr_map and frwr_unmap_* cannot run at the same time the transport
+ * connect worker is running. The connect worker holds the transport
+ * send lock, just as ->send_request does. This prevents frwr_map and
+ * the connect worker from running concurrently. When a connection is
+ * closed, the Receive completion queue is drained before allowing
+ * the connect worker to get control. This prevents frwr_unmap and the
+ * connect worker from running concurrently.
+ *
+ * When the underlying transport disconnects, MRs that are in flight
+ * are flushed and are likely unusable. Thus all MRs are destroyed.
+ * New MRs are created on demand.
+ */
+
+#include <linux/sunrpc/svc_rdma.h>
+
+#include "xprt_rdma.h"
+#include <trace/events/rpcrdma.h>
+
+static void frwr_cid_init(struct rpcrdma_ep *ep,
+ struct rpcrdma_mr *mr)
+{
+ struct rpc_rdma_cid *cid = &mr->mr_cid;
+
+ cid->ci_queue_id = ep->re_attr.send_cq->res.id;
+ cid->ci_completion_id = mr->mr_ibmr->res.id;
+}
+
+static void frwr_mr_unmap(struct rpcrdma_xprt *r_xprt, struct rpcrdma_mr *mr)
+{
+ if (mr->mr_device) {
+ trace_xprtrdma_mr_unmap(mr);
+ ib_dma_unmap_sg(mr->mr_device, mr->mr_sg, mr->mr_nents,
+ mr->mr_dir);
+ mr->mr_device = NULL;
+ }
+}
+
+/**
+ * frwr_mr_release - Destroy one MR
+ * @mr: MR allocated by frwr_mr_init
+ *
+ */
+void frwr_mr_release(struct rpcrdma_mr *mr)
+{
+ int rc;
+
+ frwr_mr_unmap(mr->mr_xprt, mr);
+
+ rc = ib_dereg_mr(mr->mr_ibmr);
+ if (rc)
+ trace_xprtrdma_frwr_dereg(mr, rc);
+ kfree(mr->mr_sg);
+ kfree(mr);
+}
+
+static void frwr_mr_put(struct rpcrdma_mr *mr)
+{
+ frwr_mr_unmap(mr->mr_xprt, mr);
+
+ /* The MR is returned to the req's MR free list instead
+ * of to the xprt's MR free list. No spinlock is needed.
+ */
+ rpcrdma_mr_push(mr, &mr->mr_req->rl_free_mrs);
+}
+
+/* frwr_reset - Place MRs back on the free list
+ * @req: request to reset
+ *
+ * Used after a failed marshal. For FRWR, this means the MRs
+ * don't have to be fully released and recreated.
+ *
+ * NB: This is safe only as long as none of @req's MRs are
+ * involved with an ongoing asynchronous FAST_REG or LOCAL_INV
+ * Work Request.
+ */
+void frwr_reset(struct rpcrdma_req *req)
+{
+ struct rpcrdma_mr *mr;
+
+ while ((mr = rpcrdma_mr_pop(&req->rl_registered)))
+ frwr_mr_put(mr);
+}
+
+/**
+ * frwr_mr_init - Initialize one MR
+ * @r_xprt: controlling transport instance
+ * @mr: generic MR to prepare for FRWR
+ *
+ * Returns zero if successful. Otherwise a negative errno
+ * is returned.
+ */
+int frwr_mr_init(struct rpcrdma_xprt *r_xprt, struct rpcrdma_mr *mr)
+{
+ struct rpcrdma_ep *ep = r_xprt->rx_ep;
+ unsigned int depth = ep->re_max_fr_depth;
+ struct scatterlist *sg;
+ struct ib_mr *frmr;
+
+ sg = kcalloc_node(depth, sizeof(*sg), XPRTRDMA_GFP_FLAGS,
+ ibdev_to_node(ep->re_id->device));
+ if (!sg)
+ return -ENOMEM;
+
+ frmr = ib_alloc_mr(ep->re_pd, ep->re_mrtype, depth);
+ if (IS_ERR(frmr))
+ goto out_mr_err;
+
+ mr->mr_xprt = r_xprt;
+ mr->mr_ibmr = frmr;
+ mr->mr_device = NULL;
+ INIT_LIST_HEAD(&mr->mr_list);
+ init_completion(&mr->mr_linv_done);
+ frwr_cid_init(ep, mr);
+
+ sg_init_table(sg, depth);
+ mr->mr_sg = sg;
+ return 0;
+
+out_mr_err:
+ kfree(sg);
+ trace_xprtrdma_frwr_alloc(mr, PTR_ERR(frmr));
+ return PTR_ERR(frmr);
+}
+
+/**
+ * frwr_query_device - Prepare a transport for use with FRWR
+ * @ep: endpoint to fill in
+ * @device: RDMA device to query
+ *
+ * On success, sets:
+ * ep->re_attr
+ * ep->re_max_requests
+ * ep->re_max_rdma_segs
+ * ep->re_max_fr_depth
+ * ep->re_mrtype
+ *
+ * Return values:
+ * On success, returns zero.
+ * %-EINVAL - the device does not support FRWR memory registration
+ * %-ENOMEM - the device is not sufficiently capable for NFS/RDMA
+ */
+int frwr_query_device(struct rpcrdma_ep *ep, const struct ib_device *device)
+{
+ const struct ib_device_attr *attrs = &device->attrs;
+ int max_qp_wr, depth, delta;
+ unsigned int max_sge;
+
+ if (!(attrs->device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS) ||
+ attrs->max_fast_reg_page_list_len == 0) {
+ pr_err("rpcrdma: 'frwr' mode is not supported by device %s\n",
+ device->name);
+ return -EINVAL;
+ }
+
+ max_sge = min_t(unsigned int, attrs->max_send_sge,
+ RPCRDMA_MAX_SEND_SGES);
+ if (max_sge < RPCRDMA_MIN_SEND_SGES) {
+ pr_err("rpcrdma: HCA provides only %u send SGEs\n", max_sge);
+ return -ENOMEM;
+ }
+ ep->re_attr.cap.max_send_sge = max_sge;
+ ep->re_attr.cap.max_recv_sge = 1;
+
+ ep->re_mrtype = IB_MR_TYPE_MEM_REG;
+ if (attrs->kernel_cap_flags & IBK_SG_GAPS_REG)
+ ep->re_mrtype = IB_MR_TYPE_SG_GAPS;
+
+ /* Quirk: Some devices advertise a large max_fast_reg_page_list_len
+ * capability, but perform optimally when the MRs are not larger
+ * than a page.
+ */
+ if (attrs->max_sge_rd > RPCRDMA_MAX_HDR_SEGS)
+ ep->re_max_fr_depth = attrs->max_sge_rd;
+ else
+ ep->re_max_fr_depth = attrs->max_fast_reg_page_list_len;
+ if (ep->re_max_fr_depth > RPCRDMA_MAX_DATA_SEGS)
+ ep->re_max_fr_depth = RPCRDMA_MAX_DATA_SEGS;
+
+ /* Add room for frwr register and invalidate WRs.
+ * 1. FRWR reg WR for head
+ * 2. FRWR invalidate WR for head
+ * 3. N FRWR reg WRs for pagelist
+ * 4. N FRWR invalidate WRs for pagelist
+ * 5. FRWR reg WR for tail
+ * 6. FRWR invalidate WR for tail
+ * 7. The RDMA_SEND WR
+ */
+ depth = 7;
+
+ /* Calculate N if the device max FRWR depth is smaller than
+ * RPCRDMA_MAX_DATA_SEGS.
+ */
+ if (ep->re_max_fr_depth < RPCRDMA_MAX_DATA_SEGS) {
+ delta = RPCRDMA_MAX_DATA_SEGS - ep->re_max_fr_depth;
+ do {
+ depth += 2; /* FRWR reg + invalidate */
+ delta -= ep->re_max_fr_depth;
+ } while (delta > 0);
+ }
+
+ max_qp_wr = attrs->max_qp_wr;
+ max_qp_wr -= RPCRDMA_BACKWARD_WRS;
+ max_qp_wr -= 1;
+ if (max_qp_wr < RPCRDMA_MIN_SLOT_TABLE)
+ return -ENOMEM;
+ if (ep->re_max_requests > max_qp_wr)
+ ep->re_max_requests = max_qp_wr;
+ ep->re_attr.cap.max_send_wr = ep->re_max_requests * depth;
+ if (ep->re_attr.cap.max_send_wr > max_qp_wr) {
+ ep->re_max_requests = max_qp_wr / depth;
+ if (!ep->re_max_requests)
+ return -ENOMEM;
+ ep->re_attr.cap.max_send_wr = ep->re_max_requests * depth;
+ }
+ ep->re_attr.cap.max_send_wr += RPCRDMA_BACKWARD_WRS;
+ ep->re_attr.cap.max_send_wr += 1; /* for ib_drain_sq */
+ ep->re_attr.cap.max_recv_wr = ep->re_max_requests;
+ ep->re_attr.cap.max_recv_wr += RPCRDMA_BACKWARD_WRS;
+ ep->re_attr.cap.max_recv_wr += RPCRDMA_MAX_RECV_BATCH;
+ ep->re_attr.cap.max_recv_wr += 1; /* for ib_drain_rq */
+
+ ep->re_max_rdma_segs =
+ DIV_ROUND_UP(RPCRDMA_MAX_DATA_SEGS, ep->re_max_fr_depth);
+ /* Reply chunks require segments for head and tail buffers */
+ ep->re_max_rdma_segs += 2;
+ if (ep->re_max_rdma_segs > RPCRDMA_MAX_HDR_SEGS)
+ ep->re_max_rdma_segs = RPCRDMA_MAX_HDR_SEGS;
+
+ /* Ensure the underlying device is capable of conveying the
+ * largest r/wsize NFS will ask for. This guarantees that
+ * failing over from one RDMA device to another will not
+ * break NFS I/O.
+ */
+ if ((ep->re_max_rdma_segs * ep->re_max_fr_depth) < RPCRDMA_MAX_SEGS)
+ return -ENOMEM;
+
+ return 0;
+}
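A small sketch of the send-queue depth calculation above, using hypothetical limits (64 data segments per RPC and a device fast-registration depth of 16). This is not the kernel function, just the same arithmetic in standalone form:

#include <stdio.h>

static int frwr_send_wr_depth(unsigned int max_data_segs,
			      unsigned int max_fr_depth)
{
	/* Head, one pagelist chunk, and tail each need a reg+invalidate
	 * pair, plus the Send WR itself: 2 + 2 + 2 + 1 = 7.
	 */
	int depth = 7;
	int delta;

	/* Each additional chunk of max_fr_depth segments costs another
	 * FRWR reg + invalidate pair.
	 */
	if (max_fr_depth < max_data_segs) {
		delta = max_data_segs - max_fr_depth;
		do {
			depth += 2;
			delta -= max_fr_depth;
		} while (delta > 0);
	}
	return depth;
}

int main(void)
{
	printf("depth = %d\n", frwr_send_wr_depth(64, 16));	/* 13 */
	return 0;
}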
+
+/**
+ * frwr_map - Register a memory region
+ * @r_xprt: controlling transport
+ * @seg: memory region co-ordinates
+ * @nsegs: number of segments remaining
+ * @writing: true when RDMA Write will be used
+ * @xid: XID of RPC using the registered memory
+ * @mr: MR to fill in
+ *
+ * Prepare a REG_MR Work Request to register a memory region
+ * for remote access via RDMA READ or RDMA WRITE.
+ *
+ * Returns the next segment or a negative errno pointer.
+ * On success, @mr is filled in.
+ */
+struct rpcrdma_mr_seg *frwr_map(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_mr_seg *seg,
+ int nsegs, bool writing, __be32 xid,
+ struct rpcrdma_mr *mr)
+{
+ struct rpcrdma_ep *ep = r_xprt->rx_ep;
+ struct ib_reg_wr *reg_wr;
+ int i, n, dma_nents;
+ struct ib_mr *ibmr;
+ u8 key;
+
+ if (nsegs > ep->re_max_fr_depth)
+ nsegs = ep->re_max_fr_depth;
+ for (i = 0; i < nsegs;) {
+ sg_set_page(&mr->mr_sg[i], seg->mr_page,
+ seg->mr_len, seg->mr_offset);
+
+ ++seg;
+ ++i;
+ if (ep->re_mrtype == IB_MR_TYPE_SG_GAPS)
+ continue;
+ if ((i < nsegs && seg->mr_offset) ||
+ offset_in_page((seg-1)->mr_offset + (seg-1)->mr_len))
+ break;
+ }
+ mr->mr_dir = rpcrdma_data_dir(writing);
+ mr->mr_nents = i;
+
+ dma_nents = ib_dma_map_sg(ep->re_id->device, mr->mr_sg, mr->mr_nents,
+ mr->mr_dir);
+ if (!dma_nents)
+ goto out_dmamap_err;
+ mr->mr_device = ep->re_id->device;
+
+ ibmr = mr->mr_ibmr;
+ n = ib_map_mr_sg(ibmr, mr->mr_sg, dma_nents, NULL, PAGE_SIZE);
+ if (n != dma_nents)
+ goto out_mapmr_err;
+
+ ibmr->iova &= 0x00000000ffffffff;
+ ibmr->iova |= ((u64)be32_to_cpu(xid)) << 32;
+ key = (u8)(ibmr->rkey & 0x000000FF);
+ ib_update_fast_reg_key(ibmr, ++key);
+
+ reg_wr = &mr->mr_regwr;
+ reg_wr->mr = ibmr;
+ reg_wr->key = ibmr->rkey;
+ reg_wr->access = writing ?
+ IB_ACCESS_REMOTE_WRITE | IB_ACCESS_LOCAL_WRITE :
+ IB_ACCESS_REMOTE_READ;
+
+ mr->mr_handle = ibmr->rkey;
+ mr->mr_length = ibmr->length;
+ mr->mr_offset = ibmr->iova;
+ trace_xprtrdma_mr_map(mr);
+
+ return seg;
+
+out_dmamap_err:
+ trace_xprtrdma_frwr_sgerr(mr, i);
+ return ERR_PTR(-EIO);
+
+out_mapmr_err:
+ trace_xprtrdma_frwr_maperr(mr, n);
+ return ERR_PTR(-EIO);
+}
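The tail of frwr_map() tags the registration with the RPC's XID and rotates the rkey's low-order byte so stale keys are rejected. A standalone sketch of just that bit arithmetic, with made-up values; in the kernel the XID is a __be32 and the key rotation is performed by ib_update_fast_reg_key():

#include <stdint.h>
#include <stdio.h>

int main(void)
{
	uint64_t iova = 0x0000abcd00004000ULL;	/* illustrative MR address */
	uint32_t rkey = 0x11223344;		/* illustrative remote key */
	uint32_t xid  = 0x5f0e21aa;		/* illustrative RPC XID */
	uint8_t key;

	/* Fold the owning XID into the upper 32 bits of the iova */
	iova &= 0x00000000ffffffffULL;
	iova |= (uint64_t)xid << 32;

	/* Bump the low-order "key" byte of the rkey */
	key = (uint8_t)(rkey & 0xff);
	rkey = (rkey & 0xffffff00u) | (uint8_t)(key + 1);

	printf("iova=0x%016llx rkey=0x%08x\n",
	       (unsigned long long)iova, rkey);
	return 0;
}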
+
+/**
+ * frwr_wc_fastreg - Invoked by RDMA provider for a flushed FastReg WC
+ * @cq: completion queue
+ * @wc: WCE for a completed FastReg WR
+ *
+ * Each flushed MR gets destroyed after the QP has drained.
+ */
+static void frwr_wc_fastreg(struct ib_cq *cq, struct ib_wc *wc)
+{
+ struct ib_cqe *cqe = wc->wr_cqe;
+ struct rpcrdma_mr *mr = container_of(cqe, struct rpcrdma_mr, mr_cqe);
+
+ /* WARNING: Only wr_cqe and status are reliable at this point */
+ trace_xprtrdma_wc_fastreg(wc, &mr->mr_cid);
+
+ rpcrdma_flush_disconnect(cq->cq_context, wc);
+}
+
+/**
+ * frwr_send - post Send WRs containing the RPC Call message
+ * @r_xprt: controlling transport instance
+ * @req: prepared RPC Call
+ *
+ * For FRWR, chain any FastReg WRs to the Send WR. Only a
+ * single ib_post_send call is needed to register memory
+ * and then post the Send WR.
+ *
+ * Returns the return code from ib_post_send.
+ *
+ * Caller must hold the transport send lock to ensure that the
+ * pointers to the transport's rdma_cm_id and QP are stable.
+ */
+int frwr_send(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req)
+{
+ struct ib_send_wr *post_wr, *send_wr = &req->rl_wr;
+ struct rpcrdma_ep *ep = r_xprt->rx_ep;
+ struct rpcrdma_mr *mr;
+ unsigned int num_wrs;
+ int ret;
+
+ num_wrs = 1;
+ post_wr = send_wr;
+ list_for_each_entry(mr, &req->rl_registered, mr_list) {
+ trace_xprtrdma_mr_fastreg(mr);
+
+ mr->mr_cqe.done = frwr_wc_fastreg;
+ mr->mr_regwr.wr.next = post_wr;
+ mr->mr_regwr.wr.wr_cqe = &mr->mr_cqe;
+ mr->mr_regwr.wr.num_sge = 0;
+ mr->mr_regwr.wr.opcode = IB_WR_REG_MR;
+ mr->mr_regwr.wr.send_flags = 0;
+ post_wr = &mr->mr_regwr.wr;
+ ++num_wrs;
+ }
+
+ if ((kref_read(&req->rl_kref) > 1) || num_wrs > ep->re_send_count) {
+ send_wr->send_flags |= IB_SEND_SIGNALED;
+ ep->re_send_count = min_t(unsigned int, ep->re_send_batch,
+ num_wrs - ep->re_send_count);
+ } else {
+ send_wr->send_flags &= ~IB_SEND_SIGNALED;
+ ep->re_send_count -= num_wrs;
+ }
+
+ trace_xprtrdma_post_send(req);
+ ret = ib_post_send(ep->re_id->qp, post_wr, NULL);
+ if (ret)
+ trace_xprtrdma_post_send_err(r_xprt, req, ret);
+ return ret;
+}
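A simplified sketch of the occasional-signaling budget used above. The real code also folds in the request's kref count and refills the budget from re_send_batch with a slightly different expression; the counters and values here are hypothetical:

#include <stdbool.h>
#include <stdio.h>

static bool should_signal(unsigned int *budget, unsigned int batch,
			  unsigned int num_wrs, bool needs_completion)
{
	if (needs_completion || num_wrs > *budget) {
		*budget = batch;	/* refill after a signaled Send */
		return true;
	}
	*budget -= num_wrs;		/* spend budget, post unsignaled */
	return false;
}

int main(void)
{
	unsigned int budget = 8;	/* hypothetical starting budget */
	int i;

	/* Posting 3 WRs per Send, roughly every third Send is signaled */
	for (i = 0; i < 6; i++)
		printf("send %d: %s\n", i,
		       should_signal(&budget, 8, 3, false) ?
				"signaled" : "unsignaled");
	return 0;
}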
+
+/**
+ * frwr_reminv - handle a remotely invalidated mr on the @mrs list
+ * @rep: Received reply
+ * @mrs: list of MRs to check
+ *
+ */
+void frwr_reminv(struct rpcrdma_rep *rep, struct list_head *mrs)
+{
+ struct rpcrdma_mr *mr;
+
+ list_for_each_entry(mr, mrs, mr_list)
+ if (mr->mr_handle == rep->rr_inv_rkey) {
+ list_del_init(&mr->mr_list);
+ trace_xprtrdma_mr_reminv(mr);
+ frwr_mr_put(mr);
+ break; /* only one invalidated MR per RPC */
+ }
+}
+
+static void frwr_mr_done(struct ib_wc *wc, struct rpcrdma_mr *mr)
+{
+ if (likely(wc->status == IB_WC_SUCCESS))
+ frwr_mr_put(mr);
+}
+
+/**
+ * frwr_wc_localinv - Invoked by RDMA provider for a LOCAL_INV WC
+ * @cq: completion queue
+ * @wc: WCE for a completed LocalInv WR
+ *
+ */
+static void frwr_wc_localinv(struct ib_cq *cq, struct ib_wc *wc)
+{
+ struct ib_cqe *cqe = wc->wr_cqe;
+ struct rpcrdma_mr *mr = container_of(cqe, struct rpcrdma_mr, mr_cqe);
+
+ /* WARNING: Only wr_cqe and status are reliable at this point */
+ trace_xprtrdma_wc_li(wc, &mr->mr_cid);
+ frwr_mr_done(wc, mr);
+
+ rpcrdma_flush_disconnect(cq->cq_context, wc);
+}
+
+/**
+ * frwr_wc_localinv_wake - Invoked by RDMA provider for a LOCAL_INV WC
+ * @cq: completion queue
+ * @wc: WCE for a completed LocalInv WR
+ *
+ * Awaken anyone waiting for an MR to finish being fenced.
+ */
+static void frwr_wc_localinv_wake(struct ib_cq *cq, struct ib_wc *wc)
+{
+ struct ib_cqe *cqe = wc->wr_cqe;
+ struct rpcrdma_mr *mr = container_of(cqe, struct rpcrdma_mr, mr_cqe);
+
+ /* WARNING: Only wr_cqe and status are reliable at this point */
+ trace_xprtrdma_wc_li_wake(wc, &mr->mr_cid);
+ frwr_mr_done(wc, mr);
+ complete(&mr->mr_linv_done);
+
+ rpcrdma_flush_disconnect(cq->cq_context, wc);
+}
+
+/**
+ * frwr_unmap_sync - invalidate memory regions that were registered for @req
+ * @r_xprt: controlling transport instance
+ * @req: rpcrdma_req with a non-empty list of MRs to process
+ *
+ * Sleeps until it is safe for the host CPU to access the previously mapped
+ * memory regions. This guarantees that registered MRs are properly fenced
+ * from the server before the RPC consumer accesses the data in them. It
+ * also ensures proper Send flow control: waking the next RPC waits until
+ * this RPC has relinquished all its Send Queue entries.
+ */
+void frwr_unmap_sync(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req)
+{
+ struct ib_send_wr *first, **prev, *last;
+ struct rpcrdma_ep *ep = r_xprt->rx_ep;
+ const struct ib_send_wr *bad_wr;
+ struct rpcrdma_mr *mr;
+ int rc;
+
+ /* ORDER: Invalidate all of the MRs first
+ *
+ * Chain the LOCAL_INV Work Requests and post them with
+ * a single ib_post_send() call.
+ */
+ prev = &first;
+ mr = rpcrdma_mr_pop(&req->rl_registered);
+ do {
+ trace_xprtrdma_mr_localinv(mr);
+ r_xprt->rx_stats.local_inv_needed++;
+
+ last = &mr->mr_invwr;
+ last->next = NULL;
+ last->wr_cqe = &mr->mr_cqe;
+ last->sg_list = NULL;
+ last->num_sge = 0;
+ last->opcode = IB_WR_LOCAL_INV;
+ last->send_flags = IB_SEND_SIGNALED;
+ last->ex.invalidate_rkey = mr->mr_handle;
+
+ last->wr_cqe->done = frwr_wc_localinv;
+
+ *prev = last;
+ prev = &last->next;
+ } while ((mr = rpcrdma_mr_pop(&req->rl_registered)));
+
+ mr = container_of(last, struct rpcrdma_mr, mr_invwr);
+
+ /* Strong send queue ordering guarantees that when the
+ * last WR in the chain completes, all WRs in the chain
+ * are complete.
+ */
+ last->wr_cqe->done = frwr_wc_localinv_wake;
+ reinit_completion(&mr->mr_linv_done);
+
+ /* Transport disconnect drains the receive CQ before it
+ * replaces the QP. The RPC reply handler won't call us
+ * unless re_id->qp is a valid pointer.
+ */
+ bad_wr = NULL;
+ rc = ib_post_send(ep->re_id->qp, first, &bad_wr);
+
+ /* The final LOCAL_INV WR in the chain is supposed to
+ * do the wake. If it was never posted, the wake will
+ * not happen, so don't wait in that case.
+ */
+ if (bad_wr != first)
+ wait_for_completion(&mr->mr_linv_done);
+ if (!rc)
+ return;
+
+ /* On error, the MRs get destroyed once the QP has drained. */
+ trace_xprtrdma_post_linv_err(req, rc);
+
+ /* Force a connection loss to ensure complete recovery.
+ */
+ rpcrdma_force_disconnect(ep);
+}
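The prev/last pointer handling above is a common chain-building idiom: *prev always names the slot that should receive the next WR, so the chain comes out in posting order without a separate tail pointer or empty-list check. A tiny standalone illustration of that idiom, not the kernel structures:

#include <stdio.h>
#include <stdlib.h>

struct wr {
	struct wr *next;
	int id;
};

int main(void)
{
	struct wr *first = NULL, **prev = &first, *last, *w, *next;
	int i;

	for (i = 0; i < 3; i++) {
		last = malloc(sizeof(*last));
		if (!last)
			return 1;
		last->next = NULL;
		last->id = i;
		*prev = last;		/* link into the chain */
		prev = &last->next;	/* next WR goes after this one */
	}

	for (w = first; w; w = next) {
		next = w->next;
		printf("LOCAL_INV wr %d\n", w->id);
		free(w);
	}
	return 0;
}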
+
+/**
+ * frwr_wc_localinv_done - Invoked by RDMA provider for a signaled LOCAL_INV WC
+ * @cq: completion queue
+ * @wc: WCE for a completed LocalInv WR
+ *
+ */
+static void frwr_wc_localinv_done(struct ib_cq *cq, struct ib_wc *wc)
+{
+ struct ib_cqe *cqe = wc->wr_cqe;
+ struct rpcrdma_mr *mr = container_of(cqe, struct rpcrdma_mr, mr_cqe);
+ struct rpcrdma_rep *rep;
+
+ /* WARNING: Only wr_cqe and status are reliable at this point */
+ trace_xprtrdma_wc_li_done(wc, &mr->mr_cid);
+
+ /* Ensure that @rep is generated before the MR is released */
+ rep = mr->mr_req->rl_reply;
+ smp_rmb();
+
+ if (wc->status != IB_WC_SUCCESS) {
+ if (rep)
+ rpcrdma_unpin_rqst(rep);
+ rpcrdma_flush_disconnect(cq->cq_context, wc);
+ return;
+ }
+ frwr_mr_put(mr);
+ rpcrdma_complete_rqst(rep);
+}
+
+/**
+ * frwr_unmap_async - invalidate memory regions that were registered for @req
+ * @r_xprt: controlling transport instance
+ * @req: rpcrdma_req with a non-empty list of MRs to process
+ *
+ * This guarantees that registered MRs are properly fenced from the
+ * server before the RPC consumer accesses the data in them. It also
+ * ensures proper Send flow control: waking the next RPC waits until
+ * this RPC has relinquished all its Send Queue entries.
+ */
+void frwr_unmap_async(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req)
+{
+ struct ib_send_wr *first, *last, **prev;
+ struct rpcrdma_ep *ep = r_xprt->rx_ep;
+ struct rpcrdma_mr *mr;
+ int rc;
+
+ /* Chain the LOCAL_INV Work Requests and post them with
+ * a single ib_post_send() call.
+ */
+ prev = &first;
+ mr = rpcrdma_mr_pop(&req->rl_registered);
+ do {
+ trace_xprtrdma_mr_localinv(mr);
+ r_xprt->rx_stats.local_inv_needed++;
+
+ last = &mr->mr_invwr;
+ last->next = NULL;
+ last->wr_cqe = &mr->mr_cqe;
+ last->sg_list = NULL;
+ last->num_sge = 0;
+ last->opcode = IB_WR_LOCAL_INV;
+ last->send_flags = IB_SEND_SIGNALED;
+ last->ex.invalidate_rkey = mr->mr_handle;
+
+ last->wr_cqe->done = frwr_wc_localinv;
+
+ *prev = last;
+ prev = &last->next;
+ } while ((mr = rpcrdma_mr_pop(&req->rl_registered)));
+
+ /* Strong send queue ordering guarantees that when the
+ * last WR in the chain completes, all WRs in the chain
+ * are complete. The last completion will wake up the
+ * RPC waiter.
+ */
+ last->wr_cqe->done = frwr_wc_localinv_done;
+
+ /* Transport disconnect drains the receive CQ before it
+ * replaces the QP. The RPC reply handler won't call us
+ * unless re_id->qp is a valid pointer.
+ */
+ rc = ib_post_send(ep->re_id->qp, first, NULL);
+ if (!rc)
+ return;
+
+ /* On error, the MRs get destroyed once the QP has drained. */
+ trace_xprtrdma_post_linv_err(req, rc);
+
+ /* The final LOCAL_INV WR in the chain is supposed to
+ * do the wake. If it was never posted, the wake does
+ * not happen. Unpin the rqst in preparation for its
+ * retransmission.
+ */
+ rpcrdma_unpin_rqst(req->rl_reply);
+
+ /* Force a connection loss to ensure complete recovery.
+ */
+ rpcrdma_force_disconnect(ep);
+}
+
+/**
+ * frwr_wp_create - Create an MR for padding Write chunks
+ * @r_xprt: transport resources to use
+ *
+ * Return 0 on success, negative errno on failure.
+ */
+int frwr_wp_create(struct rpcrdma_xprt *r_xprt)
+{
+ struct rpcrdma_ep *ep = r_xprt->rx_ep;
+ struct rpcrdma_mr_seg seg;
+ struct rpcrdma_mr *mr;
+
+ mr = rpcrdma_mr_get(r_xprt);
+ if (!mr)
+ return -EAGAIN;
+ mr->mr_req = NULL;
+ ep->re_write_pad_mr = mr;
+
+ seg.mr_len = XDR_UNIT;
+ seg.mr_page = virt_to_page(ep->re_write_pad);
+ seg.mr_offset = offset_in_page(ep->re_write_pad);
+ if (IS_ERR(frwr_map(r_xprt, &seg, 1, true, xdr_zero, mr)))
+ return -EIO;
+ trace_xprtrdma_mr_fastreg(mr);
+
+ mr->mr_cqe.done = frwr_wc_fastreg;
+ mr->mr_regwr.wr.next = NULL;
+ mr->mr_regwr.wr.wr_cqe = &mr->mr_cqe;
+ mr->mr_regwr.wr.num_sge = 0;
+ mr->mr_regwr.wr.opcode = IB_WR_REG_MR;
+ mr->mr_regwr.wr.send_flags = 0;
+
+ return ib_post_send(ep->re_id->qp, &mr->mr_regwr.wr, NULL);
+}
diff --git a/net/sunrpc/xprtrdma/module.c b/net/sunrpc/xprtrdma/module.c
new file mode 100644
index 0000000000..45c5b41ac8
--- /dev/null
+++ b/net/sunrpc/xprtrdma/module.c
@@ -0,0 +1,52 @@
+// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
+/*
+ * Copyright (c) 2015, 2017 Oracle. All rights reserved.
+ */
+
+/* rpcrdma.ko module initialization
+ */
+
+#include <linux/types.h>
+#include <linux/compiler.h>
+#include <linux/module.h>
+#include <linux/init.h>
+#include <linux/sunrpc/svc_rdma.h>
+
+#include <asm/swab.h>
+
+#include "xprt_rdma.h"
+
+#define CREATE_TRACE_POINTS
+#include <trace/events/rpcrdma.h>
+
+MODULE_AUTHOR("Open Grid Computing and Network Appliance, Inc.");
+MODULE_DESCRIPTION("RPC/RDMA Transport");
+MODULE_LICENSE("Dual BSD/GPL");
+MODULE_ALIAS("svcrdma");
+MODULE_ALIAS("xprtrdma");
+MODULE_ALIAS("rpcrdma6");
+
+static void __exit rpc_rdma_cleanup(void)
+{
+ xprt_rdma_cleanup();
+ svc_rdma_cleanup();
+}
+
+static int __init rpc_rdma_init(void)
+{
+ int rc;
+
+ rc = svc_rdma_init();
+ if (rc)
+ goto out;
+
+ rc = xprt_rdma_init();
+ if (rc)
+ svc_rdma_cleanup();
+
+out:
+ return rc;
+}
+
+module_init(rpc_rdma_init);
+module_exit(rpc_rdma_cleanup);
diff --git a/net/sunrpc/xprtrdma/rpc_rdma.c b/net/sunrpc/xprtrdma/rpc_rdma.c
new file mode 100644
index 0000000000..190a4de239
--- /dev/null
+++ b/net/sunrpc/xprtrdma/rpc_rdma.c
@@ -0,0 +1,1510 @@
+// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
+/*
+ * Copyright (c) 2014-2020, Oracle and/or its affiliates.
+ * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved.
+ *
+ * This software is available to you under a choice of one of two
+ * licenses. You may choose to be licensed under the terms of the GNU
+ * General Public License (GPL) Version 2, available from the file
+ * COPYING in the main directory of this source tree, or the BSD-type
+ * license below:
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *
+ * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials provided
+ * with the distribution.
+ *
+ * Neither the name of the Network Appliance, Inc. nor the names of
+ * its contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/*
+ * rpc_rdma.c
+ *
+ * This file contains the guts of the RPC RDMA protocol, and
+ * does marshaling/unmarshaling, etc. It is also where interfacing
+ * to the Linux RPC framework lives.
+ */
+
+#include <linux/highmem.h>
+
+#include <linux/sunrpc/svc_rdma.h>
+
+#include "xprt_rdma.h"
+#include <trace/events/rpcrdma.h>
+
+/* Returns size of largest RPC-over-RDMA header in a Call message
+ *
+ * The largest Call header contains a full-size Read list and a
+ * minimal Reply chunk.
+ */
+static unsigned int rpcrdma_max_call_header_size(unsigned int maxsegs)
+{
+ unsigned int size;
+
+ /* Fixed header fields and list discriminators */
+ size = RPCRDMA_HDRLEN_MIN;
+
+ /* Maximum Read list size */
+ size += maxsegs * rpcrdma_readchunk_maxsz * sizeof(__be32);
+
+	/* Minimal Reply chunk size */
+ size += sizeof(__be32); /* segment count */
+ size += rpcrdma_segment_maxsz * sizeof(__be32);
+ size += sizeof(__be32); /* list discriminator */
+
+ return size;
+}
+
+/* Returns size of largest RPC-over-RDMA header in a Reply message
+ *
+ * There is only one Write list or one Reply chunk per Reply
+ * message. The larger list is the Write list.
+ */
+static unsigned int rpcrdma_max_reply_header_size(unsigned int maxsegs)
+{
+ unsigned int size;
+
+ /* Fixed header fields and list discriminators */
+ size = RPCRDMA_HDRLEN_MIN;
+
+ /* Maximum Write list size */
+ size += sizeof(__be32); /* segment count */
+ size += maxsegs * rpcrdma_segment_maxsz * sizeof(__be32);
+ size += sizeof(__be32); /* list discriminator */
+
+ return size;
+}
+
+/**
+ * rpcrdma_set_max_header_sizes - Initialize inline payload sizes
+ * @ep: endpoint to initialize
+ *
+ * The max_inline fields contain the maximum size of an RPC message
+ * so the marshaling code doesn't have to repeat this calculation
+ * for every RPC.
+ */
+void rpcrdma_set_max_header_sizes(struct rpcrdma_ep *ep)
+{
+ unsigned int maxsegs = ep->re_max_rdma_segs;
+
+ ep->re_max_inline_send =
+ ep->re_inline_send - rpcrdma_max_call_header_size(maxsegs);
+ ep->re_max_inline_recv =
+ ep->re_inline_recv - rpcrdma_max_reply_header_size(maxsegs);
+}
+
+/* The client can send a request inline as long as the RPCRDMA header
+ * plus the RPC call fit under the transport's inline limit. If the
+ * combined call message size exceeds that limit, the client must use
+ * a Read chunk for this operation.
+ *
+ * A Read chunk is also required if sending the RPC call inline would
+ * exceed this device's max_sge limit.
+ */
+static bool rpcrdma_args_inline(struct rpcrdma_xprt *r_xprt,
+ struct rpc_rqst *rqst)
+{
+ struct xdr_buf *xdr = &rqst->rq_snd_buf;
+ struct rpcrdma_ep *ep = r_xprt->rx_ep;
+ unsigned int count, remaining, offset;
+
+ if (xdr->len > ep->re_max_inline_send)
+ return false;
+
+ if (xdr->page_len) {
+ remaining = xdr->page_len;
+ offset = offset_in_page(xdr->page_base);
+ count = RPCRDMA_MIN_SEND_SGES;
+ while (remaining) {
+ remaining -= min_t(unsigned int,
+ PAGE_SIZE - offset, remaining);
+ offset = 0;
+ if (++count > ep->re_attr.cap.max_send_sge)
+ return false;
+ }
+ }
+
+ return true;
+}
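A rough standalone model of the SGE-counting loop above. The starting minimum is passed in as a parameter rather than assuming the actual value of RPCRDMA_MIN_SEND_SGES; the sizes are hypothetical:

#include <stdio.h>

static unsigned int count_send_sges(unsigned int page_len,
				    unsigned int page_base,
				    unsigned int page_size,
				    unsigned int min_sges)
{
	unsigned int count = min_sges;
	unsigned int offset = page_base % page_size;

	/* One SGE per page touched by the payload */
	while (page_len) {
		unsigned int chunk = page_size - offset;

		if (chunk > page_len)
			chunk = page_len;
		page_len -= chunk;
		offset = 0;
		count++;
	}
	return count;
}

int main(void)
{
	/* 8192 bytes starting 100 bytes into a 4096-byte page spans
	 * three pages, so three SGEs on top of the assumed minimum of 2.
	 */
	printf("%u SGEs\n", count_send_sges(8192, 100, 4096, 2));	/* 5 */
	return 0;
}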
+
+/* The client can't know how large the actual reply will be. Thus it
+ * plans for the largest possible reply for that particular ULP
+ * operation. If the maximum combined reply message size exceeds that
+ * limit, the client must provide a write list or a reply chunk for
+ * this request.
+ */
+static bool rpcrdma_results_inline(struct rpcrdma_xprt *r_xprt,
+ struct rpc_rqst *rqst)
+{
+ return rqst->rq_rcv_buf.buflen <= r_xprt->rx_ep->re_max_inline_recv;
+}
+
+/* The client is required to provide a Reply chunk if the maximum
+ * size of the non-payload part of the RPC Reply is larger than
+ * the inline threshold.
+ */
+static bool
+rpcrdma_nonpayload_inline(const struct rpcrdma_xprt *r_xprt,
+ const struct rpc_rqst *rqst)
+{
+ const struct xdr_buf *buf = &rqst->rq_rcv_buf;
+
+ return (buf->head[0].iov_len + buf->tail[0].iov_len) <
+ r_xprt->rx_ep->re_max_inline_recv;
+}
+
+/* ACL likes to be lazy in allocating pages. For TCP, these
+ * pages can be allocated during receive processing. Not true
+ * for RDMA, which must always provision receive buffers
+ * up front.
+ */
+static noinline int
+rpcrdma_alloc_sparse_pages(struct xdr_buf *buf)
+{
+ struct page **ppages;
+ int len;
+
+ len = buf->page_len;
+ ppages = buf->pages + (buf->page_base >> PAGE_SHIFT);
+ while (len > 0) {
+ if (!*ppages)
+ *ppages = alloc_page(GFP_NOWAIT | __GFP_NOWARN);
+ if (!*ppages)
+ return -ENOBUFS;
+ ppages++;
+ len -= PAGE_SIZE;
+ }
+
+ return 0;
+}
+
+/* Convert @vec to a single SGL element.
+ *
+ * Returns pointer to next available SGE, and bumps the total number
+ * of SGEs consumed.
+ */
+static struct rpcrdma_mr_seg *
+rpcrdma_convert_kvec(struct kvec *vec, struct rpcrdma_mr_seg *seg,
+ unsigned int *n)
+{
+ seg->mr_page = virt_to_page(vec->iov_base);
+ seg->mr_offset = offset_in_page(vec->iov_base);
+ seg->mr_len = vec->iov_len;
+ ++seg;
+ ++(*n);
+ return seg;
+}
+
+/* Convert @xdrbuf into SGEs no larger than a page each. As they
+ * are registered, these SGEs are then coalesced into RDMA segments
+ * when the selected memreg mode supports it.
+ *
+ * Returns positive number of SGEs consumed, or a negative errno.
+ */
+
+static int
+rpcrdma_convert_iovs(struct rpcrdma_xprt *r_xprt, struct xdr_buf *xdrbuf,
+ unsigned int pos, enum rpcrdma_chunktype type,
+ struct rpcrdma_mr_seg *seg)
+{
+ unsigned long page_base;
+ unsigned int len, n;
+ struct page **ppages;
+
+ n = 0;
+ if (pos == 0)
+ seg = rpcrdma_convert_kvec(&xdrbuf->head[0], seg, &n);
+
+ len = xdrbuf->page_len;
+ ppages = xdrbuf->pages + (xdrbuf->page_base >> PAGE_SHIFT);
+ page_base = offset_in_page(xdrbuf->page_base);
+ while (len) {
+ seg->mr_page = *ppages;
+ seg->mr_offset = page_base;
+ seg->mr_len = min_t(u32, PAGE_SIZE - page_base, len);
+ len -= seg->mr_len;
+ ++ppages;
+ ++seg;
+ ++n;
+ page_base = 0;
+ }
+
+ if (type == rpcrdma_readch || type == rpcrdma_writech)
+ goto out;
+
+ if (xdrbuf->tail[0].iov_len)
+ rpcrdma_convert_kvec(&xdrbuf->tail[0], seg, &n);
+
+out:
+ if (unlikely(n > RPCRDMA_MAX_SEGS))
+ return -EIO;
+ return n;
+}
+
+static int
+encode_rdma_segment(struct xdr_stream *xdr, struct rpcrdma_mr *mr)
+{
+ __be32 *p;
+
+ p = xdr_reserve_space(xdr, 4 * sizeof(*p));
+ if (unlikely(!p))
+ return -EMSGSIZE;
+
+ xdr_encode_rdma_segment(p, mr->mr_handle, mr->mr_length, mr->mr_offset);
+ return 0;
+}
+
+static int
+encode_read_segment(struct xdr_stream *xdr, struct rpcrdma_mr *mr,
+ u32 position)
+{
+ __be32 *p;
+
+ p = xdr_reserve_space(xdr, 6 * sizeof(*p));
+ if (unlikely(!p))
+ return -EMSGSIZE;
+
+ *p++ = xdr_one; /* Item present */
+ xdr_encode_read_segment(p, position, mr->mr_handle, mr->mr_length,
+ mr->mr_offset);
+ return 0;
+}
+
+static struct rpcrdma_mr_seg *rpcrdma_mr_prepare(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_req *req,
+ struct rpcrdma_mr_seg *seg,
+ int nsegs, bool writing,
+ struct rpcrdma_mr **mr)
+{
+ *mr = rpcrdma_mr_pop(&req->rl_free_mrs);
+ if (!*mr) {
+ *mr = rpcrdma_mr_get(r_xprt);
+ if (!*mr)
+ goto out_getmr_err;
+ (*mr)->mr_req = req;
+ }
+
+ rpcrdma_mr_push(*mr, &req->rl_registered);
+ return frwr_map(r_xprt, seg, nsegs, writing, req->rl_slot.rq_xid, *mr);
+
+out_getmr_err:
+ trace_xprtrdma_nomrs_err(r_xprt, req);
+ xprt_wait_for_buffer_space(&r_xprt->rx_xprt);
+ rpcrdma_mrs_refresh(r_xprt);
+ return ERR_PTR(-EAGAIN);
+}
+
+/* Register and XDR encode the Read list. Supports encoding a list of read
+ * segments that belong to a single read chunk.
+ *
+ * Encoding key for single-list chunks (HLOO = Handle32 Length32 Offset64):
+ *
+ * Read chunklist (a linked list):
+ * N elements, position P (same P for all chunks of same arg!):
+ * 1 - PHLOO - 1 - PHLOO - ... - 1 - PHLOO - 0
+ *
+ * Returns zero on success, or a negative errno if a failure occurred.
+ * @xdr is advanced to the next position in the stream.
+ *
+ * Only a single @pos value is currently supported.
+ */
+static int rpcrdma_encode_read_list(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_req *req,
+ struct rpc_rqst *rqst,
+ enum rpcrdma_chunktype rtype)
+{
+ struct xdr_stream *xdr = &req->rl_stream;
+ struct rpcrdma_mr_seg *seg;
+ struct rpcrdma_mr *mr;
+ unsigned int pos;
+ int nsegs;
+
+ if (rtype == rpcrdma_noch_pullup || rtype == rpcrdma_noch_mapped)
+ goto done;
+
+ pos = rqst->rq_snd_buf.head[0].iov_len;
+ if (rtype == rpcrdma_areadch)
+ pos = 0;
+ seg = req->rl_segments;
+ nsegs = rpcrdma_convert_iovs(r_xprt, &rqst->rq_snd_buf, pos,
+ rtype, seg);
+ if (nsegs < 0)
+ return nsegs;
+
+ do {
+ seg = rpcrdma_mr_prepare(r_xprt, req, seg, nsegs, false, &mr);
+ if (IS_ERR(seg))
+ return PTR_ERR(seg);
+
+ if (encode_read_segment(xdr, mr, pos) < 0)
+ return -EMSGSIZE;
+
+ trace_xprtrdma_chunk_read(rqst->rq_task, pos, mr, nsegs);
+ r_xprt->rx_stats.read_chunk_count++;
+ nsegs -= mr->mr_nents;
+ } while (nsegs);
+
+done:
+ if (xdr_stream_encode_item_absent(xdr) < 0)
+ return -EMSGSIZE;
+ return 0;
+}
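Following the "1 - PHLOO - ... - 0" key in the comment above, a standalone sketch of how a two-segment Read list lays out on the wire. Values are illustrative and written in host order here, whereas the kernel emits big-endian XDR words:

#include <stdint.h>
#include <stdio.h>

static uint32_t *encode_read_seg(uint32_t *p, uint32_t position,
				 uint32_t handle, uint32_t length,
				 uint64_t offset)
{
	*p++ = 1;			/* discriminator: item present */
	*p++ = position;		/* P: byte offset into the RPC call */
	*p++ = handle;			/* H: rkey */
	*p++ = length;			/* L: number of bytes */
	*p++ = (uint32_t)(offset >> 32);/* OO: 64-bit registered offset */
	*p++ = (uint32_t)offset;
	return p;
}

int main(void)
{
	uint32_t buf[16], *p = buf;

	/* Two segments of one chunk share the same position */
	p = encode_read_seg(p, 140, 0x11223344, 8192, 0xabc000);
	p = encode_read_seg(p, 140, 0x55667788, 4096, 0xdef000);
	*p++ = 0;			/* discriminator: end of Read list */

	printf("Read list uses %ld words\n", (long)(p - buf));	/* 13 */
	return 0;
}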
+
+/* Register and XDR encode the Write list. Supports encoding a list
+ * containing one array of plain segments that belong to a single
+ * write chunk.
+ *
+ * Encoding key for single-list chunks (HLOO = Handle32 Length32 Offset64):
+ *
+ * Write chunklist (a list of (one) counted array):
+ * N elements:
+ * 1 - N - HLOO - HLOO - ... - HLOO - 0
+ *
+ * Returns zero on success, or a negative errno if a failure occurred.
+ * @xdr is advanced to the next position in the stream.
+ *
+ * Only a single Write chunk is currently supported.
+ */
+static int rpcrdma_encode_write_list(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_req *req,
+ struct rpc_rqst *rqst,
+ enum rpcrdma_chunktype wtype)
+{
+ struct xdr_stream *xdr = &req->rl_stream;
+ struct rpcrdma_ep *ep = r_xprt->rx_ep;
+ struct rpcrdma_mr_seg *seg;
+ struct rpcrdma_mr *mr;
+ int nsegs, nchunks;
+ __be32 *segcount;
+
+ if (wtype != rpcrdma_writech)
+ goto done;
+
+ seg = req->rl_segments;
+ nsegs = rpcrdma_convert_iovs(r_xprt, &rqst->rq_rcv_buf,
+ rqst->rq_rcv_buf.head[0].iov_len,
+ wtype, seg);
+ if (nsegs < 0)
+ return nsegs;
+
+ if (xdr_stream_encode_item_present(xdr) < 0)
+ return -EMSGSIZE;
+ segcount = xdr_reserve_space(xdr, sizeof(*segcount));
+ if (unlikely(!segcount))
+ return -EMSGSIZE;
+ /* Actual value encoded below */
+
+ nchunks = 0;
+ do {
+ seg = rpcrdma_mr_prepare(r_xprt, req, seg, nsegs, true, &mr);
+ if (IS_ERR(seg))
+ return PTR_ERR(seg);
+
+ if (encode_rdma_segment(xdr, mr) < 0)
+ return -EMSGSIZE;
+
+ trace_xprtrdma_chunk_write(rqst->rq_task, mr, nsegs);
+ r_xprt->rx_stats.write_chunk_count++;
+ r_xprt->rx_stats.total_rdma_request += mr->mr_length;
+ nchunks++;
+ nsegs -= mr->mr_nents;
+ } while (nsegs);
+
+ if (xdr_pad_size(rqst->rq_rcv_buf.page_len)) {
+ if (encode_rdma_segment(xdr, ep->re_write_pad_mr) < 0)
+ return -EMSGSIZE;
+
+ trace_xprtrdma_chunk_wp(rqst->rq_task, ep->re_write_pad_mr,
+ nsegs);
+ r_xprt->rx_stats.write_chunk_count++;
+ r_xprt->rx_stats.total_rdma_request += mr->mr_length;
+ nchunks++;
+ nsegs -= mr->mr_nents;
+ }
+
+ /* Update count of segments in this Write chunk */
+ *segcount = cpu_to_be32(nchunks);
+
+done:
+ if (xdr_stream_encode_item_absent(xdr) < 0)
+ return -EMSGSIZE;
+ return 0;
+}
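/* Illustrative sketch (not from the upstream file; values hypothetical):
 * the Write list emitted above for a chunk that needed two MRs and whose
 * payload length was not a multiple of four, so the write-pad segment is
 * appended and included in the segment count.
 *
 *   1                    <- Write list item present
 *   3                    <- segment count (2 MRs + 1 pad segment)
 *   H1, L1, O1           <- HLOO for the first MR
 *   H2, L2, O2           <- HLOO for the second MR
 *   Hp, Lp, Op           <- HLOO for ep->re_write_pad_mr
 *   0                    <- Write list terminator (item absent)
 */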
+
+/* Register and XDR encode the Reply chunk. Supports encoding an array
+ * of plain segments that belong to a single write (reply) chunk.
+ *
+ * Encoding key for single-list chunks (HLOO = Handle32 Length32 Offset64):
+ *
+ * Reply chunk (a counted array):
+ * N elements:
+ * 1 - N - HLOO - HLOO - ... - HLOO
+ *
+ * Returns zero on success, or a negative errno if a failure occurred.
+ * @xdr is advanced to the next position in the stream.
+ */
+static int rpcrdma_encode_reply_chunk(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_req *req,
+ struct rpc_rqst *rqst,
+ enum rpcrdma_chunktype wtype)
+{
+ struct xdr_stream *xdr = &req->rl_stream;
+ struct rpcrdma_mr_seg *seg;
+ struct rpcrdma_mr *mr;
+ int nsegs, nchunks;
+ __be32 *segcount;
+
+ if (wtype != rpcrdma_replych) {
+ if (xdr_stream_encode_item_absent(xdr) < 0)
+ return -EMSGSIZE;
+ return 0;
+ }
+
+ seg = req->rl_segments;
+ nsegs = rpcrdma_convert_iovs(r_xprt, &rqst->rq_rcv_buf, 0, wtype, seg);
+ if (nsegs < 0)
+ return nsegs;
+
+ if (xdr_stream_encode_item_present(xdr) < 0)
+ return -EMSGSIZE;
+ segcount = xdr_reserve_space(xdr, sizeof(*segcount));
+ if (unlikely(!segcount))
+ return -EMSGSIZE;
+ /* Actual value encoded below */
+
+ nchunks = 0;
+ do {
+ seg = rpcrdma_mr_prepare(r_xprt, req, seg, nsegs, true, &mr);
+ if (IS_ERR(seg))
+ return PTR_ERR(seg);
+
+ if (encode_rdma_segment(xdr, mr) < 0)
+ return -EMSGSIZE;
+
+ trace_xprtrdma_chunk_reply(rqst->rq_task, mr, nsegs);
+ r_xprt->rx_stats.reply_chunk_count++;
+ r_xprt->rx_stats.total_rdma_request += mr->mr_length;
+ nchunks++;
+ nsegs -= mr->mr_nents;
+ } while (nsegs);
+
+ /* Update count of segments in the Reply chunk */
+ *segcount = cpu_to_be32(nchunks);
+
+ return 0;
+}
+
+static void rpcrdma_sendctx_done(struct kref *kref)
+{
+ struct rpcrdma_req *req =
+ container_of(kref, struct rpcrdma_req, rl_kref);
+ struct rpcrdma_rep *rep = req->rl_reply;
+
+ rpcrdma_complete_rqst(rep);
+ rep->rr_rxprt->rx_stats.reply_waits_for_send++;
+}
+
+/**
+ * rpcrdma_sendctx_unmap - DMA-unmap Send buffer
+ * @sc: sendctx containing SGEs to unmap
+ *
+ */
+void rpcrdma_sendctx_unmap(struct rpcrdma_sendctx *sc)
+{
+ struct rpcrdma_regbuf *rb = sc->sc_req->rl_sendbuf;
+ struct ib_sge *sge;
+
+ if (!sc->sc_unmap_count)
+ return;
+
+ /* The first two SGEs contain the transport header and
+ * the inline buffer. These are always left mapped so
+ * they can be cheaply re-used.
+ */
+ for (sge = &sc->sc_sges[2]; sc->sc_unmap_count;
+ ++sge, --sc->sc_unmap_count)
+ ib_dma_unmap_page(rdmab_device(rb), sge->addr, sge->length,
+ DMA_TO_DEVICE);
+
+ kref_put(&sc->sc_req->rl_kref, rpcrdma_sendctx_done);
+}
+
+/* Prepare an SGE for the RPC-over-RDMA transport header.
+ */
+static void rpcrdma_prepare_hdr_sge(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_req *req, u32 len)
+{
+ struct rpcrdma_sendctx *sc = req->rl_sendctx;
+ struct rpcrdma_regbuf *rb = req->rl_rdmabuf;
+ struct ib_sge *sge = &sc->sc_sges[req->rl_wr.num_sge++];
+
+ sge->addr = rdmab_addr(rb);
+ sge->length = len;
+ sge->lkey = rdmab_lkey(rb);
+
+ ib_dma_sync_single_for_device(rdmab_device(rb), sge->addr, sge->length,
+ DMA_TO_DEVICE);
+}
+
+/* The head iovec is straightforward, as it is usually already
+ * DMA-mapped. Sync the content that has changed.
+ */
+static bool rpcrdma_prepare_head_iov(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_req *req, unsigned int len)
+{
+ struct rpcrdma_sendctx *sc = req->rl_sendctx;
+ struct ib_sge *sge = &sc->sc_sges[req->rl_wr.num_sge++];
+ struct rpcrdma_regbuf *rb = req->rl_sendbuf;
+
+ if (!rpcrdma_regbuf_dma_map(r_xprt, rb))
+ return false;
+
+ sge->addr = rdmab_addr(rb);
+ sge->length = len;
+ sge->lkey = rdmab_lkey(rb);
+
+ ib_dma_sync_single_for_device(rdmab_device(rb), sge->addr, sge->length,
+ DMA_TO_DEVICE);
+ return true;
+}
+
+/* If there is a page list present, DMA map and prepare an
+ * SGE for each page to be sent.
+ */
+static bool rpcrdma_prepare_pagelist(struct rpcrdma_req *req,
+ struct xdr_buf *xdr)
+{
+ struct rpcrdma_sendctx *sc = req->rl_sendctx;
+ struct rpcrdma_regbuf *rb = req->rl_sendbuf;
+ unsigned int page_base, len, remaining;
+ struct page **ppages;
+ struct ib_sge *sge;
+
+ ppages = xdr->pages + (xdr->page_base >> PAGE_SHIFT);
+ page_base = offset_in_page(xdr->page_base);
+ remaining = xdr->page_len;
+ while (remaining) {
+ sge = &sc->sc_sges[req->rl_wr.num_sge++];
+ len = min_t(unsigned int, PAGE_SIZE - page_base, remaining);
+ sge->addr = ib_dma_map_page(rdmab_device(rb), *ppages,
+ page_base, len, DMA_TO_DEVICE);
+ if (ib_dma_mapping_error(rdmab_device(rb), sge->addr))
+ goto out_mapping_err;
+
+ sge->length = len;
+ sge->lkey = rdmab_lkey(rb);
+
+ sc->sc_unmap_count++;
+ ppages++;
+ remaining -= len;
+ page_base = 0;
+ }
+
+ return true;
+
+out_mapping_err:
+ trace_xprtrdma_dma_maperr(sge->addr);
+ return false;
+}
+
+/* The tail iovec may include an XDR pad for the page list,
+ * as well as additional content, and may not reside in the
+ * same page as the head iovec.
+ */
+static bool rpcrdma_prepare_tail_iov(struct rpcrdma_req *req,
+ struct xdr_buf *xdr,
+ unsigned int page_base, unsigned int len)
+{
+ struct rpcrdma_sendctx *sc = req->rl_sendctx;
+ struct ib_sge *sge = &sc->sc_sges[req->rl_wr.num_sge++];
+ struct rpcrdma_regbuf *rb = req->rl_sendbuf;
+ struct page *page = virt_to_page(xdr->tail[0].iov_base);
+
+ sge->addr = ib_dma_map_page(rdmab_device(rb), page, page_base, len,
+ DMA_TO_DEVICE);
+ if (ib_dma_mapping_error(rdmab_device(rb), sge->addr))
+ goto out_mapping_err;
+
+ sge->length = len;
+ sge->lkey = rdmab_lkey(rb);
+ ++sc->sc_unmap_count;
+ return true;
+
+out_mapping_err:
+ trace_xprtrdma_dma_maperr(sge->addr);
+ return false;
+}
+
+/* Copy the tail to the end of the head buffer.
+ */
+static void rpcrdma_pullup_tail_iov(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_req *req,
+ struct xdr_buf *xdr)
+{
+ unsigned char *dst;
+
+ dst = (unsigned char *)xdr->head[0].iov_base;
+ dst += xdr->head[0].iov_len + xdr->page_len;
+ memmove(dst, xdr->tail[0].iov_base, xdr->tail[0].iov_len);
+ r_xprt->rx_stats.pullup_copy_count += xdr->tail[0].iov_len;
+}
+
+/* Copy pagelist content into the head buffer.
+ */
+static void rpcrdma_pullup_pagelist(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_req *req,
+ struct xdr_buf *xdr)
+{
+ unsigned int len, page_base, remaining;
+ struct page **ppages;
+ unsigned char *src, *dst;
+
+ dst = (unsigned char *)xdr->head[0].iov_base;
+ dst += xdr->head[0].iov_len;
+ ppages = xdr->pages + (xdr->page_base >> PAGE_SHIFT);
+ page_base = offset_in_page(xdr->page_base);
+ remaining = xdr->page_len;
+ while (remaining) {
+ src = page_address(*ppages);
+ src += page_base;
+ len = min_t(unsigned int, PAGE_SIZE - page_base, remaining);
+ memcpy(dst, src, len);
+ r_xprt->rx_stats.pullup_copy_count += len;
+
+ ppages++;
+ dst += len;
+ remaining -= len;
+ page_base = 0;
+ }
+}
+
+/* Copy the contents of @xdr into @rl_sendbuf and DMA sync it.
+ * When the head, pagelist, and tail are small, a pull-up copy
+ * is considerably less costly than DMA mapping the components
+ * of @xdr.
+ *
+ * Assumptions:
+ * - the caller has already verified that the total length
+ * of the RPC Call body will fit into @rl_sendbuf.
+ */
+static bool rpcrdma_prepare_noch_pullup(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_req *req,
+ struct xdr_buf *xdr)
+{
+ if (unlikely(xdr->tail[0].iov_len))
+ rpcrdma_pullup_tail_iov(r_xprt, req, xdr);
+
+ if (unlikely(xdr->page_len))
+ rpcrdma_pullup_pagelist(r_xprt, req, xdr);
+
+ /* The whole RPC message resides in the head iovec now */
+ return rpcrdma_prepare_head_iov(r_xprt, req, xdr->len);
+}
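/* Illustrative sketch (hypothetical sizes, not from the upstream file):
 * with head.iov_len = 120, page_len = 512, and tail.iov_len = 8, the
 * pull-up path above copies
 *
 *   tail      -> head.iov_base + 120 + 512
 *   page list -> head.iov_base + 120
 *
 * and the Send then needs only the transport-header SGE plus a single
 * 640-byte head SGE (xdr->len), leaving nothing extra to DMA-unmap later.
 */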
+
+static bool rpcrdma_prepare_noch_mapped(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_req *req,
+ struct xdr_buf *xdr)
+{
+ struct kvec *tail = &xdr->tail[0];
+
+ if (!rpcrdma_prepare_head_iov(r_xprt, req, xdr->head[0].iov_len))
+ return false;
+ if (xdr->page_len)
+ if (!rpcrdma_prepare_pagelist(req, xdr))
+ return false;
+ if (tail->iov_len)
+ if (!rpcrdma_prepare_tail_iov(req, xdr,
+ offset_in_page(tail->iov_base),
+ tail->iov_len))
+ return false;
+
+ if (req->rl_sendctx->sc_unmap_count)
+ kref_get(&req->rl_kref);
+ return true;
+}
+
+static bool rpcrdma_prepare_readch(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_req *req,
+ struct xdr_buf *xdr)
+{
+ if (!rpcrdma_prepare_head_iov(r_xprt, req, xdr->head[0].iov_len))
+ return false;
+
+ /* If there is a Read chunk, the page list is being handled
+ * via explicit RDMA, and thus is skipped here.
+ */
+
+ /* Do not include the tail if it is only an XDR pad */
+ if (xdr->tail[0].iov_len > 3) {
+ unsigned int page_base, len;
+
+ /* If the content in the page list is an odd length,
+ * xdr_write_pages() adds a pad at the beginning of
+ * the tail iovec. Force the tail's non-pad content to
+ * land at the next XDR position in the Send message.
+ */
+ page_base = offset_in_page(xdr->tail[0].iov_base);
+ len = xdr->tail[0].iov_len;
+ page_base += len & 3;
+ len -= len & 3;
+ if (!rpcrdma_prepare_tail_iov(req, xdr, page_base, len))
+ return false;
+ kref_get(&req->rl_kref);
+ }
+
+ return true;
+}
+
+/**
+ * rpcrdma_prepare_send_sges - Construct SGEs for a Send WR
+ * @r_xprt: controlling transport
+ * @req: context of RPC Call being marshalled
+ * @hdrlen: size of transport header, in bytes
+ * @xdr: xdr_buf containing RPC Call
+ * @rtype: chunk type being encoded
+ *
+ * Returns 0 on success; otherwise a negative errno is returned.
+ */
+inline int rpcrdma_prepare_send_sges(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_req *req, u32 hdrlen,
+ struct xdr_buf *xdr,
+ enum rpcrdma_chunktype rtype)
+{
+ int ret;
+
+ ret = -EAGAIN;
+ req->rl_sendctx = rpcrdma_sendctx_get_locked(r_xprt);
+ if (!req->rl_sendctx)
+ goto out_nosc;
+ req->rl_sendctx->sc_unmap_count = 0;
+ req->rl_sendctx->sc_req = req;
+ kref_init(&req->rl_kref);
+ req->rl_wr.wr_cqe = &req->rl_sendctx->sc_cqe;
+ req->rl_wr.sg_list = req->rl_sendctx->sc_sges;
+ req->rl_wr.num_sge = 0;
+ req->rl_wr.opcode = IB_WR_SEND;
+
+ rpcrdma_prepare_hdr_sge(r_xprt, req, hdrlen);
+
+ ret = -EIO;
+ switch (rtype) {
+ case rpcrdma_noch_pullup:
+ if (!rpcrdma_prepare_noch_pullup(r_xprt, req, xdr))
+ goto out_unmap;
+ break;
+ case rpcrdma_noch_mapped:
+ if (!rpcrdma_prepare_noch_mapped(r_xprt, req, xdr))
+ goto out_unmap;
+ break;
+ case rpcrdma_readch:
+ if (!rpcrdma_prepare_readch(r_xprt, req, xdr))
+ goto out_unmap;
+ break;
+ case rpcrdma_areadch:
+ break;
+ default:
+ goto out_unmap;
+ }
+
+ return 0;
+
+out_unmap:
+ rpcrdma_sendctx_unmap(req->rl_sendctx);
+out_nosc:
+ trace_xprtrdma_prepsend_failed(&req->rl_slot, ret);
+ return ret;
+}
+
+/**
+ * rpcrdma_marshal_req - Marshal and send one RPC request
+ * @r_xprt: controlling transport
+ * @rqst: RPC request to be marshaled
+ *
+ * For the RPC in "rqst", this function:
+ * - Chooses the transfer mode (e.g., RDMA_MSG or RDMA_NOMSG)
+ * - Registers Read, Write, and Reply chunks
+ * - Constructs the transport header
+ * - Posts a Send WR to send the transport header and request
+ *
+ * Returns:
+ * %0 if the RPC was sent successfully,
+ * %-ENOTCONN if the connection was lost,
+ * %-EAGAIN if the caller should call again with the same arguments,
+ * %-ENOBUFS if the caller should call again after a delay,
+ * %-EMSGSIZE if the transport header is too small,
+ * %-EIO if a permanent problem occurred while marshaling.
+ */
+int
+rpcrdma_marshal_req(struct rpcrdma_xprt *r_xprt, struct rpc_rqst *rqst)
+{
+ struct rpcrdma_req *req = rpcr_to_rdmar(rqst);
+ struct xdr_stream *xdr = &req->rl_stream;
+ enum rpcrdma_chunktype rtype, wtype;
+ struct xdr_buf *buf = &rqst->rq_snd_buf;
+ bool ddp_allowed;
+ __be32 *p;
+ int ret;
+
+ if (unlikely(rqst->rq_rcv_buf.flags & XDRBUF_SPARSE_PAGES)) {
+ ret = rpcrdma_alloc_sparse_pages(&rqst->rq_rcv_buf);
+ if (ret)
+ return ret;
+ }
+
+ rpcrdma_set_xdrlen(&req->rl_hdrbuf, 0);
+ xdr_init_encode(xdr, &req->rl_hdrbuf, rdmab_data(req->rl_rdmabuf),
+ rqst);
+
+ /* Fixed header fields */
+ ret = -EMSGSIZE;
+ p = xdr_reserve_space(xdr, 4 * sizeof(*p));
+ if (!p)
+ goto out_err;
+ *p++ = rqst->rq_xid;
+ *p++ = rpcrdma_version;
+ *p++ = r_xprt->rx_buf.rb_max_requests;
+
+ /* When the ULP employs a GSS flavor that guarantees integrity
+ * or privacy, direct data placement of individual data items
+ * is not allowed.
+ */
+ ddp_allowed = !test_bit(RPCAUTH_AUTH_DATATOUCH,
+ &rqst->rq_cred->cr_auth->au_flags);
+
+ /*
+ * Chunks needed for results?
+ *
+ * o If the expected result is under the inline threshold, all ops
+ * return as inline.
+ * o Large read ops return data as write chunk(s), header as
+ * inline.
+ * o Large non-read ops return as a single reply chunk.
+ */
+ if (rpcrdma_results_inline(r_xprt, rqst))
+ wtype = rpcrdma_noch;
+ else if ((ddp_allowed && rqst->rq_rcv_buf.flags & XDRBUF_READ) &&
+ rpcrdma_nonpayload_inline(r_xprt, rqst))
+ wtype = rpcrdma_writech;
+ else
+ wtype = rpcrdma_replych;
+
+ /*
+ * Chunks needed for arguments?
+ *
+ * o If the total request is under the inline threshold, all ops
+ * are sent as inline.
+ * o Large write ops transmit data as read chunk(s), header as
+ * inline.
+ * o Large non-write ops are sent with the entire message as a
+ * single read chunk (protocol 0-position special case).
+ *
+ * This assumes that the upper layer does not present a request
+ * that both has a data payload, and whose non-data arguments
+ * by themselves are larger than the inline threshold.
+ */
+ if (rpcrdma_args_inline(r_xprt, rqst)) {
+ *p++ = rdma_msg;
+ rtype = buf->len < rdmab_length(req->rl_sendbuf) ?
+ rpcrdma_noch_pullup : rpcrdma_noch_mapped;
+ } else if (ddp_allowed && buf->flags & XDRBUF_WRITE) {
+ *p++ = rdma_msg;
+ rtype = rpcrdma_readch;
+ } else {
+ r_xprt->rx_stats.nomsg_call_count++;
+ *p++ = rdma_nomsg;
+ rtype = rpcrdma_areadch;
+ }
+
+ /* This implementation supports the following combinations
+ * of chunk lists in one RPC-over-RDMA Call message:
+ *
+ * - Read list
+ * - Write list
+ * - Reply chunk
+ * - Read list + Reply chunk
+ *
+ * It might not yet support the following combinations:
+ *
+ * - Read list + Write list
+ *
+ * It does not support the following combinations:
+ *
+ * - Write list + Reply chunk
+ * - Read list + Write list + Reply chunk
+ *
+ * This implementation supports only a single chunk in each
+ * Read or Write list. Thus for example the client cannot
+ * send a Call message with a Position Zero Read chunk and a
+ * regular Read chunk at the same time.
+ */
+ ret = rpcrdma_encode_read_list(r_xprt, req, rqst, rtype);
+ if (ret)
+ goto out_err;
+ ret = rpcrdma_encode_write_list(r_xprt, req, rqst, wtype);
+ if (ret)
+ goto out_err;
+ ret = rpcrdma_encode_reply_chunk(r_xprt, req, rqst, wtype);
+ if (ret)
+ goto out_err;
+
+ ret = rpcrdma_prepare_send_sges(r_xprt, req, req->rl_hdrbuf.len,
+ buf, rtype);
+ if (ret)
+ goto out_err;
+
+ trace_xprtrdma_marshal(req, rtype, wtype);
+ return 0;
+
+out_err:
+ trace_xprtrdma_marshal_failed(rqst, ret);
+ r_xprt->rx_stats.failed_marshal_count++;
+ frwr_reset(req);
+ return ret;
+}
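/* Illustrative reading of the rtype/wtype decisions above (not from the
 * upstream file; assumes an NFS-like upper layer that marks DDP-eligible
 * buffers with XDRBUF_READ/XDRBUF_WRITE, and hypothetical "large" sizes):
 *
 *   small call and reply           rtype = noch_pullup or noch_mapped,
 *                                  wtype = noch      -> RDMA_MSG, no chunks
 *   large DDP-eligible reply       wtype = writech   -> Write list
 *   large DDP-eligible call        rtype = readch    -> Read list, RDMA_MSG
 *   large call, not DDP-eligible   rtype = areadch   -> RDMA_NOMSG with a
 *                                                       Position Zero Read chunk
 *   large reply, not DDP-eligible  wtype = replych   -> Reply chunk
 */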
+
+static void __rpcrdma_update_cwnd_locked(struct rpc_xprt *xprt,
+ struct rpcrdma_buffer *buf,
+ u32 grant)
+{
+ buf->rb_credits = grant;
+ xprt->cwnd = grant << RPC_CWNDSHIFT;
+}
+
+static void rpcrdma_update_cwnd(struct rpcrdma_xprt *r_xprt, u32 grant)
+{
+ struct rpc_xprt *xprt = &r_xprt->rx_xprt;
+
+ spin_lock(&xprt->transport_lock);
+ __rpcrdma_update_cwnd_locked(xprt, &r_xprt->rx_buf, grant);
+ spin_unlock(&xprt->transport_lock);
+}
+
+/**
+ * rpcrdma_reset_cwnd - Reset the xprt's congestion window
+ * @r_xprt: controlling transport instance
+ *
+ * Prepare @r_xprt for the next connection by reinitializing
+ * its credit grant to one (see RFC 8166, Section 3.3.3).
+ */
+void rpcrdma_reset_cwnd(struct rpcrdma_xprt *r_xprt)
+{
+ struct rpc_xprt *xprt = &r_xprt->rx_xprt;
+
+ spin_lock(&xprt->transport_lock);
+ xprt->cong = 0;
+ __rpcrdma_update_cwnd_locked(xprt, &r_xprt->rx_buf, 1);
+ spin_unlock(&xprt->transport_lock);
+}
+
+/**
+ * rpcrdma_inline_fixup - Scatter inline received data into rqst's iovecs
+ * @rqst: controlling RPC request
+ * @srcp: points to RPC message payload in receive buffer
+ * @copy_len: remaining length of receive buffer content
+ * @pad: Write chunk pad bytes needed (zero for pure inline)
+ *
+ * The upper layer has set the maximum number of bytes it can
+ * receive in each component of rq_rcv_buf. These values are set in
+ * the head.iov_len, page_len, tail.iov_len, and buflen fields.
+ *
+ * Unlike the TCP equivalent (xdr_partial_copy_from_skb), in
+ * many cases this function simply updates iov_base pointers in
+ * rq_rcv_buf to point directly to the received reply data, to
+ * avoid copying reply data.
+ *
+ * Returns the count of bytes which had to be memcopied.
+ */
+static unsigned long
+rpcrdma_inline_fixup(struct rpc_rqst *rqst, char *srcp, int copy_len, int pad)
+{
+ unsigned long fixup_copy_count;
+ int i, npages, curlen;
+ char *destp;
+ struct page **ppages;
+ int page_base;
+
+ /* The head iovec is redirected to the RPC reply message
+ * in the receive buffer, to avoid a memcopy.
+ */
+ rqst->rq_rcv_buf.head[0].iov_base = srcp;
+ rqst->rq_private_buf.head[0].iov_base = srcp;
+
+ /* The contents of the receive buffer that follow
+ * head.iov_len bytes are copied into the page list.
+ */
+ curlen = rqst->rq_rcv_buf.head[0].iov_len;
+ if (curlen > copy_len)
+ curlen = copy_len;
+ srcp += curlen;
+ copy_len -= curlen;
+
+ ppages = rqst->rq_rcv_buf.pages +
+ (rqst->rq_rcv_buf.page_base >> PAGE_SHIFT);
+ page_base = offset_in_page(rqst->rq_rcv_buf.page_base);
+ fixup_copy_count = 0;
+ if (copy_len && rqst->rq_rcv_buf.page_len) {
+ int pagelist_len;
+
+ pagelist_len = rqst->rq_rcv_buf.page_len;
+ if (pagelist_len > copy_len)
+ pagelist_len = copy_len;
+ npages = PAGE_ALIGN(page_base + pagelist_len) >> PAGE_SHIFT;
+ for (i = 0; i < npages; i++) {
+ curlen = PAGE_SIZE - page_base;
+ if (curlen > pagelist_len)
+ curlen = pagelist_len;
+
+ destp = kmap_atomic(ppages[i]);
+ memcpy(destp + page_base, srcp, curlen);
+ flush_dcache_page(ppages[i]);
+ kunmap_atomic(destp);
+ srcp += curlen;
+ copy_len -= curlen;
+ fixup_copy_count += curlen;
+ pagelist_len -= curlen;
+ if (!pagelist_len)
+ break;
+ page_base = 0;
+ }
+
+ /* Implicit padding for the last segment in a Write
+ * chunk is inserted inline at the front of the tail
+ * iovec. The upper layer ignores the content of
+ * the pad. Simply ensure inline content in the tail
+ * that follows the Write chunk is properly aligned.
+ */
+ if (pad)
+ srcp -= pad;
+ }
+
+ /* The tail iovec is redirected to the remaining data
+ * in the receive buffer, to avoid a memcopy.
+ */
+ if (copy_len || pad) {
+ rqst->rq_rcv_buf.tail[0].iov_base = srcp;
+ rqst->rq_private_buf.tail[0].iov_base = srcp;
+ }
+
+ if (fixup_copy_count)
+ trace_xprtrdma_fixup(rqst, fixup_copy_count);
+ return fixup_copy_count;
+}
+
+/* By convention, backchannel calls arrive via rdma_msg type
+ * messages, and never populate the chunk lists. This makes
+ * the RPC/RDMA header small and fixed in size, so it is
+ * straightforward to check the RPC header's direction field.
+ */
+static bool
+rpcrdma_is_bcall(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep)
+#if defined(CONFIG_SUNRPC_BACKCHANNEL)
+{
+ struct rpc_xprt *xprt = &r_xprt->rx_xprt;
+ struct xdr_stream *xdr = &rep->rr_stream;
+ __be32 *p;
+
+ if (rep->rr_proc != rdma_msg)
+ return false;
+
+ /* Peek at stream contents without advancing. */
+ p = xdr_inline_decode(xdr, 0);
+
+ /* Chunk lists */
+ if (xdr_item_is_present(p++))
+ return false;
+ if (xdr_item_is_present(p++))
+ return false;
+ if (xdr_item_is_present(p++))
+ return false;
+
+ /* RPC header */
+ if (*p++ != rep->rr_xid)
+ return false;
+ if (*p != cpu_to_be32(RPC_CALL))
+ return false;
+
+ /* No bc service. */
+ if (xprt->bc_serv == NULL)
+ return false;
+
+ /* Now that we are sure this is a backchannel call,
+ * advance to the RPC header.
+ */
+ p = xdr_inline_decode(xdr, 3 * sizeof(*p));
+ if (unlikely(!p))
+ return true;
+
+ rpcrdma_bc_receive_call(r_xprt, rep);
+ return true;
+}
+#else /* CONFIG_SUNRPC_BACKCHANNEL */
+{
+ return false;
+}
+#endif /* CONFIG_SUNRPC_BACKCHANNEL */
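/* Illustrative sketch (not from the upstream file): the five XDR words the
 * peek above examines for a backchannel Call, after the caller has already
 * consumed the four fixed header words (xid, vers, credits, proc):
 *
 *   0      <- Read list absent
 *   0      <- Write list absent
 *   0      <- Reply chunk absent
 *   xid    <- RPC header xid, must equal rr_xid
 *   0      <- RPC header msg_type, RPC_CALL
 */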
+
+static int decode_rdma_segment(struct xdr_stream *xdr, u32 *length)
+{
+ u32 handle;
+ u64 offset;
+ __be32 *p;
+
+ p = xdr_inline_decode(xdr, 4 * sizeof(*p));
+ if (unlikely(!p))
+ return -EIO;
+
+ xdr_decode_rdma_segment(p, &handle, length, &offset);
+ trace_xprtrdma_decode_seg(handle, *length, offset);
+ return 0;
+}
+
+static int decode_write_chunk(struct xdr_stream *xdr, u32 *length)
+{
+ u32 segcount, seglength;
+ __be32 *p;
+
+ p = xdr_inline_decode(xdr, sizeof(*p));
+ if (unlikely(!p))
+ return -EIO;
+
+ *length = 0;
+ segcount = be32_to_cpup(p);
+ while (segcount--) {
+ if (decode_rdma_segment(xdr, &seglength))
+ return -EIO;
+ *length += seglength;
+ }
+
+ return 0;
+}
+
+/* In RPC-over-RDMA Version One replies, a Read list is never
+ * expected. This decoder is a stub that returns an error if
+ * a Read list is present.
+ */
+static int decode_read_list(struct xdr_stream *xdr)
+{
+ __be32 *p;
+
+ p = xdr_inline_decode(xdr, sizeof(*p));
+ if (unlikely(!p))
+ return -EIO;
+ if (unlikely(xdr_item_is_present(p)))
+ return -EIO;
+ return 0;
+}
+
+/* Supports only one Write chunk in the Write list
+ */
+static int decode_write_list(struct xdr_stream *xdr, u32 *length)
+{
+ u32 chunklen;
+ bool first;
+ __be32 *p;
+
+ *length = 0;
+ first = true;
+ do {
+ p = xdr_inline_decode(xdr, sizeof(*p));
+ if (unlikely(!p))
+ return -EIO;
+ if (xdr_item_is_absent(p))
+ break;
+ if (!first)
+ return -EIO;
+
+ if (decode_write_chunk(xdr, &chunklen))
+ return -EIO;
+ *length += chunklen;
+ first = false;
+ } while (true);
+ return 0;
+}
+
+static int decode_reply_chunk(struct xdr_stream *xdr, u32 *length)
+{
+ __be32 *p;
+
+ p = xdr_inline_decode(xdr, sizeof(*p));
+ if (unlikely(!p))
+ return -EIO;
+
+ *length = 0;
+ if (xdr_item_is_present(p))
+ if (decode_write_chunk(xdr, length))
+ return -EIO;
+ return 0;
+}
+
+static int
+rpcrdma_decode_msg(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep,
+ struct rpc_rqst *rqst)
+{
+ struct xdr_stream *xdr = &rep->rr_stream;
+ u32 writelist, replychunk, rpclen;
+ char *base;
+
+ /* Decode the chunk lists */
+ if (decode_read_list(xdr))
+ return -EIO;
+ if (decode_write_list(xdr, &writelist))
+ return -EIO;
+ if (decode_reply_chunk(xdr, &replychunk))
+ return -EIO;
+
+ /* RDMA_MSG sanity checks */
+ if (unlikely(replychunk))
+ return -EIO;
+
+ /* Build the RPC reply's Payload stream in rqst->rq_rcv_buf */
+ base = (char *)xdr_inline_decode(xdr, 0);
+ rpclen = xdr_stream_remaining(xdr);
+ r_xprt->rx_stats.fixup_copy_count +=
+ rpcrdma_inline_fixup(rqst, base, rpclen, writelist & 3);
+
+ r_xprt->rx_stats.total_rdma_reply += writelist;
+ return rpclen + xdr_align_size(writelist);
+}
+
+static noinline int
+rpcrdma_decode_nomsg(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep)
+{
+ struct xdr_stream *xdr = &rep->rr_stream;
+ u32 writelist, replychunk;
+
+ /* Decode the chunk lists */
+ if (decode_read_list(xdr))
+ return -EIO;
+ if (decode_write_list(xdr, &writelist))
+ return -EIO;
+ if (decode_reply_chunk(xdr, &replychunk))
+ return -EIO;
+
+ /* RDMA_NOMSG sanity checks */
+ if (unlikely(writelist))
+ return -EIO;
+ if (unlikely(!replychunk))
+ return -EIO;
+
+ /* Reply chunk buffer already is the reply vector */
+ r_xprt->rx_stats.total_rdma_reply += replychunk;
+ return replychunk;
+}
+
+static noinline int
+rpcrdma_decode_error(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep,
+ struct rpc_rqst *rqst)
+{
+ struct xdr_stream *xdr = &rep->rr_stream;
+ __be32 *p;
+
+ p = xdr_inline_decode(xdr, sizeof(*p));
+ if (unlikely(!p))
+ return -EIO;
+
+ switch (*p) {
+ case err_vers:
+ p = xdr_inline_decode(xdr, 2 * sizeof(*p));
+ if (!p)
+ break;
+ trace_xprtrdma_err_vers(rqst, p, p + 1);
+ break;
+ case err_chunk:
+ trace_xprtrdma_err_chunk(rqst);
+ break;
+ default:
+ trace_xprtrdma_err_unrecognized(rqst, p);
+ }
+
+ return -EIO;
+}
+
+/**
+ * rpcrdma_unpin_rqst - Release rqst without completing it
+ * @rep: RPC/RDMA Receive context
+ *
+ * This is done when a connection is lost so that a Reply
+ * can be dropped and its matching Call can be subsequently
+ * retransmitted on a new connection.
+ */
+void rpcrdma_unpin_rqst(struct rpcrdma_rep *rep)
+{
+ struct rpc_xprt *xprt = &rep->rr_rxprt->rx_xprt;
+ struct rpc_rqst *rqst = rep->rr_rqst;
+ struct rpcrdma_req *req = rpcr_to_rdmar(rqst);
+
+ req->rl_reply = NULL;
+ rep->rr_rqst = NULL;
+
+ spin_lock(&xprt->queue_lock);
+ xprt_unpin_rqst(rqst);
+ spin_unlock(&xprt->queue_lock);
+}
+
+/**
+ * rpcrdma_complete_rqst - Pass completed rqst back to RPC
+ * @rep: RPC/RDMA Receive context
+ *
+ * Reconstruct the RPC reply and complete the transaction
+ * while @rqst is still pinned to ensure the rep, rqst, and
+ * rq_task pointers remain stable.
+ */
+void rpcrdma_complete_rqst(struct rpcrdma_rep *rep)
+{
+ struct rpcrdma_xprt *r_xprt = rep->rr_rxprt;
+ struct rpc_xprt *xprt = &r_xprt->rx_xprt;
+ struct rpc_rqst *rqst = rep->rr_rqst;
+ int status;
+
+ switch (rep->rr_proc) {
+ case rdma_msg:
+ status = rpcrdma_decode_msg(r_xprt, rep, rqst);
+ break;
+ case rdma_nomsg:
+ status = rpcrdma_decode_nomsg(r_xprt, rep);
+ break;
+ case rdma_error:
+ status = rpcrdma_decode_error(r_xprt, rep, rqst);
+ break;
+ default:
+ status = -EIO;
+ }
+ if (status < 0)
+ goto out_badheader;
+
+out:
+ spin_lock(&xprt->queue_lock);
+ xprt_complete_rqst(rqst->rq_task, status);
+ xprt_unpin_rqst(rqst);
+ spin_unlock(&xprt->queue_lock);
+ return;
+
+out_badheader:
+ trace_xprtrdma_reply_hdr_err(rep);
+ r_xprt->rx_stats.bad_reply_count++;
+ rqst->rq_task->tk_status = status;
+ status = 0;
+ goto out;
+}
+
+static void rpcrdma_reply_done(struct kref *kref)
+{
+ struct rpcrdma_req *req =
+ container_of(kref, struct rpcrdma_req, rl_kref);
+
+ rpcrdma_complete_rqst(req->rl_reply);
+}
+
+/**
+ * rpcrdma_reply_handler - Process received RPC/RDMA messages
+ * @rep: Incoming rpcrdma_rep object to process
+ *
+ * Errors must result in the RPC task either being awakened, or
+ * allowed to timeout, to discover the errors at that time.
+ */
+void rpcrdma_reply_handler(struct rpcrdma_rep *rep)
+{
+ struct rpcrdma_xprt *r_xprt = rep->rr_rxprt;
+ struct rpc_xprt *xprt = &r_xprt->rx_xprt;
+ struct rpcrdma_buffer *buf = &r_xprt->rx_buf;
+ struct rpcrdma_req *req;
+ struct rpc_rqst *rqst;
+ u32 credits;
+ __be32 *p;
+
+	/* Any data means we had a useful conversation, so
+	 * we don't need to delay the next reconnect.
+	 */
+ if (xprt->reestablish_timeout)
+ xprt->reestablish_timeout = 0;
+
+ /* Fixed transport header fields */
+ xdr_init_decode(&rep->rr_stream, &rep->rr_hdrbuf,
+ rep->rr_hdrbuf.head[0].iov_base, NULL);
+ p = xdr_inline_decode(&rep->rr_stream, 4 * sizeof(*p));
+ if (unlikely(!p))
+ goto out_shortreply;
+ rep->rr_xid = *p++;
+ rep->rr_vers = *p++;
+ credits = be32_to_cpu(*p++);
+ rep->rr_proc = *p++;
+
+ if (rep->rr_vers != rpcrdma_version)
+ goto out_badversion;
+
+ if (rpcrdma_is_bcall(r_xprt, rep))
+ return;
+
+ /* Match incoming rpcrdma_rep to an rpcrdma_req to
+ * get context for handling any incoming chunks.
+ */
+ spin_lock(&xprt->queue_lock);
+ rqst = xprt_lookup_rqst(xprt, rep->rr_xid);
+ if (!rqst)
+ goto out_norqst;
+ xprt_pin_rqst(rqst);
+ spin_unlock(&xprt->queue_lock);
+
+ if (credits == 0)
+ credits = 1; /* don't deadlock */
+ else if (credits > r_xprt->rx_ep->re_max_requests)
+ credits = r_xprt->rx_ep->re_max_requests;
+ rpcrdma_post_recvs(r_xprt, credits + (buf->rb_bc_srv_max_requests << 1),
+ false);
+ if (buf->rb_credits != credits)
+ rpcrdma_update_cwnd(r_xprt, credits);
+
+ req = rpcr_to_rdmar(rqst);
+ if (unlikely(req->rl_reply))
+ rpcrdma_rep_put(buf, req->rl_reply);
+ req->rl_reply = rep;
+ rep->rr_rqst = rqst;
+
+ trace_xprtrdma_reply(rqst->rq_task, rep, credits);
+
+ if (rep->rr_wc_flags & IB_WC_WITH_INVALIDATE)
+ frwr_reminv(rep, &req->rl_registered);
+ if (!list_empty(&req->rl_registered))
+ frwr_unmap_async(r_xprt, req);
+ /* LocalInv completion will complete the RPC */
+ else
+ kref_put(&req->rl_kref, rpcrdma_reply_done);
+ return;
+
+out_badversion:
+ trace_xprtrdma_reply_vers_err(rep);
+ goto out;
+
+out_norqst:
+ spin_unlock(&xprt->queue_lock);
+ trace_xprtrdma_reply_rqst_err(rep);
+ goto out;
+
+out_shortreply:
+ trace_xprtrdma_reply_short_err(rep);
+
+out:
+ rpcrdma_rep_put(buf, rep);
+}
diff --git a/net/sunrpc/xprtrdma/svc_rdma.c b/net/sunrpc/xprtrdma/svc_rdma.c
new file mode 100644
index 0000000000..f0d5eeed4c
--- /dev/null
+++ b/net/sunrpc/xprtrdma/svc_rdma.c
@@ -0,0 +1,283 @@
+// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
+/*
+ * Copyright (c) 2015-2018 Oracle. All rights reserved.
+ * Copyright (c) 2005-2006 Network Appliance, Inc. All rights reserved.
+ *
+ * This software is available to you under a choice of one of two
+ * licenses. You may choose to be licensed under the terms of the GNU
+ * General Public License (GPL) Version 2, available from the file
+ * COPYING in the main directory of this source tree, or the BSD-type
+ * license below:
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *
+ * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials provided
+ * with the distribution.
+ *
+ * Neither the name of the Network Appliance, Inc. nor the names of
+ * its contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * Author: Tom Tucker <tom@opengridcomputing.com>
+ */
+
+#include <linux/slab.h>
+#include <linux/fs.h>
+#include <linux/sysctl.h>
+#include <linux/workqueue.h>
+#include <linux/sunrpc/clnt.h>
+#include <linux/sunrpc/sched.h>
+#include <linux/sunrpc/svc_rdma.h>
+
+#define RPCDBG_FACILITY RPCDBG_SVCXPRT
+
+/* RPC/RDMA parameters */
+unsigned int svcrdma_ord = 16; /* historical default */
+static unsigned int min_ord = 1;
+static unsigned int max_ord = 255;
+unsigned int svcrdma_max_requests = RPCRDMA_MAX_REQUESTS;
+unsigned int svcrdma_max_bc_requests = RPCRDMA_MAX_BC_REQUESTS;
+static unsigned int min_max_requests = 4;
+static unsigned int max_max_requests = 16384;
+unsigned int svcrdma_max_req_size = RPCRDMA_DEF_INLINE_THRESH;
+static unsigned int min_max_inline = RPCRDMA_DEF_INLINE_THRESH;
+static unsigned int max_max_inline = RPCRDMA_MAX_INLINE_THRESH;
+static unsigned int svcrdma_stat_unused;
+static unsigned int zero;
+
+struct percpu_counter svcrdma_stat_read;
+struct percpu_counter svcrdma_stat_recv;
+struct percpu_counter svcrdma_stat_sq_starve;
+struct percpu_counter svcrdma_stat_write;
+
+enum {
+ SVCRDMA_COUNTER_BUFSIZ = sizeof(unsigned long long),
+};
+
+static int svcrdma_counter_handler(struct ctl_table *table, int write,
+ void *buffer, size_t *lenp, loff_t *ppos)
+{
+ struct percpu_counter *stat = (struct percpu_counter *)table->data;
+ char tmp[SVCRDMA_COUNTER_BUFSIZ + 1];
+ int len;
+
+ if (write) {
+ percpu_counter_set(stat, 0);
+ return 0;
+ }
+
+ len = snprintf(tmp, SVCRDMA_COUNTER_BUFSIZ, "%lld\n",
+ percpu_counter_sum_positive(stat));
+ if (len >= SVCRDMA_COUNTER_BUFSIZ)
+ return -EFAULT;
+ len = strlen(tmp);
+ if (*ppos > len) {
+ *lenp = 0;
+ return 0;
+ }
+ len -= *ppos;
+ if (len > *lenp)
+ len = *lenp;
+ if (len)
+ memcpy(buffer, tmp, len);
+ *lenp = len;
+ *ppos += len;
+
+ return 0;
+}
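/* Illustrative usage note (not from the upstream file): once the table
 * below is registered under "sunrpc/svc_rdma", this handler backs reads
 * such as
 *
 *   $ cat /proc/sys/sunrpc/svc_rdma/rdma_stat_recv
 *   42
 *
 * (the value shown is hypothetical), and writing any value to the file
 * resets the counter to zero.
 */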
+
+static struct ctl_table_header *svcrdma_table_header;
+static struct ctl_table svcrdma_parm_table[] = {
+ {
+ .procname = "max_requests",
+ .data = &svcrdma_max_requests,
+ .maxlen = sizeof(unsigned int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_minmax,
+ .extra1 = &min_max_requests,
+ .extra2 = &max_max_requests
+ },
+ {
+ .procname = "max_req_size",
+ .data = &svcrdma_max_req_size,
+ .maxlen = sizeof(unsigned int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_minmax,
+ .extra1 = &min_max_inline,
+ .extra2 = &max_max_inline
+ },
+ {
+ .procname = "max_outbound_read_requests",
+ .data = &svcrdma_ord,
+ .maxlen = sizeof(unsigned int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_minmax,
+ .extra1 = &min_ord,
+ .extra2 = &max_ord,
+ },
+
+ {
+ .procname = "rdma_stat_read",
+ .data = &svcrdma_stat_read,
+ .maxlen = SVCRDMA_COUNTER_BUFSIZ,
+ .mode = 0644,
+ .proc_handler = svcrdma_counter_handler,
+ },
+ {
+ .procname = "rdma_stat_recv",
+ .data = &svcrdma_stat_recv,
+ .maxlen = SVCRDMA_COUNTER_BUFSIZ,
+ .mode = 0644,
+ .proc_handler = svcrdma_counter_handler,
+ },
+ {
+ .procname = "rdma_stat_write",
+ .data = &svcrdma_stat_write,
+ .maxlen = SVCRDMA_COUNTER_BUFSIZ,
+ .mode = 0644,
+ .proc_handler = svcrdma_counter_handler,
+ },
+ {
+ .procname = "rdma_stat_sq_starve",
+ .data = &svcrdma_stat_sq_starve,
+ .maxlen = SVCRDMA_COUNTER_BUFSIZ,
+ .mode = 0644,
+ .proc_handler = svcrdma_counter_handler,
+ },
+ {
+ .procname = "rdma_stat_rq_starve",
+ .data = &svcrdma_stat_unused,
+ .maxlen = sizeof(unsigned int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_minmax,
+ .extra1 = &zero,
+ .extra2 = &zero,
+ },
+ {
+ .procname = "rdma_stat_rq_poll",
+ .data = &svcrdma_stat_unused,
+ .maxlen = sizeof(unsigned int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_minmax,
+ .extra1 = &zero,
+ .extra2 = &zero,
+ },
+ {
+ .procname = "rdma_stat_rq_prod",
+ .data = &svcrdma_stat_unused,
+ .maxlen = sizeof(unsigned int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_minmax,
+ .extra1 = &zero,
+ .extra2 = &zero,
+ },
+ {
+ .procname = "rdma_stat_sq_poll",
+ .data = &svcrdma_stat_unused,
+ .maxlen = sizeof(unsigned int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_minmax,
+ .extra1 = &zero,
+ .extra2 = &zero,
+ },
+ {
+ .procname = "rdma_stat_sq_prod",
+ .data = &svcrdma_stat_unused,
+ .maxlen = sizeof(unsigned int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_minmax,
+ .extra1 = &zero,
+ .extra2 = &zero,
+ },
+ { },
+};
+
+static void svc_rdma_proc_cleanup(void)
+{
+ if (!svcrdma_table_header)
+ return;
+ unregister_sysctl_table(svcrdma_table_header);
+ svcrdma_table_header = NULL;
+
+ percpu_counter_destroy(&svcrdma_stat_write);
+ percpu_counter_destroy(&svcrdma_stat_sq_starve);
+ percpu_counter_destroy(&svcrdma_stat_recv);
+ percpu_counter_destroy(&svcrdma_stat_read);
+}
+
+static int svc_rdma_proc_init(void)
+{
+ int rc;
+
+ if (svcrdma_table_header)
+ return 0;
+
+ rc = percpu_counter_init(&svcrdma_stat_read, 0, GFP_KERNEL);
+ if (rc)
+ goto out_err;
+ rc = percpu_counter_init(&svcrdma_stat_recv, 0, GFP_KERNEL);
+ if (rc)
+ goto out_err;
+ rc = percpu_counter_init(&svcrdma_stat_sq_starve, 0, GFP_KERNEL);
+ if (rc)
+ goto out_err;
+ rc = percpu_counter_init(&svcrdma_stat_write, 0, GFP_KERNEL);
+ if (rc)
+ goto out_err;
+
+ svcrdma_table_header = register_sysctl("sunrpc/svc_rdma",
+ svcrdma_parm_table);
+ return 0;
+
+out_err:
+ percpu_counter_destroy(&svcrdma_stat_sq_starve);
+ percpu_counter_destroy(&svcrdma_stat_recv);
+ percpu_counter_destroy(&svcrdma_stat_read);
+ return rc;
+}
+
+void svc_rdma_cleanup(void)
+{
+ dprintk("SVCRDMA Module Removed, deregister RPC RDMA transport\n");
+ svc_unreg_xprt_class(&svc_rdma_class);
+ svc_rdma_proc_cleanup();
+}
+
+int svc_rdma_init(void)
+{
+ int rc;
+
+ dprintk("SVCRDMA Module Init, register RPC RDMA transport\n");
+ dprintk("\tsvcrdma_ord : %d\n", svcrdma_ord);
+ dprintk("\tmax_requests : %u\n", svcrdma_max_requests);
+ dprintk("\tmax_bc_requests : %u\n", svcrdma_max_bc_requests);
+ dprintk("\tmax_inline : %d\n", svcrdma_max_req_size);
+
+ rc = svc_rdma_proc_init();
+ if (rc)
+ return rc;
+
+ /* Register RDMA with the SVC transport switch */
+ svc_reg_xprt_class(&svc_rdma_class);
+ return 0;
+}
diff --git a/net/sunrpc/xprtrdma/svc_rdma_backchannel.c b/net/sunrpc/xprtrdma/svc_rdma_backchannel.c
new file mode 100644
index 0000000000..7420a2c990
--- /dev/null
+++ b/net/sunrpc/xprtrdma/svc_rdma_backchannel.c
@@ -0,0 +1,287 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (c) 2015-2018 Oracle. All rights reserved.
+ *
+ * Support for reverse-direction RPCs on RPC/RDMA (server-side).
+ */
+
+#include <linux/sunrpc/svc_rdma.h>
+
+#include "xprt_rdma.h"
+#include <trace/events/rpcrdma.h>
+
+/**
+ * svc_rdma_handle_bc_reply - Process incoming backchannel Reply
+ * @rqstp: resources for handling the Reply
+ * @rctxt: Received message
+ *
+ */
+void svc_rdma_handle_bc_reply(struct svc_rqst *rqstp,
+ struct svc_rdma_recv_ctxt *rctxt)
+{
+ struct svc_xprt *sxprt = rqstp->rq_xprt;
+ struct rpc_xprt *xprt = sxprt->xpt_bc_xprt;
+ struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt);
+ struct xdr_buf *rcvbuf = &rqstp->rq_arg;
+ struct kvec *dst, *src = &rcvbuf->head[0];
+ __be32 *rdma_resp = rctxt->rc_recv_buf;
+ struct rpc_rqst *req;
+ u32 credits;
+
+ spin_lock(&xprt->queue_lock);
+ req = xprt_lookup_rqst(xprt, *rdma_resp);
+ if (!req)
+ goto out_unlock;
+
+ dst = &req->rq_private_buf.head[0];
+ memcpy(&req->rq_private_buf, &req->rq_rcv_buf, sizeof(struct xdr_buf));
+ if (dst->iov_len < src->iov_len)
+ goto out_unlock;
+ memcpy(dst->iov_base, src->iov_base, src->iov_len);
+ xprt_pin_rqst(req);
+ spin_unlock(&xprt->queue_lock);
+
+ credits = be32_to_cpup(rdma_resp + 2);
+ if (credits == 0)
+ credits = 1; /* don't deadlock */
+ else if (credits > r_xprt->rx_buf.rb_bc_max_requests)
+ credits = r_xprt->rx_buf.rb_bc_max_requests;
+ spin_lock(&xprt->transport_lock);
+ xprt->cwnd = credits << RPC_CWNDSHIFT;
+ spin_unlock(&xprt->transport_lock);
+
+ spin_lock(&xprt->queue_lock);
+ xprt_complete_rqst(req->rq_task, rcvbuf->len);
+ xprt_unpin_rqst(req);
+ rcvbuf->len = 0;
+
+out_unlock:
+ spin_unlock(&xprt->queue_lock);
+}
+
+/* Send a reverse-direction RPC Call.
+ *
+ * Caller holds the connection's mutex and has already marshaled
+ * the RPC/RDMA request.
+ *
+ * This is similar to svc_rdma_send_reply_msg, but takes a struct
+ * rpc_rqst instead, does not support chunks, and avoids blocking
+ * memory allocation.
+ *
+ * XXX: There is still an opportunity to block in svc_rdma_send()
+ * if there are no SQ entries to post the Send. This may occur if
+ * the adapter has a small maximum SQ depth.
+ */
+static int svc_rdma_bc_sendto(struct svcxprt_rdma *rdma,
+ struct rpc_rqst *rqst,
+ struct svc_rdma_send_ctxt *sctxt)
+{
+ struct svc_rdma_recv_ctxt *rctxt;
+ int ret;
+
+ rctxt = svc_rdma_recv_ctxt_get(rdma);
+ if (!rctxt)
+ return -EIO;
+
+ ret = svc_rdma_map_reply_msg(rdma, sctxt, rctxt, &rqst->rq_snd_buf);
+ svc_rdma_recv_ctxt_put(rdma, rctxt);
+ if (ret < 0)
+ return -EIO;
+
+ /* Bump page refcnt so Send completion doesn't release
+ * the rq_buffer before all retransmits are complete.
+ */
+ get_page(virt_to_page(rqst->rq_buffer));
+ sctxt->sc_send_wr.opcode = IB_WR_SEND;
+ return svc_rdma_send(rdma, sctxt);
+}
+
+/* Server-side transport endpoint wants a whole page for its send
+ * buffer. The client RPC code constructs the RPC header in this
+ * buffer before it invokes ->send_request.
+ */
+static int
+xprt_rdma_bc_allocate(struct rpc_task *task)
+{
+ struct rpc_rqst *rqst = task->tk_rqstp;
+ size_t size = rqst->rq_callsize;
+ struct page *page;
+
+ if (size > PAGE_SIZE) {
+ WARN_ONCE(1, "svcrdma: large bc buffer request (size %zu)\n",
+ size);
+ return -EINVAL;
+ }
+
+ page = alloc_page(GFP_NOIO | __GFP_NOWARN);
+ if (!page)
+ return -ENOMEM;
+ rqst->rq_buffer = page_address(page);
+
+ rqst->rq_rbuffer = kmalloc(rqst->rq_rcvsize, GFP_NOIO | __GFP_NOWARN);
+ if (!rqst->rq_rbuffer) {
+ put_page(page);
+ return -ENOMEM;
+ }
+ return 0;
+}
+
+static void
+xprt_rdma_bc_free(struct rpc_task *task)
+{
+ struct rpc_rqst *rqst = task->tk_rqstp;
+
+ put_page(virt_to_page(rqst->rq_buffer));
+ kfree(rqst->rq_rbuffer);
+}
+
+static int
+rpcrdma_bc_send_request(struct svcxprt_rdma *rdma, struct rpc_rqst *rqst)
+{
+ struct rpc_xprt *xprt = rqst->rq_xprt;
+ struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt);
+ struct svc_rdma_send_ctxt *ctxt;
+ __be32 *p;
+ int rc;
+
+ ctxt = svc_rdma_send_ctxt_get(rdma);
+ if (!ctxt)
+ goto drop_connection;
+
+ p = xdr_reserve_space(&ctxt->sc_stream, RPCRDMA_HDRLEN_MIN);
+ if (!p)
+ goto put_ctxt;
+ *p++ = rqst->rq_xid;
+ *p++ = rpcrdma_version;
+ *p++ = cpu_to_be32(r_xprt->rx_buf.rb_bc_max_requests);
+ *p++ = rdma_msg;
+ *p++ = xdr_zero;
+ *p++ = xdr_zero;
+ *p = xdr_zero;
+
+ rqst->rq_xtime = ktime_get();
+ rc = svc_rdma_bc_sendto(rdma, rqst, ctxt);
+ if (rc)
+ goto put_ctxt;
+ return 0;
+
+put_ctxt:
+ svc_rdma_send_ctxt_put(rdma, ctxt);
+
+drop_connection:
+ return -ENOTCONN;
+}
+
+/**
+ * xprt_rdma_bc_send_request - Send a reverse-direction Call
+ * @rqst: rpc_rqst containing Call message to be sent
+ *
+ * Return values:
+ * %0 if the message was sent successfully
+ *	%-ENOTCONN if the message was not sent
+ */
+static int xprt_rdma_bc_send_request(struct rpc_rqst *rqst)
+{
+ struct svc_xprt *sxprt = rqst->rq_xprt->bc_xprt;
+ struct svcxprt_rdma *rdma =
+ container_of(sxprt, struct svcxprt_rdma, sc_xprt);
+ int ret;
+
+ if (test_bit(XPT_DEAD, &sxprt->xpt_flags))
+ return -ENOTCONN;
+
+ ret = rpcrdma_bc_send_request(rdma, rqst);
+ if (ret == -ENOTCONN)
+ svc_xprt_close(sxprt);
+ return ret;
+}
+
+static void
+xprt_rdma_bc_close(struct rpc_xprt *xprt)
+{
+ xprt_disconnect_done(xprt);
+ xprt->cwnd = RPC_CWNDSHIFT;
+}
+
+static void
+xprt_rdma_bc_put(struct rpc_xprt *xprt)
+{
+ xprt_rdma_free_addresses(xprt);
+ xprt_free(xprt);
+}
+
+static const struct rpc_xprt_ops xprt_rdma_bc_procs = {
+ .reserve_xprt = xprt_reserve_xprt_cong,
+ .release_xprt = xprt_release_xprt_cong,
+ .alloc_slot = xprt_alloc_slot,
+ .free_slot = xprt_free_slot,
+ .release_request = xprt_release_rqst_cong,
+ .buf_alloc = xprt_rdma_bc_allocate,
+ .buf_free = xprt_rdma_bc_free,
+ .send_request = xprt_rdma_bc_send_request,
+ .wait_for_reply_request = xprt_wait_for_reply_request_def,
+ .close = xprt_rdma_bc_close,
+ .destroy = xprt_rdma_bc_put,
+ .print_stats = xprt_rdma_print_stats
+};
+
+static const struct rpc_timeout xprt_rdma_bc_timeout = {
+ .to_initval = 60 * HZ,
+ .to_maxval = 60 * HZ,
+};
+
+/* It shouldn't matter if the number of backchannel session slots
+ * doesn't match the number of RPC/RDMA credits. That just means
+ * one or the other will have extra slots that aren't used.
+ */
+static struct rpc_xprt *
+xprt_setup_rdma_bc(struct xprt_create *args)
+{
+ struct rpc_xprt *xprt;
+ struct rpcrdma_xprt *new_xprt;
+
+ if (args->addrlen > sizeof(xprt->addr))
+ return ERR_PTR(-EBADF);
+
+ xprt = xprt_alloc(args->net, sizeof(*new_xprt),
+ RPCRDMA_MAX_BC_REQUESTS,
+ RPCRDMA_MAX_BC_REQUESTS);
+ if (!xprt)
+ return ERR_PTR(-ENOMEM);
+
+ xprt->timeout = &xprt_rdma_bc_timeout;
+ xprt_set_bound(xprt);
+ xprt_set_connected(xprt);
+ xprt->bind_timeout = 0;
+ xprt->reestablish_timeout = 0;
+ xprt->idle_timeout = 0;
+
+ xprt->prot = XPRT_TRANSPORT_BC_RDMA;
+ xprt->ops = &xprt_rdma_bc_procs;
+
+ memcpy(&xprt->addr, args->dstaddr, args->addrlen);
+ xprt->addrlen = args->addrlen;
+ xprt_rdma_format_addresses(xprt, (struct sockaddr *)&xprt->addr);
+ xprt->resvport = 0;
+
+ xprt->max_payload = xprt_rdma_max_inline_read;
+
+ new_xprt = rpcx_to_rdmax(xprt);
+ new_xprt->rx_buf.rb_bc_max_requests = xprt->max_reqs;
+
+ xprt_get(xprt);
+ args->bc_xprt->xpt_bc_xprt = xprt;
+ xprt->bc_xprt = args->bc_xprt;
+
+ /* Final put for backchannel xprt is in __svc_rdma_free */
+ xprt_get(xprt);
+ return xprt;
+}
+
+struct xprt_class xprt_rdma_bc = {
+ .list = LIST_HEAD_INIT(xprt_rdma_bc.list),
+ .name = "rdma backchannel",
+ .owner = THIS_MODULE,
+ .ident = XPRT_TRANSPORT_BC_RDMA,
+ .setup = xprt_setup_rdma_bc,
+};
diff --git a/net/sunrpc/xprtrdma/svc_rdma_pcl.c b/net/sunrpc/xprtrdma/svc_rdma_pcl.c
new file mode 100644
index 0000000000..b63cfeaa29
--- /dev/null
+++ b/net/sunrpc/xprtrdma/svc_rdma_pcl.c
@@ -0,0 +1,306 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (c) 2020 Oracle. All rights reserved.
+ */
+
+#include <linux/sunrpc/svc_rdma.h>
+#include <linux/sunrpc/rpc_rdma.h>
+
+#include "xprt_rdma.h"
+#include <trace/events/rpcrdma.h>
+
+/**
+ * pcl_free - Release all memory associated with a parsed chunk list
+ * @pcl: parsed chunk list
+ *
+ */
+void pcl_free(struct svc_rdma_pcl *pcl)
+{
+ while (!list_empty(&pcl->cl_chunks)) {
+ struct svc_rdma_chunk *chunk;
+
+ chunk = pcl_first_chunk(pcl);
+ list_del(&chunk->ch_list);
+ kfree(chunk);
+ }
+}
+
+static struct svc_rdma_chunk *pcl_alloc_chunk(u32 segcount, u32 position)
+{
+ struct svc_rdma_chunk *chunk;
+
+ chunk = kmalloc(struct_size(chunk, ch_segments, segcount), GFP_KERNEL);
+ if (!chunk)
+ return NULL;
+
+ chunk->ch_position = position;
+ chunk->ch_length = 0;
+ chunk->ch_payload_length = 0;
+ chunk->ch_segcount = 0;
+ return chunk;
+}
+
+static struct svc_rdma_chunk *
+pcl_lookup_position(struct svc_rdma_pcl *pcl, u32 position)
+{
+ struct svc_rdma_chunk *pos;
+
+ pcl_for_each_chunk(pos, pcl) {
+ if (pos->ch_position == position)
+ return pos;
+ }
+ return NULL;
+}
+
+static void pcl_insert_position(struct svc_rdma_pcl *pcl,
+ struct svc_rdma_chunk *chunk)
+{
+ struct svc_rdma_chunk *pos;
+
+ pcl_for_each_chunk(pos, pcl) {
+ if (pos->ch_position > chunk->ch_position)
+ break;
+ }
+ __list_add(&chunk->ch_list, pos->ch_list.prev, &pos->ch_list);
+ pcl->cl_count++;
+}
+
+static void pcl_set_read_segment(const struct svc_rdma_recv_ctxt *rctxt,
+ struct svc_rdma_chunk *chunk,
+ u32 handle, u32 length, u64 offset)
+{
+ struct svc_rdma_segment *segment;
+
+ segment = &chunk->ch_segments[chunk->ch_segcount];
+ segment->rs_handle = handle;
+ segment->rs_length = length;
+ segment->rs_offset = offset;
+
+ trace_svcrdma_decode_rseg(&rctxt->rc_cid, chunk, segment);
+
+ chunk->ch_length += length;
+ chunk->ch_segcount++;
+}
+
+/**
+ * pcl_alloc_call - Construct a parsed chunk list for the Call body
+ * @rctxt: Ingress receive context
+ * @p: Start of an un-decoded Read list
+ *
+ * Assumptions:
+ * - The incoming Read list has already been sanity checked.
+ * - cl_count is already set to the number of segments in
+ * the un-decoded list.
+ * - The list might not be in order by position.
+ *
+ * Return values:
+ * %true: Parsed chunk list was successfully constructed, and
+ *	cl_count is updated to be the number of chunks (i.e.
+ * unique positions) in the Read list.
+ * %false: Memory allocation failed.
+ */
+bool pcl_alloc_call(struct svc_rdma_recv_ctxt *rctxt, __be32 *p)
+{
+ struct svc_rdma_pcl *pcl = &rctxt->rc_call_pcl;
+ unsigned int i, segcount = pcl->cl_count;
+
+ pcl->cl_count = 0;
+ for (i = 0; i < segcount; i++) {
+ struct svc_rdma_chunk *chunk;
+ u32 position, handle, length;
+ u64 offset;
+
+ p++; /* skip the list discriminator */
+ p = xdr_decode_read_segment(p, &position, &handle,
+ &length, &offset);
+ if (position != 0)
+ continue;
+
+ if (pcl_is_empty(pcl)) {
+ chunk = pcl_alloc_chunk(segcount, position);
+ if (!chunk)
+ return false;
+ pcl_insert_position(pcl, chunk);
+ } else {
+ chunk = list_first_entry(&pcl->cl_chunks,
+ struct svc_rdma_chunk,
+ ch_list);
+ }
+
+ pcl_set_read_segment(rctxt, chunk, handle, length, offset);
+ }
+
+ return true;
+}
+
+/**
+ * pcl_alloc_read - Construct a parsed chunk list for normal Read chunks
+ * @rctxt: Ingress receive context
+ * @p: Start of an un-decoded Read list
+ *
+ * Assumptions:
+ * - The incoming Read list has already been sanity checked.
+ * - cl_count is already set to the number of segments in
+ * the un-decoded list.
+ * - The list might not be in order by position.
+ *
+ * Return values:
+ * %true: Parsed chunk list was successfully constructed, and
+ *	cl_count is updated to be the number of chunks (i.e.
+ * unique position values) in the Read list.
+ * %false: Memory allocation failed.
+ *
+ * TODO:
+ * - Check for chunk range overlaps
+ */
+bool pcl_alloc_read(struct svc_rdma_recv_ctxt *rctxt, __be32 *p)
+{
+ struct svc_rdma_pcl *pcl = &rctxt->rc_read_pcl;
+ unsigned int i, segcount = pcl->cl_count;
+
+ pcl->cl_count = 0;
+ for (i = 0; i < segcount; i++) {
+ struct svc_rdma_chunk *chunk;
+ u32 position, handle, length;
+ u64 offset;
+
+ p++; /* skip the list discriminator */
+ p = xdr_decode_read_segment(p, &position, &handle,
+ &length, &offset);
+ if (position == 0)
+ continue;
+
+ chunk = pcl_lookup_position(pcl, position);
+ if (!chunk) {
+ chunk = pcl_alloc_chunk(segcount, position);
+ if (!chunk)
+ return false;
+ pcl_insert_position(pcl, chunk);
+ }
+
+ pcl_set_read_segment(rctxt, chunk, handle, length, offset);
+ }
+
+ return true;
+}
+
+/**
+ * pcl_alloc_write - Construct a parsed chunk list from a Write list
+ * @rctxt: Ingress receive context
+ * @pcl: Parsed chunk list to populate
+ * @p: Start of an un-decoded Write list
+ *
+ * Assumptions:
+ * - The incoming Write list has already been sanity checked, and
+ * - cl_count is set to the number of chunks in the un-decoded list.
+ *
+ * Return values:
+ * %true: Parsed chunk list was successfully constructed.
+ * %false: Memory allocation failed.
+ */
+bool pcl_alloc_write(struct svc_rdma_recv_ctxt *rctxt,
+ struct svc_rdma_pcl *pcl, __be32 *p)
+{
+ struct svc_rdma_segment *segment;
+ struct svc_rdma_chunk *chunk;
+ unsigned int i, j;
+ u32 segcount;
+
+ for (i = 0; i < pcl->cl_count; i++) {
+ p++; /* skip the list discriminator */
+ segcount = be32_to_cpup(p++);
+
+ chunk = pcl_alloc_chunk(segcount, 0);
+ if (!chunk)
+ return false;
+ list_add_tail(&chunk->ch_list, &pcl->cl_chunks);
+
+ for (j = 0; j < segcount; j++) {
+ segment = &chunk->ch_segments[j];
+ p = xdr_decode_rdma_segment(p, &segment->rs_handle,
+ &segment->rs_length,
+ &segment->rs_offset);
+ trace_svcrdma_decode_wseg(&rctxt->rc_cid, chunk, j);
+
+ chunk->ch_length += segment->rs_length;
+ chunk->ch_segcount++;
+ }
+ }
+ return true;
+}
+
+static int pcl_process_region(const struct xdr_buf *xdr,
+ unsigned int offset, unsigned int length,
+ int (*actor)(const struct xdr_buf *, void *),
+ void *data)
+{
+ struct xdr_buf subbuf;
+
+ if (!length)
+ return 0;
+ if (xdr_buf_subsegment(xdr, &subbuf, offset, length))
+ return -EMSGSIZE;
+ return actor(&subbuf, data);
+}
+
+/**
+ * pcl_process_nonpayloads - Process non-payload regions inside @xdr
+ * @pcl: Chunk list to process
+ * @xdr: xdr_buf to process
+ * @actor: Function to invoke on each non-payload region
+ * @data: Arguments for @actor
+ *
+ * This mechanism must ignore not only result payloads that were already
+ * sent via RDMA Write, but also XDR padding for those payloads that
+ * the upper layer has added.
+ *
+ * Assumptions:
+ * The xdr->len and ch_position fields are aligned to 4-byte multiples.
+ *
+ * Returns:
+ * On success, zero,
+ * %-EMSGSIZE on XDR buffer overflow, or
+ * The return value of @actor
+ */
+int pcl_process_nonpayloads(const struct svc_rdma_pcl *pcl,
+ const struct xdr_buf *xdr,
+ int (*actor)(const struct xdr_buf *, void *),
+ void *data)
+{
+ struct svc_rdma_chunk *chunk, *next;
+ unsigned int start;
+ int ret;
+
+ chunk = pcl_first_chunk(pcl);
+
+ /* No result payloads were generated */
+ if (!chunk || !chunk->ch_payload_length)
+ return actor(xdr, data);
+
+ /* Process the region before the first result payload */
+ ret = pcl_process_region(xdr, 0, chunk->ch_position, actor, data);
+ if (ret < 0)
+ return ret;
+
+ /* Process the regions between each middle result payload */
+ while ((next = pcl_next_chunk(pcl, chunk))) {
+ if (!next->ch_payload_length)
+ break;
+
+ start = pcl_chunk_end_offset(chunk);
+ ret = pcl_process_region(xdr, start, next->ch_position - start,
+ actor, data);
+ if (ret < 0)
+ return ret;
+
+ chunk = next;
+ }
+
+ /* Process the region after the last result payload */
+ start = pcl_chunk_end_offset(chunk);
+ ret = pcl_process_region(xdr, start, xdr->len - start, actor, data);
+ if (ret < 0)
+ return ret;
+
+ return 0;
+}
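/* Illustrative sketch (hypothetical caller, not from the upstream file):
 * an @actor that totals the bytes which must still be sent inline because
 * they are not covered by a result payload already pushed via RDMA Write.
 */
static int count_inline_bytes(const struct xdr_buf *subbuf, void *data)
{
	*(unsigned int *)data += subbuf->len;
	return 0;
}

/* ... which a caller might apply to, e.g., the Write-list chunk list:
 *
 *	unsigned int inline_bytes = 0;
 *	int ret = pcl_process_nonpayloads(&rctxt->rc_write_pcl, xdr,
 *					  count_inline_bytes, &inline_bytes);
 */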
diff --git a/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c b/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c
new file mode 100644
index 0000000000..3b05f90a3e
--- /dev/null
+++ b/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c
@@ -0,0 +1,863 @@
+// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
+/*
+ * Copyright (c) 2016-2018 Oracle. All rights reserved.
+ * Copyright (c) 2014 Open Grid Computing, Inc. All rights reserved.
+ * Copyright (c) 2005-2006 Network Appliance, Inc. All rights reserved.
+ *
+ * This software is available to you under a choice of one of two
+ * licenses. You may choose to be licensed under the terms of the GNU
+ * General Public License (GPL) Version 2, available from the file
+ * COPYING in the main directory of this source tree, or the BSD-type
+ * license below:
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *
+ * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials provided
+ * with the distribution.
+ *
+ * Neither the name of the Network Appliance, Inc. nor the names of
+ * its contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * Author: Tom Tucker <tom@opengridcomputing.com>
+ */
+
+/* Operation
+ *
+ * The main entry point is svc_rdma_recvfrom. This is called from
+ * svc_recv when the transport indicates there is incoming data to
+ * be read. "Data Ready" is signaled when an RDMA Receive completes,
+ * or when a set of RDMA Reads complete.
+ *
+ * An svc_rqst is passed in. This structure contains an array of
+ * free pages (rq_pages) that will contain the incoming RPC message.
+ *
+ * Short messages are moved directly into svc_rqst::rq_arg, and
+ * the RPC Call is ready to be processed by the Upper Layer.
+ * svc_rdma_recvfrom returns the length of the RPC Call message,
+ * completing the reception of the RPC Call.
+ *
+ * However, when an incoming message has Read chunks,
+ * svc_rdma_recvfrom must post RDMA Reads to pull the RPC Call's
+ * data payload from the client. svc_rdma_recvfrom sets up the
+ * RDMA Reads using pages in svc_rqst::rq_pages as the Read sink
+ * buffers, posts the Read WRs, and waits for them to complete
+ * (see svc_rdma_process_read_list). Once all Read payloads have
+ * arrived, the full RPC Call message is assembled in rq_arg and
+ * svc_rdma_recvfrom returns its length.
+ *
+ * State that must outlive this call, such as the parsed chunk
+ * lists and the Receive buffer that backs rq_arg's head iovec,
+ * is kept in an svc_rdma_recv_ctxt. The ctxt is attached to the
+ * svc_rqst as rq_xprt_ctxt and is released later, via
+ * svc_rdma_release_ctxt, whether or not a Reply is sent.
+ *
+ * Page Management
+ *
+ * Pages posted as RDMA Read sink buffers come directly from
+ * rqstp::rq_pages. The number of pages consumed is recorded in
+ * the ctxt's rc_page_count so that rq_respages can be pointed
+ * past the received payload before svc_rdma_recvfrom returns.
+ */
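+
+/* Illustrative control flow only (error handling and the backchannel
+ * path are omitted); this is a sketch of the steps described above,
+ * not additional transport code. Names refer to functions defined
+ * later in this file:
+ *
+ *	ctxt = next completed Receive on sc_rq_dto_q;
+ *	svc_rdma_build_arg_xdr(rqstp, ctxt);
+ *	svc_rdma_xdr_decode_req(&rqstp->rq_arg, ctxt);
+ *	if (Read or call chunks are present)
+ *		svc_rdma_process_read_list(rdma_xprt, rqstp, ctxt);
+ *	rqstp->rq_xprt_ctxt = ctxt;
+ *	return rqstp->rq_arg.len;
+ */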
+
+#include <linux/slab.h>
+#include <linux/spinlock.h>
+#include <asm/unaligned.h>
+#include <rdma/ib_verbs.h>
+#include <rdma/rdma_cm.h>
+
+#include <linux/sunrpc/xdr.h>
+#include <linux/sunrpc/debug.h>
+#include <linux/sunrpc/rpc_rdma.h>
+#include <linux/sunrpc/svc_rdma.h>
+
+#include "xprt_rdma.h"
+#include <trace/events/rpcrdma.h>
+
+static void svc_rdma_wc_receive(struct ib_cq *cq, struct ib_wc *wc);
+
+static inline struct svc_rdma_recv_ctxt *
+svc_rdma_next_recv_ctxt(struct list_head *list)
+{
+ return list_first_entry_or_null(list, struct svc_rdma_recv_ctxt,
+ rc_list);
+}
+
+static void svc_rdma_recv_cid_init(struct svcxprt_rdma *rdma,
+ struct rpc_rdma_cid *cid)
+{
+ cid->ci_queue_id = rdma->sc_rq_cq->res.id;
+ cid->ci_completion_id = atomic_inc_return(&rdma->sc_completion_ids);
+}
+
+static struct svc_rdma_recv_ctxt *
+svc_rdma_recv_ctxt_alloc(struct svcxprt_rdma *rdma)
+{
+ int node = ibdev_to_node(rdma->sc_cm_id->device);
+ struct svc_rdma_recv_ctxt *ctxt;
+ dma_addr_t addr;
+ void *buffer;
+
+ ctxt = kmalloc_node(sizeof(*ctxt), GFP_KERNEL, node);
+ if (!ctxt)
+ goto fail0;
+ buffer = kmalloc_node(rdma->sc_max_req_size, GFP_KERNEL, node);
+ if (!buffer)
+ goto fail1;
+ addr = ib_dma_map_single(rdma->sc_pd->device, buffer,
+ rdma->sc_max_req_size, DMA_FROM_DEVICE);
+ if (ib_dma_mapping_error(rdma->sc_pd->device, addr))
+ goto fail2;
+
+ svc_rdma_recv_cid_init(rdma, &ctxt->rc_cid);
+ pcl_init(&ctxt->rc_call_pcl);
+ pcl_init(&ctxt->rc_read_pcl);
+ pcl_init(&ctxt->rc_write_pcl);
+ pcl_init(&ctxt->rc_reply_pcl);
+
+ ctxt->rc_recv_wr.next = NULL;
+ ctxt->rc_recv_wr.wr_cqe = &ctxt->rc_cqe;
+ ctxt->rc_recv_wr.sg_list = &ctxt->rc_recv_sge;
+ ctxt->rc_recv_wr.num_sge = 1;
+ ctxt->rc_cqe.done = svc_rdma_wc_receive;
+ ctxt->rc_recv_sge.addr = addr;
+ ctxt->rc_recv_sge.length = rdma->sc_max_req_size;
+ ctxt->rc_recv_sge.lkey = rdma->sc_pd->local_dma_lkey;
+ ctxt->rc_recv_buf = buffer;
+ return ctxt;
+
+fail2:
+ kfree(buffer);
+fail1:
+ kfree(ctxt);
+fail0:
+ return NULL;
+}
+
+static void svc_rdma_recv_ctxt_destroy(struct svcxprt_rdma *rdma,
+ struct svc_rdma_recv_ctxt *ctxt)
+{
+ ib_dma_unmap_single(rdma->sc_pd->device, ctxt->rc_recv_sge.addr,
+ ctxt->rc_recv_sge.length, DMA_FROM_DEVICE);
+ kfree(ctxt->rc_recv_buf);
+ kfree(ctxt);
+}
+
+/**
+ * svc_rdma_recv_ctxts_destroy - Release all recv_ctxt's for an xprt
+ * @rdma: svcxprt_rdma being torn down
+ *
+ */
+void svc_rdma_recv_ctxts_destroy(struct svcxprt_rdma *rdma)
+{
+ struct svc_rdma_recv_ctxt *ctxt;
+ struct llist_node *node;
+
+ while ((node = llist_del_first(&rdma->sc_recv_ctxts))) {
+ ctxt = llist_entry(node, struct svc_rdma_recv_ctxt, rc_node);
+ svc_rdma_recv_ctxt_destroy(rdma, ctxt);
+ }
+}
+
+/**
+ * svc_rdma_recv_ctxt_get - Allocate a recv_ctxt
+ * @rdma: controlling svcxprt_rdma
+ *
+ * Returns a recv_ctxt or (rarely) NULL if none are available.
+ */
+struct svc_rdma_recv_ctxt *svc_rdma_recv_ctxt_get(struct svcxprt_rdma *rdma)
+{
+ struct svc_rdma_recv_ctxt *ctxt;
+ struct llist_node *node;
+
+ node = llist_del_first(&rdma->sc_recv_ctxts);
+ if (!node)
+ goto out_empty;
+ ctxt = llist_entry(node, struct svc_rdma_recv_ctxt, rc_node);
+
+out:
+ ctxt->rc_page_count = 0;
+ return ctxt;
+
+out_empty:
+ ctxt = svc_rdma_recv_ctxt_alloc(rdma);
+ if (!ctxt)
+ return NULL;
+ goto out;
+}
+
+/**
+ * svc_rdma_recv_ctxt_put - Return recv_ctxt to free list
+ * @rdma: controlling svcxprt_rdma
+ * @ctxt: object to return to the free list
+ *
+ */
+void svc_rdma_recv_ctxt_put(struct svcxprt_rdma *rdma,
+ struct svc_rdma_recv_ctxt *ctxt)
+{
+ pcl_free(&ctxt->rc_call_pcl);
+ pcl_free(&ctxt->rc_read_pcl);
+ pcl_free(&ctxt->rc_write_pcl);
+ pcl_free(&ctxt->rc_reply_pcl);
+
+ llist_add(&ctxt->rc_node, &rdma->sc_recv_ctxts);
+}
+
+/**
+ * svc_rdma_release_ctxt - Release transport-specific per-rqst resources
+ * @xprt: the transport which owned the context
+ * @vctxt: the context from rqstp->rq_xprt_ctxt or dr->xprt_ctxt
+ *
+ * Ensure that the recv_ctxt is released whether or not a Reply
+ * was sent. For example, the client could close the connection,
+ * or svc_process could drop an RPC, before the Reply is sent.
+ */
+void svc_rdma_release_ctxt(struct svc_xprt *xprt, void *vctxt)
+{
+ struct svc_rdma_recv_ctxt *ctxt = vctxt;
+ struct svcxprt_rdma *rdma =
+ container_of(xprt, struct svcxprt_rdma, sc_xprt);
+
+ if (ctxt)
+ svc_rdma_recv_ctxt_put(rdma, ctxt);
+}
+
+static bool svc_rdma_refresh_recvs(struct svcxprt_rdma *rdma,
+ unsigned int wanted)
+{
+ const struct ib_recv_wr *bad_wr = NULL;
+ struct svc_rdma_recv_ctxt *ctxt;
+ struct ib_recv_wr *recv_chain;
+ int ret;
+
+ if (test_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags))
+ return false;
+
+ recv_chain = NULL;
+ while (wanted--) {
+ ctxt = svc_rdma_recv_ctxt_get(rdma);
+ if (!ctxt)
+ break;
+
+ trace_svcrdma_post_recv(ctxt);
+ ctxt->rc_recv_wr.next = recv_chain;
+ recv_chain = &ctxt->rc_recv_wr;
+ rdma->sc_pending_recvs++;
+ }
+ if (!recv_chain)
+ return false;
+
+ ret = ib_post_recv(rdma->sc_qp, recv_chain, &bad_wr);
+ if (ret)
+ goto err_free;
+ return true;
+
+err_free:
+ trace_svcrdma_rq_post_err(rdma, ret);
+ while (bad_wr) {
+ ctxt = container_of(bad_wr, struct svc_rdma_recv_ctxt,
+ rc_recv_wr);
+ bad_wr = bad_wr->next;
+ svc_rdma_recv_ctxt_put(rdma, ctxt);
+ }
+	/* Since we're destroying the xprt, no need to reset
+	 * sc_pending_recvs.
+	 */
+ return false;
+}
+
+/**
+ * svc_rdma_post_recvs - Post initial set of Recv WRs
+ * @rdma: fresh svcxprt_rdma
+ *
+ * Returns true if successful, otherwise false.
+ */
+bool svc_rdma_post_recvs(struct svcxprt_rdma *rdma)
+{
+ return svc_rdma_refresh_recvs(rdma, rdma->sc_max_requests);
+}
+
+/**
+ * svc_rdma_wc_receive - Invoked by RDMA provider for each polled Receive WC
+ * @cq: Completion Queue context
+ * @wc: Work Completion object
+ *
+ */
+static void svc_rdma_wc_receive(struct ib_cq *cq, struct ib_wc *wc)
+{
+ struct svcxprt_rdma *rdma = cq->cq_context;
+ struct ib_cqe *cqe = wc->wr_cqe;
+ struct svc_rdma_recv_ctxt *ctxt;
+
+ rdma->sc_pending_recvs--;
+
+ /* WARNING: Only wc->wr_cqe and wc->status are reliable */
+ ctxt = container_of(cqe, struct svc_rdma_recv_ctxt, rc_cqe);
+
+ if (wc->status != IB_WC_SUCCESS)
+ goto flushed;
+ trace_svcrdma_wc_recv(wc, &ctxt->rc_cid);
+
+ /* If receive posting fails, the connection is about to be
+ * lost anyway. The server will not be able to send a reply
+ * for this RPC, and the client will retransmit this RPC
+ * anyway when it reconnects.
+ *
+	 * Therefore we drop the Receive, even if status was SUCCESS,
+ * to reduce the likelihood of replayed requests once the
+ * client reconnects.
+ */
+ if (rdma->sc_pending_recvs < rdma->sc_max_requests)
+ if (!svc_rdma_refresh_recvs(rdma, rdma->sc_recv_batch))
+ goto dropped;
+
+ /* All wc fields are now known to be valid */
+ ctxt->rc_byte_len = wc->byte_len;
+
+ spin_lock(&rdma->sc_rq_dto_lock);
+ list_add_tail(&ctxt->rc_list, &rdma->sc_rq_dto_q);
+ /* Note the unlock pairs with the smp_rmb in svc_xprt_ready: */
+ set_bit(XPT_DATA, &rdma->sc_xprt.xpt_flags);
+ spin_unlock(&rdma->sc_rq_dto_lock);
+ if (!test_bit(RDMAXPRT_CONN_PENDING, &rdma->sc_flags))
+ svc_xprt_enqueue(&rdma->sc_xprt);
+ return;
+
+flushed:
+ if (wc->status == IB_WC_WR_FLUSH_ERR)
+ trace_svcrdma_wc_recv_flush(wc, &ctxt->rc_cid);
+ else
+ trace_svcrdma_wc_recv_err(wc, &ctxt->rc_cid);
+dropped:
+ svc_rdma_recv_ctxt_put(rdma, ctxt);
+ svc_xprt_deferred_close(&rdma->sc_xprt);
+}
+
+/**
+ * svc_rdma_flush_recv_queues - Drain pending Receive work
+ * @rdma: svcxprt_rdma being shut down
+ *
+ */
+void svc_rdma_flush_recv_queues(struct svcxprt_rdma *rdma)
+{
+ struct svc_rdma_recv_ctxt *ctxt;
+
+ while ((ctxt = svc_rdma_next_recv_ctxt(&rdma->sc_rq_dto_q))) {
+ list_del(&ctxt->rc_list);
+ svc_rdma_recv_ctxt_put(rdma, ctxt);
+ }
+}
+
+static void svc_rdma_build_arg_xdr(struct svc_rqst *rqstp,
+ struct svc_rdma_recv_ctxt *ctxt)
+{
+ struct xdr_buf *arg = &rqstp->rq_arg;
+
+ arg->head[0].iov_base = ctxt->rc_recv_buf;
+ arg->head[0].iov_len = ctxt->rc_byte_len;
+ arg->tail[0].iov_base = NULL;
+ arg->tail[0].iov_len = 0;
+ arg->page_len = 0;
+ arg->page_base = 0;
+ arg->buflen = ctxt->rc_byte_len;
+ arg->len = ctxt->rc_byte_len;
+}
+
+/**
+ * xdr_count_read_segments - Count number of Read segments in Read list
+ * @rctxt: Ingress receive context
+ * @p: Start of an un-decoded Read list
+ *
+ * Before allocating anything, ensure the ingress Read list is safe
+ * to use.
+ *
+ * The segment count is limited to how many segments can fit in the
+ * transport header without overflowing the buffer. That's about 40
+ * Read segments for a 1KB inline threshold.
+ *
+ * Return values:
+ * %true: Read list is valid. @rctxt's xdr_stream is updated to point
+ * to the first byte past the Read list. rc_read_pcl and
+ * rc_call_pcl cl_count fields are set to the number of
+ * Read segments in the list.
+ * %false: Read list is corrupt. @rctxt's xdr_stream is left in an
+ * unknown state.
+ */
+static bool xdr_count_read_segments(struct svc_rdma_recv_ctxt *rctxt, __be32 *p)
+{
+ rctxt->rc_call_pcl.cl_count = 0;
+ rctxt->rc_read_pcl.cl_count = 0;
+ while (xdr_item_is_present(p)) {
+ u32 position, handle, length;
+ u64 offset;
+
+ p = xdr_inline_decode(&rctxt->rc_stream,
+ rpcrdma_readseg_maxsz * sizeof(*p));
+ if (!p)
+ return false;
+
+ xdr_decode_read_segment(p, &position, &handle,
+ &length, &offset);
+ if (position) {
+ if (position & 3)
+ return false;
+ ++rctxt->rc_read_pcl.cl_count;
+ } else {
+ ++rctxt->rc_call_pcl.cl_count;
+ }
+
+ p = xdr_inline_decode(&rctxt->rc_stream, sizeof(*p));
+ if (!p)
+ return false;
+ }
+ return true;
+}
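+
+/* Rough arithmetic behind the "about 40 Read segments" estimate in
+ * the kernel-doc above: each Read list entry is one discriminator
+ * word plus rpcrdma_readseg_maxsz (5) words, or 24 bytes on the
+ * wire, so a 1KB inline threshold holds roughly 1024 / 24 ~= 42
+ * entries.
+ */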
+
+/* Sanity check the Read list.
+ *
+ * Sanity checks:
+ * - Read list does not overflow Receive buffer.
+ * - Chunk size limited by largest NFS data payload.
+ *
+ * Return values:
+ * %true: Read list is valid. @rctxt's xdr_stream is updated
+ * to point to the first byte past the Read list.
+ * %false: Read list is corrupt. @rctxt's xdr_stream is left
+ * in an unknown state.
+ */
+static bool xdr_check_read_list(struct svc_rdma_recv_ctxt *rctxt)
+{
+ __be32 *p;
+
+ p = xdr_inline_decode(&rctxt->rc_stream, sizeof(*p));
+ if (!p)
+ return false;
+ if (!xdr_count_read_segments(rctxt, p))
+ return false;
+ if (!pcl_alloc_call(rctxt, p))
+ return false;
+ return pcl_alloc_read(rctxt, p);
+}
+
+static bool xdr_check_write_chunk(struct svc_rdma_recv_ctxt *rctxt)
+{
+ u32 segcount;
+ __be32 *p;
+
+ if (xdr_stream_decode_u32(&rctxt->rc_stream, &segcount))
+ return false;
+
+ /* A bogus segcount causes this buffer overflow check to fail. */
+ p = xdr_inline_decode(&rctxt->rc_stream,
+ segcount * rpcrdma_segment_maxsz * sizeof(*p));
+ return p != NULL;
+}
+
+/**
+ * xdr_count_write_chunks - Count number of Write chunks in Write list
+ * @rctxt: Received header and decoding state
+ * @p: start of an un-decoded Write list
+ *
+ * Before allocating anything, ensure the ingress Write list is
+ * safe to use.
+ *
+ * Return values:
+ * %true: Write list is valid. @rctxt's xdr_stream is updated
+ * to point to the first byte past the Write list, and
+ * the number of Write chunks is in rc_write_pcl.cl_count.
+ * %false: Write list is corrupt. @rctxt's xdr_stream is left
+ * in an indeterminate state.
+ */
+static bool xdr_count_write_chunks(struct svc_rdma_recv_ctxt *rctxt, __be32 *p)
+{
+ rctxt->rc_write_pcl.cl_count = 0;
+ while (xdr_item_is_present(p)) {
+ if (!xdr_check_write_chunk(rctxt))
+ return false;
+ ++rctxt->rc_write_pcl.cl_count;
+ p = xdr_inline_decode(&rctxt->rc_stream, sizeof(*p));
+ if (!p)
+ return false;
+ }
+ return true;
+}
+
+/* Sanity check the Write list.
+ *
+ * Implementation limits:
+ * - This implementation currently supports only one Write chunk.
+ *
+ * Sanity checks:
+ * - Write list does not overflow Receive buffer.
+ * - Chunk size limited by largest NFS data payload.
+ *
+ * Return values:
+ * %true: Write list is valid. @rctxt's xdr_stream is updated
+ * to point to the first byte past the Write list.
+ * %false: Write list is corrupt. @rctxt's xdr_stream is left
+ * in an unknown state.
+ */
+static bool xdr_check_write_list(struct svc_rdma_recv_ctxt *rctxt)
+{
+ __be32 *p;
+
+ p = xdr_inline_decode(&rctxt->rc_stream, sizeof(*p));
+ if (!p)
+ return false;
+ if (!xdr_count_write_chunks(rctxt, p))
+ return false;
+ if (!pcl_alloc_write(rctxt, &rctxt->rc_write_pcl, p))
+ return false;
+
+ rctxt->rc_cur_result_payload = pcl_first_chunk(&rctxt->rc_write_pcl);
+ return true;
+}
+
+/* Sanity check the Reply chunk.
+ *
+ * Sanity checks:
+ * - Reply chunk does not overflow Receive buffer.
+ * - Chunk size limited by largest NFS data payload.
+ *
+ * Return values:
+ * %true: Reply chunk is valid. @rctxt's xdr_stream is updated
+ * to point to the first byte past the Reply chunk.
+ * %false: Reply chunk is corrupt. @rctxt's xdr_stream is left
+ * in an unknown state.
+ */
+static bool xdr_check_reply_chunk(struct svc_rdma_recv_ctxt *rctxt)
+{
+ __be32 *p;
+
+ p = xdr_inline_decode(&rctxt->rc_stream, sizeof(*p));
+ if (!p)
+ return false;
+
+ if (!xdr_item_is_present(p))
+ return true;
+ if (!xdr_check_write_chunk(rctxt))
+ return false;
+
+ rctxt->rc_reply_pcl.cl_count = 1;
+ return pcl_alloc_write(rctxt, &rctxt->rc_reply_pcl, p);
+}
+
+/* RPC-over-RDMA Version One private extension: Remote Invalidation.
+ * Responder's choice: requester signals it can handle Send With
+ * Invalidate, and responder chooses one R_key to invalidate.
+ *
+ * If there is exactly one distinct R_key in the received transport
+ * header, set rc_inv_rkey to that R_key. Otherwise, set it to zero.
+ */
+static void svc_rdma_get_inv_rkey(struct svcxprt_rdma *rdma,
+ struct svc_rdma_recv_ctxt *ctxt)
+{
+ struct svc_rdma_segment *segment;
+ struct svc_rdma_chunk *chunk;
+ u32 inv_rkey;
+
+ ctxt->rc_inv_rkey = 0;
+
+ if (!rdma->sc_snd_w_inv)
+ return;
+
+ inv_rkey = 0;
+ pcl_for_each_chunk(chunk, &ctxt->rc_call_pcl) {
+ pcl_for_each_segment(segment, chunk) {
+ if (inv_rkey == 0)
+ inv_rkey = segment->rs_handle;
+ else if (inv_rkey != segment->rs_handle)
+ return;
+ }
+ }
+ pcl_for_each_chunk(chunk, &ctxt->rc_read_pcl) {
+ pcl_for_each_segment(segment, chunk) {
+ if (inv_rkey == 0)
+ inv_rkey = segment->rs_handle;
+ else if (inv_rkey != segment->rs_handle)
+ return;
+ }
+ }
+ pcl_for_each_chunk(chunk, &ctxt->rc_write_pcl) {
+ pcl_for_each_segment(segment, chunk) {
+ if (inv_rkey == 0)
+ inv_rkey = segment->rs_handle;
+ else if (inv_rkey != segment->rs_handle)
+ return;
+ }
+ }
+ pcl_for_each_chunk(chunk, &ctxt->rc_reply_pcl) {
+ pcl_for_each_segment(segment, chunk) {
+ if (inv_rkey == 0)
+ inv_rkey = segment->rs_handle;
+ else if (inv_rkey != segment->rs_handle)
+ return;
+ }
+ }
+ ctxt->rc_inv_rkey = inv_rkey;
+}
+
+/**
+ * svc_rdma_xdr_decode_req - Decode the transport header
+ * @rq_arg: xdr_buf containing ingress RPC/RDMA message
+ * @rctxt: state of decoding
+ *
+ * On entry, xdr->head[0].iov_base points to first byte of the
+ * RPC-over-RDMA transport header.
+ *
+ * On successful exit, head[0] points to first byte past the
+ * RPC-over-RDMA header. For RDMA_MSG, this is the RPC message.
+ *
+ * The length of the RPC-over-RDMA header is returned.
+ *
+ * Assumptions:
+ * - The transport header is entirely contained in the head iovec.
+ */
+static int svc_rdma_xdr_decode_req(struct xdr_buf *rq_arg,
+ struct svc_rdma_recv_ctxt *rctxt)
+{
+ __be32 *p, *rdma_argp;
+ unsigned int hdr_len;
+
+ rdma_argp = rq_arg->head[0].iov_base;
+ xdr_init_decode(&rctxt->rc_stream, rq_arg, rdma_argp, NULL);
+
+ p = xdr_inline_decode(&rctxt->rc_stream,
+ rpcrdma_fixed_maxsz * sizeof(*p));
+ if (unlikely(!p))
+ goto out_short;
+ p++;
+ if (*p != rpcrdma_version)
+ goto out_version;
+ p += 2;
+ rctxt->rc_msgtype = *p;
+ switch (rctxt->rc_msgtype) {
+ case rdma_msg:
+ break;
+ case rdma_nomsg:
+ break;
+ case rdma_done:
+ goto out_drop;
+ case rdma_error:
+ goto out_drop;
+ default:
+ goto out_proc;
+ }
+
+ if (!xdr_check_read_list(rctxt))
+ goto out_inval;
+ if (!xdr_check_write_list(rctxt))
+ goto out_inval;
+ if (!xdr_check_reply_chunk(rctxt))
+ goto out_inval;
+
+ rq_arg->head[0].iov_base = rctxt->rc_stream.p;
+ hdr_len = xdr_stream_pos(&rctxt->rc_stream);
+ rq_arg->head[0].iov_len -= hdr_len;
+ rq_arg->len -= hdr_len;
+ trace_svcrdma_decode_rqst(rctxt, rdma_argp, hdr_len);
+ return hdr_len;
+
+out_short:
+ trace_svcrdma_decode_short_err(rctxt, rq_arg->len);
+ return -EINVAL;
+
+out_version:
+ trace_svcrdma_decode_badvers_err(rctxt, rdma_argp);
+ return -EPROTONOSUPPORT;
+
+out_drop:
+ trace_svcrdma_decode_drop_err(rctxt, rdma_argp);
+ return 0;
+
+out_proc:
+ trace_svcrdma_decode_badproc_err(rctxt, rdma_argp);
+ return -EINVAL;
+
+out_inval:
+ trace_svcrdma_decode_parse_err(rctxt, rdma_argp);
+ return -EINVAL;
+}
+
+static void svc_rdma_send_error(struct svcxprt_rdma *rdma,
+ struct svc_rdma_recv_ctxt *rctxt,
+ int status)
+{
+ struct svc_rdma_send_ctxt *sctxt;
+
+ sctxt = svc_rdma_send_ctxt_get(rdma);
+ if (!sctxt)
+ return;
+ svc_rdma_send_error_msg(rdma, sctxt, rctxt, status);
+}
+
+/* By convention, backchannel calls arrive via rdma_msg type
+ * messages, and never populate the chunk lists. This makes
+ * the RPC/RDMA header small and fixed in size, so it is
+ * straightforward to check the RPC header's direction field.
+ */
+static bool svc_rdma_is_reverse_direction_reply(struct svc_xprt *xprt,
+ struct svc_rdma_recv_ctxt *rctxt)
+{
+ __be32 *p = rctxt->rc_recv_buf;
+
+ if (!xprt->xpt_bc_xprt)
+ return false;
+
+ if (rctxt->rc_msgtype != rdma_msg)
+ return false;
+
+ if (!pcl_is_empty(&rctxt->rc_call_pcl))
+ return false;
+ if (!pcl_is_empty(&rctxt->rc_read_pcl))
+ return false;
+ if (!pcl_is_empty(&rctxt->rc_write_pcl))
+ return false;
+ if (!pcl_is_empty(&rctxt->rc_reply_pcl))
+ return false;
+
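+	/* In a chunk-less rdma_msg, the transport header is exactly
+	 * seven XDR words: xid, vers, credits, and msg_type, plus one
+	 * "not present" discriminator for each of the three chunk
+	 * lists. The RPC header follows immediately, so word 7 is the
+	 * RPC XID and word 8 is the message direction.
+	 */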
+ /* RPC call direction */
+ if (*(p + 8) == cpu_to_be32(RPC_CALL))
+ return false;
+
+ return true;
+}
+
+/**
+ * svc_rdma_recvfrom - Receive an RPC call
+ * @rqstp: request structure into which to receive an RPC Call
+ *
+ * Returns:
+ * The positive number of bytes in the RPC Call message,
+ * %0 if there were no Calls ready to return,
+ * %-EINVAL if the Read chunk data is too large,
+ * %-ENOMEM if rdma_rw context pool was exhausted,
+ * %-ENOTCONN if posting failed (connection is lost),
+ * %-EIO if rdma_rw initialization failed (DMA mapping, etc).
+ *
+ * Called in a loop when XPT_DATA is set. XPT_DATA is cleared only
+ * when there are no remaining ctxt's to process.
+ *
+ * The next ctxt is removed from the "receive" lists.
+ *
+ * - If the ctxt completes a Receive, then construct the Call
+ * message from the contents of the Receive buffer.
+ *
+ * - If there are no Read chunks in this message, then finish
+ * assembling the Call message and return the number of bytes
+ * in the message.
+ *
+ * - If there are Read chunks in this message, post Read WRs to
+ * pull that payload. When the Read WRs complete, build the
+ * full message and return the number of bytes in it.
+ */
+int svc_rdma_recvfrom(struct svc_rqst *rqstp)
+{
+ struct svc_xprt *xprt = rqstp->rq_xprt;
+ struct svcxprt_rdma *rdma_xprt =
+ container_of(xprt, struct svcxprt_rdma, sc_xprt);
+ struct svc_rdma_recv_ctxt *ctxt;
+ int ret;
+
+ /* Prevent svc_xprt_release() from releasing pages in rq_pages
+ * when returning 0 or an error.
+ */
+ rqstp->rq_respages = rqstp->rq_pages;
+ rqstp->rq_next_page = rqstp->rq_respages;
+
+ rqstp->rq_xprt_ctxt = NULL;
+
+ ctxt = NULL;
+ spin_lock(&rdma_xprt->sc_rq_dto_lock);
+ ctxt = svc_rdma_next_recv_ctxt(&rdma_xprt->sc_rq_dto_q);
+ if (ctxt)
+ list_del(&ctxt->rc_list);
+ else
+ /* No new incoming requests, terminate the loop */
+ clear_bit(XPT_DATA, &xprt->xpt_flags);
+ spin_unlock(&rdma_xprt->sc_rq_dto_lock);
+
+ /* Unblock the transport for the next receive */
+ svc_xprt_received(xprt);
+ if (!ctxt)
+ return 0;
+
+ percpu_counter_inc(&svcrdma_stat_recv);
+ ib_dma_sync_single_for_cpu(rdma_xprt->sc_pd->device,
+ ctxt->rc_recv_sge.addr, ctxt->rc_byte_len,
+ DMA_FROM_DEVICE);
+ svc_rdma_build_arg_xdr(rqstp, ctxt);
+
+ ret = svc_rdma_xdr_decode_req(&rqstp->rq_arg, ctxt);
+ if (ret < 0)
+ goto out_err;
+ if (ret == 0)
+ goto out_drop;
+
+ if (svc_rdma_is_reverse_direction_reply(xprt, ctxt))
+ goto out_backchannel;
+
+ svc_rdma_get_inv_rkey(rdma_xprt, ctxt);
+
+ if (!pcl_is_empty(&ctxt->rc_read_pcl) ||
+ !pcl_is_empty(&ctxt->rc_call_pcl)) {
+ ret = svc_rdma_process_read_list(rdma_xprt, rqstp, ctxt);
+ if (ret < 0)
+ goto out_readfail;
+ }
+
+ rqstp->rq_xprt_ctxt = ctxt;
+ rqstp->rq_prot = IPPROTO_MAX;
+ svc_xprt_copy_addrs(rqstp, xprt);
+ set_bit(RQ_SECURE, &rqstp->rq_flags);
+ return rqstp->rq_arg.len;
+
+out_err:
+ svc_rdma_send_error(rdma_xprt, ctxt, ret);
+ svc_rdma_recv_ctxt_put(rdma_xprt, ctxt);
+ return 0;
+
+out_readfail:
+ if (ret == -EINVAL)
+ svc_rdma_send_error(rdma_xprt, ctxt, ret);
+ svc_rdma_recv_ctxt_put(rdma_xprt, ctxt);
+ svc_xprt_deferred_close(xprt);
+ return -ENOTCONN;
+
+out_backchannel:
+ svc_rdma_handle_bc_reply(rqstp, ctxt);
+out_drop:
+ svc_rdma_recv_ctxt_put(rdma_xprt, ctxt);
+ return 0;
+}
diff --git a/net/sunrpc/xprtrdma/svc_rdma_rw.c b/net/sunrpc/xprtrdma/svc_rdma_rw.c
new file mode 100644
index 0000000000..e460e25a1d
--- /dev/null
+++ b/net/sunrpc/xprtrdma/svc_rdma_rw.c
@@ -0,0 +1,1169 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (c) 2016-2018 Oracle. All rights reserved.
+ *
+ * Use the core R/W API to move RPC-over-RDMA Read and Write chunks.
+ */
+
+#include <rdma/rw.h>
+
+#include <linux/sunrpc/xdr.h>
+#include <linux/sunrpc/rpc_rdma.h>
+#include <linux/sunrpc/svc_rdma.h>
+
+#include "xprt_rdma.h"
+#include <trace/events/rpcrdma.h>
+
+static void svc_rdma_write_done(struct ib_cq *cq, struct ib_wc *wc);
+static void svc_rdma_wc_read_done(struct ib_cq *cq, struct ib_wc *wc);
+
+/* Each R/W context contains state for one chain of RDMA Read or
+ * Write Work Requests.
+ *
+ * Each WR chain handles a single contiguous server-side buffer,
+ * because scatterlist entries after the first have to start on
+ * page alignment. xdr_buf iovecs cannot guarantee alignment.
+ *
+ * Each WR chain handles only one R_key. Each RPC-over-RDMA segment
+ * from a client may contain a unique R_key, so each WR chain moves
+ * up to one segment at a time.
+ *
+ * The scatterlist makes this data structure over 4KB in size. To
+ * make it less likely to fail, and to handle the allocation for
+ * smaller I/O requests without disabling bottom-halves, these
+ * contexts are created on demand, but cached and reused until the
+ * controlling svcxprt_rdma is destroyed.
+ */
+struct svc_rdma_rw_ctxt {
+ struct llist_node rw_node;
+ struct list_head rw_list;
+ struct rdma_rw_ctx rw_ctx;
+ unsigned int rw_nents;
+ struct sg_table rw_sg_table;
+ struct scatterlist rw_first_sgl[];
+};
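+
+/* A rough sizing note: rw_first_sgl holds SG_CHUNK_SIZE (128) inline
+ * scatterlist entries, each a few dozen bytes depending on the
+ * kernel configuration, which accounts for most of the "over 4KB"
+ * figure quoted above.
+ */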
+
+static inline struct svc_rdma_rw_ctxt *
+svc_rdma_next_ctxt(struct list_head *list)
+{
+ return list_first_entry_or_null(list, struct svc_rdma_rw_ctxt,
+ rw_list);
+}
+
+static struct svc_rdma_rw_ctxt *
+svc_rdma_get_rw_ctxt(struct svcxprt_rdma *rdma, unsigned int sges)
+{
+ struct svc_rdma_rw_ctxt *ctxt;
+ struct llist_node *node;
+
+ spin_lock(&rdma->sc_rw_ctxt_lock);
+ node = llist_del_first(&rdma->sc_rw_ctxts);
+ spin_unlock(&rdma->sc_rw_ctxt_lock);
+ if (node) {
+ ctxt = llist_entry(node, struct svc_rdma_rw_ctxt, rw_node);
+ } else {
+ ctxt = kmalloc_node(struct_size(ctxt, rw_first_sgl, SG_CHUNK_SIZE),
+ GFP_KERNEL, ibdev_to_node(rdma->sc_cm_id->device));
+ if (!ctxt)
+ goto out_noctx;
+
+ INIT_LIST_HEAD(&ctxt->rw_list);
+ }
+
+ ctxt->rw_sg_table.sgl = ctxt->rw_first_sgl;
+ if (sg_alloc_table_chained(&ctxt->rw_sg_table, sges,
+ ctxt->rw_sg_table.sgl,
+ SG_CHUNK_SIZE))
+ goto out_free;
+ return ctxt;
+
+out_free:
+ kfree(ctxt);
+out_noctx:
+ trace_svcrdma_no_rwctx_err(rdma, sges);
+ return NULL;
+}
+
+static void __svc_rdma_put_rw_ctxt(struct svc_rdma_rw_ctxt *ctxt,
+ struct llist_head *list)
+{
+ sg_free_table_chained(&ctxt->rw_sg_table, SG_CHUNK_SIZE);
+ llist_add(&ctxt->rw_node, list);
+}
+
+static void svc_rdma_put_rw_ctxt(struct svcxprt_rdma *rdma,
+ struct svc_rdma_rw_ctxt *ctxt)
+{
+ __svc_rdma_put_rw_ctxt(ctxt, &rdma->sc_rw_ctxts);
+}
+
+/**
+ * svc_rdma_destroy_rw_ctxts - Free accumulated R/W contexts
+ * @rdma: transport about to be destroyed
+ *
+ */
+void svc_rdma_destroy_rw_ctxts(struct svcxprt_rdma *rdma)
+{
+ struct svc_rdma_rw_ctxt *ctxt;
+ struct llist_node *node;
+
+ while ((node = llist_del_first(&rdma->sc_rw_ctxts)) != NULL) {
+ ctxt = llist_entry(node, struct svc_rdma_rw_ctxt, rw_node);
+ kfree(ctxt);
+ }
+}
+
+/**
+ * svc_rdma_rw_ctx_init - Prepare a R/W context for I/O
+ * @rdma: controlling transport instance
+ * @ctxt: R/W context to prepare
+ * @offset: RDMA offset
+ * @handle: RDMA tag/handle
+ * @direction: I/O direction
+ *
+ * On success, returns the number of WQEs that will be needed
+ * on the Send Queue to post this context; otherwise returns a
+ * negative errno.
+ */
+static int svc_rdma_rw_ctx_init(struct svcxprt_rdma *rdma,
+ struct svc_rdma_rw_ctxt *ctxt,
+ u64 offset, u32 handle,
+ enum dma_data_direction direction)
+{
+ int ret;
+
+ ret = rdma_rw_ctx_init(&ctxt->rw_ctx, rdma->sc_qp, rdma->sc_port_num,
+ ctxt->rw_sg_table.sgl, ctxt->rw_nents,
+ 0, offset, handle, direction);
+ if (unlikely(ret < 0)) {
+ svc_rdma_put_rw_ctxt(rdma, ctxt);
+ trace_svcrdma_dma_map_rw_err(rdma, ctxt->rw_nents, ret);
+ }
+ return ret;
+}
+
+/* A chunk context tracks all I/O for moving one Read or Write
+ * chunk. This is a set of rdma_rw's that handle data movement
+ * for all segments of one chunk.
+ *
+ * These are small, acquired with a single allocator call, and
+ * no more than one is needed per chunk. They are allocated on
+ * demand, and not cached.
+ */
+struct svc_rdma_chunk_ctxt {
+ struct rpc_rdma_cid cc_cid;
+ struct ib_cqe cc_cqe;
+ struct svcxprt_rdma *cc_rdma;
+ struct list_head cc_rwctxts;
+ ktime_t cc_posttime;
+ int cc_sqecount;
+ enum ib_wc_status cc_status;
+ struct completion cc_done;
+};
+
+static void svc_rdma_cc_cid_init(struct svcxprt_rdma *rdma,
+ struct rpc_rdma_cid *cid)
+{
+ cid->ci_queue_id = rdma->sc_sq_cq->res.id;
+ cid->ci_completion_id = atomic_inc_return(&rdma->sc_completion_ids);
+}
+
+static void svc_rdma_cc_init(struct svcxprt_rdma *rdma,
+ struct svc_rdma_chunk_ctxt *cc)
+{
+ svc_rdma_cc_cid_init(rdma, &cc->cc_cid);
+ cc->cc_rdma = rdma;
+
+ INIT_LIST_HEAD(&cc->cc_rwctxts);
+ cc->cc_sqecount = 0;
+}
+
+/*
+ * The consumed rw_ctx's are cleaned and placed on a local llist so
+ * that only one atomic llist operation is needed to put them all
+ * back on the free list.
+ */
+static void svc_rdma_cc_release(struct svc_rdma_chunk_ctxt *cc,
+ enum dma_data_direction dir)
+{
+ struct svcxprt_rdma *rdma = cc->cc_rdma;
+ struct llist_node *first, *last;
+ struct svc_rdma_rw_ctxt *ctxt;
+ LLIST_HEAD(free);
+
+ trace_svcrdma_cc_release(&cc->cc_cid, cc->cc_sqecount);
+
+ first = last = NULL;
+ while ((ctxt = svc_rdma_next_ctxt(&cc->cc_rwctxts)) != NULL) {
+ list_del(&ctxt->rw_list);
+
+ rdma_rw_ctx_destroy(&ctxt->rw_ctx, rdma->sc_qp,
+ rdma->sc_port_num, ctxt->rw_sg_table.sgl,
+ ctxt->rw_nents, dir);
+ __svc_rdma_put_rw_ctxt(ctxt, &free);
+
+ ctxt->rw_node.next = first;
+ first = &ctxt->rw_node;
+ if (!last)
+ last = first;
+ }
+ if (first)
+ llist_add_batch(first, last, &rdma->sc_rw_ctxts);
+}
+
+/* State for sending a Write or Reply chunk.
+ * - Tracks progress of writing one chunk over all its segments
+ * - Stores arguments for the SGL constructor functions
+ */
+struct svc_rdma_write_info {
+ const struct svc_rdma_chunk *wi_chunk;
+
+ /* write state of this chunk */
+ unsigned int wi_seg_off;
+ unsigned int wi_seg_no;
+
+ /* SGL constructor arguments */
+ const struct xdr_buf *wi_xdr;
+ unsigned char *wi_base;
+ unsigned int wi_next_off;
+
+ struct svc_rdma_chunk_ctxt wi_cc;
+};
+
+static struct svc_rdma_write_info *
+svc_rdma_write_info_alloc(struct svcxprt_rdma *rdma,
+ const struct svc_rdma_chunk *chunk)
+{
+ struct svc_rdma_write_info *info;
+
+ info = kmalloc_node(sizeof(*info), GFP_KERNEL,
+ ibdev_to_node(rdma->sc_cm_id->device));
+ if (!info)
+ return info;
+
+ info->wi_chunk = chunk;
+ info->wi_seg_off = 0;
+ info->wi_seg_no = 0;
+ svc_rdma_cc_init(rdma, &info->wi_cc);
+ info->wi_cc.cc_cqe.done = svc_rdma_write_done;
+ return info;
+}
+
+static void svc_rdma_write_info_free(struct svc_rdma_write_info *info)
+{
+ svc_rdma_cc_release(&info->wi_cc, DMA_TO_DEVICE);
+ kfree(info);
+}
+
+/**
+ * svc_rdma_write_done - Write chunk completion
+ * @cq: controlling Completion Queue
+ * @wc: Work Completion
+ *
+ * Pages under I/O are freed by a subsequent Send completion.
+ */
+static void svc_rdma_write_done(struct ib_cq *cq, struct ib_wc *wc)
+{
+ struct ib_cqe *cqe = wc->wr_cqe;
+ struct svc_rdma_chunk_ctxt *cc =
+ container_of(cqe, struct svc_rdma_chunk_ctxt, cc_cqe);
+ struct svcxprt_rdma *rdma = cc->cc_rdma;
+ struct svc_rdma_write_info *info =
+ container_of(cc, struct svc_rdma_write_info, wi_cc);
+
+ switch (wc->status) {
+ case IB_WC_SUCCESS:
+ trace_svcrdma_wc_write(wc, &cc->cc_cid);
+ break;
+ case IB_WC_WR_FLUSH_ERR:
+ trace_svcrdma_wc_write_flush(wc, &cc->cc_cid);
+ break;
+ default:
+ trace_svcrdma_wc_write_err(wc, &cc->cc_cid);
+ }
+
+ svc_rdma_wake_send_waiters(rdma, cc->cc_sqecount);
+
+ if (unlikely(wc->status != IB_WC_SUCCESS))
+ svc_xprt_deferred_close(&rdma->sc_xprt);
+
+ svc_rdma_write_info_free(info);
+}
+
+/* State for pulling a Read chunk.
+ */
+struct svc_rdma_read_info {
+ struct svc_rqst *ri_rqst;
+ struct svc_rdma_recv_ctxt *ri_readctxt;
+ unsigned int ri_pageno;
+ unsigned int ri_pageoff;
+ unsigned int ri_totalbytes;
+
+ struct svc_rdma_chunk_ctxt ri_cc;
+};
+
+static struct svc_rdma_read_info *
+svc_rdma_read_info_alloc(struct svcxprt_rdma *rdma)
+{
+ struct svc_rdma_read_info *info;
+
+ info = kmalloc_node(sizeof(*info), GFP_KERNEL,
+ ibdev_to_node(rdma->sc_cm_id->device));
+ if (!info)
+ return info;
+
+ svc_rdma_cc_init(rdma, &info->ri_cc);
+ info->ri_cc.cc_cqe.done = svc_rdma_wc_read_done;
+ return info;
+}
+
+static void svc_rdma_read_info_free(struct svc_rdma_read_info *info)
+{
+ svc_rdma_cc_release(&info->ri_cc, DMA_FROM_DEVICE);
+ kfree(info);
+}
+
+/**
+ * svc_rdma_wc_read_done - Handle completion of an RDMA Read ctx
+ * @cq: controlling Completion Queue
+ * @wc: Work Completion
+ *
+ */
+static void svc_rdma_wc_read_done(struct ib_cq *cq, struct ib_wc *wc)
+{
+ struct ib_cqe *cqe = wc->wr_cqe;
+ struct svc_rdma_chunk_ctxt *cc =
+ container_of(cqe, struct svc_rdma_chunk_ctxt, cc_cqe);
+ struct svc_rdma_read_info *info;
+
+ switch (wc->status) {
+ case IB_WC_SUCCESS:
+ info = container_of(cc, struct svc_rdma_read_info, ri_cc);
+ trace_svcrdma_wc_read(wc, &cc->cc_cid, info->ri_totalbytes,
+ cc->cc_posttime);
+ break;
+ case IB_WC_WR_FLUSH_ERR:
+ trace_svcrdma_wc_read_flush(wc, &cc->cc_cid);
+ break;
+ default:
+ trace_svcrdma_wc_read_err(wc, &cc->cc_cid);
+ }
+
+ svc_rdma_wake_send_waiters(cc->cc_rdma, cc->cc_sqecount);
+ cc->cc_status = wc->status;
+ complete(&cc->cc_done);
+ return;
+}
+
+/*
+ * Assumptions:
+ * - If ib_post_send() succeeds, only one completion is expected,
+ * even if one or more WRs are flushed. This is true when posting
+ * an rdma_rw_ctx or when posting a single signaled WR.
+ */
+static int svc_rdma_post_chunk_ctxt(struct svc_rdma_chunk_ctxt *cc)
+{
+ struct svcxprt_rdma *rdma = cc->cc_rdma;
+ struct ib_send_wr *first_wr;
+ const struct ib_send_wr *bad_wr;
+ struct list_head *tmp;
+ struct ib_cqe *cqe;
+ int ret;
+
+ might_sleep();
+
+ if (cc->cc_sqecount > rdma->sc_sq_depth)
+ return -EINVAL;
+
+ first_wr = NULL;
+ cqe = &cc->cc_cqe;
+ list_for_each(tmp, &cc->cc_rwctxts) {
+ struct svc_rdma_rw_ctxt *ctxt;
+
+ ctxt = list_entry(tmp, struct svc_rdma_rw_ctxt, rw_list);
+ first_wr = rdma_rw_ctx_wrs(&ctxt->rw_ctx, rdma->sc_qp,
+ rdma->sc_port_num, cqe, first_wr);
+ cqe = NULL;
+ }
+
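+	/* Reserve cc_sqecount SQEs up front. If the subtraction would
+	 * leave the Send Queue oversubscribed, return the reservation,
+	 * wait for completions to release SQEs, and try again.
+	 */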
+ do {
+ if (atomic_sub_return(cc->cc_sqecount,
+ &rdma->sc_sq_avail) > 0) {
+ cc->cc_posttime = ktime_get();
+ ret = ib_post_send(rdma->sc_qp, first_wr, &bad_wr);
+ if (ret)
+ break;
+ return 0;
+ }
+
+ percpu_counter_inc(&svcrdma_stat_sq_starve);
+ trace_svcrdma_sq_full(rdma);
+ atomic_add(cc->cc_sqecount, &rdma->sc_sq_avail);
+ wait_event(rdma->sc_send_wait,
+ atomic_read(&rdma->sc_sq_avail) > cc->cc_sqecount);
+ trace_svcrdma_sq_retry(rdma);
+ } while (1);
+
+ trace_svcrdma_sq_post_err(rdma, ret);
+ svc_xprt_deferred_close(&rdma->sc_xprt);
+
+ /* If even one was posted, there will be a completion. */
+ if (bad_wr != first_wr)
+ return 0;
+
+ atomic_add(cc->cc_sqecount, &rdma->sc_sq_avail);
+ wake_up(&rdma->sc_send_wait);
+ return -ENOTCONN;
+}
+
+/* Build and DMA-map an SGL that covers one kvec in an xdr_buf
+ */
+static void svc_rdma_vec_to_sg(struct svc_rdma_write_info *info,
+ unsigned int len,
+ struct svc_rdma_rw_ctxt *ctxt)
+{
+ struct scatterlist *sg = ctxt->rw_sg_table.sgl;
+
+ sg_set_buf(&sg[0], info->wi_base, len);
+ info->wi_base += len;
+
+ ctxt->rw_nents = 1;
+}
+
+/* Build and DMA-map an SGL that covers part of an xdr_buf's pagelist.
+ */
+static void svc_rdma_pagelist_to_sg(struct svc_rdma_write_info *info,
+ unsigned int remaining,
+ struct svc_rdma_rw_ctxt *ctxt)
+{
+ unsigned int sge_no, sge_bytes, page_off, page_no;
+ const struct xdr_buf *xdr = info->wi_xdr;
+ struct scatterlist *sg;
+ struct page **page;
+
+ page_off = info->wi_next_off + xdr->page_base;
+ page_no = page_off >> PAGE_SHIFT;
+ page_off = offset_in_page(page_off);
+ page = xdr->pages + page_no;
+ info->wi_next_off += remaining;
+ sg = ctxt->rw_sg_table.sgl;
+ sge_no = 0;
+ do {
+ sge_bytes = min_t(unsigned int, remaining,
+ PAGE_SIZE - page_off);
+ sg_set_page(sg, *page, sge_bytes, page_off);
+
+ remaining -= sge_bytes;
+ sg = sg_next(sg);
+ page_off = 0;
+ sge_no++;
+ page++;
+ } while (remaining);
+
+ ctxt->rw_nents = sge_no;
+}
+
+/* Construct RDMA Write WRs to send a portion of an xdr_buf containing
+ * an RPC Reply.
+ */
+static int
+svc_rdma_build_writes(struct svc_rdma_write_info *info,
+ void (*constructor)(struct svc_rdma_write_info *info,
+ unsigned int len,
+ struct svc_rdma_rw_ctxt *ctxt),
+ unsigned int remaining)
+{
+ struct svc_rdma_chunk_ctxt *cc = &info->wi_cc;
+ struct svcxprt_rdma *rdma = cc->cc_rdma;
+ const struct svc_rdma_segment *seg;
+ struct svc_rdma_rw_ctxt *ctxt;
+ int ret;
+
+ do {
+ unsigned int write_len;
+ u64 offset;
+
+ if (info->wi_seg_no >= info->wi_chunk->ch_segcount)
+ goto out_overflow;
+
+ seg = &info->wi_chunk->ch_segments[info->wi_seg_no];
+ write_len = min(remaining, seg->rs_length - info->wi_seg_off);
+ if (!write_len)
+ goto out_overflow;
+ ctxt = svc_rdma_get_rw_ctxt(rdma,
+ (write_len >> PAGE_SHIFT) + 2);
+ if (!ctxt)
+ return -ENOMEM;
+
+ constructor(info, write_len, ctxt);
+ offset = seg->rs_offset + info->wi_seg_off;
+ ret = svc_rdma_rw_ctx_init(rdma, ctxt, offset, seg->rs_handle,
+ DMA_TO_DEVICE);
+ if (ret < 0)
+ return -EIO;
+ percpu_counter_inc(&svcrdma_stat_write);
+
+ list_add(&ctxt->rw_list, &cc->cc_rwctxts);
+ cc->cc_sqecount += ret;
+ if (write_len == seg->rs_length - info->wi_seg_off) {
+ info->wi_seg_no++;
+ info->wi_seg_off = 0;
+ } else {
+ info->wi_seg_off += write_len;
+ }
+ remaining -= write_len;
+ } while (remaining);
+
+ return 0;
+
+out_overflow:
+ trace_svcrdma_small_wrch_err(rdma, remaining, info->wi_seg_no,
+ info->wi_chunk->ch_segcount);
+ return -E2BIG;
+}
+
+/**
+ * svc_rdma_iov_write - Construct RDMA Writes from an iov
+ * @info: pointer to write arguments
+ * @iov: kvec to write
+ *
+ * Returns:
+ * On success, returns zero
+ * %-E2BIG if the client-provided Write chunk is too small
+ * %-ENOMEM if a resource has been exhausted
+ * %-EIO if an rdma-rw error occurred
+ */
+static int svc_rdma_iov_write(struct svc_rdma_write_info *info,
+ const struct kvec *iov)
+{
+ info->wi_base = iov->iov_base;
+ return svc_rdma_build_writes(info, svc_rdma_vec_to_sg,
+ iov->iov_len);
+}
+
+/**
+ * svc_rdma_pages_write - Construct RDMA Writes from pages
+ * @info: pointer to write arguments
+ * @xdr: xdr_buf with pages to write
+ * @offset: offset into the content of @xdr
+ * @length: number of bytes to write
+ *
+ * Returns:
+ * On success, returns zero
+ * %-E2BIG if the client-provided Write chunk is too small
+ * %-ENOMEM if a resource has been exhausted
+ * %-EIO if an rdma-rw error occurred
+ */
+static int svc_rdma_pages_write(struct svc_rdma_write_info *info,
+ const struct xdr_buf *xdr,
+ unsigned int offset,
+ unsigned long length)
+{
+ info->wi_xdr = xdr;
+ info->wi_next_off = offset - xdr->head[0].iov_len;
+ return svc_rdma_build_writes(info, svc_rdma_pagelist_to_sg,
+ length);
+}
+
+/**
+ * svc_rdma_xb_write - Construct RDMA Writes to write an xdr_buf
+ * @xdr: xdr_buf to write
+ * @data: pointer to write arguments
+ *
+ * Returns:
+ * On success, returns zero
+ * %-E2BIG if the client-provided Write chunk is too small
+ * %-ENOMEM if a resource has been exhausted
+ * %-EIO if an rdma-rw error occurred
+ */
+static int svc_rdma_xb_write(const struct xdr_buf *xdr, void *data)
+{
+ struct svc_rdma_write_info *info = data;
+ int ret;
+
+ if (xdr->head[0].iov_len) {
+ ret = svc_rdma_iov_write(info, &xdr->head[0]);
+ if (ret < 0)
+ return ret;
+ }
+
+ if (xdr->page_len) {
+ ret = svc_rdma_pages_write(info, xdr, xdr->head[0].iov_len,
+ xdr->page_len);
+ if (ret < 0)
+ return ret;
+ }
+
+ if (xdr->tail[0].iov_len) {
+ ret = svc_rdma_iov_write(info, &xdr->tail[0]);
+ if (ret < 0)
+ return ret;
+ }
+
+ return xdr->len;
+}
+
+/**
+ * svc_rdma_send_write_chunk - Write all segments in a Write chunk
+ * @rdma: controlling RDMA transport
+ * @chunk: Write chunk provided by the client
+ * @xdr: xdr_buf containing the data payload
+ *
+ * Returns a non-negative number of bytes the chunk consumed, or
+ * %-E2BIG if the payload was larger than the Write chunk,
+ * %-EINVAL if client provided too many segments,
+ * %-ENOMEM if rdma_rw context pool was exhausted,
+ * %-ENOTCONN if posting failed (connection is lost),
+ * %-EIO if rdma_rw initialization failed (DMA mapping, etc).
+ */
+int svc_rdma_send_write_chunk(struct svcxprt_rdma *rdma,
+ const struct svc_rdma_chunk *chunk,
+ const struct xdr_buf *xdr)
+{
+ struct svc_rdma_write_info *info;
+ struct svc_rdma_chunk_ctxt *cc;
+ int ret;
+
+ info = svc_rdma_write_info_alloc(rdma, chunk);
+ if (!info)
+ return -ENOMEM;
+ cc = &info->wi_cc;
+
+ ret = svc_rdma_xb_write(xdr, info);
+ if (ret != xdr->len)
+ goto out_err;
+
+ trace_svcrdma_post_write_chunk(&cc->cc_cid, cc->cc_sqecount);
+ ret = svc_rdma_post_chunk_ctxt(cc);
+ if (ret < 0)
+ goto out_err;
+ return xdr->len;
+
+out_err:
+ svc_rdma_write_info_free(info);
+ return ret;
+}
+
+/**
+ * svc_rdma_send_reply_chunk - Write all segments in the Reply chunk
+ * @rdma: controlling RDMA transport
+ * @rctxt: Write and Reply chunks from client
+ * @xdr: xdr_buf containing an RPC Reply
+ *
+ * Returns a non-negative number of bytes the chunk consumed, or
+ * %-E2BIG if the payload was larger than the Reply chunk,
+ * %-EINVAL if client provided too many segments,
+ * %-ENOMEM if rdma_rw context pool was exhausted,
+ * %-ENOTCONN if posting failed (connection is lost),
+ * %-EIO if rdma_rw initialization failed (DMA mapping, etc).
+ */
+int svc_rdma_send_reply_chunk(struct svcxprt_rdma *rdma,
+ const struct svc_rdma_recv_ctxt *rctxt,
+ const struct xdr_buf *xdr)
+{
+ struct svc_rdma_write_info *info;
+ struct svc_rdma_chunk_ctxt *cc;
+ struct svc_rdma_chunk *chunk;
+ int ret;
+
+ if (pcl_is_empty(&rctxt->rc_reply_pcl))
+ return 0;
+
+ chunk = pcl_first_chunk(&rctxt->rc_reply_pcl);
+ info = svc_rdma_write_info_alloc(rdma, chunk);
+ if (!info)
+ return -ENOMEM;
+ cc = &info->wi_cc;
+
+ ret = pcl_process_nonpayloads(&rctxt->rc_write_pcl, xdr,
+ svc_rdma_xb_write, info);
+ if (ret < 0)
+ goto out_err;
+
+ trace_svcrdma_post_reply_chunk(&cc->cc_cid, cc->cc_sqecount);
+ ret = svc_rdma_post_chunk_ctxt(cc);
+ if (ret < 0)
+ goto out_err;
+
+ return xdr->len;
+
+out_err:
+ svc_rdma_write_info_free(info);
+ return ret;
+}
+
+/**
+ * svc_rdma_build_read_segment - Build RDMA Read WQEs to pull one RDMA segment
+ * @info: context for ongoing I/O
+ * @segment: co-ordinates of remote memory to be read
+ *
+ * Returns:
+ * %0: the Read WR chain was constructed successfully
+ * %-EINVAL: there were not enough rq_pages to finish
+ * %-ENOMEM: allocating local resources failed
+ * %-EIO: a DMA mapping error occurred
+ */
+static int svc_rdma_build_read_segment(struct svc_rdma_read_info *info,
+ const struct svc_rdma_segment *segment)
+{
+ struct svc_rdma_recv_ctxt *head = info->ri_readctxt;
+ struct svc_rdma_chunk_ctxt *cc = &info->ri_cc;
+ struct svc_rqst *rqstp = info->ri_rqst;
+ unsigned int sge_no, seg_len, len;
+ struct svc_rdma_rw_ctxt *ctxt;
+ struct scatterlist *sg;
+ int ret;
+
+ len = segment->rs_length;
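+	/* e.g., a 9000-byte segment landing at page offset 1024 needs
+	 * PAGE_ALIGN(1024 + 9000) >> PAGE_SHIFT == 3 sink pages (with
+	 * 4KB pages), and thus three SGEs.
+	 */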
+ sge_no = PAGE_ALIGN(info->ri_pageoff + len) >> PAGE_SHIFT;
+ ctxt = svc_rdma_get_rw_ctxt(cc->cc_rdma, sge_no);
+ if (!ctxt)
+ return -ENOMEM;
+ ctxt->rw_nents = sge_no;
+
+ sg = ctxt->rw_sg_table.sgl;
+ for (sge_no = 0; sge_no < ctxt->rw_nents; sge_no++) {
+ seg_len = min_t(unsigned int, len,
+ PAGE_SIZE - info->ri_pageoff);
+
+ if (!info->ri_pageoff)
+ head->rc_page_count++;
+
+ sg_set_page(sg, rqstp->rq_pages[info->ri_pageno],
+ seg_len, info->ri_pageoff);
+ sg = sg_next(sg);
+
+ info->ri_pageoff += seg_len;
+ if (info->ri_pageoff == PAGE_SIZE) {
+ info->ri_pageno++;
+ info->ri_pageoff = 0;
+ }
+ len -= seg_len;
+
+ /* Safety check */
+ if (len &&
+ &rqstp->rq_pages[info->ri_pageno + 1] > rqstp->rq_page_end)
+ goto out_overrun;
+ }
+
+ ret = svc_rdma_rw_ctx_init(cc->cc_rdma, ctxt, segment->rs_offset,
+ segment->rs_handle, DMA_FROM_DEVICE);
+ if (ret < 0)
+ return -EIO;
+ percpu_counter_inc(&svcrdma_stat_read);
+
+ list_add(&ctxt->rw_list, &cc->cc_rwctxts);
+ cc->cc_sqecount += ret;
+ return 0;
+
+out_overrun:
+ trace_svcrdma_page_overrun_err(cc->cc_rdma, rqstp, info->ri_pageno);
+ return -EINVAL;
+}
+
+/**
+ * svc_rdma_build_read_chunk - Build RDMA Read WQEs to pull one RDMA chunk
+ * @info: context for ongoing I/O
+ * @chunk: Read chunk to pull
+ *
+ * Return values:
+ * %0: the Read WR chain was constructed successfully
+ * %-EINVAL: there were not enough resources to finish
+ * %-ENOMEM: allocating local resources failed
+ * %-EIO: a DMA mapping error occurred
+ */
+static int svc_rdma_build_read_chunk(struct svc_rdma_read_info *info,
+ const struct svc_rdma_chunk *chunk)
+{
+ const struct svc_rdma_segment *segment;
+ int ret;
+
+ ret = -EINVAL;
+ pcl_for_each_segment(segment, chunk) {
+ ret = svc_rdma_build_read_segment(info, segment);
+ if (ret < 0)
+ break;
+ info->ri_totalbytes += segment->rs_length;
+ }
+ return ret;
+}
+
+/**
+ * svc_rdma_copy_inline_range - Copy part of the inline content into pages
+ * @info: context for RDMA Reads
+ * @offset: offset into the Receive buffer of region to copy
+ * @remaining: length of region to copy
+ *
+ * Take a page at a time from rqstp->rq_pages and copy the inline
+ * content from the Receive buffer into that page. Update
+ * info->ri_pageno and info->ri_pageoff so that the next RDMA Read
+ * result will land contiguously with the copied content.
+ *
+ * Return values:
+ * %0: Inline content was successfully copied
+ * %-EINVAL: offset or length was incorrect
+ */
+static int svc_rdma_copy_inline_range(struct svc_rdma_read_info *info,
+ unsigned int offset,
+ unsigned int remaining)
+{
+ struct svc_rdma_recv_ctxt *head = info->ri_readctxt;
+ unsigned char *dst, *src = head->rc_recv_buf;
+ struct svc_rqst *rqstp = info->ri_rqst;
+ unsigned int page_no, numpages;
+
+ numpages = PAGE_ALIGN(info->ri_pageoff + remaining) >> PAGE_SHIFT;
+ for (page_no = 0; page_no < numpages; page_no++) {
+ unsigned int page_len;
+
+ page_len = min_t(unsigned int, remaining,
+ PAGE_SIZE - info->ri_pageoff);
+
+ if (!info->ri_pageoff)
+ head->rc_page_count++;
+
+ dst = page_address(rqstp->rq_pages[info->ri_pageno]);
+		memcpy(dst + info->ri_pageoff, src + offset, page_len);
+
+ info->ri_totalbytes += page_len;
+ info->ri_pageoff += page_len;
+ if (info->ri_pageoff == PAGE_SIZE) {
+ info->ri_pageno++;
+ info->ri_pageoff = 0;
+ }
+ remaining -= page_len;
+ offset += page_len;
+ }
+
+	return 0;
+}
+
+/**
+ * svc_rdma_read_multiple_chunks - Construct RDMA Reads to pull data item Read chunks
+ * @info: context for RDMA Reads
+ *
+ * The chunk data lands in rqstp->rq_arg as a series of contiguous pages,
+ * like an incoming TCP call.
+ *
+ * Return values:
+ * %0: RDMA Read WQEs were successfully built
+ * %-EINVAL: client provided too many chunks or segments,
+ * %-ENOMEM: rdma_rw context pool was exhausted,
+ * %-ENOTCONN: posting failed (connection is lost),
+ * %-EIO: rdma_rw initialization failed (DMA mapping, etc).
+ */
+static noinline int svc_rdma_read_multiple_chunks(struct svc_rdma_read_info *info)
+{
+ struct svc_rdma_recv_ctxt *head = info->ri_readctxt;
+ const struct svc_rdma_pcl *pcl = &head->rc_read_pcl;
+ struct xdr_buf *buf = &info->ri_rqst->rq_arg;
+ struct svc_rdma_chunk *chunk, *next;
+ unsigned int start, length;
+ int ret;
+
+ start = 0;
+ chunk = pcl_first_chunk(pcl);
+ length = chunk->ch_position;
+ ret = svc_rdma_copy_inline_range(info, start, length);
+ if (ret < 0)
+ return ret;
+
+ pcl_for_each_chunk(chunk, pcl) {
+ ret = svc_rdma_build_read_chunk(info, chunk);
+ if (ret < 0)
+ return ret;
+
+ next = pcl_next_chunk(pcl, chunk);
+ if (!next)
+ break;
+
+ start += length;
+ length = next->ch_position - info->ri_totalbytes;
+ ret = svc_rdma_copy_inline_range(info, start, length);
+ if (ret < 0)
+ return ret;
+ }
+
+ start += length;
+ length = head->rc_byte_len - start;
+ ret = svc_rdma_copy_inline_range(info, start, length);
+ if (ret < 0)
+ return ret;
+
+ buf->len += info->ri_totalbytes;
+ buf->buflen += info->ri_totalbytes;
+
+ buf->head[0].iov_base = page_address(info->ri_rqst->rq_pages[0]);
+ buf->head[0].iov_len = min_t(size_t, PAGE_SIZE, info->ri_totalbytes);
+ buf->pages = &info->ri_rqst->rq_pages[1];
+ buf->page_len = info->ri_totalbytes - buf->head[0].iov_len;
+ return 0;
+}
+
+/**
+ * svc_rdma_read_data_item - Construct RDMA Reads to pull data item Read chunks
+ * @info: context for RDMA Reads
+ *
+ * The chunk data lands in the page list of rqstp->rq_arg.pages.
+ *
+ * Currently NFSD does not look at the rqstp->rq_arg.tail[0] kvec.
+ * Therefore, XDR round-up of the Read chunk and trailing
+ * inline content must both be added at the end of the pagelist.
+ *
+ * Return values:
+ * %0: RDMA Read WQEs were successfully built
+ * %-EINVAL: client provided too many chunks or segments,
+ * %-ENOMEM: rdma_rw context pool was exhausted,
+ * %-ENOTCONN: posting failed (connection is lost),
+ * %-EIO: rdma_rw initialization failed (DMA mapping, etc).
+ */
+static int svc_rdma_read_data_item(struct svc_rdma_read_info *info)
+{
+ struct svc_rdma_recv_ctxt *head = info->ri_readctxt;
+ struct xdr_buf *buf = &info->ri_rqst->rq_arg;
+ struct svc_rdma_chunk *chunk;
+ unsigned int length;
+ int ret;
+
+ chunk = pcl_first_chunk(&head->rc_read_pcl);
+ ret = svc_rdma_build_read_chunk(info, chunk);
+ if (ret < 0)
+ goto out;
+
+ /* Split the Receive buffer between the head and tail
+ * buffers at Read chunk's position. XDR roundup of the
+ * chunk is not included in either the pagelist or in
+ * the tail.
+ */
+ buf->tail[0].iov_base = buf->head[0].iov_base + chunk->ch_position;
+ buf->tail[0].iov_len = buf->head[0].iov_len - chunk->ch_position;
+ buf->head[0].iov_len = chunk->ch_position;
+
+ /* Read chunk may need XDR roundup (see RFC 8166, s. 3.4.5.2).
+ *
+ * If the client already rounded up the chunk length, the
+ * length does not change. Otherwise, the length of the page
+ * list is increased to include XDR round-up.
+ *
+ * Currently these chunks always start at page offset 0,
+ * thus the rounded-up length never crosses a page boundary.
+ */
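+	/* e.g., a 5-byte Read chunk yields xdr_align_size(5) == 8
+	 * bytes of page_len below.
+	 */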
+ buf->pages = &info->ri_rqst->rq_pages[0];
+ length = xdr_align_size(chunk->ch_length);
+ buf->page_len = length;
+ buf->len += length;
+ buf->buflen += length;
+
+out:
+ return ret;
+}
+
+/**
+ * svc_rdma_read_chunk_range - Build RDMA Read WQEs for portion of a chunk
+ * @info: context for RDMA Reads
+ * @chunk: parsed Call chunk to pull
+ * @offset: offset of region to pull
+ * @length: length of region to pull
+ *
+ * Return values:
+ * %0: RDMA Read WQEs were successfully built
+ * %-EINVAL: there were not enough resources to finish
+ * %-ENOMEM: rdma_rw context pool was exhausted,
+ * %-ENOTCONN: posting failed (connection is lost),
+ * %-EIO: rdma_rw initialization failed (DMA mapping, etc).
+ */
+static int svc_rdma_read_chunk_range(struct svc_rdma_read_info *info,
+ const struct svc_rdma_chunk *chunk,
+ unsigned int offset, unsigned int length)
+{
+ const struct svc_rdma_segment *segment;
+ int ret;
+
+ ret = -EINVAL;
+ pcl_for_each_segment(segment, chunk) {
+ struct svc_rdma_segment dummy;
+
+ if (offset > segment->rs_length) {
+ offset -= segment->rs_length;
+ continue;
+ }
+
+ dummy.rs_handle = segment->rs_handle;
+ dummy.rs_length = min_t(u32, length, segment->rs_length) - offset;
+ dummy.rs_offset = segment->rs_offset + offset;
+
+ ret = svc_rdma_build_read_segment(info, &dummy);
+ if (ret < 0)
+ break;
+
+ info->ri_totalbytes += dummy.rs_length;
+ length -= dummy.rs_length;
+ offset = 0;
+ }
+ return ret;
+}
+
+/**
+ * svc_rdma_read_call_chunk - Build RDMA Read WQEs to pull a Long Message
+ * @info: context for RDMA Reads
+ *
+ * Return values:
+ * %0: RDMA Read WQEs were successfully built
+ * %-EINVAL: there were not enough resources to finish
+ * %-ENOMEM: rdma_rw context pool was exhausted,
+ * %-ENOTCONN: posting failed (connection is lost),
+ * %-EIO: rdma_rw initialization failed (DMA mapping, etc).
+ */
+static int svc_rdma_read_call_chunk(struct svc_rdma_read_info *info)
+{
+ struct svc_rdma_recv_ctxt *head = info->ri_readctxt;
+ const struct svc_rdma_chunk *call_chunk =
+ pcl_first_chunk(&head->rc_call_pcl);
+ const struct svc_rdma_pcl *pcl = &head->rc_read_pcl;
+ struct svc_rdma_chunk *chunk, *next;
+ unsigned int start, length;
+ int ret;
+
+ if (pcl_is_empty(pcl))
+ return svc_rdma_build_read_chunk(info, call_chunk);
+
+ start = 0;
+ chunk = pcl_first_chunk(pcl);
+ length = chunk->ch_position;
+ ret = svc_rdma_read_chunk_range(info, call_chunk, start, length);
+ if (ret < 0)
+ return ret;
+
+ pcl_for_each_chunk(chunk, pcl) {
+ ret = svc_rdma_build_read_chunk(info, chunk);
+ if (ret < 0)
+ return ret;
+
+ next = pcl_next_chunk(pcl, chunk);
+ if (!next)
+ break;
+
+ start += length;
+ length = next->ch_position - info->ri_totalbytes;
+ ret = svc_rdma_read_chunk_range(info, call_chunk,
+ start, length);
+ if (ret < 0)
+ return ret;
+ }
+
+ start += length;
+ length = call_chunk->ch_length - start;
+ return svc_rdma_read_chunk_range(info, call_chunk, start, length);
+}
+
+/**
+ * svc_rdma_read_special - Build RDMA Read WQEs to pull a Long Message
+ * @info: context for RDMA Reads
+ *
+ * The start of the data lands in the first page just after the
+ * Transport header, and the rest lands in rqstp->rq_arg.pages.
+ *
+ * Assumptions:
+ * - A PZRC (Position Zero Read Chunk) is never sent in an RDMA_MSG
+ *   message, though the spec allows it.
+ *
+ * Return values:
+ * %0: RDMA Read WQEs were successfully built
+ * %-EINVAL: client provided too many chunks or segments,
+ * %-ENOMEM: rdma_rw context pool was exhausted,
+ * %-ENOTCONN: posting failed (connection is lost),
+ * %-EIO: rdma_rw initialization failed (DMA mapping, etc).
+ */
+static noinline int svc_rdma_read_special(struct svc_rdma_read_info *info)
+{
+ struct xdr_buf *buf = &info->ri_rqst->rq_arg;
+ int ret;
+
+ ret = svc_rdma_read_call_chunk(info);
+ if (ret < 0)
+ goto out;
+
+ buf->len += info->ri_totalbytes;
+ buf->buflen += info->ri_totalbytes;
+
+ buf->head[0].iov_base = page_address(info->ri_rqst->rq_pages[0]);
+ buf->head[0].iov_len = min_t(size_t, PAGE_SIZE, info->ri_totalbytes);
+ buf->pages = &info->ri_rqst->rq_pages[1];
+ buf->page_len = info->ri_totalbytes - buf->head[0].iov_len;
+
+out:
+ return ret;
+}
+
+/**
+ * svc_rdma_process_read_list - Pull list of Read chunks from the client
+ * @rdma: controlling RDMA transport
+ * @rqstp: set of pages to use as Read sink buffers
+ * @head: pages under I/O collect here
+ *
+ * The RPC/RDMA protocol assumes that the upper layer's XDR decoders
+ * pull each Read chunk as they decode an incoming RPC message.
+ *
+ * On Linux, however, the server needs to have a fully-constructed RPC
+ * message in rqstp->rq_arg when there is a positive return code from
+ * ->xpo_recvfrom. So the Read list is safety-checked immediately when
+ * it is received, then here the whole Read list is pulled all at once.
+ * The ingress RPC message is fully reconstructed once all associated
+ * RDMA Reads have completed.
+ *
+ * Return values:
+ * %1: all needed RDMA Reads were posted successfully,
+ * %-EINVAL: client provided too many chunks or segments,
+ * %-ENOMEM: rdma_rw context pool was exhausted,
+ * %-ENOTCONN: posting failed (connection is lost),
+ * %-EIO: rdma_rw initialization failed (DMA mapping, etc).
+ */
+int svc_rdma_process_read_list(struct svcxprt_rdma *rdma,
+ struct svc_rqst *rqstp,
+ struct svc_rdma_recv_ctxt *head)
+{
+ struct svc_rdma_read_info *info;
+ struct svc_rdma_chunk_ctxt *cc;
+ int ret;
+
+ info = svc_rdma_read_info_alloc(rdma);
+ if (!info)
+ return -ENOMEM;
+ cc = &info->ri_cc;
+ info->ri_rqst = rqstp;
+ info->ri_readctxt = head;
+ info->ri_pageno = 0;
+ info->ri_pageoff = 0;
+ info->ri_totalbytes = 0;
+
+ if (pcl_is_empty(&head->rc_call_pcl)) {
+ if (head->rc_read_pcl.cl_count == 1)
+ ret = svc_rdma_read_data_item(info);
+ else
+ ret = svc_rdma_read_multiple_chunks(info);
+ } else
+ ret = svc_rdma_read_special(info);
+ if (ret < 0)
+ goto out_err;
+
+ trace_svcrdma_post_read_chunk(&cc->cc_cid, cc->cc_sqecount);
+ init_completion(&cc->cc_done);
+ ret = svc_rdma_post_chunk_ctxt(cc);
+ if (ret < 0)
+ goto out_err;
+
+ ret = 1;
+ wait_for_completion(&cc->cc_done);
+ if (cc->cc_status != IB_WC_SUCCESS)
+ ret = -EIO;
+
+ /* rq_respages starts after the last arg page */
+ rqstp->rq_respages = &rqstp->rq_pages[head->rc_page_count];
+ rqstp->rq_next_page = rqstp->rq_respages + 1;
+
+ /* Ensure svc_rdma_recv_ctxt_put() does not try to release pages */
+ head->rc_page_count = 0;
+
+out_err:
+ svc_rdma_read_info_free(info);
+ return ret;
+}
diff --git a/net/sunrpc/xprtrdma/svc_rdma_sendto.c b/net/sunrpc/xprtrdma/svc_rdma_sendto.c
new file mode 100644
index 0000000000..c6644cca52
--- /dev/null
+++ b/net/sunrpc/xprtrdma/svc_rdma_sendto.c
@@ -0,0 +1,1062 @@
+// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
+/*
+ * Copyright (c) 2016-2018 Oracle. All rights reserved.
+ * Copyright (c) 2014 Open Grid Computing, Inc. All rights reserved.
+ * Copyright (c) 2005-2006 Network Appliance, Inc. All rights reserved.
+ *
+ * This software is available to you under a choice of one of two
+ * licenses. You may choose to be licensed under the terms of the GNU
+ * General Public License (GPL) Version 2, available from the file
+ * COPYING in the main directory of this source tree, or the BSD-type
+ * license below:
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *
+ * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials provided
+ * with the distribution.
+ *
+ * Neither the name of the Network Appliance, Inc. nor the names of
+ * its contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * Author: Tom Tucker <tom@opengridcomputing.com>
+ */
+
+/* Operation
+ *
+ * The main entry point is svc_rdma_sendto. This is called by the
+ * RPC server when an RPC Reply is ready to be transmitted to a client.
+ *
+ * The passed-in svc_rqst contains a struct xdr_buf which holds an
+ * XDR-encoded RPC Reply message. sendto must construct the RPC-over-RDMA
+ * transport header, post all Write WRs needed for this Reply, then post
+ * a Send WR conveying the transport header and the RPC message itself to
+ * the client.
+ *
+ * svc_rdma_sendto must fully transmit the Reply before returning, as
+ * the svc_rqst will be recycled as soon as sendto returns. Remaining
+ * resources referred to by the svc_rqst are also recycled at that time.
+ * Therefore any resources that must remain longer must be detached
+ * from the svc_rqst and released later.
+ *
+ * Page Management
+ *
+ * The I/O that performs Reply transmission is asynchronous, and may
+ * complete well after sendto returns. Thus pages under I/O must be
+ * removed from the svc_rqst before sendto returns.
+ *
+ * The logic here depends on Send Queue and completion ordering. Since
+ * the Send WR is always posted last, it will always complete last. Thus
+ * when it completes, it is guaranteed that all previous Write WRs have
+ * also completed.
+ *
+ * Write WRs are constructed and posted. Each Write segment gets its own
+ * svc_rdma_rw_ctxt, allowing the Write completion handler to find and
+ * DMA-unmap the pages under I/O for that Write segment. The Write
+ * completion handler does not release any pages.
+ *
+ * When the Send WR is constructed, it also gets its own svc_rdma_send_ctxt.
+ * Ownership of all of the Reply's pages is transferred into that
+ * ctxt, the Send WR is posted, and sendto returns.
+ *
+ * The svc_rdma_send_ctxt is presented when the Send WR completes. The
+ * Send completion handler finally releases the Reply's pages.
+ *
+ * This mechanism also assumes that completions on the transport's Send
+ * Completion Queue do not run in parallel. Otherwise a Write completion
+ * and Send completion running at the same time could release pages that
+ * are still DMA-mapped.
+ *
+ * Error Handling
+ *
+ * - If the Send WR is posted successfully, it will either complete
+ * successfully, or get flushed. Either way, the Send completion
+ * handler releases the Reply's pages.
+ * - If the Send WR cannot be posted, the forward path releases
+ * the Reply's pages.
+ *
+ * This handles the case, without the use of page reference counting,
+ * where two different Write segments send portions of the same page.
+ */
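+
+/* Editor's sketch (not part of the upstream file): the page-ownership
+ * rule described above, modelled as stand-alone user-space C. All type
+ * and function names below are invented for illustration; only the
+ * ordering rule mirrors the real code: pages move from the request into
+ * the send context before posting, and only the final Send completion
+ * releases them, because the Send WR is posted last and completes last.
+ */
+#if 0 /* illustrative only -- never built with the kernel */
+#include <stdlib.h>
+
+struct page; /* stand-in for the kernel's struct page */
+
+struct send_ctxt {
+ struct page *pages[16];
+ int page_count;
+};
+
+struct reply_rqst {
+ struct page *pages[16];
+ int page_count;
+};
+
+/* Hand every page under I/O from the request to the send context
+ * before the Send WR is posted; the request must not release them.
+ */
+static void transfer_reply_pages(struct reply_rqst *rq, struct send_ctxt *ctxt)
+{
+ int i;
+
+ for (i = 0; i < rq->page_count; i++) {
+ ctxt->pages[ctxt->page_count++] = rq->pages[i];
+ rq->pages[i] = NULL;
+ }
+ rq->page_count = 0;
+}
+
+/* Runs only when the Send WR completes. Because the Send WR was posted
+ * last on the same Send Queue, every earlier Write WR has completed,
+ * so no page here is still under I/O and all of them can be released.
+ */
+static void model_send_completion(struct send_ctxt *ctxt)
+{
+ while (ctxt->page_count)
+ free(ctxt->pages[--ctxt->page_count]);
+}
+#endif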
+
+#include <linux/spinlock.h>
+#include <asm/unaligned.h>
+
+#include <rdma/ib_verbs.h>
+#include <rdma/rdma_cm.h>
+
+#include <linux/sunrpc/debug.h>
+#include <linux/sunrpc/svc_rdma.h>
+
+#include "xprt_rdma.h"
+#include <trace/events/rpcrdma.h>
+
+static void svc_rdma_wc_send(struct ib_cq *cq, struct ib_wc *wc);
+
+static void svc_rdma_send_cid_init(struct svcxprt_rdma *rdma,
+ struct rpc_rdma_cid *cid)
+{
+ cid->ci_queue_id = rdma->sc_sq_cq->res.id;
+ cid->ci_completion_id = atomic_inc_return(&rdma->sc_completion_ids);
+}
+
+static struct svc_rdma_send_ctxt *
+svc_rdma_send_ctxt_alloc(struct svcxprt_rdma *rdma)
+{
+ int node = ibdev_to_node(rdma->sc_cm_id->device);
+ struct svc_rdma_send_ctxt *ctxt;
+ dma_addr_t addr;
+ void *buffer;
+ int i;
+
+ ctxt = kmalloc_node(struct_size(ctxt, sc_sges, rdma->sc_max_send_sges),
+ GFP_KERNEL, node);
+ if (!ctxt)
+ goto fail0;
+ buffer = kmalloc_node(rdma->sc_max_req_size, GFP_KERNEL, node);
+ if (!buffer)
+ goto fail1;
+ addr = ib_dma_map_single(rdma->sc_pd->device, buffer,
+ rdma->sc_max_req_size, DMA_TO_DEVICE);
+ if (ib_dma_mapping_error(rdma->sc_pd->device, addr))
+ goto fail2;
+
+ svc_rdma_send_cid_init(rdma, &ctxt->sc_cid);
+
+ ctxt->sc_send_wr.next = NULL;
+ ctxt->sc_send_wr.wr_cqe = &ctxt->sc_cqe;
+ ctxt->sc_send_wr.sg_list = ctxt->sc_sges;
+ ctxt->sc_send_wr.send_flags = IB_SEND_SIGNALED;
+ ctxt->sc_cqe.done = svc_rdma_wc_send;
+ ctxt->sc_xprt_buf = buffer;
+ xdr_buf_init(&ctxt->sc_hdrbuf, ctxt->sc_xprt_buf,
+ rdma->sc_max_req_size);
+ ctxt->sc_sges[0].addr = addr;
+
+ for (i = 0; i < rdma->sc_max_send_sges; i++)
+ ctxt->sc_sges[i].lkey = rdma->sc_pd->local_dma_lkey;
+ return ctxt;
+
+fail2:
+ kfree(buffer);
+fail1:
+ kfree(ctxt);
+fail0:
+ return NULL;
+}
+
+/**
+ * svc_rdma_send_ctxts_destroy - Release all send_ctxt's for an xprt
+ * @rdma: svcxprt_rdma being torn down
+ *
+ */
+void svc_rdma_send_ctxts_destroy(struct svcxprt_rdma *rdma)
+{
+ struct svc_rdma_send_ctxt *ctxt;
+ struct llist_node *node;
+
+ while ((node = llist_del_first(&rdma->sc_send_ctxts)) != NULL) {
+ ctxt = llist_entry(node, struct svc_rdma_send_ctxt, sc_node);
+ ib_dma_unmap_single(rdma->sc_pd->device,
+ ctxt->sc_sges[0].addr,
+ rdma->sc_max_req_size,
+ DMA_TO_DEVICE);
+ kfree(ctxt->sc_xprt_buf);
+ kfree(ctxt);
+ }
+}
+
+/**
+ * svc_rdma_send_ctxt_get - Get a free send_ctxt
+ * @rdma: controlling svcxprt_rdma
+ *
+ * Returns a ready-to-use send_ctxt, or NULL if none are
+ * available and a fresh one cannot be allocated.
+ */
+struct svc_rdma_send_ctxt *svc_rdma_send_ctxt_get(struct svcxprt_rdma *rdma)
+{
+ struct svc_rdma_send_ctxt *ctxt;
+ struct llist_node *node;
+
+ spin_lock(&rdma->sc_send_lock);
+ node = llist_del_first(&rdma->sc_send_ctxts);
+ if (!node)
+ goto out_empty;
+ ctxt = llist_entry(node, struct svc_rdma_send_ctxt, sc_node);
+ spin_unlock(&rdma->sc_send_lock);
+
+out:
+ rpcrdma_set_xdrlen(&ctxt->sc_hdrbuf, 0);
+ xdr_init_encode(&ctxt->sc_stream, &ctxt->sc_hdrbuf,
+ ctxt->sc_xprt_buf, NULL);
+
+ ctxt->sc_send_wr.num_sge = 0;
+ ctxt->sc_cur_sge_no = 0;
+ ctxt->sc_page_count = 0;
+ return ctxt;
+
+out_empty:
+ spin_unlock(&rdma->sc_send_lock);
+ ctxt = svc_rdma_send_ctxt_alloc(rdma);
+ if (!ctxt)
+ return NULL;
+ goto out;
+}
+
+/**
+ * svc_rdma_send_ctxt_put - Return send_ctxt to free list
+ * @rdma: controlling svcxprt_rdma
+ * @ctxt: object to return to the free list
+ *
+ * Pages left in sc_pages are DMA unmapped and released.
+ */
+void svc_rdma_send_ctxt_put(struct svcxprt_rdma *rdma,
+ struct svc_rdma_send_ctxt *ctxt)
+{
+ struct ib_device *device = rdma->sc_cm_id->device;
+ unsigned int i;
+
+ if (ctxt->sc_page_count)
+ release_pages(ctxt->sc_pages, ctxt->sc_page_count);
+
+ /* The first SGE contains the transport header, which
+ * remains mapped until @ctxt is destroyed.
+ */
+ for (i = 1; i < ctxt->sc_send_wr.num_sge; i++) {
+ ib_dma_unmap_page(device,
+ ctxt->sc_sges[i].addr,
+ ctxt->sc_sges[i].length,
+ DMA_TO_DEVICE);
+ trace_svcrdma_dma_unmap_page(rdma,
+ ctxt->sc_sges[i].addr,
+ ctxt->sc_sges[i].length);
+ }
+
+ llist_add(&ctxt->sc_node, &rdma->sc_send_ctxts);
+}
+
+/**
+ * svc_rdma_wake_send_waiters - manage Send Queue accounting
+ * @rdma: controlling transport
+ * @avail: Number of additional SQEs that are now available
+ *
+ */
+void svc_rdma_wake_send_waiters(struct svcxprt_rdma *rdma, int avail)
+{
+ atomic_add(avail, &rdma->sc_sq_avail);
+ smp_mb__after_atomic();
+ if (unlikely(waitqueue_active(&rdma->sc_send_wait)))
+ wake_up(&rdma->sc_send_wait);
+}
+
+/**
+ * svc_rdma_wc_send - Invoked by RDMA provider for each polled Send WC
+ * @cq: Completion Queue context
+ * @wc: Work Completion object
+ *
+ * NB: The svc_xprt/svcxprt_rdma is pinned whenever it's possible that
+ * the Send completion handler could be running.
+ */
+static void svc_rdma_wc_send(struct ib_cq *cq, struct ib_wc *wc)
+{
+ struct svcxprt_rdma *rdma = cq->cq_context;
+ struct ib_cqe *cqe = wc->wr_cqe;
+ struct svc_rdma_send_ctxt *ctxt =
+ container_of(cqe, struct svc_rdma_send_ctxt, sc_cqe);
+
+ svc_rdma_wake_send_waiters(rdma, 1);
+
+ if (unlikely(wc->status != IB_WC_SUCCESS))
+ goto flushed;
+
+ trace_svcrdma_wc_send(wc, &ctxt->sc_cid);
+ svc_rdma_send_ctxt_put(rdma, ctxt);
+ return;
+
+flushed:
+ if (wc->status != IB_WC_WR_FLUSH_ERR)
+ trace_svcrdma_wc_send_err(wc, &ctxt->sc_cid);
+ else
+ trace_svcrdma_wc_send_flush(wc, &ctxt->sc_cid);
+ svc_rdma_send_ctxt_put(rdma, ctxt);
+ svc_xprt_deferred_close(&rdma->sc_xprt);
+}
+
+/**
+ * svc_rdma_send - Post a single Send WR
+ * @rdma: transport on which to post the WR
+ * @ctxt: send ctxt with a Send WR ready to post
+ *
+ * Returns zero if the Send WR was posted successfully. Otherwise, a
+ * negative errno is returned.
+ */
+int svc_rdma_send(struct svcxprt_rdma *rdma, struct svc_rdma_send_ctxt *ctxt)
+{
+ struct ib_send_wr *wr = &ctxt->sc_send_wr;
+ int ret;
+
+ might_sleep();
+
+ /* Sync the transport header buffer */
+ ib_dma_sync_single_for_device(rdma->sc_pd->device,
+ wr->sg_list[0].addr,
+ wr->sg_list[0].length,
+ DMA_TO_DEVICE);
+
+ /* If the SQ is full, wait until an SQ entry is available */
+ while (1) {
+ if ((atomic_dec_return(&rdma->sc_sq_avail) < 0)) {
+ percpu_counter_inc(&svcrdma_stat_sq_starve);
+ trace_svcrdma_sq_full(rdma);
+ atomic_inc(&rdma->sc_sq_avail);
+ wait_event(rdma->sc_send_wait,
+ atomic_read(&rdma->sc_sq_avail) > 1);
+ if (test_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags))
+ return -ENOTCONN;
+ trace_svcrdma_sq_retry(rdma);
+ continue;
+ }
+
+ trace_svcrdma_post_send(ctxt);
+ ret = ib_post_send(rdma->sc_qp, wr, NULL);
+ if (ret)
+ break;
+ return 0;
+ }
+
+ trace_svcrdma_sq_post_err(rdma, ret);
+ svc_xprt_deferred_close(&rdma->sc_xprt);
+ wake_up(&rdma->sc_send_wait);
+ return ret;
+}
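+
+/* Editor's sketch (not part of the upstream file): the Send Queue
+ * accounting used by svc_rdma_send() and svc_rdma_wake_send_waiters(),
+ * reduced to a stand-alone C11 model. Names here are invented; the real
+ * code pairs a failed reservation with wait_event()/wake_up() rather
+ * than a bare retry.
+ */
+#if 0 /* illustrative only -- never built with the kernel */
+#include <stdatomic.h>
+#include <stdbool.h>
+
+static atomic_int sq_avail; /* count of free Send Queue entries */
+
+/* Reserve one SQ entry. Driving the counter negative means the SQ is
+ * full, so the reservation is returned and the caller must wait for
+ * completions to replenish the queue before retrying.
+ */
+static bool sq_reserve_one(void)
+{
+ if (atomic_fetch_sub(&sq_avail, 1) - 1 < 0) {
+ atomic_fetch_add(&sq_avail, 1);
+ return false;
+ }
+ return true;
+}
+
+/* Completion path: return @n entries and wake any waiting senders. */
+static void sq_release(int n)
+{
+ atomic_fetch_add(&sq_avail, n);
+ /* wake_up(&rdma->sc_send_wait) in the real code */
+}
+#endif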
+
+/**
+ * svc_rdma_encode_read_list - Encode RPC Reply's Read chunk list
+ * @sctxt: Send context for the RPC Reply
+ *
+ * Return values:
+ * On success, returns length in bytes of the Reply XDR buffer
+ * that was consumed by the Reply Read list
+ * %-EMSGSIZE on XDR buffer overflow
+ */
+static ssize_t svc_rdma_encode_read_list(struct svc_rdma_send_ctxt *sctxt)
+{
+ /* RPC-over-RDMA version 1 replies never have a Read list. */
+ return xdr_stream_encode_item_absent(&sctxt->sc_stream);
+}
+
+/**
+ * svc_rdma_encode_write_segment - Encode one Write segment
+ * @sctxt: Send context for the RPC Reply
+ * @chunk: Write chunk to push
+ * @remaining: remaining bytes of the payload left in the Write chunk
+ * @segno: which segment in the chunk
+ *
+ * Return values:
+ * On success, returns length in bytes of the Reply XDR buffer
+ * that was consumed by the Write segment, and updates @remaining
+ * %-EMSGSIZE on XDR buffer overflow
+ */
+static ssize_t svc_rdma_encode_write_segment(struct svc_rdma_send_ctxt *sctxt,
+ const struct svc_rdma_chunk *chunk,
+ u32 *remaining, unsigned int segno)
+{
+ const struct svc_rdma_segment *segment = &chunk->ch_segments[segno];
+ const size_t len = rpcrdma_segment_maxsz * sizeof(__be32);
+ u32 length;
+ __be32 *p;
+
+ p = xdr_reserve_space(&sctxt->sc_stream, len);
+ if (!p)
+ return -EMSGSIZE;
+
+ length = min_t(u32, *remaining, segment->rs_length);
+ *remaining -= length;
+ xdr_encode_rdma_segment(p, segment->rs_handle, length,
+ segment->rs_offset);
+ trace_svcrdma_encode_wseg(sctxt, segno, segment->rs_handle, length,
+ segment->rs_offset);
+ return len;
+}
+
+/**
+ * svc_rdma_encode_write_chunk - Encode one Write chunk
+ * @sctxt: Send context for the RPC Reply
+ * @chunk: Write chunk to push
+ *
+ * Copy a Write chunk from the Call transport header to the
+ * Reply transport header. Update each segment's length field
+ * to reflect the number of bytes written in that segment.
+ *
+ * Return values:
+ * On success, returns length in bytes of the Reply XDR buffer
+ * that was consumed by the Write chunk
+ * %-EMSGSIZE on XDR buffer overflow
+ */
+static ssize_t svc_rdma_encode_write_chunk(struct svc_rdma_send_ctxt *sctxt,
+ const struct svc_rdma_chunk *chunk)
+{
+ u32 remaining = chunk->ch_payload_length;
+ unsigned int segno;
+ ssize_t len, ret;
+
+ len = 0;
+ ret = xdr_stream_encode_item_present(&sctxt->sc_stream);
+ if (ret < 0)
+ return ret;
+ len += ret;
+
+ ret = xdr_stream_encode_u32(&sctxt->sc_stream, chunk->ch_segcount);
+ if (ret < 0)
+ return ret;
+ len += ret;
+
+ for (segno = 0; segno < chunk->ch_segcount; segno++) {
+ ret = svc_rdma_encode_write_segment(sctxt, chunk, &remaining, segno);
+ if (ret < 0)
+ return ret;
+ len += ret;
+ }
+
+ return len;
+}
+
+/**
+ * svc_rdma_encode_write_list - Encode RPC Reply's Write chunk list
+ * @rctxt: Reply context with information about the RPC Call
+ * @sctxt: Send context for the RPC Reply
+ *
+ * Return values:
+ * On success, returns length in bytes of the Reply XDR buffer
+ * that was consumed by the Reply's Write list
+ * %-EMSGSIZE on XDR buffer overflow
+ */
+static ssize_t svc_rdma_encode_write_list(struct svc_rdma_recv_ctxt *rctxt,
+ struct svc_rdma_send_ctxt *sctxt)
+{
+ struct svc_rdma_chunk *chunk;
+ ssize_t len, ret;
+
+ len = 0;
+ pcl_for_each_chunk(chunk, &rctxt->rc_write_pcl) {
+ ret = svc_rdma_encode_write_chunk(sctxt, chunk);
+ if (ret < 0)
+ return ret;
+ len += ret;
+ }
+
+ /* Terminate the Write list */
+ ret = xdr_stream_encode_item_absent(&sctxt->sc_stream);
+ if (ret < 0)
+ return ret;
+
+ return len + ret;
+}
+
+/**
+ * svc_rdma_encode_reply_chunk - Encode RPC Reply's Reply chunk
+ * @rctxt: Reply context with information about the RPC Call
+ * @sctxt: Send context for the RPC Reply
+ * @length: size in bytes of the payload in the Reply chunk
+ *
+ * Return values:
+ * On success, returns length in bytes of the Reply XDR buffer
+ * that was consumed by the Reply's Reply chunk
+ * %-EMSGSIZE on XDR buffer overflow
+ * %-E2BIG if the RPC message is larger than the Reply chunk
+ */
+static ssize_t
+svc_rdma_encode_reply_chunk(struct svc_rdma_recv_ctxt *rctxt,
+ struct svc_rdma_send_ctxt *sctxt,
+ unsigned int length)
+{
+ struct svc_rdma_chunk *chunk;
+
+ if (pcl_is_empty(&rctxt->rc_reply_pcl))
+ return xdr_stream_encode_item_absent(&sctxt->sc_stream);
+
+ chunk = pcl_first_chunk(&rctxt->rc_reply_pcl);
+ if (length > chunk->ch_length)
+ return -E2BIG;
+
+ chunk->ch_payload_length = length;
+ return svc_rdma_encode_write_chunk(sctxt, chunk);
+}
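+
+/* Editor's sketch (not part of the upstream file): the transport header
+ * that the encoders above produce in the simplest case -- an RDMA_MSG
+ * reply with no Write chunks and no Reply chunk, where each absent list
+ * is a single XDR "absent" word. The helper name and buffer handling
+ * are invented for illustration; field values follow RFC 8166.
+ */
+#if 0 /* illustrative only -- never built with the kernel */
+#include <arpa/inet.h>
+#include <stddef.h>
+#include <stdint.h>
+
+static size_t encode_minimal_reply_hdr(uint32_t *buf, uint32_t raw_xid,
+ uint32_t credits)
+{
+ uint32_t *p = buf;
+
+ *p++ = raw_xid; /* XID, copied verbatim from the Call */
+ *p++ = htonl(1); /* rdma_vers is always 1 */
+ *p++ = htonl(credits); /* flow-control credit grant */
+ *p++ = htonl(0); /* proc: RDMA_MSG */
+ *p++ = htonl(0); /* Read list: absent (replies never carry one) */
+ *p++ = htonl(0); /* Write list: no chunks, terminator only */
+ *p++ = htonl(0); /* Reply chunk: absent */
+ return (p - buf) * sizeof(*p); /* 28 bytes of transport header */
+}
+#endif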
+
+struct svc_rdma_map_data {
+ struct svcxprt_rdma *md_rdma;
+ struct svc_rdma_send_ctxt *md_ctxt;
+};
+
+/**
+ * svc_rdma_page_dma_map - DMA map one page
+ * @data: pointer to arguments
+ * @page: struct page to DMA map
+ * @offset: offset into the page
+ * @len: number of bytes to map
+ *
+ * Returns:
+ * %0 if DMA mapping was successful
+ * %-EIO if the page cannot be DMA mapped
+ */
+static int svc_rdma_page_dma_map(void *data, struct page *page,
+ unsigned long offset, unsigned int len)
+{
+ struct svc_rdma_map_data *args = data;
+ struct svcxprt_rdma *rdma = args->md_rdma;
+ struct svc_rdma_send_ctxt *ctxt = args->md_ctxt;
+ struct ib_device *dev = rdma->sc_cm_id->device;
+ dma_addr_t dma_addr;
+
+ ++ctxt->sc_cur_sge_no;
+
+ dma_addr = ib_dma_map_page(dev, page, offset, len, DMA_TO_DEVICE);
+ if (ib_dma_mapping_error(dev, dma_addr))
+ goto out_maperr;
+
+ trace_svcrdma_dma_map_page(rdma, dma_addr, len);
+ ctxt->sc_sges[ctxt->sc_cur_sge_no].addr = dma_addr;
+ ctxt->sc_sges[ctxt->sc_cur_sge_no].length = len;
+ ctxt->sc_send_wr.num_sge++;
+ return 0;
+
+out_maperr:
+ trace_svcrdma_dma_map_err(rdma, dma_addr, len);
+ return -EIO;
+}
+
+/**
+ * svc_rdma_iov_dma_map - DMA map an iovec
+ * @data: pointer to arguments
+ * @iov: kvec to DMA map
+ *
+ * ib_dma_map_page() is used here because svc_rdma_send_ctxt_put()
+ * handles the DMA unmapping, and it uses ib_dma_unmap_page() exclusively.
+ *
+ * Returns:
+ * %0 if DMA mapping was successful
+ * %-EIO if the iovec cannot be DMA mapped
+ */
+static int svc_rdma_iov_dma_map(void *data, const struct kvec *iov)
+{
+ if (!iov->iov_len)
+ return 0;
+ return svc_rdma_page_dma_map(data, virt_to_page(iov->iov_base),
+ offset_in_page(iov->iov_base),
+ iov->iov_len);
+}
+
+/**
+ * svc_rdma_xb_dma_map - DMA map all segments of an xdr_buf
+ * @xdr: xdr_buf containing portion of an RPC message to transmit
+ * @data: pointer to arguments
+ *
+ * Returns:
+ * %0 if DMA mapping was successful
+ * %-EIO if DMA mapping failed
+ *
+ * On failure, any DMA mappings that have been already done must be
+ * unmapped by the caller.
+ */
+static int svc_rdma_xb_dma_map(const struct xdr_buf *xdr, void *data)
+{
+ unsigned int len, remaining;
+ unsigned long pageoff;
+ struct page **ppages;
+ int ret;
+
+ ret = svc_rdma_iov_dma_map(data, &xdr->head[0]);
+ if (ret < 0)
+ return ret;
+
+ ppages = xdr->pages + (xdr->page_base >> PAGE_SHIFT);
+ pageoff = offset_in_page(xdr->page_base);
+ remaining = xdr->page_len;
+ while (remaining) {
+ len = min_t(u32, PAGE_SIZE - pageoff, remaining);
+
+ ret = svc_rdma_page_dma_map(data, *ppages++, pageoff, len);
+ if (ret < 0)
+ return ret;
+
+ remaining -= len;
+ pageoff = 0;
+ }
+
+ ret = svc_rdma_iov_dma_map(data, &xdr->tail[0]);
+ if (ret < 0)
+ return ret;
+
+ return xdr->len;
+}
+
+struct svc_rdma_pullup_data {
+ u8 *pd_dest;
+ unsigned int pd_length;
+ unsigned int pd_num_sges;
+};
+
+/**
+ * svc_rdma_xb_count_sges - Count how many SGEs will be needed
+ * @xdr: xdr_buf containing portion of an RPC message to transmit
+ * @data: pointer to arguments
+ *
+ * Returns:
+ * Number of SGEs needed to Send the contents of @xdr inline
+ */
+static int svc_rdma_xb_count_sges(const struct xdr_buf *xdr,
+ void *data)
+{
+ struct svc_rdma_pullup_data *args = data;
+ unsigned int remaining;
+ unsigned long offset;
+
+ if (xdr->head[0].iov_len)
+ ++args->pd_num_sges;
+
+ offset = offset_in_page(xdr->page_base);
+ remaining = xdr->page_len;
+ while (remaining) {
+ ++args->pd_num_sges;
+ remaining -= min_t(u32, PAGE_SIZE - offset, remaining);
+ offset = 0;
+ }
+
+ if (xdr->tail[0].iov_len)
+ ++args->pd_num_sges;
+
+ args->pd_length += xdr->len;
+ return 0;
+}
+
+/**
+ * svc_rdma_pull_up_needed - Determine whether to use pull-up
+ * @rdma: controlling transport
+ * @sctxt: send_ctxt for the Send WR
+ * @rctxt: Write and Reply chunks provided by client
+ * @xdr: xdr_buf containing RPC message to transmit
+ *
+ * Returns:
+ * %true if pull-up must be used
+ * %false otherwise
+ */
+static bool svc_rdma_pull_up_needed(const struct svcxprt_rdma *rdma,
+ const struct svc_rdma_send_ctxt *sctxt,
+ const struct svc_rdma_recv_ctxt *rctxt,
+ const struct xdr_buf *xdr)
+{
+ /* Resources needed for the transport header */
+ struct svc_rdma_pullup_data args = {
+ .pd_length = sctxt->sc_hdrbuf.len,
+ .pd_num_sges = 1,
+ };
+ int ret;
+
+ ret = pcl_process_nonpayloads(&rctxt->rc_write_pcl, xdr,
+ svc_rdma_xb_count_sges, &args);
+ if (ret < 0)
+ return false;
+
+ if (args.pd_length < RPCRDMA_PULLUP_THRESH)
+ return true;
+ return args.pd_num_sges >= rdma->sc_max_send_sges;
+}
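+
+/* Editor's note (not part of the upstream file): a concrete reading of
+ * the test above. A reply short enough to fall below
+ * RPCRDMA_PULLUP_THRESH is copied into the already-mapped transport
+ * header buffer, because the copy is cheaper than DMA-mapping extra
+ * SGEs for a small message. A reply whose xdr_buf needs at least as
+ * many SGEs as the device supports (pd_num_sges >= sc_max_send_sges)
+ * must also be pulled up, since it could not otherwise be described by
+ * a single Send WR.
+ */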
+
+/**
+ * svc_rdma_xb_linearize - Copy region of xdr_buf to flat buffer
+ * @xdr: xdr_buf containing portion of an RPC message to copy
+ * @data: pointer to arguments
+ *
+ * Returns:
+ * Always zero.
+ */
+static int svc_rdma_xb_linearize(const struct xdr_buf *xdr,
+ void *data)
+{
+ struct svc_rdma_pullup_data *args = data;
+ unsigned int len, remaining;
+ unsigned long pageoff;
+ struct page **ppages;
+
+ if (xdr->head[0].iov_len) {
+ memcpy(args->pd_dest, xdr->head[0].iov_base, xdr->head[0].iov_len);
+ args->pd_dest += xdr->head[0].iov_len;
+ }
+
+ ppages = xdr->pages + (xdr->page_base >> PAGE_SHIFT);
+ pageoff = offset_in_page(xdr->page_base);
+ remaining = xdr->page_len;
+ while (remaining) {
+ len = min_t(u32, PAGE_SIZE - pageoff, remaining);
+ memcpy(args->pd_dest, page_address(*ppages) + pageoff, len);
+ remaining -= len;
+ args->pd_dest += len;
+ pageoff = 0;
+ ppages++;
+ }
+
+ if (xdr->tail[0].iov_len) {
+ memcpy(args->pd_dest, xdr->tail[0].iov_base, xdr->tail[0].iov_len);
+ args->pd_dest += xdr->tail[0].iov_len;
+ }
+
+ args->pd_length += xdr->len;
+ return 0;
+}
+
+/**
+ * svc_rdma_pull_up_reply_msg - Copy Reply into a single buffer
+ * @rdma: controlling transport
+ * @sctxt: send_ctxt for the Send WR; xprt hdr is already prepared
+ * @rctxt: Write and Reply chunks provided by client
+ * @xdr: prepared xdr_buf containing RPC message
+ *
+ * The device is not capable of sending the reply directly.
+ * Assemble the elements of @xdr into the transport header buffer.
+ *
+ * Assumptions:
+ * pull_up_needed has determined that @xdr will fit in the buffer.
+ *
+ * Returns:
+ * %0 if pull-up was successful
+ * %-EMSGSIZE if a buffer manipulation problem occurred
+ */
+static int svc_rdma_pull_up_reply_msg(const struct svcxprt_rdma *rdma,
+ struct svc_rdma_send_ctxt *sctxt,
+ const struct svc_rdma_recv_ctxt *rctxt,
+ const struct xdr_buf *xdr)
+{
+ struct svc_rdma_pullup_data args = {
+ .pd_dest = sctxt->sc_xprt_buf + sctxt->sc_hdrbuf.len,
+ };
+ int ret;
+
+ ret = pcl_process_nonpayloads(&rctxt->rc_write_pcl, xdr,
+ svc_rdma_xb_linearize, &args);
+ if (ret < 0)
+ return ret;
+
+ sctxt->sc_sges[0].length = sctxt->sc_hdrbuf.len + args.pd_length;
+ trace_svcrdma_send_pullup(sctxt, args.pd_length);
+ return 0;
+}
+
+/**
+ * svc_rdma_map_reply_msg - DMA map the buffer holding RPC message
+ * @rdma: controlling transport
+ * @sctxt: send_ctxt for the Send WR
+ * @rctxt: Write and Reply chunks provided by client
+ * @xdr: prepared xdr_buf containing RPC message
+ *
+ * Returns:
+ * %0 if DMA mapping was successful.
+ * %-EMSGSIZE if a buffer manipulation problem occurred
+ * %-EIO if DMA mapping failed
+ *
+ * The Send WR's num_sge field is set in all cases.
+ */
+int svc_rdma_map_reply_msg(struct svcxprt_rdma *rdma,
+ struct svc_rdma_send_ctxt *sctxt,
+ const struct svc_rdma_recv_ctxt *rctxt,
+ const struct xdr_buf *xdr)
+{
+ struct svc_rdma_map_data args = {
+ .md_rdma = rdma,
+ .md_ctxt = sctxt,
+ };
+
+ /* Set up the (persistently-mapped) transport header SGE. */
+ sctxt->sc_send_wr.num_sge = 1;
+ sctxt->sc_sges[0].length = sctxt->sc_hdrbuf.len;
+
+ /* If there is a Reply chunk, nothing follows the transport
+ * header, and we're done here.
+ */
+ if (!pcl_is_empty(&rctxt->rc_reply_pcl))
+ return 0;
+
+ /* For pull-up, svc_rdma_send() will sync the transport header.
+ * No additional DMA mapping is necessary.
+ */
+ if (svc_rdma_pull_up_needed(rdma, sctxt, rctxt, xdr))
+ return svc_rdma_pull_up_reply_msg(rdma, sctxt, rctxt, xdr);
+
+ return pcl_process_nonpayloads(&rctxt->rc_write_pcl, xdr,
+ svc_rdma_xb_dma_map, &args);
+}
+
+/* The svc_rqst and all resources it owns are released as soon as
+ * svc_rdma_sendto returns. Transfer pages under I/O to the ctxt
+ * so they are released by the Send completion handler.
+ */
+static void svc_rdma_save_io_pages(struct svc_rqst *rqstp,
+ struct svc_rdma_send_ctxt *ctxt)
+{
+ int i, pages = rqstp->rq_next_page - rqstp->rq_respages;
+
+ ctxt->sc_page_count += pages;
+ for (i = 0; i < pages; i++) {
+ ctxt->sc_pages[i] = rqstp->rq_respages[i];
+ rqstp->rq_respages[i] = NULL;
+ }
+
+ /* Prevent svc_xprt_release from releasing pages in rq_pages */
+ rqstp->rq_next_page = rqstp->rq_respages;
+}
+
+/* Prepare the portion of the RPC Reply that will be transmitted
+ * via RDMA Send. The RPC-over-RDMA transport header is prepared
+ * in sc_sges[0], and the RPC xdr_buf is prepared in following sges.
+ *
+ * Depending on whether a Write list or Reply chunk is present,
+ * the server may send all, a portion of, or none of the xdr_buf.
+ * In the latter case, only the transport header (sc_sges[0]) is
+ * transmitted.
+ *
+ * RDMA Send is the last step of transmitting an RPC reply. Pages
+ * involved in the earlier RDMA Writes are here transferred out
+ * of the rqstp and into the sctxt's page array. These pages are
+ * DMA unmapped by each Write completion, but the subsequent Send
+ * completion finally releases these pages.
+ *
+ * Assumptions:
+ * - The Reply's transport header will never be larger than a page.
+ */
+static int svc_rdma_send_reply_msg(struct svcxprt_rdma *rdma,
+ struct svc_rdma_send_ctxt *sctxt,
+ const struct svc_rdma_recv_ctxt *rctxt,
+ struct svc_rqst *rqstp)
+{
+ int ret;
+
+ ret = svc_rdma_map_reply_msg(rdma, sctxt, rctxt, &rqstp->rq_res);
+ if (ret < 0)
+ return ret;
+
+ svc_rdma_save_io_pages(rqstp, sctxt);
+
+ if (rctxt->rc_inv_rkey) {
+ sctxt->sc_send_wr.opcode = IB_WR_SEND_WITH_INV;
+ sctxt->sc_send_wr.ex.invalidate_rkey = rctxt->rc_inv_rkey;
+ } else {
+ sctxt->sc_send_wr.opcode = IB_WR_SEND;
+ }
+
+ return svc_rdma_send(rdma, sctxt);
+}
+
+/**
+ * svc_rdma_send_error_msg - Send an RPC/RDMA v1 error response
+ * @rdma: controlling transport context
+ * @sctxt: Send context for the response
+ * @rctxt: Receive context for incoming bad message
+ * @status: negative errno indicating error that occurred
+ *
+ * Given the client-provided Read, Write, and Reply chunks, the
+ * server was not able to parse the Call or form a complete Reply.
+ * Return an RDMA_ERROR message so the client can retire the RPC
+ * transaction.
+ *
+ * The caller does not have to release @sctxt. It is released by
+ * Send completion, or by this function on error.
+ */
+void svc_rdma_send_error_msg(struct svcxprt_rdma *rdma,
+ struct svc_rdma_send_ctxt *sctxt,
+ struct svc_rdma_recv_ctxt *rctxt,
+ int status)
+{
+ __be32 *rdma_argp = rctxt->rc_recv_buf;
+ __be32 *p;
+
+ rpcrdma_set_xdrlen(&sctxt->sc_hdrbuf, 0);
+ xdr_init_encode(&sctxt->sc_stream, &sctxt->sc_hdrbuf,
+ sctxt->sc_xprt_buf, NULL);
+
+ p = xdr_reserve_space(&sctxt->sc_stream,
+ rpcrdma_fixed_maxsz * sizeof(*p));
+ if (!p)
+ goto put_ctxt;
+
+ *p++ = *rdma_argp;
+ *p++ = *(rdma_argp + 1);
+ *p++ = rdma->sc_fc_credits;
+ *p = rdma_error;
+
+ switch (status) {
+ case -EPROTONOSUPPORT:
+ p = xdr_reserve_space(&sctxt->sc_stream, 3 * sizeof(*p));
+ if (!p)
+ goto put_ctxt;
+
+ *p++ = err_vers;
+ *p++ = rpcrdma_version;
+ *p = rpcrdma_version;
+ trace_svcrdma_err_vers(*rdma_argp);
+ break;
+ default:
+ p = xdr_reserve_space(&sctxt->sc_stream, sizeof(*p));
+ if (!p)
+ goto put_ctxt;
+
+ *p = err_chunk;
+ trace_svcrdma_err_chunk(*rdma_argp);
+ }
+
+ /* Remote Invalidation is skipped for simplicity. */
+ sctxt->sc_send_wr.num_sge = 1;
+ sctxt->sc_send_wr.opcode = IB_WR_SEND;
+ sctxt->sc_sges[0].length = sctxt->sc_hdrbuf.len;
+ if (svc_rdma_send(rdma, sctxt))
+ goto put_ctxt;
+ return;
+
+put_ctxt:
+ svc_rdma_send_ctxt_put(rdma, sctxt);
+}
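+
+/* Editor's sketch (not part of the upstream file): the wire layout of
+ * the ERR_VERS response built above when a client speaks an unsupported
+ * protocol version. Numeric values follow RFC 8166 (RDMA_ERROR is
+ * procedure 4, ERR_VERS is error code 1); the helper itself is invented
+ * for illustration.
+ */
+#if 0 /* illustrative only -- never built with the kernel */
+#include <arpa/inet.h>
+#include <stddef.h>
+#include <stdint.h>
+
+static size_t encode_err_vers(uint32_t *buf, uint32_t raw_xid,
+ uint32_t credits)
+{
+ uint32_t *p = buf;
+
+ *p++ = raw_xid; /* XID, copied verbatim from the Call */
+ *p++ = htonl(1); /* rdma_vers */
+ *p++ = htonl(credits); /* credit grant */
+ *p++ = htonl(4); /* proc: RDMA_ERROR */
+ *p++ = htonl(1); /* err_vers */
+ *p++ = htonl(1); /* rdma_vers_low: only version 1 is supported */
+ *p++ = htonl(1); /* rdma_vers_high */
+ return (p - buf) * sizeof(*p); /* 28 bytes */
+}
+#endif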
+
+/**
+ * svc_rdma_sendto - Transmit an RPC reply
+ * @rqstp: processed RPC request, reply XDR already in ::rq_res
+ *
+ * Any resources still associated with @rqstp are released upon return.
+ * If no reply message was possible, the connection is closed.
+ *
+ * Returns:
+ * %0 if an RPC reply has been successfully posted,
+ * %-ENOMEM if a resource shortage occurred (connection is lost),
+ * %-ENOTCONN if posting failed (connection is lost).
+ */
+int svc_rdma_sendto(struct svc_rqst *rqstp)
+{
+ struct svc_xprt *xprt = rqstp->rq_xprt;
+ struct svcxprt_rdma *rdma =
+ container_of(xprt, struct svcxprt_rdma, sc_xprt);
+ struct svc_rdma_recv_ctxt *rctxt = rqstp->rq_xprt_ctxt;
+ __be32 *rdma_argp = rctxt->rc_recv_buf;
+ struct svc_rdma_send_ctxt *sctxt;
+ unsigned int rc_size;
+ __be32 *p;
+ int ret;
+
+ ret = -ENOTCONN;
+ if (svc_xprt_is_dead(xprt))
+ goto drop_connection;
+
+ ret = -ENOMEM;
+ sctxt = svc_rdma_send_ctxt_get(rdma);
+ if (!sctxt)
+ goto drop_connection;
+
+ ret = -EMSGSIZE;
+ p = xdr_reserve_space(&sctxt->sc_stream,
+ rpcrdma_fixed_maxsz * sizeof(*p));
+ if (!p)
+ goto put_ctxt;
+
+ ret = svc_rdma_send_reply_chunk(rdma, rctxt, &rqstp->rq_res);
+ if (ret < 0)
+ goto reply_chunk;
+ rc_size = ret;
+
+ *p++ = *rdma_argp;
+ *p++ = *(rdma_argp + 1);
+ *p++ = rdma->sc_fc_credits;
+ *p = pcl_is_empty(&rctxt->rc_reply_pcl) ? rdma_msg : rdma_nomsg;
+
+ ret = svc_rdma_encode_read_list(sctxt);
+ if (ret < 0)
+ goto put_ctxt;
+ ret = svc_rdma_encode_write_list(rctxt, sctxt);
+ if (ret < 0)
+ goto put_ctxt;
+ ret = svc_rdma_encode_reply_chunk(rctxt, sctxt, rc_size);
+ if (ret < 0)
+ goto put_ctxt;
+
+ ret = svc_rdma_send_reply_msg(rdma, sctxt, rctxt, rqstp);
+ if (ret < 0)
+ goto put_ctxt;
+ return 0;
+
+reply_chunk:
+ if (ret != -E2BIG && ret != -EINVAL)
+ goto put_ctxt;
+
+ /* Send completion releases payload pages that were part
+ * of previously posted RDMA Writes.
+ */
+ svc_rdma_save_io_pages(rqstp, sctxt);
+ svc_rdma_send_error_msg(rdma, sctxt, rctxt, ret);
+ return 0;
+
+put_ctxt:
+ svc_rdma_send_ctxt_put(rdma, sctxt);
+drop_connection:
+ trace_svcrdma_send_err(rqstp, ret);
+ svc_xprt_deferred_close(&rdma->sc_xprt);
+ return -ENOTCONN;
+}
+
+/**
+ * svc_rdma_result_payload - special processing for a result payload
+ * @rqstp: svc_rqst to operate on
+ * @offset: payload's byte offset in @xdr
+ * @length: size of payload, in bytes
+ *
+ * Return values:
+ * %0 if successful or nothing needed to be done
+ * %-EMSGSIZE on XDR buffer overflow
+ * %-E2BIG if the payload was larger than the Write chunk
+ * %-EINVAL if client provided too many segments
+ * %-ENOMEM if rdma_rw context pool was exhausted
+ * %-ENOTCONN if posting failed (connection is lost)
+ * %-EIO if rdma_rw initialization failed (DMA mapping, etc)
+ */
+int svc_rdma_result_payload(struct svc_rqst *rqstp, unsigned int offset,
+ unsigned int length)
+{
+ struct svc_rdma_recv_ctxt *rctxt = rqstp->rq_xprt_ctxt;
+ struct svc_rdma_chunk *chunk;
+ struct svcxprt_rdma *rdma;
+ struct xdr_buf subbuf;
+ int ret;
+
+ chunk = rctxt->rc_cur_result_payload;
+ if (!length || !chunk)
+ return 0;
+ rctxt->rc_cur_result_payload =
+ pcl_next_chunk(&rctxt->rc_write_pcl, chunk);
+ if (length > chunk->ch_length)
+ return -E2BIG;
+
+ chunk->ch_position = offset;
+ chunk->ch_payload_length = length;
+
+ if (xdr_buf_subsegment(&rqstp->rq_res, &subbuf, offset, length))
+ return -EMSGSIZE;
+
+ rdma = container_of(rqstp->rq_xprt, struct svcxprt_rdma, sc_xprt);
+ ret = svc_rdma_send_write_chunk(rdma, chunk, &subbuf);
+ if (ret < 0)
+ return ret;
+ return 0;
+}
diff --git a/net/sunrpc/xprtrdma/svc_rdma_transport.c b/net/sunrpc/xprtrdma/svc_rdma_transport.c
new file mode 100644
index 0000000000..2abd895046
--- /dev/null
+++ b/net/sunrpc/xprtrdma/svc_rdma_transport.c
@@ -0,0 +1,603 @@
+// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
+/*
+ * Copyright (c) 2015-2018 Oracle. All rights reserved.
+ * Copyright (c) 2014 Open Grid Computing, Inc. All rights reserved.
+ * Copyright (c) 2005-2007 Network Appliance, Inc. All rights reserved.
+ *
+ * This software is available to you under a choice of one of two
+ * licenses. You may choose to be licensed under the terms of the GNU
+ * General Public License (GPL) Version 2, available from the file
+ * COPYING in the main directory of this source tree, or the BSD-type
+ * license below:
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *
+ * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials provided
+ * with the distribution.
+ *
+ * Neither the name of the Network Appliance, Inc. nor the names of
+ * its contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * Author: Tom Tucker <tom@opengridcomputing.com>
+ */
+
+#include <linux/interrupt.h>
+#include <linux/sched.h>
+#include <linux/slab.h>
+#include <linux/spinlock.h>
+#include <linux/workqueue.h>
+#include <linux/export.h>
+
+#include <rdma/ib_verbs.h>
+#include <rdma/rdma_cm.h>
+#include <rdma/rw.h>
+
+#include <linux/sunrpc/addr.h>
+#include <linux/sunrpc/debug.h>
+#include <linux/sunrpc/svc_xprt.h>
+#include <linux/sunrpc/svc_rdma.h>
+
+#include "xprt_rdma.h"
+#include <trace/events/rpcrdma.h>
+
+#define RPCDBG_FACILITY RPCDBG_SVCXPRT
+
+static struct svcxprt_rdma *svc_rdma_create_xprt(struct svc_serv *serv,
+ struct net *net, int node);
+static struct svc_xprt *svc_rdma_create(struct svc_serv *serv,
+ struct net *net,
+ struct sockaddr *sa, int salen,
+ int flags);
+static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt);
+static void svc_rdma_detach(struct svc_xprt *xprt);
+static void svc_rdma_free(struct svc_xprt *xprt);
+static int svc_rdma_has_wspace(struct svc_xprt *xprt);
+static void svc_rdma_kill_temp_xprt(struct svc_xprt *);
+
+static const struct svc_xprt_ops svc_rdma_ops = {
+ .xpo_create = svc_rdma_create,
+ .xpo_recvfrom = svc_rdma_recvfrom,
+ .xpo_sendto = svc_rdma_sendto,
+ .xpo_result_payload = svc_rdma_result_payload,
+ .xpo_release_ctxt = svc_rdma_release_ctxt,
+ .xpo_detach = svc_rdma_detach,
+ .xpo_free = svc_rdma_free,
+ .xpo_has_wspace = svc_rdma_has_wspace,
+ .xpo_accept = svc_rdma_accept,
+ .xpo_kill_temp_xprt = svc_rdma_kill_temp_xprt,
+};
+
+struct svc_xprt_class svc_rdma_class = {
+ .xcl_name = "rdma",
+ .xcl_owner = THIS_MODULE,
+ .xcl_ops = &svc_rdma_ops,
+ .xcl_max_payload = RPCSVC_MAXPAYLOAD_RDMA,
+ .xcl_ident = XPRT_TRANSPORT_RDMA,
+};
+
+/* QP event handler */
+static void qp_event_handler(struct ib_event *event, void *context)
+{
+ struct svc_xprt *xprt = context;
+
+ trace_svcrdma_qp_error(event, (struct sockaddr *)&xprt->xpt_remote);
+ switch (event->event) {
+ /* These are considered benign events */
+ case IB_EVENT_PATH_MIG:
+ case IB_EVENT_COMM_EST:
+ case IB_EVENT_SQ_DRAINED:
+ case IB_EVENT_QP_LAST_WQE_REACHED:
+ break;
+
+ /* These are considered fatal events */
+ case IB_EVENT_PATH_MIG_ERR:
+ case IB_EVENT_QP_FATAL:
+ case IB_EVENT_QP_REQ_ERR:
+ case IB_EVENT_QP_ACCESS_ERR:
+ case IB_EVENT_DEVICE_FATAL:
+ default:
+ svc_xprt_deferred_close(xprt);
+ break;
+ }
+}
+
+static struct svcxprt_rdma *svc_rdma_create_xprt(struct svc_serv *serv,
+ struct net *net, int node)
+{
+ struct svcxprt_rdma *cma_xprt;
+
+ cma_xprt = kzalloc_node(sizeof(*cma_xprt), GFP_KERNEL, node);
+ if (!cma_xprt)
+ return NULL;
+
+ svc_xprt_init(net, &svc_rdma_class, &cma_xprt->sc_xprt, serv);
+ INIT_LIST_HEAD(&cma_xprt->sc_accept_q);
+ INIT_LIST_HEAD(&cma_xprt->sc_rq_dto_q);
+ init_llist_head(&cma_xprt->sc_send_ctxts);
+ init_llist_head(&cma_xprt->sc_recv_ctxts);
+ init_llist_head(&cma_xprt->sc_rw_ctxts);
+ init_waitqueue_head(&cma_xprt->sc_send_wait);
+
+ spin_lock_init(&cma_xprt->sc_lock);
+ spin_lock_init(&cma_xprt->sc_rq_dto_lock);
+ spin_lock_init(&cma_xprt->sc_send_lock);
+ spin_lock_init(&cma_xprt->sc_rw_ctxt_lock);
+
+ /*
+ * Note that this implies that the underlying transport has some
+ * form of congestion control (see RFC 7530 section 3.1
+ * paragraph 2). For now, we assume that all supported RDMA
+ * transports are suitable here.
+ */
+ set_bit(XPT_CONG_CTRL, &cma_xprt->sc_xprt.xpt_flags);
+
+ return cma_xprt;
+}
+
+static void
+svc_rdma_parse_connect_private(struct svcxprt_rdma *newxprt,
+ struct rdma_conn_param *param)
+{
+ const struct rpcrdma_connect_private *pmsg = param->private_data;
+
+ if (pmsg &&
+ pmsg->cp_magic == rpcrdma_cmp_magic &&
+ pmsg->cp_version == RPCRDMA_CMP_VERSION) {
+ newxprt->sc_snd_w_inv = pmsg->cp_flags &
+ RPCRDMA_CMP_F_SND_W_INV_OK;
+
+ dprintk("svcrdma: client send_size %u, recv_size %u "
+ "remote inv %ssupported\n",
+ rpcrdma_decode_buffer_size(pmsg->cp_send_size),
+ rpcrdma_decode_buffer_size(pmsg->cp_recv_size),
+ newxprt->sc_snd_w_inv ? "" : "un");
+ }
+}
+
+/*
+ * This function handles the CONNECT_REQUEST event on a listening
+ * endpoint. It is passed the cma_id for the _new_ connection. The context in
+ * this cma_id is inherited from the listening cma_id and is the svc_xprt
+ * structure for the listening endpoint.
+ *
+ * This function creates a new xprt for the new connection and enqueues it on
+ * the accept queue for the listening xprt. When the listen thread is kicked, it
+ * will call the accept method on the listening xprt, which will accept the new
+ * connection.
+ */
+static void handle_connect_req(struct rdma_cm_id *new_cma_id,
+ struct rdma_conn_param *param)
+{
+ struct svcxprt_rdma *listen_xprt = new_cma_id->context;
+ struct svcxprt_rdma *newxprt;
+ struct sockaddr *sa;
+
+ newxprt = svc_rdma_create_xprt(listen_xprt->sc_xprt.xpt_server,
+ listen_xprt->sc_xprt.xpt_net,
+ ibdev_to_node(new_cma_id->device));
+ if (!newxprt)
+ return;
+ newxprt->sc_cm_id = new_cma_id;
+ new_cma_id->context = newxprt;
+ svc_rdma_parse_connect_private(newxprt, param);
+
+ /* Save client advertised inbound read limit for use later in accept. */
+ newxprt->sc_ord = param->initiator_depth;
+
+ sa = (struct sockaddr *)&newxprt->sc_cm_id->route.addr.dst_addr;
+ newxprt->sc_xprt.xpt_remotelen = svc_addr_len(sa);
+ memcpy(&newxprt->sc_xprt.xpt_remote, sa,
+ newxprt->sc_xprt.xpt_remotelen);
+ snprintf(newxprt->sc_xprt.xpt_remotebuf,
+ sizeof(newxprt->sc_xprt.xpt_remotebuf) - 1, "%pISc", sa);
+
+ /* The remote port is arbitrary and not under the control of the
+ * client ULP. Set it to a fixed value so that the DRC (duplicate
+ * reply cache) continues to be effective after a reconnect.
+ */
+ rpc_set_port((struct sockaddr *)&newxprt->sc_xprt.xpt_remote, 0);
+
+ sa = (struct sockaddr *)&newxprt->sc_cm_id->route.addr.src_addr;
+ svc_xprt_set_local(&newxprt->sc_xprt, sa, svc_addr_len(sa));
+
+ /*
+ * Enqueue the new transport on the accept queue of the listening
+ * transport
+ */
+ spin_lock(&listen_xprt->sc_lock);
+ list_add_tail(&newxprt->sc_accept_q, &listen_xprt->sc_accept_q);
+ spin_unlock(&listen_xprt->sc_lock);
+
+ set_bit(XPT_CONN, &listen_xprt->sc_xprt.xpt_flags);
+ svc_xprt_enqueue(&listen_xprt->sc_xprt);
+}
+
+/**
+ * svc_rdma_listen_handler - Handle CM events generated on a listening endpoint
+ * @cma_id: the server's listener rdma_cm_id
+ * @event: details of the event
+ *
+ * Return values:
+ * %0: Do not destroy @cma_id
+ * %1: Destroy @cma_id (never returned here)
+ *
+ * NB: There is never a DEVICE_REMOVAL event for INADDR_ANY listeners.
+ */
+static int svc_rdma_listen_handler(struct rdma_cm_id *cma_id,
+ struct rdma_cm_event *event)
+{
+ switch (event->event) {
+ case RDMA_CM_EVENT_CONNECT_REQUEST:
+ handle_connect_req(cma_id, &event->param.conn);
+ break;
+ default:
+ break;
+ }
+ return 0;
+}
+
+/**
+ * svc_rdma_cma_handler - Handle CM events on client connections
+ * @cma_id: the connection's rdma_cm_id
+ * @event: details of the event
+ *
+ * Return values:
+ * %0: Do not destroy @cma_id
+ * %1: Destroy @cma_id (never returned here)
+ */
+static int svc_rdma_cma_handler(struct rdma_cm_id *cma_id,
+ struct rdma_cm_event *event)
+{
+ struct svcxprt_rdma *rdma = cma_id->context;
+ struct svc_xprt *xprt = &rdma->sc_xprt;
+
+ switch (event->event) {
+ case RDMA_CM_EVENT_ESTABLISHED:
+ clear_bit(RDMAXPRT_CONN_PENDING, &rdma->sc_flags);
+
+ /* Handle any requests that were received while
+ * CONN_PENDING was set. */
+ svc_xprt_enqueue(xprt);
+ break;
+ case RDMA_CM_EVENT_DISCONNECTED:
+ case RDMA_CM_EVENT_DEVICE_REMOVAL:
+ svc_xprt_deferred_close(xprt);
+ break;
+ default:
+ break;
+ }
+ return 0;
+}
+
+/*
+ * Create a listening RDMA service endpoint.
+ */
+static struct svc_xprt *svc_rdma_create(struct svc_serv *serv,
+ struct net *net,
+ struct sockaddr *sa, int salen,
+ int flags)
+{
+ struct rdma_cm_id *listen_id;
+ struct svcxprt_rdma *cma_xprt;
+ int ret;
+
+ if (sa->sa_family != AF_INET && sa->sa_family != AF_INET6)
+ return ERR_PTR(-EAFNOSUPPORT);
+ cma_xprt = svc_rdma_create_xprt(serv, net, NUMA_NO_NODE);
+ if (!cma_xprt)
+ return ERR_PTR(-ENOMEM);
+ set_bit(XPT_LISTENER, &cma_xprt->sc_xprt.xpt_flags);
+ strcpy(cma_xprt->sc_xprt.xpt_remotebuf, "listener");
+
+ listen_id = rdma_create_id(net, svc_rdma_listen_handler, cma_xprt,
+ RDMA_PS_TCP, IB_QPT_RC);
+ if (IS_ERR(listen_id)) {
+ ret = PTR_ERR(listen_id);
+ goto err0;
+ }
+
+ /* Allow both IPv4 and IPv6 sockets to bind a single port
+ * at the same time.
+ */
+#if IS_ENABLED(CONFIG_IPV6)
+ ret = rdma_set_afonly(listen_id, 1);
+ if (ret)
+ goto err1;
+#endif
+ ret = rdma_bind_addr(listen_id, sa);
+ if (ret)
+ goto err1;
+ cma_xprt->sc_cm_id = listen_id;
+
+ ret = rdma_listen(listen_id, RPCRDMA_LISTEN_BACKLOG);
+ if (ret)
+ goto err1;
+
+ /*
+ * We need to use the address from the cm_id in case the
+ * caller specified 0 for the port number.
+ */
+ sa = (struct sockaddr *)&cma_xprt->sc_cm_id->route.addr.src_addr;
+ svc_xprt_set_local(&cma_xprt->sc_xprt, sa, salen);
+
+ return &cma_xprt->sc_xprt;
+
+ err1:
+ rdma_destroy_id(listen_id);
+ err0:
+ kfree(cma_xprt);
+ return ERR_PTR(ret);
+}
+
+/*
+ * This is the xpo_accept function for listening endpoints. Its
+ * purpose is to accept incoming connections. The CMA callback handler
+ * has already created a new transport and attached it to the new CMA
+ * ID.
+ *
+ * There is a queue of pending connections hung on the listening
+ * transport. This queue holds newly created svc_xprt structures. This
+ * function takes svc_xprt structures off the accept_q and completes
+ * the connection.
+ */
+static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt)
+{
+ struct svcxprt_rdma *listen_rdma;
+ struct svcxprt_rdma *newxprt = NULL;
+ struct rdma_conn_param conn_param;
+ struct rpcrdma_connect_private pmsg;
+ struct ib_qp_init_attr qp_attr;
+ unsigned int ctxts, rq_depth;
+ struct ib_device *dev;
+ int ret = 0;
+ RPC_IFDEBUG(struct sockaddr *sap);
+
+ listen_rdma = container_of(xprt, struct svcxprt_rdma, sc_xprt);
+ clear_bit(XPT_CONN, &xprt->xpt_flags);
+ /* Get the next entry off the accept list */
+ spin_lock(&listen_rdma->sc_lock);
+ if (!list_empty(&listen_rdma->sc_accept_q)) {
+ newxprt = list_entry(listen_rdma->sc_accept_q.next,
+ struct svcxprt_rdma, sc_accept_q);
+ list_del_init(&newxprt->sc_accept_q);
+ }
+ if (!list_empty(&listen_rdma->sc_accept_q))
+ set_bit(XPT_CONN, &listen_rdma->sc_xprt.xpt_flags);
+ spin_unlock(&listen_rdma->sc_lock);
+ if (!newxprt)
+ return NULL;
+
+ dev = newxprt->sc_cm_id->device;
+ newxprt->sc_port_num = newxprt->sc_cm_id->port_num;
+
+ /* Qualify the transport resource defaults with the
+ * capabilities of this particular device */
+ /* Transport header, head iovec, tail iovec */
+ newxprt->sc_max_send_sges = 3;
+ /* Add one SGE per page list entry */
+ newxprt->sc_max_send_sges += (svcrdma_max_req_size / PAGE_SIZE) + 1;
+ if (newxprt->sc_max_send_sges > dev->attrs.max_send_sge)
+ newxprt->sc_max_send_sges = dev->attrs.max_send_sge;
+ newxprt->sc_max_req_size = svcrdma_max_req_size;
+ newxprt->sc_max_requests = svcrdma_max_requests;
+ newxprt->sc_max_bc_requests = svcrdma_max_bc_requests;
+ newxprt->sc_recv_batch = RPCRDMA_MAX_RECV_BATCH;
+ rq_depth = newxprt->sc_max_requests + newxprt->sc_max_bc_requests +
+ newxprt->sc_recv_batch;
+ if (rq_depth > dev->attrs.max_qp_wr) {
+ pr_warn("svcrdma: reducing receive depth to %d\n",
+ dev->attrs.max_qp_wr);
+ rq_depth = dev->attrs.max_qp_wr;
+ newxprt->sc_recv_batch = 1;
+ newxprt->sc_max_requests = rq_depth - 2;
+ newxprt->sc_max_bc_requests = 2;
+ }
+ newxprt->sc_fc_credits = cpu_to_be32(newxprt->sc_max_requests);
+ ctxts = rdma_rw_mr_factor(dev, newxprt->sc_port_num, RPCSVC_MAXPAGES);
+ ctxts *= newxprt->sc_max_requests;
+ newxprt->sc_sq_depth = rq_depth + ctxts;
+ if (newxprt->sc_sq_depth > dev->attrs.max_qp_wr) {
+ pr_warn("svcrdma: reducing send depth to %d\n",
+ dev->attrs.max_qp_wr);
+ newxprt->sc_sq_depth = dev->attrs.max_qp_wr;
+ }
+ atomic_set(&newxprt->sc_sq_avail, newxprt->sc_sq_depth);
+
+ newxprt->sc_pd = ib_alloc_pd(dev, 0);
+ if (IS_ERR(newxprt->sc_pd)) {
+ trace_svcrdma_pd_err(newxprt, PTR_ERR(newxprt->sc_pd));
+ goto errout;
+ }
+ newxprt->sc_sq_cq = ib_alloc_cq_any(dev, newxprt, newxprt->sc_sq_depth,
+ IB_POLL_WORKQUEUE);
+ if (IS_ERR(newxprt->sc_sq_cq))
+ goto errout;
+ newxprt->sc_rq_cq =
+ ib_alloc_cq_any(dev, newxprt, rq_depth, IB_POLL_WORKQUEUE);
+ if (IS_ERR(newxprt->sc_rq_cq))
+ goto errout;
+
+ memset(&qp_attr, 0, sizeof qp_attr);
+ qp_attr.event_handler = qp_event_handler;
+ qp_attr.qp_context = &newxprt->sc_xprt;
+ qp_attr.port_num = newxprt->sc_port_num;
+ qp_attr.cap.max_rdma_ctxs = ctxts;
+ qp_attr.cap.max_send_wr = newxprt->sc_sq_depth - ctxts;
+ qp_attr.cap.max_recv_wr = rq_depth;
+ qp_attr.cap.max_send_sge = newxprt->sc_max_send_sges;
+ qp_attr.cap.max_recv_sge = 1;
+ qp_attr.sq_sig_type = IB_SIGNAL_REQ_WR;
+ qp_attr.qp_type = IB_QPT_RC;
+ qp_attr.send_cq = newxprt->sc_sq_cq;
+ qp_attr.recv_cq = newxprt->sc_rq_cq;
+ dprintk("svcrdma: newxprt->sc_cm_id=%p, newxprt->sc_pd=%p\n",
+ newxprt->sc_cm_id, newxprt->sc_pd);
+ dprintk(" cap.max_send_wr = %d, cap.max_recv_wr = %d\n",
+ qp_attr.cap.max_send_wr, qp_attr.cap.max_recv_wr);
+ dprintk(" cap.max_send_sge = %d, cap.max_recv_sge = %d\n",
+ qp_attr.cap.max_send_sge, qp_attr.cap.max_recv_sge);
+
+ ret = rdma_create_qp(newxprt->sc_cm_id, newxprt->sc_pd, &qp_attr);
+ if (ret) {
+ trace_svcrdma_qp_err(newxprt, ret);
+ goto errout;
+ }
+ newxprt->sc_qp = newxprt->sc_cm_id->qp;
+
+ if (!(dev->attrs.device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS))
+ newxprt->sc_snd_w_inv = false;
+ if (!rdma_protocol_iwarp(dev, newxprt->sc_port_num) &&
+ !rdma_ib_or_roce(dev, newxprt->sc_port_num)) {
+ trace_svcrdma_fabric_err(newxprt, -EINVAL);
+ goto errout;
+ }
+
+ if (!svc_rdma_post_recvs(newxprt))
+ goto errout;
+
+ /* Construct RDMA-CM private message */
+ pmsg.cp_magic = rpcrdma_cmp_magic;
+ pmsg.cp_version = RPCRDMA_CMP_VERSION;
+ pmsg.cp_flags = 0;
+ pmsg.cp_send_size = pmsg.cp_recv_size =
+ rpcrdma_encode_buffer_size(newxprt->sc_max_req_size);
+
+ /* Accept Connection */
+ set_bit(RDMAXPRT_CONN_PENDING, &newxprt->sc_flags);
+ memset(&conn_param, 0, sizeof conn_param);
+ conn_param.responder_resources = 0;
+ conn_param.initiator_depth = min_t(int, newxprt->sc_ord,
+ dev->attrs.max_qp_init_rd_atom);
+ if (!conn_param.initiator_depth) {
+ ret = -EINVAL;
+ trace_svcrdma_initdepth_err(newxprt, ret);
+ goto errout;
+ }
+ conn_param.private_data = &pmsg;
+ conn_param.private_data_len = sizeof(pmsg);
+ rdma_lock_handler(newxprt->sc_cm_id);
+ newxprt->sc_cm_id->event_handler = svc_rdma_cma_handler;
+ ret = rdma_accept(newxprt->sc_cm_id, &conn_param);
+ rdma_unlock_handler(newxprt->sc_cm_id);
+ if (ret) {
+ trace_svcrdma_accept_err(newxprt, ret);
+ goto errout;
+ }
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+ dprintk("svcrdma: new connection %p accepted:\n", newxprt);
+ sap = (struct sockaddr *)&newxprt->sc_cm_id->route.addr.src_addr;
+ dprintk(" local address : %pIS:%u\n", sap, rpc_get_port(sap));
+ sap = (struct sockaddr *)&newxprt->sc_cm_id->route.addr.dst_addr;
+ dprintk(" remote address : %pIS:%u\n", sap, rpc_get_port(sap));
+ dprintk(" max_sge : %d\n", newxprt->sc_max_send_sges);
+ dprintk(" sq_depth : %d\n", newxprt->sc_sq_depth);
+ dprintk(" rdma_rw_ctxs : %d\n", ctxts);
+ dprintk(" max_requests : %d\n", newxprt->sc_max_requests);
+ dprintk(" ord : %d\n", conn_param.initiator_depth);
+#endif
+
+ return &newxprt->sc_xprt;
+
+ errout:
+ /* Take a reference in case the DTO handler runs */
+ svc_xprt_get(&newxprt->sc_xprt);
+ if (newxprt->sc_qp && !IS_ERR(newxprt->sc_qp))
+ ib_destroy_qp(newxprt->sc_qp);
+ rdma_destroy_id(newxprt->sc_cm_id);
+ /* This call to put will destroy the transport */
+ svc_xprt_put(&newxprt->sc_xprt);
+ return NULL;
+}
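+
+/* Editor's note (not part of the upstream file): a worked example of the
+ * queue-depth arithmetic in svc_rdma_accept() above, using hypothetical
+ * but plausible numbers. Suppose sc_max_requests = 64,
+ * sc_max_bc_requests = 2, sc_recv_batch = 8, and rdma_rw_mr_factor()
+ * reports 2 contexts per request:
+ *
+ * rq_depth    = 64 + 2 + 8       = 74    (Receive Queue depth)
+ * ctxts       = 2 * 64           = 128   (RDMA R/W contexts)
+ * sc_sq_depth = rq_depth + ctxts = 202   (Send Queue depth)
+ *
+ * Each depth is then capped at the device's attrs.max_qp_wr, and the
+ * QP's max_send_wr is sc_sq_depth minus the R/W contexts.
+ */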
+
+static void svc_rdma_detach(struct svc_xprt *xprt)
+{
+ struct svcxprt_rdma *rdma =
+ container_of(xprt, struct svcxprt_rdma, sc_xprt);
+
+ rdma_disconnect(rdma->sc_cm_id);
+}
+
+static void __svc_rdma_free(struct work_struct *work)
+{
+ struct svcxprt_rdma *rdma =
+ container_of(work, struct svcxprt_rdma, sc_work);
+
+ /* This blocks until the Completion Queues are empty */
+ if (rdma->sc_qp && !IS_ERR(rdma->sc_qp))
+ ib_drain_qp(rdma->sc_qp);
+
+ svc_rdma_flush_recv_queues(rdma);
+
+ svc_rdma_destroy_rw_ctxts(rdma);
+ svc_rdma_send_ctxts_destroy(rdma);
+ svc_rdma_recv_ctxts_destroy(rdma);
+
+ /* Destroy the QP if present (not a listener) */
+ if (rdma->sc_qp && !IS_ERR(rdma->sc_qp))
+ ib_destroy_qp(rdma->sc_qp);
+
+ if (rdma->sc_sq_cq && !IS_ERR(rdma->sc_sq_cq))
+ ib_free_cq(rdma->sc_sq_cq);
+
+ if (rdma->sc_rq_cq && !IS_ERR(rdma->sc_rq_cq))
+ ib_free_cq(rdma->sc_rq_cq);
+
+ if (rdma->sc_pd && !IS_ERR(rdma->sc_pd))
+ ib_dealloc_pd(rdma->sc_pd);
+
+ /* Destroy the CM ID */
+ rdma_destroy_id(rdma->sc_cm_id);
+
+ kfree(rdma);
+}
+
+static void svc_rdma_free(struct svc_xprt *xprt)
+{
+ struct svcxprt_rdma *rdma =
+ container_of(xprt, struct svcxprt_rdma, sc_xprt);
+
+ INIT_WORK(&rdma->sc_work, __svc_rdma_free);
+ schedule_work(&rdma->sc_work);
+}
+
+static int svc_rdma_has_wspace(struct svc_xprt *xprt)
+{
+ struct svcxprt_rdma *rdma =
+ container_of(xprt, struct svcxprt_rdma, sc_xprt);
+
+ /*
+ * If there are already waiters on the SQ,
+ * return false.
+ */
+ if (waitqueue_active(&rdma->sc_send_wait))
+ return 0;
+
+ /* Otherwise return true. */
+ return 1;
+}
+
+static void svc_rdma_kill_temp_xprt(struct svc_xprt *xprt)
+{
+}
diff --git a/net/sunrpc/xprtrdma/transport.c b/net/sunrpc/xprtrdma/transport.c
new file mode 100644
index 0000000000..29b0562d62
--- /dev/null
+++ b/net/sunrpc/xprtrdma/transport.c
@@ -0,0 +1,796 @@
+// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
+/*
+ * Copyright (c) 2014-2017 Oracle. All rights reserved.
+ * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved.
+ *
+ * This software is available to you under a choice of one of two
+ * licenses. You may choose to be licensed under the terms of the GNU
+ * General Public License (GPL) Version 2, available from the file
+ * COPYING in the main directory of this source tree, or the BSD-type
+ * license below:
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *
+ * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials provided
+ * with the distribution.
+ *
+ * Neither the name of the Network Appliance, Inc. nor the names of
+ * its contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/*
+ * transport.c
+ *
+ * This file contains the top-level implementation of an RPC RDMA
+ * transport.
+ *
+ * Naming convention: functions beginning with xprt_ are part of the
+ * transport switch. All others are RPC RDMA internal.
+ */
+
+#include <linux/module.h>
+#include <linux/slab.h>
+#include <linux/seq_file.h>
+#include <linux/smp.h>
+
+#include <linux/sunrpc/addr.h>
+#include <linux/sunrpc/svc_rdma.h>
+
+#include "xprt_rdma.h"
+#include <trace/events/rpcrdma.h>
+
+/*
+ * tunables
+ */
+
+static unsigned int xprt_rdma_slot_table_entries = RPCRDMA_DEF_SLOT_TABLE;
+unsigned int xprt_rdma_max_inline_read = RPCRDMA_DEF_INLINE;
+unsigned int xprt_rdma_max_inline_write = RPCRDMA_DEF_INLINE;
+unsigned int xprt_rdma_memreg_strategy = RPCRDMA_FRWR;
+int xprt_rdma_pad_optimize;
+static struct xprt_class xprt_rdma;
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+
+static unsigned int min_slot_table_size = RPCRDMA_MIN_SLOT_TABLE;
+static unsigned int max_slot_table_size = RPCRDMA_MAX_SLOT_TABLE;
+static unsigned int min_inline_size = RPCRDMA_MIN_INLINE;
+static unsigned int max_inline_size = RPCRDMA_MAX_INLINE;
+static unsigned int max_padding = PAGE_SIZE;
+static unsigned int min_memreg = RPCRDMA_BOUNCEBUFFERS;
+static unsigned int max_memreg = RPCRDMA_LAST - 1;
+static unsigned int dummy;
+
+static struct ctl_table_header *sunrpc_table_header;
+
+static struct ctl_table xr_tunables_table[] = {
+ {
+ .procname = "rdma_slot_table_entries",
+ .data = &xprt_rdma_slot_table_entries,
+ .maxlen = sizeof(unsigned int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_minmax,
+ .extra1 = &min_slot_table_size,
+ .extra2 = &max_slot_table_size
+ },
+ {
+ .procname = "rdma_max_inline_read",
+ .data = &xprt_rdma_max_inline_read,
+ .maxlen = sizeof(unsigned int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_minmax,
+ .extra1 = &min_inline_size,
+ .extra2 = &max_inline_size,
+ },
+ {
+ .procname = "rdma_max_inline_write",
+ .data = &xprt_rdma_max_inline_write,
+ .maxlen = sizeof(unsigned int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_minmax,
+ .extra1 = &min_inline_size,
+ .extra2 = &max_inline_size,
+ },
+ {
+ .procname = "rdma_inline_write_padding",
+ .data = &dummy,
+ .maxlen = sizeof(unsigned int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_minmax,
+ .extra1 = SYSCTL_ZERO,
+ .extra2 = &max_padding,
+ },
+ {
+ .procname = "rdma_memreg_strategy",
+ .data = &xprt_rdma_memreg_strategy,
+ .maxlen = sizeof(unsigned int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_minmax,
+ .extra1 = &min_memreg,
+ .extra2 = &max_memreg,
+ },
+ {
+ .procname = "rdma_pad_optimize",
+ .data = &xprt_rdma_pad_optimize,
+ .maxlen = sizeof(unsigned int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec,
+ },
+ { },
+};
+
+#endif
+
+static const struct rpc_xprt_ops xprt_rdma_procs;
+
+static void
+xprt_rdma_format_addresses4(struct rpc_xprt *xprt, struct sockaddr *sap)
+{
+ struct sockaddr_in *sin = (struct sockaddr_in *)sap;
+ char buf[20];
+
+ snprintf(buf, sizeof(buf), "%08x", ntohl(sin->sin_addr.s_addr));
+ xprt->address_strings[RPC_DISPLAY_HEX_ADDR] = kstrdup(buf, GFP_KERNEL);
+
+ xprt->address_strings[RPC_DISPLAY_NETID] = RPCBIND_NETID_RDMA;
+}
+
+static void
+xprt_rdma_format_addresses6(struct rpc_xprt *xprt, struct sockaddr *sap)
+{
+ struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)sap;
+ char buf[40];
+
+ snprintf(buf, sizeof(buf), "%pi6", &sin6->sin6_addr);
+ xprt->address_strings[RPC_DISPLAY_HEX_ADDR] = kstrdup(buf, GFP_KERNEL);
+
+ xprt->address_strings[RPC_DISPLAY_NETID] = RPCBIND_NETID_RDMA6;
+}
+
+void
+xprt_rdma_format_addresses(struct rpc_xprt *xprt, struct sockaddr *sap)
+{
+ char buf[128];
+
+ switch (sap->sa_family) {
+ case AF_INET:
+ xprt_rdma_format_addresses4(xprt, sap);
+ break;
+ case AF_INET6:
+ xprt_rdma_format_addresses6(xprt, sap);
+ break;
+ default:
+ pr_err("rpcrdma: Unrecognized address family\n");
+ return;
+ }
+
+ (void)rpc_ntop(sap, buf, sizeof(buf));
+ xprt->address_strings[RPC_DISPLAY_ADDR] = kstrdup(buf, GFP_KERNEL);
+
+ snprintf(buf, sizeof(buf), "%u", rpc_get_port(sap));
+ xprt->address_strings[RPC_DISPLAY_PORT] = kstrdup(buf, GFP_KERNEL);
+
+ snprintf(buf, sizeof(buf), "%4hx", rpc_get_port(sap));
+ xprt->address_strings[RPC_DISPLAY_HEX_PORT] = kstrdup(buf, GFP_KERNEL);
+
+ xprt->address_strings[RPC_DISPLAY_PROTO] = "rdma";
+}
+
+void
+xprt_rdma_free_addresses(struct rpc_xprt *xprt)
+{
+ unsigned int i;
+
+ for (i = 0; i < RPC_DISPLAY_MAX; i++)
+ switch (i) {
+ case RPC_DISPLAY_PROTO:
+ case RPC_DISPLAY_NETID:
+ continue;
+ default:
+ kfree(xprt->address_strings[i]);
+ }
+}
+
+/**
+ * xprt_rdma_connect_worker - establish connection in the background
+ * @work: worker thread context
+ *
+ * Requester holds the xprt's send lock to prevent activity on this
+ * transport while a fresh connection is being established. RPC tasks
+ * sleep on the xprt's pending queue waiting for connect to complete.
+ */
+static void
+xprt_rdma_connect_worker(struct work_struct *work)
+{
+ struct rpcrdma_xprt *r_xprt = container_of(work, struct rpcrdma_xprt,
+ rx_connect_worker.work);
+ struct rpc_xprt *xprt = &r_xprt->rx_xprt;
+ unsigned int pflags = current->flags;
+ int rc;
+
+ if (atomic_read(&xprt->swapper))
+ current->flags |= PF_MEMALLOC;
+ rc = rpcrdma_xprt_connect(r_xprt);
+ xprt_clear_connecting(xprt);
+ if (!rc) {
+ xprt->connect_cookie++;
+ xprt->stat.connect_count++;
+ xprt->stat.connect_time += (long)jiffies -
+ xprt->stat.connect_start;
+ xprt_set_connected(xprt);
+ rc = -EAGAIN;
+ } else
+ rpcrdma_xprt_disconnect(r_xprt);
+ xprt_unlock_connect(xprt, r_xprt);
+ xprt_wake_pending_tasks(xprt, rc);
+ current_restore_flags(pflags, PF_MEMALLOC);
+}
+
+/**
+ * xprt_rdma_inject_disconnect - inject a connection fault
+ * @xprt: transport context
+ *
+ * If @xprt is connected, disconnect it to simulate spurious
+ * connection loss. Caller must hold @xprt's send lock to
+ * ensure that data structures and hardware resources are
+ * stable during the rdma_disconnect() call.
+ */
+static void
+xprt_rdma_inject_disconnect(struct rpc_xprt *xprt)
+{
+ struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt);
+
+ trace_xprtrdma_op_inject_dsc(r_xprt);
+ rdma_disconnect(r_xprt->rx_ep->re_id);
+}
+
+/**
+ * xprt_rdma_destroy - Full tear down of transport
+ * @xprt: doomed transport context
+ *
+ * Caller guarantees there will be no more calls to us with
+ * this @xprt.
+ */
+static void
+xprt_rdma_destroy(struct rpc_xprt *xprt)
+{
+ struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt);
+
+ cancel_delayed_work_sync(&r_xprt->rx_connect_worker);
+
+ rpcrdma_xprt_disconnect(r_xprt);
+ rpcrdma_buffer_destroy(&r_xprt->rx_buf);
+
+ xprt_rdma_free_addresses(xprt);
+ xprt_free(xprt);
+
+ module_put(THIS_MODULE);
+}
+
+/* 60 second timeout, no retries */
+static const struct rpc_timeout xprt_rdma_default_timeout = {
+ .to_initval = 60 * HZ,
+ .to_maxval = 60 * HZ,
+};
+
+/**
+ * xprt_setup_rdma - Set up transport to use RDMA
+ *
+ * @args: rpc transport arguments
+ */
+static struct rpc_xprt *
+xprt_setup_rdma(struct xprt_create *args)
+{
+ struct rpc_xprt *xprt;
+ struct rpcrdma_xprt *new_xprt;
+ struct sockaddr *sap;
+ int rc;
+
+ if (args->addrlen > sizeof(xprt->addr))
+ return ERR_PTR(-EBADF);
+
+ if (!try_module_get(THIS_MODULE))
+ return ERR_PTR(-EIO);
+
+ xprt = xprt_alloc(args->net, sizeof(struct rpcrdma_xprt), 0,
+ xprt_rdma_slot_table_entries);
+ if (!xprt) {
+ module_put(THIS_MODULE);
+ return ERR_PTR(-ENOMEM);
+ }
+
+ xprt->timeout = &xprt_rdma_default_timeout;
+ xprt->connect_timeout = xprt->timeout->to_initval;
+ xprt->max_reconnect_timeout = xprt->timeout->to_maxval;
+ xprt->bind_timeout = RPCRDMA_BIND_TO;
+ xprt->reestablish_timeout = RPCRDMA_INIT_REEST_TO;
+ xprt->idle_timeout = RPCRDMA_IDLE_DISC_TO;
+
+ xprt->resvport = 0; /* privileged port not needed */
+ xprt->ops = &xprt_rdma_procs;
+
+ /*
+ * Set up RDMA-specific connect data.
+ */
+ sap = args->dstaddr;
+
+ /* Ensure xprt->addr holds valid server TCP (not RDMA)
+ * address, for any side protocols which peek at it */
+ xprt->prot = IPPROTO_TCP;
+ xprt->xprt_class = &xprt_rdma;
+ xprt->addrlen = args->addrlen;
+ memcpy(&xprt->addr, sap, xprt->addrlen);
+
+ if (rpc_get_port(sap))
+ xprt_set_bound(xprt);
+ xprt_rdma_format_addresses(xprt, sap);
+
+ new_xprt = rpcx_to_rdmax(xprt);
+ rc = rpcrdma_buffer_create(new_xprt);
+ if (rc) {
+ xprt_rdma_free_addresses(xprt);
+ xprt_free(xprt);
+ module_put(THIS_MODULE);
+ return ERR_PTR(rc);
+ }
+
+ INIT_DELAYED_WORK(&new_xprt->rx_connect_worker,
+ xprt_rdma_connect_worker);
+
+ xprt->max_payload = RPCRDMA_MAX_DATA_SEGS << PAGE_SHIFT;
+
+ return xprt;
+}
+
+/**
+ * xprt_rdma_close - close a transport connection
+ * @xprt: transport context
+ *
+ * Called during autoclose or device removal.
+ *
+ * Caller holds @xprt's send lock to prevent activity on this
+ * transport while the connection is torn down.
+ */
+void xprt_rdma_close(struct rpc_xprt *xprt)
+{
+ struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt);
+
+ rpcrdma_xprt_disconnect(r_xprt);
+
+ xprt->reestablish_timeout = 0;
+ ++xprt->connect_cookie;
+ xprt_disconnect_done(xprt);
+}
+
+/**
+ * xprt_rdma_set_port - update server port with rpcbind result
+ * @xprt: controlling RPC transport
+ * @port: new port value
+ *
+ * Transport connect status is unchanged.
+ */
+static void
+xprt_rdma_set_port(struct rpc_xprt *xprt, u16 port)
+{
+ struct sockaddr *sap = (struct sockaddr *)&xprt->addr;
+ char buf[8];
+
+ rpc_set_port(sap, port);
+
+ kfree(xprt->address_strings[RPC_DISPLAY_PORT]);
+ snprintf(buf, sizeof(buf), "%u", port);
+ xprt->address_strings[RPC_DISPLAY_PORT] = kstrdup(buf, GFP_KERNEL);
+
+ kfree(xprt->address_strings[RPC_DISPLAY_HEX_PORT]);
+ snprintf(buf, sizeof(buf), "%4hx", port);
+ xprt->address_strings[RPC_DISPLAY_HEX_PORT] = kstrdup(buf, GFP_KERNEL);
+}
+
+/**
+ * xprt_rdma_timer - invoked when an RPC times out
+ * @xprt: controlling RPC transport
+ * @task: RPC task that timed out
+ *
+ * Invoked when the transport is still connected, but an RPC
+ * retransmit timeout occurs.
+ *
+ * Since RDMA connections don't have a keep-alive, forcibly
+ * disconnect and retry connecting. This drives full
+ * detection of the network path, and retransmission of
+ * all pending RPCs.
+ */
+static void
+xprt_rdma_timer(struct rpc_xprt *xprt, struct rpc_task *task)
+{
+ xprt_force_disconnect(xprt);
+}
+
+/**
+ * xprt_rdma_set_connect_timeout - set timeouts for establishing a connection
+ * @xprt: controlling transport instance
+ * @connect_timeout: reconnect timeout after client disconnects
+ * @reconnect_timeout: reconnect timeout after server disconnects
+ *
+ */
+static void xprt_rdma_set_connect_timeout(struct rpc_xprt *xprt,
+ unsigned long connect_timeout,
+ unsigned long reconnect_timeout)
+{
+ struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt);
+
+ trace_xprtrdma_op_set_cto(r_xprt, connect_timeout, reconnect_timeout);
+
+ spin_lock(&xprt->transport_lock);
+
+ if (connect_timeout < xprt->connect_timeout) {
+ struct rpc_timeout to;
+ unsigned long initval;
+
+ to = *xprt->timeout;
+ initval = connect_timeout;
+ if (initval < RPCRDMA_INIT_REEST_TO << 1)
+ initval = RPCRDMA_INIT_REEST_TO << 1;
+ to.to_initval = initval;
+ to.to_maxval = initval;
+ r_xprt->rx_timeout = to;
+ xprt->timeout = &r_xprt->rx_timeout;
+ xprt->connect_timeout = connect_timeout;
+ }
+
+ if (reconnect_timeout < xprt->max_reconnect_timeout)
+ xprt->max_reconnect_timeout = reconnect_timeout;
+
+ spin_unlock(&xprt->transport_lock);
+}
+
+/**
+ * xprt_rdma_connect - schedule an attempt to reconnect
+ * @xprt: transport state
+ * @task: RPC scheduler context (unused)
+ *
+ */
+static void
+xprt_rdma_connect(struct rpc_xprt *xprt, struct rpc_task *task)
+{
+ struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt);
+ struct rpcrdma_ep *ep = r_xprt->rx_ep;
+ unsigned long delay;
+
+ WARN_ON_ONCE(!xprt_lock_connect(xprt, task, r_xprt));
+
+ delay = 0;
+ if (ep && ep->re_connect_status != 0) {
+ delay = xprt_reconnect_delay(xprt);
+ xprt_reconnect_backoff(xprt, RPCRDMA_INIT_REEST_TO);
+ }
+ trace_xprtrdma_op_connect(r_xprt, delay);
+ queue_delayed_work(system_long_wq, &r_xprt->rx_connect_worker, delay);
+}
+
+/**
+ * xprt_rdma_alloc_slot - allocate an rpc_rqst
+ * @xprt: controlling RPC transport
+ * @task: RPC task requesting a fresh rpc_rqst
+ *
+ * tk_status values:
+ * %0 if task->tk_rqstp points to a fresh rpc_rqst
+ * %-ENOMEM if no rpc_rqst is available; the task is queued on the backlog
+ */
+static void
+xprt_rdma_alloc_slot(struct rpc_xprt *xprt, struct rpc_task *task)
+{
+ struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt);
+ struct rpcrdma_req *req;
+
+ req = rpcrdma_buffer_get(&r_xprt->rx_buf);
+ if (!req)
+ goto out_sleep;
+ task->tk_rqstp = &req->rl_slot;
+ task->tk_status = 0;
+ return;
+
+out_sleep:
+ task->tk_status = -ENOMEM;
+ xprt_add_backlog(xprt, task);
+}
+
+/**
+ * xprt_rdma_free_slot - release an rpc_rqst
+ * @xprt: controlling RPC transport
+ * @rqst: rpc_rqst to release
+ *
+ */
+static void
+xprt_rdma_free_slot(struct rpc_xprt *xprt, struct rpc_rqst *rqst)
+{
+ struct rpcrdma_xprt *r_xprt =
+ container_of(xprt, struct rpcrdma_xprt, rx_xprt);
+
+ rpcrdma_reply_put(&r_xprt->rx_buf, rpcr_to_rdmar(rqst));
+ if (!xprt_wake_up_backlog(xprt, rqst)) {
+ memset(rqst, 0, sizeof(*rqst));
+ rpcrdma_buffer_put(&r_xprt->rx_buf, rpcr_to_rdmar(rqst));
+ }
+}
+
+static bool rpcrdma_check_regbuf(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_regbuf *rb, size_t size,
+ gfp_t flags)
+{
+ if (unlikely(rdmab_length(rb) < size)) {
+ if (!rpcrdma_regbuf_realloc(rb, size, flags))
+ return false;
+ r_xprt->rx_stats.hardway_register_count += size;
+ }
+ return true;
+}
+
+/**
+ * xprt_rdma_allocate - allocate transport resources for an RPC
+ * @task: RPC task
+ *
+ * Return values:
+ * 0: Success; rq_buffer points to RPC buffer to use
+ * ENOMEM: Out of memory, call again later
+ * EIO: A permanent error occurred, do not retry
+ */
+static int
+xprt_rdma_allocate(struct rpc_task *task)
+{
+ struct rpc_rqst *rqst = task->tk_rqstp;
+ struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(rqst->rq_xprt);
+ struct rpcrdma_req *req = rpcr_to_rdmar(rqst);
+ gfp_t flags = rpc_task_gfp_mask();
+
+ if (!rpcrdma_check_regbuf(r_xprt, req->rl_sendbuf, rqst->rq_callsize,
+ flags))
+ goto out_fail;
+ if (!rpcrdma_check_regbuf(r_xprt, req->rl_recvbuf, rqst->rq_rcvsize,
+ flags))
+ goto out_fail;
+
+ rqst->rq_buffer = rdmab_data(req->rl_sendbuf);
+ rqst->rq_rbuffer = rdmab_data(req->rl_recvbuf);
+ return 0;
+
+out_fail:
+ return -ENOMEM;
+}
+
+/**
+ * xprt_rdma_free - release resources allocated by xprt_rdma_allocate
+ * @task: RPC task
+ *
+ * Caller guarantees rqst->rq_buffer is non-NULL.
+ */
+static void
+xprt_rdma_free(struct rpc_task *task)
+{
+ struct rpc_rqst *rqst = task->tk_rqstp;
+ struct rpcrdma_req *req = rpcr_to_rdmar(rqst);
+
+ if (unlikely(!list_empty(&req->rl_registered))) {
+ trace_xprtrdma_mrs_zap(task);
+ frwr_unmap_sync(rpcx_to_rdmax(rqst->rq_xprt), req);
+ }
+
+ /* XXX: If the RPC is completing because of a signal and
+ * not because a reply was received, we ought to ensure
+ * that the Send completion has fired, so that memory
+ * involved with the Send is not still visible to the NIC.
+ */
+}
+
+/**
+ * xprt_rdma_send_request - marshal and send an RPC request
+ * @rqst: RPC message in rq_snd_buf
+ *
+ * Caller holds the transport's write lock.
+ *
+ * Returns:
+ * %0 if the RPC message has been sent
+ * %-ENOTCONN if the caller should reconnect and call again
+ * %-EAGAIN if the caller should call again
+ * %-ENOBUFS if the caller should call again after a delay
+ * %-EMSGSIZE if encoding ran out of buffer space. The request
+ * was not sent. Do not try to send this message again.
+ * %-EIO if an I/O error occurred. The request was not sent.
+ * Do not try to send this message again.
+ */
+static int
+xprt_rdma_send_request(struct rpc_rqst *rqst)
+{
+ struct rpc_xprt *xprt = rqst->rq_xprt;
+ struct rpcrdma_req *req = rpcr_to_rdmar(rqst);
+ struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt);
+ int rc = 0;
+
+#if defined(CONFIG_SUNRPC_BACKCHANNEL)
+ if (unlikely(!rqst->rq_buffer))
+ return xprt_rdma_bc_send_reply(rqst);
+#endif /* CONFIG_SUNRPC_BACKCHANNEL */
+
+ if (!xprt_connected(xprt))
+ return -ENOTCONN;
+
+ if (!xprt_request_get_cong(xprt, rqst))
+ return -EBADSLT;
+
+ rc = rpcrdma_marshal_req(r_xprt, rqst);
+ if (rc < 0)
+ goto failed_marshal;
+
+ /* Must suppress retransmit to maintain credits */
+ if (rqst->rq_connect_cookie == xprt->connect_cookie)
+ goto drop_connection;
+ rqst->rq_xtime = ktime_get();
+
+ if (frwr_send(r_xprt, req))
+ goto drop_connection;
+
+ rqst->rq_xmit_bytes_sent += rqst->rq_snd_buf.len;
+
+ /* An RPC with no reply will throw off credit accounting,
+ * so drop the connection to reset the credit grant.
+ */
+ if (!rpc_reply_expected(rqst->rq_task))
+ goto drop_connection;
+ return 0;
+
+failed_marshal:
+ if (rc != -ENOTCONN)
+ return rc;
+drop_connection:
+ xprt_rdma_close(xprt);
+ return -ENOTCONN;
+}
+
+void xprt_rdma_print_stats(struct rpc_xprt *xprt, struct seq_file *seq)
+{
+ struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt);
+ long idle_time = 0;
+
+ if (xprt_connected(xprt))
+ idle_time = (long)(jiffies - xprt->last_used) / HZ;
+
+ seq_puts(seq, "\txprt:\trdma ");
+ seq_printf(seq, "%u %lu %lu %lu %ld %lu %lu %lu %llu %llu ",
+ 0, /* need a local port? */
+ xprt->stat.bind_count,
+ xprt->stat.connect_count,
+ xprt->stat.connect_time / HZ,
+ idle_time,
+ xprt->stat.sends,
+ xprt->stat.recvs,
+ xprt->stat.bad_xids,
+ xprt->stat.req_u,
+ xprt->stat.bklog_u);
+ seq_printf(seq, "%lu %lu %lu %llu %llu %llu %llu %lu %lu %lu %lu ",
+ r_xprt->rx_stats.read_chunk_count,
+ r_xprt->rx_stats.write_chunk_count,
+ r_xprt->rx_stats.reply_chunk_count,
+ r_xprt->rx_stats.total_rdma_request,
+ r_xprt->rx_stats.total_rdma_reply,
+ r_xprt->rx_stats.pullup_copy_count,
+ r_xprt->rx_stats.fixup_copy_count,
+ r_xprt->rx_stats.hardway_register_count,
+ r_xprt->rx_stats.failed_marshal_count,
+ r_xprt->rx_stats.bad_reply_count,
+ r_xprt->rx_stats.nomsg_call_count);
+ seq_printf(seq, "%lu %lu %lu %lu %lu %lu\n",
+ r_xprt->rx_stats.mrs_recycled,
+ r_xprt->rx_stats.mrs_orphaned,
+ r_xprt->rx_stats.mrs_allocated,
+ r_xprt->rx_stats.local_inv_needed,
+ r_xprt->rx_stats.empty_sendctx_q,
+ r_xprt->rx_stats.reply_waits_for_send);
+}
+
+static int
+xprt_rdma_enable_swap(struct rpc_xprt *xprt)
+{
+ return 0;
+}
+
+static void
+xprt_rdma_disable_swap(struct rpc_xprt *xprt)
+{
+}
+
+/*
+ * Plumbing for rpc transport switch and kernel module
+ */
+
+static const struct rpc_xprt_ops xprt_rdma_procs = {
+ .reserve_xprt = xprt_reserve_xprt_cong,
+ .release_xprt = xprt_release_xprt_cong, /* sunrpc/xprt.c */
+ .alloc_slot = xprt_rdma_alloc_slot,
+ .free_slot = xprt_rdma_free_slot,
+ .release_request = xprt_release_rqst_cong, /* ditto */
+ .wait_for_reply_request = xprt_wait_for_reply_request_def, /* ditto */
+ .timer = xprt_rdma_timer,
+ .rpcbind = rpcb_getport_async, /* sunrpc/rpcb_clnt.c */
+ .set_port = xprt_rdma_set_port,
+ .connect = xprt_rdma_connect,
+ .buf_alloc = xprt_rdma_allocate,
+ .buf_free = xprt_rdma_free,
+ .send_request = xprt_rdma_send_request,
+ .close = xprt_rdma_close,
+ .destroy = xprt_rdma_destroy,
+ .set_connect_timeout = xprt_rdma_set_connect_timeout,
+ .print_stats = xprt_rdma_print_stats,
+ .enable_swap = xprt_rdma_enable_swap,
+ .disable_swap = xprt_rdma_disable_swap,
+ .inject_disconnect = xprt_rdma_inject_disconnect,
+#if defined(CONFIG_SUNRPC_BACKCHANNEL)
+ .bc_setup = xprt_rdma_bc_setup,
+ .bc_maxpayload = xprt_rdma_bc_maxpayload,
+ .bc_num_slots = xprt_rdma_bc_max_slots,
+ .bc_free_rqst = xprt_rdma_bc_free_rqst,
+ .bc_destroy = xprt_rdma_bc_destroy,
+#endif
+};
+
+static struct xprt_class xprt_rdma = {
+ .list = LIST_HEAD_INIT(xprt_rdma.list),
+ .name = "rdma",
+ .owner = THIS_MODULE,
+ .ident = XPRT_TRANSPORT_RDMA,
+ .setup = xprt_setup_rdma,
+ .netid = { "rdma", "rdma6", "" },
+};
+
+void xprt_rdma_cleanup(void)
+{
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+ if (sunrpc_table_header) {
+ unregister_sysctl_table(sunrpc_table_header);
+ sunrpc_table_header = NULL;
+ }
+#endif
+
+ xprt_unregister_transport(&xprt_rdma);
+ xprt_unregister_transport(&xprt_rdma_bc);
+}
+
+int xprt_rdma_init(void)
+{
+ int rc;
+
+ rc = xprt_register_transport(&xprt_rdma);
+ if (rc)
+ return rc;
+
+ rc = xprt_register_transport(&xprt_rdma_bc);
+ if (rc) {
+ xprt_unregister_transport(&xprt_rdma);
+ return rc;
+ }
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+ if (!sunrpc_table_header)
+ sunrpc_table_header = register_sysctl("sunrpc", xr_tunables_table);
+#endif
+ return 0;
+}
diff --git a/net/sunrpc/xprtrdma/verbs.c b/net/sunrpc/xprtrdma/verbs.c
new file mode 100644
index 0000000000..28c0771c4e
--- /dev/null
+++ b/net/sunrpc/xprtrdma/verbs.c
@@ -0,0 +1,1396 @@
+// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
+/*
+ * Copyright (c) 2014-2017 Oracle. All rights reserved.
+ * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved.
+ *
+ * This software is available to you under a choice of one of two
+ * licenses. You may choose to be licensed under the terms of the GNU
+ * General Public License (GPL) Version 2, available from the file
+ * COPYING in the main directory of this source tree, or the BSD-type
+ * license below:
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *
+ * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials provided
+ * with the distribution.
+ *
+ * Neither the name of the Network Appliance, Inc. nor the names of
+ * its contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/*
+ * verbs.c
+ *
+ * Encapsulates the major functions managing:
+ * o adapters
+ * o endpoints
+ * o connections
+ * o buffer memory
+ */
+
+#include <linux/interrupt.h>
+#include <linux/slab.h>
+#include <linux/sunrpc/addr.h>
+#include <linux/sunrpc/svc_rdma.h>
+#include <linux/log2.h>
+
+#include <asm-generic/barrier.h>
+#include <asm/bitops.h>
+
+#include <rdma/ib_cm.h>
+
+#include "xprt_rdma.h"
+#include <trace/events/rpcrdma.h>
+
+static int rpcrdma_sendctxs_create(struct rpcrdma_xprt *r_xprt);
+static void rpcrdma_sendctxs_destroy(struct rpcrdma_xprt *r_xprt);
+static void rpcrdma_sendctx_put_locked(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_sendctx *sc);
+static int rpcrdma_reqs_setup(struct rpcrdma_xprt *r_xprt);
+static void rpcrdma_reqs_reset(struct rpcrdma_xprt *r_xprt);
+static void rpcrdma_rep_destroy(struct rpcrdma_rep *rep);
+static void rpcrdma_reps_unmap(struct rpcrdma_xprt *r_xprt);
+static void rpcrdma_mrs_create(struct rpcrdma_xprt *r_xprt);
+static void rpcrdma_mrs_destroy(struct rpcrdma_xprt *r_xprt);
+static void rpcrdma_ep_get(struct rpcrdma_ep *ep);
+static int rpcrdma_ep_put(struct rpcrdma_ep *ep);
+static struct rpcrdma_regbuf *
+rpcrdma_regbuf_alloc(size_t size, enum dma_data_direction direction);
+static void rpcrdma_regbuf_dma_unmap(struct rpcrdma_regbuf *rb);
+static void rpcrdma_regbuf_free(struct rpcrdma_regbuf *rb);
+
+/* Wait for outstanding transport work to finish. ib_drain_qp
+ * handles the drains in the wrong order for us, so open code
+ * them here.
+ */
+static void rpcrdma_xprt_drain(struct rpcrdma_xprt *r_xprt)
+{
+ struct rpcrdma_ep *ep = r_xprt->rx_ep;
+ struct rdma_cm_id *id = ep->re_id;
+
+ /* Wait for rpcrdma_post_recvs() to leave its critical
+ * section.
+ */
+ if (atomic_inc_return(&ep->re_receiving) > 1)
+ wait_for_completion(&ep->re_done);
+
+ /* Flush Receives, then wait for deferred Reply work
+ * to complete.
+ */
+ ib_drain_rq(id->qp);
+
+ /* Deferred Reply processing might have scheduled
+ * local invalidations.
+ */
+ ib_drain_sq(id->qp);
+
+ rpcrdma_ep_put(ep);
+}
+
+/* Ensure xprt_force_disconnect() is invoked only once when a
+ * connection is closed or lost. (The essential requirement is that
+ * it is invoked at least once; the atomic latch below prevents
+ * duplicate calls.)
+ */
+void rpcrdma_force_disconnect(struct rpcrdma_ep *ep)
+{
+ if (atomic_add_unless(&ep->re_force_disconnect, 1, 1))
+ xprt_force_disconnect(ep->re_xprt);
+}
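/* Editorial aside: a user-space sketch of the "first caller wins" latch
 * used just above, not part of this patch.  atomic_add_unless(&v, 1, 1)
 * increments v only while it is not already 1 and reports whether it did,
 * so only one caller ever reaches xprt_force_disconnect().  The same idea
 * expressed with C11 atomics (all names here are illustrative):
 */
#include <stdatomic.h>
#include <stdio.h>

static atomic_int force_disconnect;	/* 0 = not yet requested */

static void request_disconnect(void)
{
	int expected = 0;

	/* Only the 0 -> 1 transition triggers the action. */
	if (atomic_compare_exchange_strong(&force_disconnect, &expected, 1))
		printf("disconnect requested\n");
	else
		printf("already requested; nothing to do\n");
}

int main(void)
{
	request_disconnect();	/* fires */
	request_disconnect();	/* no-op */
	return 0;
}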
+
+/**
+ * rpcrdma_flush_disconnect - Disconnect on flushed completion
+ * @r_xprt: transport to disconnect
+ * @wc: work completion entry
+ *
+ * Must be called in process context.
+ */
+void rpcrdma_flush_disconnect(struct rpcrdma_xprt *r_xprt, struct ib_wc *wc)
+{
+ if (wc->status != IB_WC_SUCCESS)
+ rpcrdma_force_disconnect(r_xprt->rx_ep);
+}
+
+/**
+ * rpcrdma_wc_send - Invoked by RDMA provider for each polled Send WC
+ * @cq: completion queue
+ * @wc: WCE for a completed Send WR
+ *
+ */
+static void rpcrdma_wc_send(struct ib_cq *cq, struct ib_wc *wc)
+{
+ struct ib_cqe *cqe = wc->wr_cqe;
+ struct rpcrdma_sendctx *sc =
+ container_of(cqe, struct rpcrdma_sendctx, sc_cqe);
+ struct rpcrdma_xprt *r_xprt = cq->cq_context;
+
+ /* WARNING: Only wr_cqe and status are reliable at this point */
+ trace_xprtrdma_wc_send(wc, &sc->sc_cid);
+ rpcrdma_sendctx_put_locked(r_xprt, sc);
+ rpcrdma_flush_disconnect(r_xprt, wc);
+}
+
+/**
+ * rpcrdma_wc_receive - Invoked by RDMA provider for each polled Receive WC
+ * @cq: completion queue
+ * @wc: WCE for a completed Receive WR
+ *
+ */
+static void rpcrdma_wc_receive(struct ib_cq *cq, struct ib_wc *wc)
+{
+ struct ib_cqe *cqe = wc->wr_cqe;
+ struct rpcrdma_rep *rep = container_of(cqe, struct rpcrdma_rep,
+ rr_cqe);
+ struct rpcrdma_xprt *r_xprt = cq->cq_context;
+
+ /* WARNING: Only wr_cqe and status are reliable at this point */
+ trace_xprtrdma_wc_receive(wc, &rep->rr_cid);
+ --r_xprt->rx_ep->re_receive_count;
+ if (wc->status != IB_WC_SUCCESS)
+ goto out_flushed;
+
+ /* status == SUCCESS means all fields in wc are trustworthy */
+ rpcrdma_set_xdrlen(&rep->rr_hdrbuf, wc->byte_len);
+ rep->rr_wc_flags = wc->wc_flags;
+ rep->rr_inv_rkey = wc->ex.invalidate_rkey;
+
+ ib_dma_sync_single_for_cpu(rdmab_device(rep->rr_rdmabuf),
+ rdmab_addr(rep->rr_rdmabuf),
+ wc->byte_len, DMA_FROM_DEVICE);
+
+ rpcrdma_reply_handler(rep);
+ return;
+
+out_flushed:
+ rpcrdma_flush_disconnect(r_xprt, wc);
+ rpcrdma_rep_put(&r_xprt->rx_buf, rep);
+}
+
+static void rpcrdma_update_cm_private(struct rpcrdma_ep *ep,
+ struct rdma_conn_param *param)
+{
+ const struct rpcrdma_connect_private *pmsg = param->private_data;
+ unsigned int rsize, wsize;
+
+ /* Default settings for RPC-over-RDMA Version One */
+ rsize = RPCRDMA_V1_DEF_INLINE_SIZE;
+ wsize = RPCRDMA_V1_DEF_INLINE_SIZE;
+
+ if (pmsg &&
+ pmsg->cp_magic == rpcrdma_cmp_magic &&
+ pmsg->cp_version == RPCRDMA_CMP_VERSION) {
+ rsize = rpcrdma_decode_buffer_size(pmsg->cp_send_size);
+ wsize = rpcrdma_decode_buffer_size(pmsg->cp_recv_size);
+ }
+
+ if (rsize < ep->re_inline_recv)
+ ep->re_inline_recv = rsize;
+ if (wsize < ep->re_inline_send)
+ ep->re_inline_send = wsize;
+
+ rpcrdma_set_max_header_sizes(ep);
+}
+
+/**
+ * rpcrdma_cm_event_handler - Handle RDMA CM events
+ * @id: rdma_cm_id on which an event has occurred
+ * @event: details of the event
+ *
+ * Called with @id's mutex held. Returns 1 if caller should
+ * destroy @id, otherwise 0.
+ */
+static int
+rpcrdma_cm_event_handler(struct rdma_cm_id *id, struct rdma_cm_event *event)
+{
+ struct sockaddr *sap = (struct sockaddr *)&id->route.addr.dst_addr;
+ struct rpcrdma_ep *ep = id->context;
+
+ might_sleep();
+
+ switch (event->event) {
+ case RDMA_CM_EVENT_ADDR_RESOLVED:
+ case RDMA_CM_EVENT_ROUTE_RESOLVED:
+ ep->re_async_rc = 0;
+ complete(&ep->re_done);
+ return 0;
+ case RDMA_CM_EVENT_ADDR_ERROR:
+ ep->re_async_rc = -EPROTO;
+ complete(&ep->re_done);
+ return 0;
+ case RDMA_CM_EVENT_ROUTE_ERROR:
+ ep->re_async_rc = -ENETUNREACH;
+ complete(&ep->re_done);
+ return 0;
+ case RDMA_CM_EVENT_DEVICE_REMOVAL:
+ pr_info("rpcrdma: removing device %s for %pISpc\n",
+ ep->re_id->device->name, sap);
+ fallthrough;
+ case RDMA_CM_EVENT_ADDR_CHANGE:
+ ep->re_connect_status = -ENODEV;
+ goto disconnected;
+ case RDMA_CM_EVENT_ESTABLISHED:
+ rpcrdma_ep_get(ep);
+ ep->re_connect_status = 1;
+ rpcrdma_update_cm_private(ep, &event->param.conn);
+ trace_xprtrdma_inline_thresh(ep);
+ wake_up_all(&ep->re_connect_wait);
+ break;
+ case RDMA_CM_EVENT_CONNECT_ERROR:
+ ep->re_connect_status = -ENOTCONN;
+ goto wake_connect_worker;
+ case RDMA_CM_EVENT_UNREACHABLE:
+ ep->re_connect_status = -ENETUNREACH;
+ goto wake_connect_worker;
+ case RDMA_CM_EVENT_REJECTED:
+ ep->re_connect_status = -ECONNREFUSED;
+ if (event->status == IB_CM_REJ_STALE_CONN)
+ ep->re_connect_status = -ENOTCONN;
+wake_connect_worker:
+ wake_up_all(&ep->re_connect_wait);
+ return 0;
+ case RDMA_CM_EVENT_DISCONNECTED:
+ ep->re_connect_status = -ECONNABORTED;
+disconnected:
+ rpcrdma_force_disconnect(ep);
+ return rpcrdma_ep_put(ep);
+ default:
+ break;
+ }
+
+ return 0;
+}
+
+static struct rdma_cm_id *rpcrdma_create_id(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_ep *ep)
+{
+ unsigned long wtimeout = msecs_to_jiffies(RDMA_RESOLVE_TIMEOUT) + 1;
+ struct rpc_xprt *xprt = &r_xprt->rx_xprt;
+ struct rdma_cm_id *id;
+ int rc;
+
+ init_completion(&ep->re_done);
+
+ id = rdma_create_id(xprt->xprt_net, rpcrdma_cm_event_handler, ep,
+ RDMA_PS_TCP, IB_QPT_RC);
+ if (IS_ERR(id))
+ return id;
+
+ ep->re_async_rc = -ETIMEDOUT;
+ rc = rdma_resolve_addr(id, NULL, (struct sockaddr *)&xprt->addr,
+ RDMA_RESOLVE_TIMEOUT);
+ if (rc)
+ goto out;
+ rc = wait_for_completion_interruptible_timeout(&ep->re_done, wtimeout);
+ if (rc < 0)
+ goto out;
+
+ rc = ep->re_async_rc;
+ if (rc)
+ goto out;
+
+ ep->re_async_rc = -ETIMEDOUT;
+ rc = rdma_resolve_route(id, RDMA_RESOLVE_TIMEOUT);
+ if (rc)
+ goto out;
+ rc = wait_for_completion_interruptible_timeout(&ep->re_done, wtimeout);
+ if (rc < 0)
+ goto out;
+ rc = ep->re_async_rc;
+ if (rc)
+ goto out;
+
+ return id;
+
+out:
+ rdma_destroy_id(id);
+ return ERR_PTR(rc);
+}
+
+static void rpcrdma_ep_destroy(struct kref *kref)
+{
+ struct rpcrdma_ep *ep = container_of(kref, struct rpcrdma_ep, re_kref);
+
+ if (ep->re_id->qp) {
+ rdma_destroy_qp(ep->re_id);
+ ep->re_id->qp = NULL;
+ }
+
+ if (ep->re_attr.recv_cq)
+ ib_free_cq(ep->re_attr.recv_cq);
+ ep->re_attr.recv_cq = NULL;
+ if (ep->re_attr.send_cq)
+ ib_free_cq(ep->re_attr.send_cq);
+ ep->re_attr.send_cq = NULL;
+
+ if (ep->re_pd)
+ ib_dealloc_pd(ep->re_pd);
+ ep->re_pd = NULL;
+
+ kfree(ep);
+ module_put(THIS_MODULE);
+}
+
+static noinline void rpcrdma_ep_get(struct rpcrdma_ep *ep)
+{
+ kref_get(&ep->re_kref);
+}
+
+/* Returns:
+ * %0 if @ep still has a positive kref count, or
+ * %1 if @ep was destroyed successfully.
+ */
+static noinline int rpcrdma_ep_put(struct rpcrdma_ep *ep)
+{
+ return kref_put(&ep->re_kref, rpcrdma_ep_destroy);
+}
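/* Editorial aside: rpcrdma_ep_get/put above follow the usual kref
 * convention -- the final put invokes the release callback, and kref_put()
 * returns 1 only in that case, which is why callers can test the return
 * value to learn whether the endpoint is gone.  A bare-bones user-space
 * analogue with C11 atomics, purely illustrative and not part of this
 * patch:
 */
#include <stdatomic.h>
#include <stdio.h>
#include <stdlib.h>

struct obj {
	atomic_int refcount;
};

static struct obj *obj_create(void)
{
	struct obj *o = malloc(sizeof(*o));

	if (o)
		atomic_init(&o->refcount, 1);
	return o;
}

static void obj_get(struct obj *o)
{
	atomic_fetch_add(&o->refcount, 1);
}

/* Returns 1 if this was the final reference and @o has been released. */
static int obj_put(struct obj *o)
{
	if (atomic_fetch_sub(&o->refcount, 1) == 1) {
		free(o);	/* stands in for the release callback */
		return 1;
	}
	return 0;
}

int main(void)
{
	struct obj *o = obj_create();

	if (!o)
		return 1;
	obj_get(o);
	printf("first put released it?  %d\n", obj_put(o));	/* 0 */
	printf("second put released it? %d\n", obj_put(o));	/* 1 */
	return 0;
}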
+
+static int rpcrdma_ep_create(struct rpcrdma_xprt *r_xprt)
+{
+ struct rpcrdma_connect_private *pmsg;
+ struct ib_device *device;
+ struct rdma_cm_id *id;
+ struct rpcrdma_ep *ep;
+ int rc;
+
+ ep = kzalloc(sizeof(*ep), XPRTRDMA_GFP_FLAGS);
+ if (!ep)
+ return -ENOTCONN;
+ ep->re_xprt = &r_xprt->rx_xprt;
+ kref_init(&ep->re_kref);
+
+ id = rpcrdma_create_id(r_xprt, ep);
+ if (IS_ERR(id)) {
+ kfree(ep);
+ return PTR_ERR(id);
+ }
+ __module_get(THIS_MODULE);
+ device = id->device;
+ ep->re_id = id;
+ reinit_completion(&ep->re_done);
+
+ ep->re_max_requests = r_xprt->rx_xprt.max_reqs;
+ ep->re_inline_send = xprt_rdma_max_inline_write;
+ ep->re_inline_recv = xprt_rdma_max_inline_read;
+ rc = frwr_query_device(ep, device);
+ if (rc)
+ goto out_destroy;
+
+ r_xprt->rx_buf.rb_max_requests = cpu_to_be32(ep->re_max_requests);
+
+ ep->re_attr.srq = NULL;
+ ep->re_attr.cap.max_inline_data = 0;
+ ep->re_attr.sq_sig_type = IB_SIGNAL_REQ_WR;
+ ep->re_attr.qp_type = IB_QPT_RC;
+ ep->re_attr.port_num = ~0;
+
+ ep->re_send_batch = ep->re_max_requests >> 3;
+ ep->re_send_count = ep->re_send_batch;
+ init_waitqueue_head(&ep->re_connect_wait);
+
+ ep->re_attr.send_cq = ib_alloc_cq_any(device, r_xprt,
+ ep->re_attr.cap.max_send_wr,
+ IB_POLL_WORKQUEUE);
+ if (IS_ERR(ep->re_attr.send_cq)) {
+ rc = PTR_ERR(ep->re_attr.send_cq);
+ ep->re_attr.send_cq = NULL;
+ goto out_destroy;
+ }
+
+ ep->re_attr.recv_cq = ib_alloc_cq_any(device, r_xprt,
+ ep->re_attr.cap.max_recv_wr,
+ IB_POLL_WORKQUEUE);
+ if (IS_ERR(ep->re_attr.recv_cq)) {
+ rc = PTR_ERR(ep->re_attr.recv_cq);
+ ep->re_attr.recv_cq = NULL;
+ goto out_destroy;
+ }
+ ep->re_receive_count = 0;
+
+ /* Initialize cma parameters */
+ memset(&ep->re_remote_cma, 0, sizeof(ep->re_remote_cma));
+
+ /* Prepare RDMA-CM private message */
+ pmsg = &ep->re_cm_private;
+ pmsg->cp_magic = rpcrdma_cmp_magic;
+ pmsg->cp_version = RPCRDMA_CMP_VERSION;
+ pmsg->cp_flags |= RPCRDMA_CMP_F_SND_W_INV_OK;
+ pmsg->cp_send_size = rpcrdma_encode_buffer_size(ep->re_inline_send);
+ pmsg->cp_recv_size = rpcrdma_encode_buffer_size(ep->re_inline_recv);
+ ep->re_remote_cma.private_data = pmsg;
+ ep->re_remote_cma.private_data_len = sizeof(*pmsg);
+
+ /* Client offers RDMA Read but does not initiate */
+ ep->re_remote_cma.initiator_depth = 0;
+ ep->re_remote_cma.responder_resources =
+ min_t(int, U8_MAX, device->attrs.max_qp_rd_atom);
+
+ /* Limit transport retries so client can detect server
+ * GID changes quickly. RPC layer handles re-establishing
+ * transport connection and retransmission.
+ */
+ ep->re_remote_cma.retry_count = 6;
+
+ /* RPC-over-RDMA handles its own flow control. In addition,
+ * make all RNR NAKs visible so we know that RPC-over-RDMA
+ * flow control is working correctly (no NAKs should be seen).
+ */
+ ep->re_remote_cma.flow_control = 0;
+ ep->re_remote_cma.rnr_retry_count = 0;
+
+ ep->re_pd = ib_alloc_pd(device, 0);
+ if (IS_ERR(ep->re_pd)) {
+ rc = PTR_ERR(ep->re_pd);
+ ep->re_pd = NULL;
+ goto out_destroy;
+ }
+
+ rc = rdma_create_qp(id, ep->re_pd, &ep->re_attr);
+ if (rc)
+ goto out_destroy;
+
+ r_xprt->rx_ep = ep;
+ return 0;
+
+out_destroy:
+ rpcrdma_ep_put(ep);
+ rdma_destroy_id(id);
+ return rc;
+}
+
+/**
+ * rpcrdma_xprt_connect - Connect an unconnected transport
+ * @r_xprt: controlling transport instance
+ *
+ * Returns 0 on success or a negative errno.
+ */
+int rpcrdma_xprt_connect(struct rpcrdma_xprt *r_xprt)
+{
+ struct rpc_xprt *xprt = &r_xprt->rx_xprt;
+ struct rpcrdma_ep *ep;
+ int rc;
+
+ rc = rpcrdma_ep_create(r_xprt);
+ if (rc)
+ return rc;
+ ep = r_xprt->rx_ep;
+
+ xprt_clear_connected(xprt);
+ rpcrdma_reset_cwnd(r_xprt);
+
+ /* Bump the ep's reference count while there are
+ * outstanding Receives.
+ */
+ rpcrdma_ep_get(ep);
+ rpcrdma_post_recvs(r_xprt, 1, true);
+
+ rc = rdma_connect(ep->re_id, &ep->re_remote_cma);
+ if (rc)
+ goto out;
+
+ if (xprt->reestablish_timeout < RPCRDMA_INIT_REEST_TO)
+ xprt->reestablish_timeout = RPCRDMA_INIT_REEST_TO;
+ wait_event_interruptible(ep->re_connect_wait,
+ ep->re_connect_status != 0);
+ if (ep->re_connect_status <= 0) {
+ rc = ep->re_connect_status;
+ goto out;
+ }
+
+ rc = rpcrdma_sendctxs_create(r_xprt);
+ if (rc) {
+ rc = -ENOTCONN;
+ goto out;
+ }
+
+ rc = rpcrdma_reqs_setup(r_xprt);
+ if (rc) {
+ rc = -ENOTCONN;
+ goto out;
+ }
+ rpcrdma_mrs_create(r_xprt);
+ frwr_wp_create(r_xprt);
+
+out:
+ trace_xprtrdma_connect(r_xprt, rc);
+ return rc;
+}
+
+/**
+ * rpcrdma_xprt_disconnect - Disconnect underlying transport
+ * @r_xprt: controlling transport instance
+ *
+ * Caller serializes. Either the transport send lock is held,
+ * or we're being called to destroy the transport.
+ *
+ * On return, @r_xprt is completely divested of all hardware
+ * resources and prepared for the next ->connect operation.
+ */
+void rpcrdma_xprt_disconnect(struct rpcrdma_xprt *r_xprt)
+{
+ struct rpcrdma_ep *ep = r_xprt->rx_ep;
+ struct rdma_cm_id *id;
+ int rc;
+
+ if (!ep)
+ return;
+
+ id = ep->re_id;
+ rc = rdma_disconnect(id);
+ trace_xprtrdma_disconnect(r_xprt, rc);
+
+ rpcrdma_xprt_drain(r_xprt);
+ rpcrdma_reps_unmap(r_xprt);
+ rpcrdma_reqs_reset(r_xprt);
+ rpcrdma_mrs_destroy(r_xprt);
+ rpcrdma_sendctxs_destroy(r_xprt);
+
+ if (rpcrdma_ep_put(ep))
+ rdma_destroy_id(id);
+
+ r_xprt->rx_ep = NULL;
+}
+
+/* Fixed-size circular FIFO queue. This implementation is wait-free and
+ * lock-free.
+ *
+ * Consumer is the code path that posts Sends. This path dequeues a
+ * sendctx for use by a Send operation. Multiple consumer threads
+ * are serialized by the RPC transport lock, which allows only one
+ * ->send_request call at a time.
+ *
+ * Producer is the code path that handles Send completions. This path
+ * enqueues a sendctx that has been completed. Multiple producer
+ * threads are serialized by the ib_poll_cq() function.
+ */
+
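/* Editorial aside: a reduced user-space model of the head/tail discipline
 * described above, not part of this patch.  The consumer (send path)
 * advances a head index; the producer (Send completion) advances a tail
 * index; "next(head) == tail" means the ring of free send contexts is
 * exhausted.  The real rpcrdma_sendctx_put_locked() additionally walks the
 * tail forward, unmapping contexts for unsignaled Sends, and publishes the
 * new tail with smp_store_release().  All names below are illustrative.
 */
#include <stdio.h>

#define RING_SLOTS 5		/* deliberately not a power of two */

static unsigned long head;	/* advanced by the consumer */
static unsigned long tail;	/* advanced by the producer */

static unsigned long ring_next(unsigned long item)
{
	return item < RING_SLOTS - 1 ? item + 1 : 0;
}

/* Consumer: claim the next free slot; fail when the ring is exhausted. */
static long ring_get(void)
{
	unsigned long next = ring_next(head);

	if (next == tail)
		return -1;	/* Send Queue backing up: caller backs off */
	head = next;
	return (long)next;
}

/* Producer: release every slot up to and including @slot. */
static void ring_put(unsigned long slot)
{
	tail = slot;
}

int main(void)
{
	long a = ring_get(), b = ring_get();

	printf("claimed slots %ld and %ld\n", a, b);
	ring_put((unsigned long)a);	/* completion for slot a */
	printf("claimed slot %ld after one completion\n", ring_get());
	return 0;
}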
+/* rpcrdma_sendctxs_destroy() assumes caller has already quiesced
+ * queue activity, and rpcrdma_xprt_drain has flushed all remaining
+ * Send requests.
+ */
+static void rpcrdma_sendctxs_destroy(struct rpcrdma_xprt *r_xprt)
+{
+ struct rpcrdma_buffer *buf = &r_xprt->rx_buf;
+ unsigned long i;
+
+ if (!buf->rb_sc_ctxs)
+ return;
+ for (i = 0; i <= buf->rb_sc_last; i++)
+ kfree(buf->rb_sc_ctxs[i]);
+ kfree(buf->rb_sc_ctxs);
+ buf->rb_sc_ctxs = NULL;
+}
+
+static struct rpcrdma_sendctx *rpcrdma_sendctx_create(struct rpcrdma_ep *ep)
+{
+ struct rpcrdma_sendctx *sc;
+
+ sc = kzalloc(struct_size(sc, sc_sges, ep->re_attr.cap.max_send_sge),
+ XPRTRDMA_GFP_FLAGS);
+ if (!sc)
+ return NULL;
+
+ sc->sc_cqe.done = rpcrdma_wc_send;
+ sc->sc_cid.ci_queue_id = ep->re_attr.send_cq->res.id;
+ sc->sc_cid.ci_completion_id =
+ atomic_inc_return(&ep->re_completion_ids);
+ return sc;
+}
+
+static int rpcrdma_sendctxs_create(struct rpcrdma_xprt *r_xprt)
+{
+ struct rpcrdma_buffer *buf = &r_xprt->rx_buf;
+ struct rpcrdma_sendctx *sc;
+ unsigned long i;
+
+ /* Maximum number of concurrent outstanding Send WRs. Capping
+ * the circular queue size stops Send Queue overflow by causing
+ * the ->send_request call to fail temporarily before too many
+ * Sends are posted.
+ */
+ i = r_xprt->rx_ep->re_max_requests + RPCRDMA_MAX_BC_REQUESTS;
+ buf->rb_sc_ctxs = kcalloc(i, sizeof(sc), XPRTRDMA_GFP_FLAGS);
+ if (!buf->rb_sc_ctxs)
+ return -ENOMEM;
+
+ buf->rb_sc_last = i - 1;
+ for (i = 0; i <= buf->rb_sc_last; i++) {
+ sc = rpcrdma_sendctx_create(r_xprt->rx_ep);
+ if (!sc)
+ return -ENOMEM;
+
+ buf->rb_sc_ctxs[i] = sc;
+ }
+
+ buf->rb_sc_head = 0;
+ buf->rb_sc_tail = 0;
+ return 0;
+}
+
+/* The sendctx queue is not guaranteed to have a size that is a
+ * power of two, thus the helpers in circ_buf.h cannot be used.
+ * The other option is to use modulus (%), which can be expensive.
+ */
+static unsigned long rpcrdma_sendctx_next(struct rpcrdma_buffer *buf,
+ unsigned long item)
+{
+ return likely(item < buf->rb_sc_last) ? item + 1 : 0;
+}
+
+/**
+ * rpcrdma_sendctx_get_locked - Acquire a send context
+ * @r_xprt: controlling transport instance
+ *
+ * Returns pointer to a free send completion context; or NULL if
+ * the queue is empty.
+ *
+ * Usage: Called to acquire an SGE array before preparing a Send WR.
+ *
+ * The caller serializes calls to this function (per transport), and
+ * provides an effective memory barrier that flushes the new value
+ * of rb_sc_head.
+ */
+struct rpcrdma_sendctx *rpcrdma_sendctx_get_locked(struct rpcrdma_xprt *r_xprt)
+{
+ struct rpcrdma_buffer *buf = &r_xprt->rx_buf;
+ struct rpcrdma_sendctx *sc;
+ unsigned long next_head;
+
+ next_head = rpcrdma_sendctx_next(buf, buf->rb_sc_head);
+
+ if (next_head == READ_ONCE(buf->rb_sc_tail))
+ goto out_emptyq;
+
+ /* ORDER: item must be accessed _before_ head is updated */
+ sc = buf->rb_sc_ctxs[next_head];
+
+ /* Releasing the lock in the caller acts as a memory
+ * barrier that flushes rb_sc_head.
+ */
+ buf->rb_sc_head = next_head;
+
+ return sc;
+
+out_emptyq:
+ /* The queue is "empty" if there have not been enough Send
+ * completions recently. This is a sign the Send Queue is
+ * backing up. Cause the caller to pause and try again.
+ */
+ xprt_wait_for_buffer_space(&r_xprt->rx_xprt);
+ r_xprt->rx_stats.empty_sendctx_q++;
+ return NULL;
+}
+
+/**
+ * rpcrdma_sendctx_put_locked - Release a send context
+ * @r_xprt: controlling transport instance
+ * @sc: send context to release
+ *
+ * Usage: Called from Send completion to return a sendctx
+ * to the queue.
+ *
+ * The caller serializes calls to this function (per transport).
+ */
+static void rpcrdma_sendctx_put_locked(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_sendctx *sc)
+{
+ struct rpcrdma_buffer *buf = &r_xprt->rx_buf;
+ unsigned long next_tail;
+
+ /* Unmap SGEs of previously completed but unsignaled
+ * Sends by walking up the queue until @sc is found.
+ */
+ next_tail = buf->rb_sc_tail;
+ do {
+ next_tail = rpcrdma_sendctx_next(buf, next_tail);
+
+ /* ORDER: item must be accessed _before_ tail is updated */
+ rpcrdma_sendctx_unmap(buf->rb_sc_ctxs[next_tail]);
+
+ } while (buf->rb_sc_ctxs[next_tail] != sc);
+
+ /* Paired with READ_ONCE */
+ smp_store_release(&buf->rb_sc_tail, next_tail);
+
+ xprt_write_space(&r_xprt->rx_xprt);
+}
+
+static void
+rpcrdma_mrs_create(struct rpcrdma_xprt *r_xprt)
+{
+ struct rpcrdma_buffer *buf = &r_xprt->rx_buf;
+ struct rpcrdma_ep *ep = r_xprt->rx_ep;
+ struct ib_device *device = ep->re_id->device;
+ unsigned int count;
+
+ /* Try to allocate enough to perform one full-sized I/O */
+ for (count = 0; count < ep->re_max_rdma_segs; count++) {
+ struct rpcrdma_mr *mr;
+ int rc;
+
+ mr = kzalloc_node(sizeof(*mr), XPRTRDMA_GFP_FLAGS,
+ ibdev_to_node(device));
+ if (!mr)
+ break;
+
+ rc = frwr_mr_init(r_xprt, mr);
+ if (rc) {
+ kfree(mr);
+ break;
+ }
+
+ spin_lock(&buf->rb_lock);
+ rpcrdma_mr_push(mr, &buf->rb_mrs);
+ list_add(&mr->mr_all, &buf->rb_all_mrs);
+ spin_unlock(&buf->rb_lock);
+ }
+
+ r_xprt->rx_stats.mrs_allocated += count;
+ trace_xprtrdma_createmrs(r_xprt, count);
+}
+
+static void
+rpcrdma_mr_refresh_worker(struct work_struct *work)
+{
+ struct rpcrdma_buffer *buf = container_of(work, struct rpcrdma_buffer,
+ rb_refresh_worker);
+ struct rpcrdma_xprt *r_xprt = container_of(buf, struct rpcrdma_xprt,
+ rx_buf);
+
+ rpcrdma_mrs_create(r_xprt);
+ xprt_write_space(&r_xprt->rx_xprt);
+}
+
+/**
+ * rpcrdma_mrs_refresh - Wake the MR refresh worker
+ * @r_xprt: controlling transport instance
+ *
+ */
+void rpcrdma_mrs_refresh(struct rpcrdma_xprt *r_xprt)
+{
+ struct rpcrdma_buffer *buf = &r_xprt->rx_buf;
+ struct rpcrdma_ep *ep = r_xprt->rx_ep;
+
+ /* If there is no underlying connection, it's no use
+ * to wake the refresh worker.
+ */
+ if (ep->re_connect_status != 1)
+ return;
+ queue_work(system_highpri_wq, &buf->rb_refresh_worker);
+}
+
+/**
+ * rpcrdma_req_create - Allocate an rpcrdma_req object
+ * @r_xprt: controlling r_xprt
+ * @size: initial size, in bytes, of send and receive buffers
+ *
+ * Returns an allocated and fully initialized rpcrdma_req or NULL.
+ */
+struct rpcrdma_req *rpcrdma_req_create(struct rpcrdma_xprt *r_xprt,
+ size_t size)
+{
+ struct rpcrdma_buffer *buffer = &r_xprt->rx_buf;
+ struct rpcrdma_req *req;
+
+ req = kzalloc(sizeof(*req), XPRTRDMA_GFP_FLAGS);
+ if (req == NULL)
+ goto out1;
+
+ req->rl_sendbuf = rpcrdma_regbuf_alloc(size, DMA_TO_DEVICE);
+ if (!req->rl_sendbuf)
+ goto out2;
+
+ req->rl_recvbuf = rpcrdma_regbuf_alloc(size, DMA_NONE);
+ if (!req->rl_recvbuf)
+ goto out3;
+
+ INIT_LIST_HEAD(&req->rl_free_mrs);
+ INIT_LIST_HEAD(&req->rl_registered);
+ spin_lock(&buffer->rb_lock);
+ list_add(&req->rl_all, &buffer->rb_allreqs);
+ spin_unlock(&buffer->rb_lock);
+ return req;
+
+out3:
+ rpcrdma_regbuf_free(req->rl_sendbuf);
+out2:
+ kfree(req);
+out1:
+ return NULL;
+}
+
+/**
+ * rpcrdma_req_setup - Per-connection instance setup of an rpcrdma_req object
+ * @r_xprt: controlling transport instance
+ * @req: rpcrdma_req object to set up
+ *
+ * Returns zero on success, and a negative errno on failure.
+ */
+int rpcrdma_req_setup(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req)
+{
+ struct rpcrdma_regbuf *rb;
+ size_t maxhdrsize;
+
+ /* Compute maximum header buffer size in bytes */
+ maxhdrsize = rpcrdma_fixed_maxsz + 3 +
+ r_xprt->rx_ep->re_max_rdma_segs * rpcrdma_readchunk_maxsz;
+ maxhdrsize *= sizeof(__be32);
+ rb = rpcrdma_regbuf_alloc(__roundup_pow_of_two(maxhdrsize),
+ DMA_TO_DEVICE);
+ if (!rb)
+ goto out;
+
+ if (!__rpcrdma_regbuf_dma_map(r_xprt, rb))
+ goto out_free;
+
+ req->rl_rdmabuf = rb;
+ xdr_buf_init(&req->rl_hdrbuf, rdmab_data(rb), rdmab_length(rb));
+ return 0;
+
+out_free:
+ rpcrdma_regbuf_free(rb);
+out:
+ return -ENOMEM;
+}
+
+/* ASSUMPTION: the rb_allreqs list is stable for the duration,
+ * and thus can be walked without holding rb_lock. E.g., the
+ * caller is holding the transport send lock to exclude
+ * device removal or disconnection.
+ */
+static int rpcrdma_reqs_setup(struct rpcrdma_xprt *r_xprt)
+{
+ struct rpcrdma_buffer *buf = &r_xprt->rx_buf;
+ struct rpcrdma_req *req;
+ int rc;
+
+ list_for_each_entry(req, &buf->rb_allreqs, rl_all) {
+ rc = rpcrdma_req_setup(r_xprt, req);
+ if (rc)
+ return rc;
+ }
+ return 0;
+}
+
+static void rpcrdma_req_reset(struct rpcrdma_req *req)
+{
+ /* Credits are valid for only one connection */
+ req->rl_slot.rq_cong = 0;
+
+ rpcrdma_regbuf_free(req->rl_rdmabuf);
+ req->rl_rdmabuf = NULL;
+
+ rpcrdma_regbuf_dma_unmap(req->rl_sendbuf);
+ rpcrdma_regbuf_dma_unmap(req->rl_recvbuf);
+
+ frwr_reset(req);
+}
+
+/* ASSUMPTION: the rb_allreqs list is stable for the duration,
+ * and thus can be walked without holding rb_lock. E.g., the
+ * caller is holding the transport send lock to exclude
+ * device removal or disconnection.
+ */
+static void rpcrdma_reqs_reset(struct rpcrdma_xprt *r_xprt)
+{
+ struct rpcrdma_buffer *buf = &r_xprt->rx_buf;
+ struct rpcrdma_req *req;
+
+ list_for_each_entry(req, &buf->rb_allreqs, rl_all)
+ rpcrdma_req_reset(req);
+}
+
+static noinline
+struct rpcrdma_rep *rpcrdma_rep_create(struct rpcrdma_xprt *r_xprt,
+ bool temp)
+{
+ struct rpcrdma_buffer *buf = &r_xprt->rx_buf;
+ struct rpcrdma_rep *rep;
+
+ rep = kzalloc(sizeof(*rep), XPRTRDMA_GFP_FLAGS);
+ if (rep == NULL)
+ goto out;
+
+ rep->rr_rdmabuf = rpcrdma_regbuf_alloc(r_xprt->rx_ep->re_inline_recv,
+ DMA_FROM_DEVICE);
+ if (!rep->rr_rdmabuf)
+ goto out_free;
+
+ rep->rr_cid.ci_completion_id =
+ atomic_inc_return(&r_xprt->rx_ep->re_completion_ids);
+
+ xdr_buf_init(&rep->rr_hdrbuf, rdmab_data(rep->rr_rdmabuf),
+ rdmab_length(rep->rr_rdmabuf));
+ rep->rr_cqe.done = rpcrdma_wc_receive;
+ rep->rr_rxprt = r_xprt;
+ rep->rr_recv_wr.next = NULL;
+ rep->rr_recv_wr.wr_cqe = &rep->rr_cqe;
+ rep->rr_recv_wr.sg_list = &rep->rr_rdmabuf->rg_iov;
+ rep->rr_recv_wr.num_sge = 1;
+ rep->rr_temp = temp;
+
+ spin_lock(&buf->rb_lock);
+ list_add(&rep->rr_all, &buf->rb_all_reps);
+ spin_unlock(&buf->rb_lock);
+ return rep;
+
+out_free:
+ kfree(rep);
+out:
+ return NULL;
+}
+
+static void rpcrdma_rep_free(struct rpcrdma_rep *rep)
+{
+ rpcrdma_regbuf_free(rep->rr_rdmabuf);
+ kfree(rep);
+}
+
+static void rpcrdma_rep_destroy(struct rpcrdma_rep *rep)
+{
+ struct rpcrdma_buffer *buf = &rep->rr_rxprt->rx_buf;
+
+ spin_lock(&buf->rb_lock);
+ list_del(&rep->rr_all);
+ spin_unlock(&buf->rb_lock);
+
+ rpcrdma_rep_free(rep);
+}
+
+static struct rpcrdma_rep *rpcrdma_rep_get_locked(struct rpcrdma_buffer *buf)
+{
+ struct llist_node *node;
+
+ /* Calls to llist_del_first are required to be serialized */
+ node = llist_del_first(&buf->rb_free_reps);
+ if (!node)
+ return NULL;
+ return llist_entry(node, struct rpcrdma_rep, rr_node);
+}
+
+/**
+ * rpcrdma_rep_put - Release rpcrdma_rep back to free list
+ * @buf: buffer pool
+ * @rep: rep to release
+ *
+ */
+void rpcrdma_rep_put(struct rpcrdma_buffer *buf, struct rpcrdma_rep *rep)
+{
+ llist_add(&rep->rr_node, &buf->rb_free_reps);
+}
+
+/* Caller must ensure the QP is quiescent (RQ is drained) before
+ * invoking this function, to guarantee rb_all_reps is not
+ * changing.
+ */
+static void rpcrdma_reps_unmap(struct rpcrdma_xprt *r_xprt)
+{
+ struct rpcrdma_buffer *buf = &r_xprt->rx_buf;
+ struct rpcrdma_rep *rep;
+
+ list_for_each_entry(rep, &buf->rb_all_reps, rr_all) {
+ rpcrdma_regbuf_dma_unmap(rep->rr_rdmabuf);
+ rep->rr_temp = true; /* Mark this rep for destruction */
+ }
+}
+
+static void rpcrdma_reps_destroy(struct rpcrdma_buffer *buf)
+{
+ struct rpcrdma_rep *rep;
+
+ spin_lock(&buf->rb_lock);
+ while ((rep = list_first_entry_or_null(&buf->rb_all_reps,
+ struct rpcrdma_rep,
+ rr_all)) != NULL) {
+ list_del(&rep->rr_all);
+ spin_unlock(&buf->rb_lock);
+
+ rpcrdma_rep_free(rep);
+
+ spin_lock(&buf->rb_lock);
+ }
+ spin_unlock(&buf->rb_lock);
+}
+
+/**
+ * rpcrdma_buffer_create - Create initial set of req/rep objects
+ * @r_xprt: transport instance to (re)initialize
+ *
+ * Returns zero on success, otherwise a negative errno.
+ */
+int rpcrdma_buffer_create(struct rpcrdma_xprt *r_xprt)
+{
+ struct rpcrdma_buffer *buf = &r_xprt->rx_buf;
+ int i, rc;
+
+ buf->rb_bc_srv_max_requests = 0;
+ spin_lock_init(&buf->rb_lock);
+ INIT_LIST_HEAD(&buf->rb_mrs);
+ INIT_LIST_HEAD(&buf->rb_all_mrs);
+ INIT_WORK(&buf->rb_refresh_worker, rpcrdma_mr_refresh_worker);
+
+ INIT_LIST_HEAD(&buf->rb_send_bufs);
+ INIT_LIST_HEAD(&buf->rb_allreqs);
+ INIT_LIST_HEAD(&buf->rb_all_reps);
+
+ rc = -ENOMEM;
+ for (i = 0; i < r_xprt->rx_xprt.max_reqs; i++) {
+ struct rpcrdma_req *req;
+
+ req = rpcrdma_req_create(r_xprt,
+ RPCRDMA_V1_DEF_INLINE_SIZE * 2);
+ if (!req)
+ goto out;
+ list_add(&req->rl_list, &buf->rb_send_bufs);
+ }
+
+ init_llist_head(&buf->rb_free_reps);
+
+ return 0;
+out:
+ rpcrdma_buffer_destroy(buf);
+ return rc;
+}
+
+/**
+ * rpcrdma_req_destroy - Destroy an rpcrdma_req object
+ * @req: unused object to be destroyed
+ *
+ * Relies on caller holding the transport send lock to protect
+ * removing req->rl_all from buf->rb_allreqs safely.
+ */
+void rpcrdma_req_destroy(struct rpcrdma_req *req)
+{
+ struct rpcrdma_mr *mr;
+
+ list_del(&req->rl_all);
+
+ while ((mr = rpcrdma_mr_pop(&req->rl_free_mrs))) {
+ struct rpcrdma_buffer *buf = &mr->mr_xprt->rx_buf;
+
+ spin_lock(&buf->rb_lock);
+ list_del(&mr->mr_all);
+ spin_unlock(&buf->rb_lock);
+
+ frwr_mr_release(mr);
+ }
+
+ rpcrdma_regbuf_free(req->rl_recvbuf);
+ rpcrdma_regbuf_free(req->rl_sendbuf);
+ rpcrdma_regbuf_free(req->rl_rdmabuf);
+ kfree(req);
+}
+
+/**
+ * rpcrdma_mrs_destroy - Release all of a transport's MRs
+ * @r_xprt: controlling transport instance
+ *
+ * Relies on caller holding the transport send lock to protect
+ * removing mr->mr_list from req->rl_free_mrs safely.
+ */
+static void rpcrdma_mrs_destroy(struct rpcrdma_xprt *r_xprt)
+{
+ struct rpcrdma_buffer *buf = &r_xprt->rx_buf;
+ struct rpcrdma_mr *mr;
+
+ cancel_work_sync(&buf->rb_refresh_worker);
+
+ spin_lock(&buf->rb_lock);
+ while ((mr = list_first_entry_or_null(&buf->rb_all_mrs,
+ struct rpcrdma_mr,
+ mr_all)) != NULL) {
+ list_del(&mr->mr_list);
+ list_del(&mr->mr_all);
+ spin_unlock(&buf->rb_lock);
+
+ frwr_mr_release(mr);
+
+ spin_lock(&buf->rb_lock);
+ }
+ spin_unlock(&buf->rb_lock);
+}
+
+/**
+ * rpcrdma_buffer_destroy - Release all hw resources
+ * @buf: root control block for resources
+ *
+ * ORDERING: relies on a prior rpcrdma_xprt_drain:
+ * - No more Send or Receive completions can occur
+ * - All MRs, reps, and reqs are returned to their free lists
+ */
+void
+rpcrdma_buffer_destroy(struct rpcrdma_buffer *buf)
+{
+ rpcrdma_reps_destroy(buf);
+
+ while (!list_empty(&buf->rb_send_bufs)) {
+ struct rpcrdma_req *req;
+
+ req = list_first_entry(&buf->rb_send_bufs,
+ struct rpcrdma_req, rl_list);
+ list_del(&req->rl_list);
+ rpcrdma_req_destroy(req);
+ }
+}
+
+/**
+ * rpcrdma_mr_get - Allocate an rpcrdma_mr object
+ * @r_xprt: controlling transport
+ *
+ * Returns an initialized rpcrdma_mr or NULL if no free
+ * rpcrdma_mr objects are available.
+ */
+struct rpcrdma_mr *
+rpcrdma_mr_get(struct rpcrdma_xprt *r_xprt)
+{
+ struct rpcrdma_buffer *buf = &r_xprt->rx_buf;
+ struct rpcrdma_mr *mr;
+
+ spin_lock(&buf->rb_lock);
+ mr = rpcrdma_mr_pop(&buf->rb_mrs);
+ spin_unlock(&buf->rb_lock);
+ return mr;
+}
+
+/**
+ * rpcrdma_reply_put - Put reply buffers back into pool
+ * @buffers: buffer pool
+ * @req: object to return
+ *
+ */
+void rpcrdma_reply_put(struct rpcrdma_buffer *buffers, struct rpcrdma_req *req)
+{
+ if (req->rl_reply) {
+ rpcrdma_rep_put(buffers, req->rl_reply);
+ req->rl_reply = NULL;
+ }
+}
+
+/**
+ * rpcrdma_buffer_get - Get a request buffer
+ * @buffers: Buffer pool from which to obtain a buffer
+ *
+ * Returns a fresh rpcrdma_req, or NULL if none are available.
+ */
+struct rpcrdma_req *
+rpcrdma_buffer_get(struct rpcrdma_buffer *buffers)
+{
+ struct rpcrdma_req *req;
+
+ spin_lock(&buffers->rb_lock);
+ req = list_first_entry_or_null(&buffers->rb_send_bufs,
+ struct rpcrdma_req, rl_list);
+ if (req)
+ list_del_init(&req->rl_list);
+ spin_unlock(&buffers->rb_lock);
+ return req;
+}
+
+/**
+ * rpcrdma_buffer_put - Put request/reply buffers back into pool
+ * @buffers: buffer pool
+ * @req: object to return
+ *
+ */
+void rpcrdma_buffer_put(struct rpcrdma_buffer *buffers, struct rpcrdma_req *req)
+{
+ rpcrdma_reply_put(buffers, req);
+
+ spin_lock(&buffers->rb_lock);
+ list_add(&req->rl_list, &buffers->rb_send_bufs);
+ spin_unlock(&buffers->rb_lock);
+}
+
+/* Returns a pointer to a rpcrdma_regbuf object, or NULL.
+ *
+ * xprtrdma uses a regbuf for posting an outgoing RDMA SEND, or for
+ * receiving the payload of RDMA RECV operations. During Long Calls
+ * or Replies they may be registered externally via frwr_map.
+ */
+static struct rpcrdma_regbuf *
+rpcrdma_regbuf_alloc(size_t size, enum dma_data_direction direction)
+{
+ struct rpcrdma_regbuf *rb;
+
+ rb = kmalloc(sizeof(*rb), XPRTRDMA_GFP_FLAGS);
+ if (!rb)
+ return NULL;
+ rb->rg_data = kmalloc(size, XPRTRDMA_GFP_FLAGS);
+ if (!rb->rg_data) {
+ kfree(rb);
+ return NULL;
+ }
+
+ rb->rg_device = NULL;
+ rb->rg_direction = direction;
+ rb->rg_iov.length = size;
+ return rb;
+}
+
+/**
+ * rpcrdma_regbuf_realloc - re-allocate a SEND/RECV buffer
+ * @rb: regbuf to reallocate
+ * @size: size of buffer to be allocated, in bytes
+ * @flags: GFP flags
+ *
+ * Returns true if reallocation was successful. If false is
+ * returned, @rb is left untouched.
+ */
+bool rpcrdma_regbuf_realloc(struct rpcrdma_regbuf *rb, size_t size, gfp_t flags)
+{
+ void *buf;
+
+ buf = kmalloc(size, flags);
+ if (!buf)
+ return false;
+
+ rpcrdma_regbuf_dma_unmap(rb);
+ kfree(rb->rg_data);
+
+ rb->rg_data = buf;
+ rb->rg_iov.length = size;
+ return true;
+}
+
+/**
+ * __rpcrdma_regbuf_dma_map - DMA-map a regbuf
+ * @r_xprt: controlling transport instance
+ * @rb: regbuf to be mapped
+ *
+ * Returns true if the buffer is now DMA mapped to @r_xprt's device
+ */
+bool __rpcrdma_regbuf_dma_map(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_regbuf *rb)
+{
+ struct ib_device *device = r_xprt->rx_ep->re_id->device;
+
+ if (rb->rg_direction == DMA_NONE)
+ return false;
+
+ rb->rg_iov.addr = ib_dma_map_single(device, rdmab_data(rb),
+ rdmab_length(rb), rb->rg_direction);
+ if (ib_dma_mapping_error(device, rdmab_addr(rb))) {
+ trace_xprtrdma_dma_maperr(rdmab_addr(rb));
+ return false;
+ }
+
+ rb->rg_device = device;
+ rb->rg_iov.lkey = r_xprt->rx_ep->re_pd->local_dma_lkey;
+ return true;
+}
+
+static void rpcrdma_regbuf_dma_unmap(struct rpcrdma_regbuf *rb)
+{
+ if (!rb)
+ return;
+
+ if (!rpcrdma_regbuf_is_mapped(rb))
+ return;
+
+ ib_dma_unmap_single(rb->rg_device, rdmab_addr(rb), rdmab_length(rb),
+ rb->rg_direction);
+ rb->rg_device = NULL;
+}
+
+static void rpcrdma_regbuf_free(struct rpcrdma_regbuf *rb)
+{
+ rpcrdma_regbuf_dma_unmap(rb);
+ if (rb)
+ kfree(rb->rg_data);
+ kfree(rb);
+}
+
+/**
+ * rpcrdma_post_recvs - Refill the Receive Queue
+ * @r_xprt: controlling transport instance
+ * @needed: current credit grant
+ * @temp: mark Receive buffers to be deleted after one use
+ *
+ */
+void rpcrdma_post_recvs(struct rpcrdma_xprt *r_xprt, int needed, bool temp)
+{
+ struct rpcrdma_buffer *buf = &r_xprt->rx_buf;
+ struct rpcrdma_ep *ep = r_xprt->rx_ep;
+ struct ib_recv_wr *wr, *bad_wr;
+ struct rpcrdma_rep *rep;
+ int count, rc;
+
+ rc = 0;
+ count = 0;
+
+ if (likely(ep->re_receive_count > needed))
+ goto out;
+ needed -= ep->re_receive_count;
+ if (!temp)
+ needed += RPCRDMA_MAX_RECV_BATCH;
+
+ if (atomic_inc_return(&ep->re_receiving) > 1)
+ goto out;
+
+ /* fast path: all needed reps can be found on the free list */
+ wr = NULL;
+ while (needed) {
+ rep = rpcrdma_rep_get_locked(buf);
+ if (rep && rep->rr_temp) {
+ rpcrdma_rep_destroy(rep);
+ continue;
+ }
+ if (!rep)
+ rep = rpcrdma_rep_create(r_xprt, temp);
+ if (!rep)
+ break;
+ if (!rpcrdma_regbuf_dma_map(r_xprt, rep->rr_rdmabuf)) {
+ rpcrdma_rep_put(buf, rep);
+ break;
+ }
+
+ rep->rr_cid.ci_queue_id = ep->re_attr.recv_cq->res.id;
+ trace_xprtrdma_post_recv(rep);
+ rep->rr_recv_wr.next = wr;
+ wr = &rep->rr_recv_wr;
+ --needed;
+ ++count;
+ }
+ if (!wr)
+ goto out;
+
+ rc = ib_post_recv(ep->re_id->qp, wr,
+ (const struct ib_recv_wr **)&bad_wr);
+ if (rc) {
+ trace_xprtrdma_post_recvs_err(r_xprt, rc);
+ for (wr = bad_wr; wr;) {
+ struct rpcrdma_rep *rep;
+
+ rep = container_of(wr, struct rpcrdma_rep, rr_recv_wr);
+ wr = wr->next;
+ rpcrdma_rep_put(buf, rep);
+ --count;
+ }
+ }
+ if (atomic_dec_return(&ep->re_receiving) > 0)
+ complete(&ep->re_done);
+
+out:
+ trace_xprtrdma_post_recvs(r_xprt, count);
+ ep->re_receive_count += count;
+ return;
+}
diff --git a/net/sunrpc/xprtrdma/xprt_rdma.h b/net/sunrpc/xprtrdma/xprt_rdma.h
new file mode 100644
index 0000000000..da409450df
--- /dev/null
+++ b/net/sunrpc/xprtrdma/xprt_rdma.h
@@ -0,0 +1,604 @@
+/* SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause */
+/*
+ * Copyright (c) 2014-2017 Oracle. All rights reserved.
+ * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved.
+ *
+ * This software is available to you under a choice of one of two
+ * licenses. You may choose to be licensed under the terms of the GNU
+ * General Public License (GPL) Version 2, available from the file
+ * COPYING in the main directory of this source tree, or the BSD-type
+ * license below:
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *
+ * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials provided
+ * with the distribution.
+ *
+ * Neither the name of the Network Appliance, Inc. nor the names of
+ * its contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#ifndef _LINUX_SUNRPC_XPRT_RDMA_H
+#define _LINUX_SUNRPC_XPRT_RDMA_H
+
+#include <linux/wait.h> /* wait_queue_head_t, etc */
+#include <linux/spinlock.h> /* spinlock_t, etc */
+#include <linux/atomic.h> /* atomic_t, etc */
+#include <linux/kref.h> /* struct kref */
+#include <linux/workqueue.h> /* struct work_struct */
+#include <linux/llist.h>
+
+#include <rdma/rdma_cm.h> /* RDMA connection api */
+#include <rdma/ib_verbs.h> /* RDMA verbs api */
+
+#include <linux/sunrpc/clnt.h> /* rpc_xprt */
+#include <linux/sunrpc/rpc_rdma_cid.h> /* completion IDs */
+#include <linux/sunrpc/rpc_rdma.h> /* RPC/RDMA protocol */
+#include <linux/sunrpc/xprtrdma.h> /* xprt parameters */
+
+#define RDMA_RESOLVE_TIMEOUT (5000) /* 5 seconds */
+#define RDMA_CONNECT_RETRY_MAX (2) /* retries if no listener backlog */
+
+#define RPCRDMA_BIND_TO (60U * HZ)
+#define RPCRDMA_INIT_REEST_TO (5U * HZ)
+#define RPCRDMA_MAX_REEST_TO (30U * HZ)
+#define RPCRDMA_IDLE_DISC_TO (5U * 60 * HZ)
+
+/*
+ * RDMA Endpoint -- connection endpoint details
+ */
+struct rpcrdma_mr;
+struct rpcrdma_ep {
+ struct kref re_kref;
+ struct rdma_cm_id *re_id;
+ struct ib_pd *re_pd;
+ unsigned int re_max_rdma_segs;
+ unsigned int re_max_fr_depth;
+ struct rpcrdma_mr *re_write_pad_mr;
+ enum ib_mr_type re_mrtype;
+ struct completion re_done;
+ unsigned int re_send_count;
+ unsigned int re_send_batch;
+ unsigned int re_max_inline_send;
+ unsigned int re_max_inline_recv;
+ int re_async_rc;
+ int re_connect_status;
+ atomic_t re_receiving;
+ atomic_t re_force_disconnect;
+ struct ib_qp_init_attr re_attr;
+ wait_queue_head_t re_connect_wait;
+ struct rpc_xprt *re_xprt;
+ struct rpcrdma_connect_private
+ re_cm_private;
+ struct rdma_conn_param re_remote_cma;
+ int re_receive_count;
+ unsigned int re_max_requests; /* depends on device */
+ unsigned int re_inline_send; /* negotiated */
+ unsigned int re_inline_recv; /* negotiated */
+
+ atomic_t re_completion_ids;
+
+ char re_write_pad[XDR_UNIT];
+};
+
+/* Pre-allocate extra Work Requests for handling reverse-direction
+ * Receives and Sends. This is a fixed value because the Work Queues
+ * are allocated when the forward channel is set up, long before the
+ * backchannel is provisioned. This value is two times
+ * NFS4_DEF_CB_SLOT_TABLE_SIZE.
+ */
+#if defined(CONFIG_SUNRPC_BACKCHANNEL)
+#define RPCRDMA_BACKWARD_WRS (32)
+#else
+#define RPCRDMA_BACKWARD_WRS (0)
+#endif
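+
+/* A quick sanity check of the value above, for readers: taking
+ * NFS4_DEF_CB_SLOT_TABLE_SIZE to be 16 (assumed here for
+ * illustration), and budgeting one Receive and one Send Work Request
+ * per backchannel slot, gives 2 * 16 = 32, matching
+ * RPCRDMA_BACKWARD_WRS.
+ */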
+
+/* Registered buffer -- registered kmalloc'd memory for RDMA SEND/RECV
+ */
+
+struct rpcrdma_regbuf {
+ struct ib_sge rg_iov;
+ struct ib_device *rg_device;
+ enum dma_data_direction rg_direction;
+ void *rg_data;
+};
+
+static inline u64 rdmab_addr(struct rpcrdma_regbuf *rb)
+{
+ return rb->rg_iov.addr;
+}
+
+static inline u32 rdmab_length(struct rpcrdma_regbuf *rb)
+{
+ return rb->rg_iov.length;
+}
+
+static inline u32 rdmab_lkey(struct rpcrdma_regbuf *rb)
+{
+ return rb->rg_iov.lkey;
+}
+
+static inline struct ib_device *rdmab_device(struct rpcrdma_regbuf *rb)
+{
+ return rb->rg_device;
+}
+
+static inline void *rdmab_data(const struct rpcrdma_regbuf *rb)
+{
+ return rb->rg_data;
+}
+
+/* Do not use emergency memory reserves, and fail quickly if memory
+ * cannot be allocated easily. These flags may be used wherever there
+ * is robust logic to handle a failure to allocate.
+ */
+#define XPRTRDMA_GFP_FLAGS (__GFP_NOMEMALLOC | __GFP_NORETRY | __GFP_NOWARN)
+
+/* To ensure a transport can always make forward progress,
+ * the number of RDMA segments allowed in header chunk lists
+ * is capped at 16. This prevents less-capable devices from
+ * overrunning the Send buffer while building chunk lists.
+ *
+ * Elements of the Read list take up more room than the
+ * Write list or Reply chunk. 16 read segments means the
+ * chunk lists cannot consume more than
+ *
+ * ((16 + 2) * read segment size) + 1 XDR words,
+ *
+ * or about 400 bytes. The fixed part of the header is
+ * another 24 bytes. Thus when the inline threshold is
+ * 1024 bytes, at least 600 bytes are available for RPC
+ * message bodies.
+ */
+enum {
+ RPCRDMA_MAX_HDR_SEGS = 16,
+};
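+
+/* Illustrative arithmetic behind the estimate above, assuming each
+ * Read list entry encodes as roughly six XDR words (item
+ * discriminator, position, handle, length, 64-bit offset):
+ * (16 + 2) entries * 6 words + 1 terminating word = 109 words,
+ * about 436 bytes, broadly consistent with the "about 400 bytes"
+ * estimate above. Together with the ~24-byte fixed header, that
+ * still leaves most of a 1024-byte inline threshold for the RPC
+ * message body.
+ */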
+
+/*
+ * struct rpcrdma_rep -- this structure encapsulates state required
+ * to receive and complete an RPC Reply, asynchronously. It needs
+ * several pieces of state:
+ *
+ * o receive buffer and ib_sge (donated to provider)
+ * o status of receive (success or not, length, inv rkey)
+ * o bookkeeping state used by the reply handler (XDR stream)
+ *
+ * These structures are allocated during transport initialization.
+ * N of these are associated with a transport instance, managed by
+ * struct rpcrdma_buffer. N is the max number of outstanding RPCs.
+ */
+
+struct rpcrdma_rep {
+ struct ib_cqe rr_cqe;
+ struct rpc_rdma_cid rr_cid;
+
+ __be32 rr_xid;
+ __be32 rr_vers;
+ __be32 rr_proc;
+ int rr_wc_flags;
+ u32 rr_inv_rkey;
+ bool rr_temp;
+ struct rpcrdma_regbuf *rr_rdmabuf;
+ struct rpcrdma_xprt *rr_rxprt;
+ struct rpc_rqst *rr_rqst;
+ struct xdr_buf rr_hdrbuf;
+ struct xdr_stream rr_stream;
+ struct llist_node rr_node;
+ struct ib_recv_wr rr_recv_wr;
+ struct list_head rr_all;
+};
+
+/* To reduce the rate at which a transport invokes ib_post_recv
+ * (and thus the hardware doorbell rate), xprtrdma posts Receive
+ * WRs in batches.
+ *
+ * Setting this to zero disables Receive post batching.
+ */
+enum {
+ RPCRDMA_MAX_RECV_BATCH = 7,
+};
+
+/* struct rpcrdma_sendctx - DMA mapped SGEs to unmap after Send completes
+ */
+struct rpcrdma_req;
+struct rpcrdma_sendctx {
+ struct ib_cqe sc_cqe;
+ struct rpc_rdma_cid sc_cid;
+ struct rpcrdma_req *sc_req;
+ unsigned int sc_unmap_count;
+ struct ib_sge sc_sges[];
+};
+
+/*
+ * struct rpcrdma_mr - external memory region metadata
+ *
+ * An external memory region is any buffer or page that is registered
+ * on the fly (i.e., not pre-registered).
+ */
+struct rpcrdma_req;
+struct rpcrdma_mr {
+ struct list_head mr_list;
+ struct rpcrdma_req *mr_req;
+
+ struct ib_mr *mr_ibmr;
+ struct ib_device *mr_device;
+ struct scatterlist *mr_sg;
+ int mr_nents;
+ enum dma_data_direction mr_dir;
+ struct ib_cqe mr_cqe;
+ struct completion mr_linv_done;
+ union {
+ struct ib_reg_wr mr_regwr;
+ struct ib_send_wr mr_invwr;
+ };
+ struct rpcrdma_xprt *mr_xprt;
+ u32 mr_handle;
+ u32 mr_length;
+ u64 mr_offset;
+ struct list_head mr_all;
+ struct rpc_rdma_cid mr_cid;
+};
+
+/*
+ * struct rpcrdma_req -- structure central to the request/reply sequence.
+ *
+ * N of these are associated with a transport instance, and stored in
+ * struct rpcrdma_buffer. N is the max number of outstanding requests.
+ *
+ * It includes pre-registered buffer memory for send AND recv.
+ * The recv buffer, however, is not owned by this structure, and
+ * is "donated" to the hardware when a recv is posted. When a
+ * reply is handled, the recv buffer used is given back to the
+ * struct rpcrdma_req associated with the request.
+ *
+ * In addition to the basic memory, this structure includes an array
+ * of iovs for send operations. The reason is that the iovs passed to
+ * ib_post_{send,recv} must not be modified until the work request
+ * completes.
+ */
+
+/* Maximum number of page-sized "segments" per chunk list to be
+ * registered or invalidated. Must handle a Reply chunk:
+ */
+enum {
+ RPCRDMA_MAX_IOV_SEGS = 3,
+ RPCRDMA_MAX_DATA_SEGS = ((1 * 1024 * 1024) / PAGE_SIZE) + 1,
+ RPCRDMA_MAX_SEGS = RPCRDMA_MAX_DATA_SEGS +
+ RPCRDMA_MAX_IOV_SEGS,
+};
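+
+/* Worked example, assuming 4 KiB pages (PAGE_SIZE is
+ * architecture-dependent): RPCRDMA_MAX_DATA_SEGS is
+ * (1048576 / 4096) + 1 = 257, so RPCRDMA_MAX_SEGS is 257 + 3 = 260,
+ * enough to describe a 1 MB payload plus the head and tail iovecs.
+ */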
+
+/* Arguments for DMA mapping and registration */
+struct rpcrdma_mr_seg {
+ u32 mr_len; /* length of segment */
+ struct page *mr_page; /* underlying struct page */
+ u64 mr_offset; /* IN: page offset, OUT: iova */
+};
+
+/* The Send SGE array is provisioned to send a maximum size
+ * inline request:
+ * - RPC-over-RDMA header
+ * - xdr_buf head iovec
+ * - RPCRDMA_MAX_INLINE bytes, in pages
+ * - xdr_buf tail iovec
+ *
+ * The actual number of array elements consumed by each RPC
+ * depends on the device's max_sge limit.
+ */
+enum {
+ RPCRDMA_MIN_SEND_SGES = 3,
+ RPCRDMA_MAX_PAGE_SGES = RPCRDMA_MAX_INLINE >> PAGE_SHIFT,
+ RPCRDMA_MAX_SEND_SGES = 1 + 1 + RPCRDMA_MAX_PAGE_SGES + 1,
+};
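+
+/* Sizing sketch for the enum above, using assumed values for
+ * illustration only: if RPCRDMA_MAX_INLINE were 64 KiB with 4 KiB
+ * pages, RPCRDMA_MAX_PAGE_SGES would be 65536 >> 12 = 16 and
+ * RPCRDMA_MAX_SEND_SGES would be 1 (transport header) + 1 (head
+ * iovec) + 16 (page SGEs) + 1 (tail iovec) = 19.
+ */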
+
+struct rpcrdma_buffer;
+struct rpcrdma_req {
+ struct list_head rl_list;
+ struct rpc_rqst rl_slot;
+ struct rpcrdma_rep *rl_reply;
+ struct xdr_stream rl_stream;
+ struct xdr_buf rl_hdrbuf;
+ struct ib_send_wr rl_wr;
+ struct rpcrdma_sendctx *rl_sendctx;
+ struct rpcrdma_regbuf *rl_rdmabuf; /* xprt header */
+ struct rpcrdma_regbuf *rl_sendbuf; /* rq_snd_buf */
+ struct rpcrdma_regbuf *rl_recvbuf; /* rq_rcv_buf */
+
+ struct list_head rl_all;
+ struct kref rl_kref;
+
+ struct list_head rl_free_mrs;
+ struct list_head rl_registered;
+ struct rpcrdma_mr_seg rl_segments[RPCRDMA_MAX_SEGS];
+};
+
+static inline struct rpcrdma_req *
+rpcr_to_rdmar(const struct rpc_rqst *rqst)
+{
+ return container_of(rqst, struct rpcrdma_req, rl_slot);
+}
+
+static inline void
+rpcrdma_mr_push(struct rpcrdma_mr *mr, struct list_head *list)
+{
+ list_add(&mr->mr_list, list);
+}
+
+static inline struct rpcrdma_mr *
+rpcrdma_mr_pop(struct list_head *list)
+{
+ struct rpcrdma_mr *mr;
+
+ mr = list_first_entry_or_null(list, struct rpcrdma_mr, mr_list);
+ if (mr)
+ list_del_init(&mr->mr_list);
+ return mr;
+}
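+
+/* Usage sketch (assumptions noted): callers treat these helpers as a
+ * simple LIFO free list, e.g.
+ *
+ *	struct rpcrdma_mr *mr = rpcrdma_mr_pop(&req->rl_free_mrs);
+ *	if (!mr)
+ *		... caller must replenish or defer ...
+ *
+ * Only the NULL-on-empty behaviour is guaranteed by rpcrdma_mr_pop()
+ * itself; the replenish step is an assumption about the caller.
+ */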
+
+/*
+ * struct rpcrdma_buffer -- holds list/queue of pre-registered memory for
+ * inline requests/replies, and client/server credits.
+ *
+ * One of these is associated with a transport instance
+ */
+struct rpcrdma_buffer {
+ spinlock_t rb_lock;
+ struct list_head rb_send_bufs;
+ struct list_head rb_mrs;
+
+ unsigned long rb_sc_head;
+ unsigned long rb_sc_tail;
+ unsigned long rb_sc_last;
+ struct rpcrdma_sendctx **rb_sc_ctxs;
+
+ struct list_head rb_allreqs;
+ struct list_head rb_all_mrs;
+ struct list_head rb_all_reps;
+
+ struct llist_head rb_free_reps;
+
+ __be32 rb_max_requests;
+ u32 rb_credits; /* most recent credit grant */
+
+ u32 rb_bc_srv_max_requests;
+ u32 rb_bc_max_requests;
+
+ struct work_struct rb_refresh_worker;
+};
+
+/*
+ * Statistics for RPCRDMA
+ */
+struct rpcrdma_stats {
+ /* accessed when sending a call */
+ unsigned long read_chunk_count;
+ unsigned long write_chunk_count;
+ unsigned long reply_chunk_count;
+ unsigned long long total_rdma_request;
+
+ /* rarely accessed error counters */
+ unsigned long long pullup_copy_count;
+ unsigned long hardway_register_count;
+ unsigned long failed_marshal_count;
+ unsigned long bad_reply_count;
+ unsigned long mrs_recycled;
+ unsigned long mrs_orphaned;
+ unsigned long mrs_allocated;
+ unsigned long empty_sendctx_q;
+
+ /* accessed when receiving a reply */
+ unsigned long long total_rdma_reply;
+ unsigned long long fixup_copy_count;
+ unsigned long reply_waits_for_send;
+ unsigned long local_inv_needed;
+ unsigned long nomsg_call_count;
+ unsigned long bcall_count;
+};
+
+/*
+ * RPCRDMA transport -- encapsulates the structures above for
+ * integration with RPC.
+ *
+ * The contained structures are embedded, not pointers,
+ * for convenience. This structure need not be visible externally.
+ *
+ * It is allocated and initialized during mount, and released
+ * during unmount.
+ */
+struct rpcrdma_xprt {
+ struct rpc_xprt rx_xprt;
+ struct rpcrdma_ep *rx_ep;
+ struct rpcrdma_buffer rx_buf;
+ struct delayed_work rx_connect_worker;
+ struct rpc_timeout rx_timeout;
+ struct rpcrdma_stats rx_stats;
+};
+
+#define rpcx_to_rdmax(x) container_of(x, struct rpcrdma_xprt, rx_xprt)
+
+static inline const char *
+rpcrdma_addrstr(const struct rpcrdma_xprt *r_xprt)
+{
+ return r_xprt->rx_xprt.address_strings[RPC_DISPLAY_ADDR];
+}
+
+static inline const char *
+rpcrdma_portstr(const struct rpcrdma_xprt *r_xprt)
+{
+ return r_xprt->rx_xprt.address_strings[RPC_DISPLAY_PORT];
+}
+
+/* Setting this to 0 ensures interoperability with early servers.
+ * Setting this to 1 enhances certain unaligned read/write performance.
+ * Default is 0, see sysctl entry and rpc_rdma.c rpcrdma_convert_iovs() */
+extern int xprt_rdma_pad_optimize;
+
+/* This setting controls the hunt for a supported memory
+ * registration strategy.
+ */
+extern unsigned int xprt_rdma_memreg_strategy;
+
+/*
+ * Endpoint calls - xprtrdma/verbs.c
+ */
+void rpcrdma_force_disconnect(struct rpcrdma_ep *ep);
+void rpcrdma_flush_disconnect(struct rpcrdma_xprt *r_xprt, struct ib_wc *wc);
+int rpcrdma_xprt_connect(struct rpcrdma_xprt *r_xprt);
+void rpcrdma_xprt_disconnect(struct rpcrdma_xprt *r_xprt);
+
+void rpcrdma_post_recvs(struct rpcrdma_xprt *r_xprt, int needed, bool temp);
+
+/*
+ * Buffer calls - xprtrdma/verbs.c
+ */
+struct rpcrdma_req *rpcrdma_req_create(struct rpcrdma_xprt *r_xprt,
+ size_t size);
+int rpcrdma_req_setup(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req);
+void rpcrdma_req_destroy(struct rpcrdma_req *req);
+int rpcrdma_buffer_create(struct rpcrdma_xprt *);
+void rpcrdma_buffer_destroy(struct rpcrdma_buffer *);
+struct rpcrdma_sendctx *rpcrdma_sendctx_get_locked(struct rpcrdma_xprt *r_xprt);
+
+struct rpcrdma_mr *rpcrdma_mr_get(struct rpcrdma_xprt *r_xprt);
+void rpcrdma_mrs_refresh(struct rpcrdma_xprt *r_xprt);
+
+struct rpcrdma_req *rpcrdma_buffer_get(struct rpcrdma_buffer *);
+void rpcrdma_buffer_put(struct rpcrdma_buffer *buffers,
+ struct rpcrdma_req *req);
+void rpcrdma_rep_put(struct rpcrdma_buffer *buf, struct rpcrdma_rep *rep);
+void rpcrdma_reply_put(struct rpcrdma_buffer *buffers, struct rpcrdma_req *req);
+
+bool rpcrdma_regbuf_realloc(struct rpcrdma_regbuf *rb, size_t size,
+ gfp_t flags);
+bool __rpcrdma_regbuf_dma_map(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_regbuf *rb);
+
+/**
+ * rpcrdma_regbuf_is_mapped - check if buffer is DMA mapped
+ * @rb: regbuf to check
+ *
+ * Returns true if the buffer is currently mapped to rb->rg_device.
+ */
+static inline bool rpcrdma_regbuf_is_mapped(struct rpcrdma_regbuf *rb)
+{
+ return rb->rg_device != NULL;
+}
+
+/**
+ * rpcrdma_regbuf_dma_map - DMA-map a regbuf
+ * @r_xprt: controlling transport instance
+ * @rb: regbuf to be mapped
+ *
+ * Returns true if the buffer is currently DMA mapped.
+ */
+static inline bool rpcrdma_regbuf_dma_map(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_regbuf *rb)
+{
+ if (likely(rpcrdma_regbuf_is_mapped(rb)))
+ return true;
+ return __rpcrdma_regbuf_dma_map(r_xprt, rb);
+}
+
+/*
+ * Wrappers for chunk registration, shared by read/write chunk code.
+ */
+
+static inline enum dma_data_direction
+rpcrdma_data_dir(bool writing)
+{
+ return writing ? DMA_FROM_DEVICE : DMA_TO_DEVICE;
+}
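+
+/* Example of the mapping above (illustrative): pages that the remote
+ * peer will fill, such as an NFS READ payload arriving via RDMA
+ * Write, are registered with writing == true and mapped
+ * DMA_FROM_DEVICE; pages the device only reads, such as an NFS WRITE
+ * payload exposed as a Read chunk, map to DMA_TO_DEVICE.
+ */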
+
+/* Memory registration calls xprtrdma/frwr_ops.c
+ */
+void frwr_reset(struct rpcrdma_req *req);
+int frwr_query_device(struct rpcrdma_ep *ep, const struct ib_device *device);
+int frwr_mr_init(struct rpcrdma_xprt *r_xprt, struct rpcrdma_mr *mr);
+void frwr_mr_release(struct rpcrdma_mr *mr);
+struct rpcrdma_mr_seg *frwr_map(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_mr_seg *seg,
+ int nsegs, bool writing, __be32 xid,
+ struct rpcrdma_mr *mr);
+int frwr_send(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req);
+void frwr_reminv(struct rpcrdma_rep *rep, struct list_head *mrs);
+void frwr_unmap_sync(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req);
+void frwr_unmap_async(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req);
+int frwr_wp_create(struct rpcrdma_xprt *r_xprt);
+
+/*
+ * RPC/RDMA protocol calls - xprtrdma/rpc_rdma.c
+ */
+
+enum rpcrdma_chunktype {
+ rpcrdma_noch = 0,
+ rpcrdma_noch_pullup,
+ rpcrdma_noch_mapped,
+ rpcrdma_readch,
+ rpcrdma_areadch,
+ rpcrdma_writech,
+ rpcrdma_replych
+};
+
+int rpcrdma_prepare_send_sges(struct rpcrdma_xprt *r_xprt,
+ struct rpcrdma_req *req, u32 hdrlen,
+ struct xdr_buf *xdr,
+ enum rpcrdma_chunktype rtype);
+void rpcrdma_sendctx_unmap(struct rpcrdma_sendctx *sc);
+int rpcrdma_marshal_req(struct rpcrdma_xprt *r_xprt, struct rpc_rqst *rqst);
+void rpcrdma_set_max_header_sizes(struct rpcrdma_ep *ep);
+void rpcrdma_reset_cwnd(struct rpcrdma_xprt *r_xprt);
+void rpcrdma_complete_rqst(struct rpcrdma_rep *rep);
+void rpcrdma_unpin_rqst(struct rpcrdma_rep *rep);
+void rpcrdma_reply_handler(struct rpcrdma_rep *rep);
+
+static inline void rpcrdma_set_xdrlen(struct xdr_buf *xdr, size_t len)
+{
+ xdr->head[0].iov_len = len;
+ xdr->len = len;
+}
+
+/* RPC/RDMA module init - xprtrdma/transport.c
+ */
+extern unsigned int xprt_rdma_max_inline_read;
+extern unsigned int xprt_rdma_max_inline_write;
+void xprt_rdma_format_addresses(struct rpc_xprt *xprt, struct sockaddr *sap);
+void xprt_rdma_free_addresses(struct rpc_xprt *xprt);
+void xprt_rdma_close(struct rpc_xprt *xprt);
+void xprt_rdma_print_stats(struct rpc_xprt *xprt, struct seq_file *seq);
+int xprt_rdma_init(void);
+void xprt_rdma_cleanup(void);
+
+/* Backchannel calls - xprtrdma/backchannel.c
+ */
+#if defined(CONFIG_SUNRPC_BACKCHANNEL)
+int xprt_rdma_bc_setup(struct rpc_xprt *, unsigned int);
+size_t xprt_rdma_bc_maxpayload(struct rpc_xprt *);
+unsigned int xprt_rdma_bc_max_slots(struct rpc_xprt *);
+void rpcrdma_bc_receive_call(struct rpcrdma_xprt *, struct rpcrdma_rep *);
+int xprt_rdma_bc_send_reply(struct rpc_rqst *rqst);
+void xprt_rdma_bc_free_rqst(struct rpc_rqst *);
+void xprt_rdma_bc_destroy(struct rpc_xprt *, unsigned int);
+#endif /* CONFIG_SUNRPC_BACKCHANNEL */
+
+extern struct xprt_class xprt_rdma_bc;
+
+#endif /* _LINUX_SUNRPC_XPRT_RDMA_H */
diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c
new file mode 100644
index 0000000000..a15bf2ede8
--- /dev/null
+++ b/net/sunrpc/xprtsock.c
@@ -0,0 +1,3718 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * linux/net/sunrpc/xprtsock.c
+ *
+ * Client-side transport implementation for sockets.
+ *
+ * TCP callback races fixes (C) 1998 Red Hat
+ * TCP send fixes (C) 1998 Red Hat
+ * TCP NFS related read + write fixes
+ * (C) 1999 Dave Airlie, University of Limerick, Ireland <airlied@linux.ie>
+ *
+ * Rewrite of large parts of the code in order to stabilize the TCP handling.
+ * Fix behaviour when socket buffer is full.
+ * (C) 1999 Trond Myklebust <trond.myklebust@fys.uio.no>
+ *
+ * IP socket transport implementation, (C) 2005 Chuck Lever <cel@netapp.com>
+ *
+ * IPv6 support contributed by Gilles Quillard, Bull Open Source, 2005.
+ * <gilles.quillard@bull.net>
+ */
+
+#include <linux/types.h>
+#include <linux/string.h>
+#include <linux/slab.h>
+#include <linux/module.h>
+#include <linux/capability.h>
+#include <linux/pagemap.h>
+#include <linux/errno.h>
+#include <linux/socket.h>
+#include <linux/in.h>
+#include <linux/net.h>
+#include <linux/mm.h>
+#include <linux/un.h>
+#include <linux/udp.h>
+#include <linux/tcp.h>
+#include <linux/sunrpc/clnt.h>
+#include <linux/sunrpc/addr.h>
+#include <linux/sunrpc/sched.h>
+#include <linux/sunrpc/svcsock.h>
+#include <linux/sunrpc/xprtsock.h>
+#include <linux/file.h>
+#ifdef CONFIG_SUNRPC_BACKCHANNEL
+#include <linux/sunrpc/bc_xprt.h>
+#endif
+
+#include <net/sock.h>
+#include <net/checksum.h>
+#include <net/udp.h>
+#include <net/tcp.h>
+#include <net/tls_prot.h>
+#include <net/handshake.h>
+
+#include <linux/bvec.h>
+#include <linux/highmem.h>
+#include <linux/uio.h>
+#include <linux/sched/mm.h>
+
+#include <trace/events/sock.h>
+#include <trace/events/sunrpc.h>
+
+#include "socklib.h"
+#include "sunrpc.h"
+
+static void xs_close(struct rpc_xprt *xprt);
+static void xs_set_srcport(struct sock_xprt *transport, struct socket *sock);
+static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt,
+ struct socket *sock);
+
+/*
+ * xprtsock tunables
+ */
+static unsigned int xprt_udp_slot_table_entries = RPC_DEF_SLOT_TABLE;
+static unsigned int xprt_tcp_slot_table_entries = RPC_MIN_SLOT_TABLE;
+static unsigned int xprt_max_tcp_slot_table_entries = RPC_MAX_SLOT_TABLE;
+
+static unsigned int xprt_min_resvport = RPC_DEF_MIN_RESVPORT;
+static unsigned int xprt_max_resvport = RPC_DEF_MAX_RESVPORT;
+
+#define XS_TCP_LINGER_TO (15U * HZ)
+static unsigned int xs_tcp_fin_timeout __read_mostly = XS_TCP_LINGER_TO;
+
+/*
+ * We can register our own files under /proc/sys/sunrpc by
+ * calling register_sysctl() again. The files in that
+ * directory become the union of all files registered there.
+ *
+ * We simply need to make sure that we don't collide with
+ * someone else's file names!
+ */
+
+static unsigned int min_slot_table_size = RPC_MIN_SLOT_TABLE;
+static unsigned int max_slot_table_size = RPC_MAX_SLOT_TABLE;
+static unsigned int max_tcp_slot_table_limit = RPC_MAX_SLOT_TABLE_LIMIT;
+static unsigned int xprt_min_resvport_limit = RPC_MIN_RESVPORT;
+static unsigned int xprt_max_resvport_limit = RPC_MAX_RESVPORT;
+
+static struct ctl_table_header *sunrpc_table_header;
+
+static struct xprt_class xs_local_transport;
+static struct xprt_class xs_udp_transport;
+static struct xprt_class xs_tcp_transport;
+static struct xprt_class xs_tcp_tls_transport;
+static struct xprt_class xs_bc_tcp_transport;
+
+/*
+ * FIXME: changing the UDP slot table size should also resize the UDP
+ * socket buffers for existing UDP transports
+ */
+static struct ctl_table xs_tunables_table[] = {
+ {
+ .procname = "udp_slot_table_entries",
+ .data = &xprt_udp_slot_table_entries,
+ .maxlen = sizeof(unsigned int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_minmax,
+ .extra1 = &min_slot_table_size,
+ .extra2 = &max_slot_table_size
+ },
+ {
+ .procname = "tcp_slot_table_entries",
+ .data = &xprt_tcp_slot_table_entries,
+ .maxlen = sizeof(unsigned int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_minmax,
+ .extra1 = &min_slot_table_size,
+ .extra2 = &max_slot_table_size
+ },
+ {
+ .procname = "tcp_max_slot_table_entries",
+ .data = &xprt_max_tcp_slot_table_entries,
+ .maxlen = sizeof(unsigned int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_minmax,
+ .extra1 = &min_slot_table_size,
+ .extra2 = &max_tcp_slot_table_limit
+ },
+ {
+ .procname = "min_resvport",
+ .data = &xprt_min_resvport,
+ .maxlen = sizeof(unsigned int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_minmax,
+ .extra1 = &xprt_min_resvport_limit,
+ .extra2 = &xprt_max_resvport_limit
+ },
+ {
+ .procname = "max_resvport",
+ .data = &xprt_max_resvport,
+ .maxlen = sizeof(unsigned int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_minmax,
+ .extra1 = &xprt_min_resvport_limit,
+ .extra2 = &xprt_max_resvport_limit
+ },
+ {
+ .procname = "tcp_fin_timeout",
+ .data = &xs_tcp_fin_timeout,
+ .maxlen = sizeof(xs_tcp_fin_timeout),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_jiffies,
+ },
+ { },
+};
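+
+/* With the table above registered under the "sunrpc" directory, each
+ * .procname entry is expected to appear as a file such as
+ * /proc/sys/sunrpc/udp_slot_table_entries or
+ * /proc/sys/sunrpc/tcp_fin_timeout, and proc_dointvec_minmax clamps
+ * writes to the bounds given in .extra1/.extra2.
+ */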
+
+/*
+ * Wait duration for a reply from the RPC portmapper.
+ */
+#define XS_BIND_TO (60U * HZ)
+
+/*
+ * Delay if a UDP socket connect error occurs. This is most likely some
+ * kind of resource problem on the local host.
+ */
+#define XS_UDP_REEST_TO (2U * HZ)
+
+/*
+ * The reestablish timeout allows clients to delay for a bit before attempting
+ * to reconnect to a server that just dropped our connection.
+ *
+ * We implement an exponential backoff when trying to reestablish a TCP
+ * transport connection with the server. Some servers like to drop a TCP
+ * connection when they are overworked, so we start with a short timeout and
+ * increase over time if the server is down or not responding.
+ */
+#define XS_TCP_INIT_REEST_TO (3U * HZ)
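+
+/* Backoff sketch (illustrative): starting from XS_TCP_INIT_REEST_TO,
+ * the reestablish timeout is roughly doubled after each failed
+ * reconnect attempt (3s, 6s, 12s, 24s, ...) up to the transport's
+ * configured maximum; the actual doubling and clamping are handled
+ * in the reconnect logic later in this file.
+ */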
+
+/*
+ * TCP idle timeout; client drops the transport socket if it is idle
+ * for this long. Note that we also timeout UDP sockets to prevent
+ * holding port numbers when there is no RPC traffic.
+ */
+#define XS_IDLE_DISC_TO (5U * 60 * HZ)
+
+/*
+ * TLS handshake timeout.
+ */
+#define XS_TLS_HANDSHAKE_TO (10U * HZ)
+
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+# undef RPC_DEBUG_DATA
+# define RPCDBG_FACILITY RPCDBG_TRANS
+#endif
+
+#ifdef RPC_DEBUG_DATA
+static void xs_pktdump(char *msg, u32 *packet, unsigned int count)
+{
+ u8 *buf = (u8 *) packet;
+ int j;
+
+ dprintk("RPC: %s\n", msg);
+ for (j = 0; j < count && j < 128; j += 4) {
+ if (!(j & 31)) {
+ if (j)
+ dprintk("\n");
+ dprintk("0x%04x ", j);
+ }
+ dprintk("%02x%02x%02x%02x ",
+ buf[j], buf[j+1], buf[j+2], buf[j+3]);
+ }
+ dprintk("\n");
+}
+#else
+static inline void xs_pktdump(char *msg, u32 *packet, unsigned int count)
+{
+ /* NOP */
+}
+#endif
+
+static inline struct rpc_xprt *xprt_from_sock(struct sock *sk)
+{
+ return (struct rpc_xprt *) sk->sk_user_data;
+}
+
+static inline struct sockaddr *xs_addr(struct rpc_xprt *xprt)
+{
+ return (struct sockaddr *) &xprt->addr;
+}
+
+static inline struct sockaddr_un *xs_addr_un(struct rpc_xprt *xprt)
+{
+ return (struct sockaddr_un *) &xprt->addr;
+}
+
+static inline struct sockaddr_in *xs_addr_in(struct rpc_xprt *xprt)
+{
+ return (struct sockaddr_in *) &xprt->addr;
+}
+
+static inline struct sockaddr_in6 *xs_addr_in6(struct rpc_xprt *xprt)
+{
+ return (struct sockaddr_in6 *) &xprt->addr;
+}
+
+static void xs_format_common_peer_addresses(struct rpc_xprt *xprt)
+{
+ struct sockaddr *sap = xs_addr(xprt);
+ struct sockaddr_in6 *sin6;
+ struct sockaddr_in *sin;
+ struct sockaddr_un *sun;
+ char buf[128];
+
+ switch (sap->sa_family) {
+ case AF_LOCAL:
+ sun = xs_addr_un(xprt);
+ if (sun->sun_path[0]) {
+ strscpy(buf, sun->sun_path, sizeof(buf));
+ } else {
+ buf[0] = '@';
+ strscpy(buf+1, sun->sun_path+1, sizeof(buf)-1);
+ }
+ xprt->address_strings[RPC_DISPLAY_ADDR] =
+ kstrdup(buf, GFP_KERNEL);
+ break;
+ case AF_INET:
+ (void)rpc_ntop(sap, buf, sizeof(buf));
+ xprt->address_strings[RPC_DISPLAY_ADDR] =
+ kstrdup(buf, GFP_KERNEL);
+ sin = xs_addr_in(xprt);
+ snprintf(buf, sizeof(buf), "%08x", ntohl(sin->sin_addr.s_addr));
+ break;
+ case AF_INET6:
+ (void)rpc_ntop(sap, buf, sizeof(buf));
+ xprt->address_strings[RPC_DISPLAY_ADDR] =
+ kstrdup(buf, GFP_KERNEL);
+ sin6 = xs_addr_in6(xprt);
+ snprintf(buf, sizeof(buf), "%pi6", &sin6->sin6_addr);
+ break;
+ default:
+ BUG();
+ }
+
+ xprt->address_strings[RPC_DISPLAY_HEX_ADDR] = kstrdup(buf, GFP_KERNEL);
+}
+
+static void xs_format_common_peer_ports(struct rpc_xprt *xprt)
+{
+ struct sockaddr *sap = xs_addr(xprt);
+ char buf[128];
+
+ snprintf(buf, sizeof(buf), "%u", rpc_get_port(sap));
+ xprt->address_strings[RPC_DISPLAY_PORT] = kstrdup(buf, GFP_KERNEL);
+
+ snprintf(buf, sizeof(buf), "%4hx", rpc_get_port(sap));
+ xprt->address_strings[RPC_DISPLAY_HEX_PORT] = kstrdup(buf, GFP_KERNEL);
+}
+
+static void xs_format_peer_addresses(struct rpc_xprt *xprt,
+ const char *protocol,
+ const char *netid)
+{
+ xprt->address_strings[RPC_DISPLAY_PROTO] = protocol;
+ xprt->address_strings[RPC_DISPLAY_NETID] = netid;
+ xs_format_common_peer_addresses(xprt);
+ xs_format_common_peer_ports(xprt);
+}
+
+static void xs_update_peer_port(struct rpc_xprt *xprt)
+{
+ kfree(xprt->address_strings[RPC_DISPLAY_HEX_PORT]);
+ kfree(xprt->address_strings[RPC_DISPLAY_PORT]);
+
+ xs_format_common_peer_ports(xprt);
+}
+
+static void xs_free_peer_addresses(struct rpc_xprt *xprt)
+{
+ unsigned int i;
+
+ for (i = 0; i < RPC_DISPLAY_MAX; i++)
+ switch (i) {
+ case RPC_DISPLAY_PROTO:
+ case RPC_DISPLAY_NETID:
+ continue;
+ default:
+ kfree(xprt->address_strings[i]);
+ }
+}
+
+static size_t
+xs_alloc_sparse_pages(struct xdr_buf *buf, size_t want, gfp_t gfp)
+{
+ size_t i,n;
+
+ if (!want || !(buf->flags & XDRBUF_SPARSE_PAGES))
+ return want;
+ n = (buf->page_base + want + PAGE_SIZE - 1) >> PAGE_SHIFT;
+ for (i = 0; i < n; i++) {
+ if (buf->pages[i])
+ continue;
+ buf->bvec[i].bv_page = buf->pages[i] = alloc_page(gfp);
+ if (!buf->pages[i]) {
+ i *= PAGE_SIZE;
+ return i > buf->page_base ? i - buf->page_base : 0;
+ }
+ }
+ return want;
+}
+
+static int
+xs_sock_process_cmsg(struct socket *sock, struct msghdr *msg,
+ struct cmsghdr *cmsg, int ret)
+{
+ u8 content_type = tls_get_record_type(sock->sk, cmsg);
+ u8 level, description;
+
+ switch (content_type) {
+ case 0:
+ break;
+ case TLS_RECORD_TYPE_DATA:
+ /* TLS sets EOR at the end of each application data
+ * record, even though there might be more frames
+ * waiting to be decrypted.
+ */
+ msg->msg_flags &= ~MSG_EOR;
+ break;
+ case TLS_RECORD_TYPE_ALERT:
+ tls_alert_recv(sock->sk, msg, &level, &description);
+ ret = (level == TLS_ALERT_LEVEL_FATAL) ?
+ -EACCES : -EAGAIN;
+ break;
+ default:
+ /* discard this record type */
+ ret = -EAGAIN;
+ }
+ return ret;
+}
+
+static int
+xs_sock_recv_cmsg(struct socket *sock, struct msghdr *msg, int flags)
+{
+ union {
+ struct cmsghdr cmsg;
+ u8 buf[CMSG_SPACE(sizeof(u8))];
+ } u;
+ int ret;
+
+ msg->msg_control = &u;
+ msg->msg_controllen = sizeof(u);
+ ret = sock_recvmsg(sock, msg, flags);
+ if (msg->msg_controllen != sizeof(u))
+ ret = xs_sock_process_cmsg(sock, msg, &u.cmsg, ret);
+ return ret;
+}
+
+static ssize_t
+xs_sock_recvmsg(struct socket *sock, struct msghdr *msg, int flags, size_t seek)
+{
+ ssize_t ret;
+ if (seek != 0)
+ iov_iter_advance(&msg->msg_iter, seek);
+ ret = xs_sock_recv_cmsg(sock, msg, flags);
+ return ret > 0 ? ret + seek : ret;
+}
+
+static ssize_t
+xs_read_kvec(struct socket *sock, struct msghdr *msg, int flags,
+ struct kvec *kvec, size_t count, size_t seek)
+{
+ iov_iter_kvec(&msg->msg_iter, ITER_DEST, kvec, 1, count);
+ return xs_sock_recvmsg(sock, msg, flags, seek);
+}
+
+static ssize_t
+xs_read_bvec(struct socket *sock, struct msghdr *msg, int flags,
+ struct bio_vec *bvec, unsigned long nr, size_t count,
+ size_t seek)
+{
+ iov_iter_bvec(&msg->msg_iter, ITER_DEST, bvec, nr, count);
+ return xs_sock_recvmsg(sock, msg, flags, seek);
+}
+
+static ssize_t
+xs_read_discard(struct socket *sock, struct msghdr *msg, int flags,
+ size_t count)
+{
+ iov_iter_discard(&msg->msg_iter, ITER_DEST, count);
+ return xs_sock_recv_cmsg(sock, msg, flags);
+}
+
+#if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE
+static void
+xs_flush_bvec(const struct bio_vec *bvec, size_t count, size_t seek)
+{
+ struct bvec_iter bi = {
+ .bi_size = count,
+ };
+ struct bio_vec bv;
+
+ bvec_iter_advance(bvec, &bi, seek & PAGE_MASK);
+ for_each_bvec(bv, bvec, bi, bi)
+ flush_dcache_page(bv.bv_page);
+}
+#else
+static inline void
+xs_flush_bvec(const struct bio_vec *bvec, size_t count, size_t seek)
+{
+}
+#endif
+
+static ssize_t
+xs_read_xdr_buf(struct socket *sock, struct msghdr *msg, int flags,
+ struct xdr_buf *buf, size_t count, size_t seek, size_t *read)
+{
+ size_t want, seek_init = seek, offset = 0;
+ ssize_t ret;
+
+ want = min_t(size_t, count, buf->head[0].iov_len);
+ if (seek < want) {
+ ret = xs_read_kvec(sock, msg, flags, &buf->head[0], want, seek);
+ if (ret <= 0)
+ goto sock_err;
+ offset += ret;
+ if (offset == count || msg->msg_flags & (MSG_EOR|MSG_TRUNC))
+ goto out;
+ if (ret != want)
+ goto out;
+ seek = 0;
+ } else {
+ seek -= want;
+ offset += want;
+ }
+
+ want = xs_alloc_sparse_pages(
+ buf, min_t(size_t, count - offset, buf->page_len),
+ GFP_KERNEL | __GFP_NORETRY | __GFP_NOWARN);
+ if (seek < want) {
+ ret = xs_read_bvec(sock, msg, flags, buf->bvec,
+ xdr_buf_pagecount(buf),
+ want + buf->page_base,
+ seek + buf->page_base);
+ if (ret <= 0)
+ goto sock_err;
+ xs_flush_bvec(buf->bvec, ret, seek + buf->page_base);
+ ret -= buf->page_base;
+ offset += ret;
+ if (offset == count || msg->msg_flags & (MSG_EOR|MSG_TRUNC))
+ goto out;
+ if (ret != want)
+ goto out;
+ seek = 0;
+ } else {
+ seek -= want;
+ offset += want;
+ }
+
+ want = min_t(size_t, count - offset, buf->tail[0].iov_len);
+ if (seek < want) {
+ ret = xs_read_kvec(sock, msg, flags, &buf->tail[0], want, seek);
+ if (ret <= 0)
+ goto sock_err;
+ offset += ret;
+ if (offset == count || msg->msg_flags & (MSG_EOR|MSG_TRUNC))
+ goto out;
+ if (ret != want)
+ goto out;
+ } else if (offset < seek_init)
+ offset = seek_init;
+ ret = -EMSGSIZE;
+out:
+ *read = offset - seek_init;
+ return ret;
+sock_err:
+ offset += seek;
+ goto out;
+}
+
+static void
+xs_read_header(struct sock_xprt *transport, struct xdr_buf *buf)
+{
+ if (!transport->recv.copied) {
+ if (buf->head[0].iov_len >= transport->recv.offset)
+ memcpy(buf->head[0].iov_base,
+ &transport->recv.xid,
+ transport->recv.offset);
+ transport->recv.copied = transport->recv.offset;
+ }
+}
+
+static bool
+xs_read_stream_request_done(struct sock_xprt *transport)
+{
+ return transport->recv.fraghdr & cpu_to_be32(RPC_LAST_STREAM_FRAGMENT);
+}
+
+static void
+xs_read_stream_check_eor(struct sock_xprt *transport,
+ struct msghdr *msg)
+{
+ if (xs_read_stream_request_done(transport))
+ msg->msg_flags |= MSG_EOR;
+}
+
+static ssize_t
+xs_read_stream_request(struct sock_xprt *transport, struct msghdr *msg,
+ int flags, struct rpc_rqst *req)
+{
+ struct xdr_buf *buf = &req->rq_private_buf;
+ size_t want, read;
+ ssize_t ret;
+
+ xs_read_header(transport, buf);
+
+ want = transport->recv.len - transport->recv.offset;
+ if (want != 0) {
+ ret = xs_read_xdr_buf(transport->sock, msg, flags, buf,
+ transport->recv.copied + want,
+ transport->recv.copied,
+ &read);
+ transport->recv.offset += read;
+ transport->recv.copied += read;
+ }
+
+ if (transport->recv.offset == transport->recv.len)
+ xs_read_stream_check_eor(transport, msg);
+
+ if (want == 0)
+ return 0;
+
+ switch (ret) {
+ default:
+ break;
+ case -EFAULT:
+ case -EMSGSIZE:
+ msg->msg_flags |= MSG_TRUNC;
+ return read;
+ case 0:
+ return -ESHUTDOWN;
+ }
+ return ret < 0 ? ret : read;
+}
+
+static size_t
+xs_read_stream_headersize(bool isfrag)
+{
+ if (isfrag)
+ return sizeof(__be32);
+ return 3 * sizeof(__be32);
+}
+
+static ssize_t
+xs_read_stream_header(struct sock_xprt *transport, struct msghdr *msg,
+ int flags, size_t want, size_t seek)
+{
+ struct kvec kvec = {
+ .iov_base = &transport->recv.fraghdr,
+ .iov_len = want,
+ };
+ return xs_read_kvec(transport->sock, msg, flags, &kvec, want, seek);
+}
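+
+/* Note on the header sizes used above: the first read of a record
+ * pulls in the 4-byte record marker plus the XID and call-direction
+ * words that follow it on the wire, hence 3 * sizeof(__be32);
+ * continuation fragments of a partially received request carry only a
+ * fresh 4-byte marker. This relies on recv.fraghdr, recv.xid and
+ * recv.calldir being contiguous in struct sock_xprt, which is assumed
+ * here.
+ */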
+
+#if defined(CONFIG_SUNRPC_BACKCHANNEL)
+static ssize_t
+xs_read_stream_call(struct sock_xprt *transport, struct msghdr *msg, int flags)
+{
+ struct rpc_xprt *xprt = &transport->xprt;
+ struct rpc_rqst *req;
+ ssize_t ret;
+
+ /* Is this transport associated with the backchannel? */
+ if (!xprt->bc_serv)
+ return -ESHUTDOWN;
+
+ /* Look up and lock the request corresponding to the given XID */
+ req = xprt_lookup_bc_request(xprt, transport->recv.xid);
+ if (!req) {
+ printk(KERN_WARNING "Callback slot table overflowed\n");
+ return -ESHUTDOWN;
+ }
+ if (transport->recv.copied && !req->rq_private_buf.len)
+ return -ESHUTDOWN;
+
+ ret = xs_read_stream_request(transport, msg, flags, req);
+ if (msg->msg_flags & (MSG_EOR|MSG_TRUNC))
+ xprt_complete_bc_request(req, transport->recv.copied);
+ else
+ req->rq_private_buf.len = transport->recv.copied;
+
+ return ret;
+}
+#else /* CONFIG_SUNRPC_BACKCHANNEL */
+static ssize_t
+xs_read_stream_call(struct sock_xprt *transport, struct msghdr *msg, int flags)
+{
+ return -ESHUTDOWN;
+}
+#endif /* CONFIG_SUNRPC_BACKCHANNEL */
+
+static ssize_t
+xs_read_stream_reply(struct sock_xprt *transport, struct msghdr *msg, int flags)
+{
+ struct rpc_xprt *xprt = &transport->xprt;
+ struct rpc_rqst *req;
+ ssize_t ret = 0;
+
+ /* Look up and lock the request corresponding to the given XID */
+ spin_lock(&xprt->queue_lock);
+ req = xprt_lookup_rqst(xprt, transport->recv.xid);
+ if (!req || (transport->recv.copied && !req->rq_private_buf.len)) {
+ msg->msg_flags |= MSG_TRUNC;
+ goto out;
+ }
+ xprt_pin_rqst(req);
+ spin_unlock(&xprt->queue_lock);
+
+ ret = xs_read_stream_request(transport, msg, flags, req);
+
+ spin_lock(&xprt->queue_lock);
+ if (msg->msg_flags & (MSG_EOR|MSG_TRUNC))
+ xprt_complete_rqst(req->rq_task, transport->recv.copied);
+ else
+ req->rq_private_buf.len = transport->recv.copied;
+ xprt_unpin_rqst(req);
+out:
+ spin_unlock(&xprt->queue_lock);
+ return ret;
+}
+
+static ssize_t
+xs_read_stream(struct sock_xprt *transport, int flags)
+{
+ struct msghdr msg = { 0 };
+ size_t want, read = 0;
+ ssize_t ret = 0;
+
+ if (transport->recv.len == 0) {
+ want = xs_read_stream_headersize(transport->recv.copied != 0);
+ ret = xs_read_stream_header(transport, &msg, flags, want,
+ transport->recv.offset);
+ if (ret <= 0)
+ goto out_err;
+ transport->recv.offset = ret;
+ if (transport->recv.offset != want)
+ return transport->recv.offset;
+ transport->recv.len = be32_to_cpu(transport->recv.fraghdr) &
+ RPC_FRAGMENT_SIZE_MASK;
+ transport->recv.offset -= sizeof(transport->recv.fraghdr);
+ read = ret;
+ }
+
+ switch (be32_to_cpu(transport->recv.calldir)) {
+ default:
+ msg.msg_flags |= MSG_TRUNC;
+ break;
+ case RPC_CALL:
+ ret = xs_read_stream_call(transport, &msg, flags);
+ break;
+ case RPC_REPLY:
+ ret = xs_read_stream_reply(transport, &msg, flags);
+ }
+ if (msg.msg_flags & MSG_TRUNC) {
+ transport->recv.calldir = cpu_to_be32(-1);
+ transport->recv.copied = -1;
+ }
+ if (ret < 0)
+ goto out_err;
+ read += ret;
+ if (transport->recv.offset < transport->recv.len) {
+ if (!(msg.msg_flags & MSG_TRUNC))
+ return read;
+ msg.msg_flags = 0;
+ ret = xs_read_discard(transport->sock, &msg, flags,
+ transport->recv.len - transport->recv.offset);
+ if (ret <= 0)
+ goto out_err;
+ transport->recv.offset += ret;
+ read += ret;
+ if (transport->recv.offset != transport->recv.len)
+ return read;
+ }
+ if (xs_read_stream_request_done(transport)) {
+ trace_xs_stream_read_request(transport);
+ transport->recv.copied = 0;
+ }
+ transport->recv.offset = 0;
+ transport->recv.len = 0;
+ return read;
+out_err:
+ return ret != 0 ? ret : -ESHUTDOWN;
+}
+
+static __poll_t xs_poll_socket(struct sock_xprt *transport)
+{
+ return transport->sock->ops->poll(transport->file, transport->sock,
+ NULL);
+}
+
+static bool xs_poll_socket_readable(struct sock_xprt *transport)
+{
+ __poll_t events = xs_poll_socket(transport);
+
+ return (events & (EPOLLIN | EPOLLRDNORM)) && !(events & EPOLLRDHUP);
+}
+
+static void xs_poll_check_readable(struct sock_xprt *transport)
+{
+
+ clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state);
+ if (test_bit(XPRT_SOCK_IGNORE_RECV, &transport->sock_state))
+ return;
+ if (!xs_poll_socket_readable(transport))
+ return;
+ if (!test_and_set_bit(XPRT_SOCK_DATA_READY, &transport->sock_state))
+ queue_work(xprtiod_workqueue, &transport->recv_worker);
+}
+
+static void xs_stream_data_receive(struct sock_xprt *transport)
+{
+ size_t read = 0;
+ ssize_t ret = 0;
+
+ mutex_lock(&transport->recv_mutex);
+ if (transport->sock == NULL)
+ goto out;
+ for (;;) {
+ ret = xs_read_stream(transport, MSG_DONTWAIT);
+ if (ret < 0)
+ break;
+ read += ret;
+ cond_resched();
+ }
+ if (ret == -ESHUTDOWN)
+ kernel_sock_shutdown(transport->sock, SHUT_RDWR);
+ else if (ret == -EACCES)
+ xprt_wake_pending_tasks(&transport->xprt, -EACCES);
+ else
+ xs_poll_check_readable(transport);
+out:
+ mutex_unlock(&transport->recv_mutex);
+ trace_xs_stream_read_data(&transport->xprt, ret, read);
+}
+
+static void xs_stream_data_receive_workfn(struct work_struct *work)
+{
+ struct sock_xprt *transport =
+ container_of(work, struct sock_xprt, recv_worker);
+ unsigned int pflags = memalloc_nofs_save();
+
+ xs_stream_data_receive(transport);
+ memalloc_nofs_restore(pflags);
+}
+
+static void
+xs_stream_reset_connect(struct sock_xprt *transport)
+{
+ transport->recv.offset = 0;
+ transport->recv.len = 0;
+ transport->recv.copied = 0;
+ transport->xmit.offset = 0;
+}
+
+static void
+xs_stream_start_connect(struct sock_xprt *transport)
+{
+ transport->xprt.stat.connect_count++;
+ transport->xprt.stat.connect_start = jiffies;
+}
+
+#define XS_SENDMSG_FLAGS (MSG_DONTWAIT | MSG_NOSIGNAL)
+
+/**
+ * xs_nospace - handle an incomplete transmission
+ * @req: pointer to RPC request
+ * @transport: pointer to struct sock_xprt
+ *
+ */
+static int xs_nospace(struct rpc_rqst *req, struct sock_xprt *transport)
+{
+ struct rpc_xprt *xprt = &transport->xprt;
+ struct sock *sk = transport->inet;
+ int ret = -EAGAIN;
+
+ trace_rpc_socket_nospace(req, transport);
+
+ /* Protect against races with write_space */
+ spin_lock(&xprt->transport_lock);
+
+ /* Don't race with disconnect */
+ if (xprt_connected(xprt)) {
+ /* wait for more buffer space */
+ set_bit(XPRT_SOCK_NOSPACE, &transport->sock_state);
+ set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
+ sk->sk_write_pending++;
+ xprt_wait_for_buffer_space(xprt);
+ } else
+ ret = -ENOTCONN;
+
+ spin_unlock(&xprt->transport_lock);
+ return ret;
+}
+
+static int xs_sock_nospace(struct rpc_rqst *req)
+{
+ struct sock_xprt *transport =
+ container_of(req->rq_xprt, struct sock_xprt, xprt);
+ struct sock *sk = transport->inet;
+ int ret = -EAGAIN;
+
+ lock_sock(sk);
+ if (!sock_writeable(sk))
+ ret = xs_nospace(req, transport);
+ release_sock(sk);
+ return ret;
+}
+
+static int xs_stream_nospace(struct rpc_rqst *req, bool vm_wait)
+{
+ struct sock_xprt *transport =
+ container_of(req->rq_xprt, struct sock_xprt, xprt);
+ struct sock *sk = transport->inet;
+ int ret = -EAGAIN;
+
+ if (vm_wait)
+ return -ENOBUFS;
+ lock_sock(sk);
+ if (!sk_stream_memory_free(sk))
+ ret = xs_nospace(req, transport);
+ release_sock(sk);
+ return ret;
+}
+
+static int xs_stream_prepare_request(struct rpc_rqst *req, struct xdr_buf *buf)
+{
+ return xdr_alloc_bvec(buf, rpc_task_gfp_mask());
+}
+
+/*
+ * Determine if the previous message in the stream was aborted before it
+ * could complete transmission.
+ */
+static bool
+xs_send_request_was_aborted(struct sock_xprt *transport, struct rpc_rqst *req)
+{
+ return transport->xmit.offset != 0 && req->rq_bytes_sent == 0;
+}
+
+/*
+ * Return the stream record marker field for a record of length < 2^31-1
+ */
+static rpc_fraghdr
+xs_stream_record_marker(struct xdr_buf *xdr)
+{
+ if (!xdr->len)
+ return 0;
+ return cpu_to_be32(RPC_LAST_STREAM_FRAGMENT | (u32)xdr->len);
+}
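+
+/* Worked example: for a 200-byte RPC message the marker is
+ * RPC_LAST_STREAM_FRAGMENT | 200, i.e. 0x800000c8 sent in network
+ * byte order ahead of the record (high bit marks the last fragment,
+ * the low 31 bits carry the fragment length), per the RFC 5531
+ * record-marking scheme.
+ */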
+
+/**
+ * xs_local_send_request - write an RPC request to an AF_LOCAL socket
+ * @req: pointer to RPC request
+ *
+ * Return values:
+ * 0: The request has been sent
+ * EAGAIN: The socket was blocked, please call again later to
+ * complete the request
+ * ENOTCONN: Caller needs to invoke connect logic then call again
+ * other: Some other error occurred, the request was not sent
+ */
+static int xs_local_send_request(struct rpc_rqst *req)
+{
+ struct rpc_xprt *xprt = req->rq_xprt;
+ struct sock_xprt *transport =
+ container_of(xprt, struct sock_xprt, xprt);
+ struct xdr_buf *xdr = &req->rq_snd_buf;
+ rpc_fraghdr rm = xs_stream_record_marker(xdr);
+ unsigned int msglen = rm ? req->rq_slen + sizeof(rm) : req->rq_slen;
+ struct msghdr msg = {
+ .msg_flags = XS_SENDMSG_FLAGS,
+ };
+ bool vm_wait;
+ unsigned int sent;
+ int status;
+
+ /* Close the stream if the previous transmission was incomplete */
+ if (xs_send_request_was_aborted(transport, req)) {
+ xprt_force_disconnect(xprt);
+ return -ENOTCONN;
+ }
+
+ xs_pktdump("packet data:",
+ req->rq_svec->iov_base, req->rq_svec->iov_len);
+
+ vm_wait = sk_stream_is_writeable(transport->inet) ? true : false;
+
+ req->rq_xtime = ktime_get();
+ status = xprt_sock_sendmsg(transport->sock, &msg, xdr,
+ transport->xmit.offset, rm, &sent);
+ dprintk("RPC: %s(%u) = %d\n",
+ __func__, xdr->len - transport->xmit.offset, status);
+
+ if (likely(sent > 0) || status == 0) {
+ transport->xmit.offset += sent;
+ req->rq_bytes_sent = transport->xmit.offset;
+ if (likely(req->rq_bytes_sent >= msglen)) {
+ req->rq_xmit_bytes_sent += transport->xmit.offset;
+ transport->xmit.offset = 0;
+ return 0;
+ }
+ status = -EAGAIN;
+ vm_wait = false;
+ }
+
+ switch (status) {
+ case -EAGAIN:
+ status = xs_stream_nospace(req, vm_wait);
+ break;
+ default:
+ dprintk("RPC: sendmsg returned unrecognized error %d\n",
+ -status);
+ fallthrough;
+ case -EPIPE:
+ xprt_force_disconnect(xprt);
+ status = -ENOTCONN;
+ }
+
+ return status;
+}
+
+/**
+ * xs_udp_send_request - write an RPC request to a UDP socket
+ * @req: pointer to RPC request
+ *
+ * Return values:
+ * 0: The request has been sent
+ * EAGAIN: The socket was blocked, please call again later to
+ * complete the request
+ * ENOTCONN: Caller needs to invoke connect logic then call again
+ * other: Some other error occurred, the request was not sent
+ */
+static int xs_udp_send_request(struct rpc_rqst *req)
+{
+ struct rpc_xprt *xprt = req->rq_xprt;
+ struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+ struct xdr_buf *xdr = &req->rq_snd_buf;
+ struct msghdr msg = {
+ .msg_name = xs_addr(xprt),
+ .msg_namelen = xprt->addrlen,
+ .msg_flags = XS_SENDMSG_FLAGS,
+ };
+ unsigned int sent;
+ int status;
+
+ xs_pktdump("packet data:",
+ req->rq_svec->iov_base,
+ req->rq_svec->iov_len);
+
+ if (!xprt_bound(xprt))
+ return -ENOTCONN;
+
+ if (!xprt_request_get_cong(xprt, req))
+ return -EBADSLT;
+
+ status = xdr_alloc_bvec(xdr, rpc_task_gfp_mask());
+ if (status < 0)
+ return status;
+ req->rq_xtime = ktime_get();
+ status = xprt_sock_sendmsg(transport->sock, &msg, xdr, 0, 0, &sent);
+
+ dprintk("RPC: xs_udp_send_request(%u) = %d\n",
+ xdr->len, status);
+
+ /* firewall is blocking us, don't return -EAGAIN or we end up looping */
+ if (status == -EPERM)
+ goto process_status;
+
+ if (status == -EAGAIN && sock_writeable(transport->inet))
+ status = -ENOBUFS;
+
+ if (sent > 0 || status == 0) {
+ req->rq_xmit_bytes_sent += sent;
+ if (sent >= req->rq_slen)
+ return 0;
+ /* Still some bytes left; set up for a retry later. */
+ status = -EAGAIN;
+ }
+
+process_status:
+ switch (status) {
+ case -ENOTSOCK:
+ status = -ENOTCONN;
+ /* Should we call xs_close() here? */
+ break;
+ case -EAGAIN:
+ status = xs_sock_nospace(req);
+ break;
+ case -ENETUNREACH:
+ case -ENOBUFS:
+ case -EPIPE:
+ case -ECONNREFUSED:
+ case -EPERM:
+ /* When the server has died, an ICMP port unreachable message
+ * prompts ECONNREFUSED. */
+ break;
+ default:
+ dprintk("RPC: sendmsg returned unrecognized error %d\n",
+ -status);
+ }
+
+ return status;
+}
+
+/**
+ * xs_tcp_send_request - write an RPC request to a TCP socket
+ * @req: pointer to RPC request
+ *
+ * Return values:
+ * 0: The request has been sent
+ * EAGAIN: The socket was blocked, please call again later to
+ * complete the request
+ * ENOTCONN: Caller needs to invoke connect logic then call again
+ * other: Some other error occurred, the request was not sent
+ *
+ * XXX: In the case of soft timeouts, should we eventually give up
+ * if sendmsg is not able to make progress?
+ */
+static int xs_tcp_send_request(struct rpc_rqst *req)
+{
+ struct rpc_xprt *xprt = req->rq_xprt;
+ struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+ struct xdr_buf *xdr = &req->rq_snd_buf;
+ rpc_fraghdr rm = xs_stream_record_marker(xdr);
+ unsigned int msglen = rm ? req->rq_slen + sizeof(rm) : req->rq_slen;
+ struct msghdr msg = {
+ .msg_flags = XS_SENDMSG_FLAGS,
+ };
+ bool vm_wait;
+ unsigned int sent;
+ int status;
+
+ /* Close the stream if the previous transmission was incomplete */
+ if (xs_send_request_was_aborted(transport, req)) {
+ if (transport->sock != NULL)
+ kernel_sock_shutdown(transport->sock, SHUT_RDWR);
+ return -ENOTCONN;
+ }
+ if (!transport->inet)
+ return -ENOTCONN;
+
+ xs_pktdump("packet data:",
+ req->rq_svec->iov_base,
+ req->rq_svec->iov_len);
+
+ if (test_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state))
+ xs_tcp_set_socket_timeouts(xprt, transport->sock);
+
+ xs_set_srcport(transport, transport->sock);
+
+ /* Continue transmitting the packet/record. We must be careful
+ * to cope with writespace callbacks arriving _after_ we have
+ * called sendmsg(). */
+ req->rq_xtime = ktime_get();
+ tcp_sock_set_cork(transport->inet, true);
+
+ vm_wait = sk_stream_is_writeable(transport->inet) ? true : false;
+
+ do {
+ status = xprt_sock_sendmsg(transport->sock, &msg, xdr,
+ transport->xmit.offset, rm, &sent);
+
+ dprintk("RPC: xs_tcp_send_request(%u) = %d\n",
+ xdr->len - transport->xmit.offset, status);
+
+ /* If we've sent the entire packet, immediately
+ * reset the count of bytes sent. */
+ transport->xmit.offset += sent;
+ req->rq_bytes_sent = transport->xmit.offset;
+ if (likely(req->rq_bytes_sent >= msglen)) {
+ req->rq_xmit_bytes_sent += transport->xmit.offset;
+ transport->xmit.offset = 0;
+ if (atomic_long_read(&xprt->xmit_queuelen) == 1)
+ tcp_sock_set_cork(transport->inet, false);
+ return 0;
+ }
+
+ WARN_ON_ONCE(sent == 0 && status == 0);
+
+ if (sent > 0)
+ vm_wait = false;
+
+ } while (status == 0);
+
+ switch (status) {
+ case -ENOTSOCK:
+ status = -ENOTCONN;
+ /* Should we call xs_close() here? */
+ break;
+ case -EAGAIN:
+ status = xs_stream_nospace(req, vm_wait);
+ break;
+ case -ECONNRESET:
+ case -ECONNREFUSED:
+ case -ENOTCONN:
+ case -EADDRINUSE:
+ case -ENOBUFS:
+ case -EPIPE:
+ break;
+ default:
+ dprintk("RPC: sendmsg returned unrecognized error %d\n",
+ -status);
+ }
+
+ return status;
+}
+
+static void xs_save_old_callbacks(struct sock_xprt *transport, struct sock *sk)
+{
+ transport->old_data_ready = sk->sk_data_ready;
+ transport->old_state_change = sk->sk_state_change;
+ transport->old_write_space = sk->sk_write_space;
+ transport->old_error_report = sk->sk_error_report;
+}
+
+static void xs_restore_old_callbacks(struct sock_xprt *transport, struct sock *sk)
+{
+ sk->sk_data_ready = transport->old_data_ready;
+ sk->sk_state_change = transport->old_state_change;
+ sk->sk_write_space = transport->old_write_space;
+ sk->sk_error_report = transport->old_error_report;
+}
+
+static void xs_sock_reset_state_flags(struct rpc_xprt *xprt)
+{
+ struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+
+ clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state);
+ clear_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state);
+ clear_bit(XPRT_SOCK_WAKE_WRITE, &transport->sock_state);
+ clear_bit(XPRT_SOCK_WAKE_DISCONNECT, &transport->sock_state);
+ clear_bit(XPRT_SOCK_NOSPACE, &transport->sock_state);
+}
+
+static void xs_run_error_worker(struct sock_xprt *transport, unsigned int nr)
+{
+ set_bit(nr, &transport->sock_state);
+ queue_work(xprtiod_workqueue, &transport->error_worker);
+}
+
+static void xs_sock_reset_connection_flags(struct rpc_xprt *xprt)
+{
+ xprt->connect_cookie++;
+ smp_mb__before_atomic();
+ clear_bit(XPRT_CLOSE_WAIT, &xprt->state);
+ clear_bit(XPRT_CLOSING, &xprt->state);
+ xs_sock_reset_state_flags(xprt);
+ smp_mb__after_atomic();
+}
+
+/**
+ * xs_error_report - callback to handle TCP socket state errors
+ * @sk: socket
+ *
+ * Note: we don't call sock_error() since there may be an rpc_task
+ * using the socket, and so we don't want to clear sk->sk_err.
+ */
+static void xs_error_report(struct sock *sk)
+{
+ struct sock_xprt *transport;
+ struct rpc_xprt *xprt;
+
+ if (!(xprt = xprt_from_sock(sk)))
+ return;
+
+ transport = container_of(xprt, struct sock_xprt, xprt);
+ transport->xprt_err = -sk->sk_err;
+ if (transport->xprt_err == 0)
+ return;
+ dprintk("RPC: xs_error_report client %p, error=%d...\n",
+ xprt, -transport->xprt_err);
+ trace_rpc_socket_error(xprt, sk->sk_socket, transport->xprt_err);
+
+ /* barrier ensures xprt_err is set before XPRT_SOCK_WAKE_ERROR */
+ smp_mb__before_atomic();
+ xs_run_error_worker(transport, XPRT_SOCK_WAKE_ERROR);
+}
+
+static void xs_reset_transport(struct sock_xprt *transport)
+{
+ struct socket *sock = transport->sock;
+ struct sock *sk = transport->inet;
+ struct rpc_xprt *xprt = &transport->xprt;
+ struct file *filp = transport->file;
+
+ if (sk == NULL)
+ return;
+ /*
+ * Make sure we're calling this in a context from which it is safe
+ * to call __fput_sync(). In practice that means rpciod and the
+ * system workqueue.
+ */
+ if (!(current->flags & PF_WQ_WORKER)) {
+ WARN_ON_ONCE(1);
+ set_bit(XPRT_CLOSE_WAIT, &xprt->state);
+ return;
+ }
+
+ if (atomic_read(&transport->xprt.swapper))
+ sk_clear_memalloc(sk);
+
+ tls_handshake_cancel(sk);
+
+ kernel_sock_shutdown(sock, SHUT_RDWR);
+
+ mutex_lock(&transport->recv_mutex);
+ lock_sock(sk);
+ transport->inet = NULL;
+ transport->sock = NULL;
+ transport->file = NULL;
+
+ sk->sk_user_data = NULL;
+
+ xs_restore_old_callbacks(transport, sk);
+ xprt_clear_connected(xprt);
+ xs_sock_reset_connection_flags(xprt);
+ /* Reset stream record info */
+ xs_stream_reset_connect(transport);
+ release_sock(sk);
+ mutex_unlock(&transport->recv_mutex);
+
+ trace_rpc_socket_close(xprt, sock);
+ __fput_sync(filp);
+
+ xprt_disconnect_done(xprt);
+}
+
+/**
+ * xs_close - close a socket
+ * @xprt: transport
+ *
+ * This is used when all requests are complete; i.e., no DRC state remains
+ * on the server that we want to save.
+ *
+ * The caller _must_ be holding XPRT_LOCKED in order to avoid issues with
+ * xs_reset_transport() zeroing the socket from underneath a writer.
+ */
+static void xs_close(struct rpc_xprt *xprt)
+{
+ struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+
+ dprintk("RPC: xs_close xprt %p\n", xprt);
+
+ if (transport->sock)
+ tls_handshake_close(transport->sock);
+ xs_reset_transport(transport);
+ xprt->reestablish_timeout = 0;
+}
+
+static void xs_inject_disconnect(struct rpc_xprt *xprt)
+{
+ dprintk("RPC: injecting transport disconnect on xprt=%p\n",
+ xprt);
+ xprt_disconnect_done(xprt);
+}
+
+static void xs_xprt_free(struct rpc_xprt *xprt)
+{
+ xs_free_peer_addresses(xprt);
+ xprt_free(xprt);
+}
+
+/**
+ * xs_destroy - prepare to shutdown a transport
+ * @xprt: doomed transport
+ *
+ */
+static void xs_destroy(struct rpc_xprt *xprt)
+{
+ struct sock_xprt *transport = container_of(xprt,
+ struct sock_xprt, xprt);
+ dprintk("RPC: xs_destroy xprt %p\n", xprt);
+
+ cancel_delayed_work_sync(&transport->connect_worker);
+ xs_close(xprt);
+ cancel_work_sync(&transport->recv_worker);
+ cancel_work_sync(&transport->error_worker);
+ xs_xprt_free(xprt);
+ module_put(THIS_MODULE);
+}
+
+/**
+ * xs_udp_data_read_skb - receive callback for UDP sockets
+ * @xprt: transport
+ * @sk: socket
+ * @skb: skbuff
+ *
+ */
+static void xs_udp_data_read_skb(struct rpc_xprt *xprt,
+ struct sock *sk,
+ struct sk_buff *skb)
+{
+ struct rpc_task *task;
+ struct rpc_rqst *rovr;
+ int repsize, copied;
+ u32 _xid;
+ __be32 *xp;
+
+ repsize = skb->len;
+ if (repsize < 4) {
+ dprintk("RPC: impossible RPC reply size %d!\n", repsize);
+ return;
+ }
+
+ /* Copy the XID from the skb... */
+ xp = skb_header_pointer(skb, 0, sizeof(_xid), &_xid);
+ if (xp == NULL)
+ return;
+
+ /* Look up and lock the request corresponding to the given XID */
+ spin_lock(&xprt->queue_lock);
+ rovr = xprt_lookup_rqst(xprt, *xp);
+ if (!rovr)
+ goto out_unlock;
+ xprt_pin_rqst(rovr);
+ xprt_update_rtt(rovr->rq_task);
+ spin_unlock(&xprt->queue_lock);
+ task = rovr->rq_task;
+
+ if ((copied = rovr->rq_private_buf.buflen) > repsize)
+ copied = repsize;
+
+ /* Suck it into the iovec, verify checksum if not done by hw. */
+ if (csum_partial_copy_to_xdr(&rovr->rq_private_buf, skb)) {
+ spin_lock(&xprt->queue_lock);
+ __UDPX_INC_STATS(sk, UDP_MIB_INERRORS);
+ goto out_unpin;
+ }
+
+ spin_lock(&xprt->transport_lock);
+ xprt_adjust_cwnd(xprt, task, copied);
+ spin_unlock(&xprt->transport_lock);
+ spin_lock(&xprt->queue_lock);
+ xprt_complete_rqst(task, copied);
+ __UDPX_INC_STATS(sk, UDP_MIB_INDATAGRAMS);
+out_unpin:
+ xprt_unpin_rqst(rovr);
+out_unlock:
+ spin_unlock(&xprt->queue_lock);
+}
+
+static void xs_udp_data_receive(struct sock_xprt *transport)
+{
+ struct sk_buff *skb;
+ struct sock *sk;
+ int err;
+
+ mutex_lock(&transport->recv_mutex);
+ sk = transport->inet;
+ if (sk == NULL)
+ goto out;
+ for (;;) {
+ skb = skb_recv_udp(sk, MSG_DONTWAIT, &err);
+ if (skb == NULL)
+ break;
+ xs_udp_data_read_skb(&transport->xprt, sk, skb);
+ consume_skb(skb);
+ cond_resched();
+ }
+ xs_poll_check_readable(transport);
+out:
+ mutex_unlock(&transport->recv_mutex);
+}
+
+static void xs_udp_data_receive_workfn(struct work_struct *work)
+{
+ struct sock_xprt *transport =
+ container_of(work, struct sock_xprt, recv_worker);
+ unsigned int pflags = memalloc_nofs_save();
+
+ xs_udp_data_receive(transport);
+ memalloc_nofs_restore(pflags);
+}
+
+/**
+ * xs_data_ready - "data ready" callback for sockets
+ * @sk: socket with data to read
+ *
+ */
+static void xs_data_ready(struct sock *sk)
+{
+ struct rpc_xprt *xprt;
+
+ trace_sk_data_ready(sk);
+
+ xprt = xprt_from_sock(sk);
+ if (xprt != NULL) {
+ struct sock_xprt *transport = container_of(xprt,
+ struct sock_xprt, xprt);
+
+ trace_xs_data_ready(xprt);
+
+ transport->old_data_ready(sk);
+
+ if (test_bit(XPRT_SOCK_IGNORE_RECV, &transport->sock_state))
+ return;
+
+ /* Any data means we had a useful conversation, so
+ * we don't need to delay the next reconnect.
+ */
+ if (xprt->reestablish_timeout)
+ xprt->reestablish_timeout = 0;
+ if (!test_and_set_bit(XPRT_SOCK_DATA_READY, &transport->sock_state))
+ queue_work(xprtiod_workqueue, &transport->recv_worker);
+ }
+}
+
+/*
+ * Helper function to force a TCP close if the server is sending
+ * junk and/or it has put us in CLOSE_WAIT
+ */
+static void xs_tcp_force_close(struct rpc_xprt *xprt)
+{
+ xprt_force_disconnect(xprt);
+}
+
+#if defined(CONFIG_SUNRPC_BACKCHANNEL)
+static size_t xs_tcp_bc_maxpayload(struct rpc_xprt *xprt)
+{
+ return PAGE_SIZE;
+}
+#endif /* CONFIG_SUNRPC_BACKCHANNEL */
+
+/**
+ * xs_local_state_change - callback to handle AF_LOCAL socket state changes
+ * @sk: socket whose state has changed
+ *
+ */
+static void xs_local_state_change(struct sock *sk)
+{
+ struct rpc_xprt *xprt;
+ struct sock_xprt *transport;
+
+ if (!(xprt = xprt_from_sock(sk)))
+ return;
+ transport = container_of(xprt, struct sock_xprt, xprt);
+ if (sk->sk_shutdown & SHUTDOWN_MASK) {
+ clear_bit(XPRT_CONNECTED, &xprt->state);
+ /* Trigger the socket release */
+ xs_run_error_worker(transport, XPRT_SOCK_WAKE_DISCONNECT);
+ }
+}
+
+/**
+ * xs_tcp_state_change - callback to handle TCP socket state changes
+ * @sk: socket whose state has changed
+ *
+ */
+static void xs_tcp_state_change(struct sock *sk)
+{
+ struct rpc_xprt *xprt;
+ struct sock_xprt *transport;
+
+ if (!(xprt = xprt_from_sock(sk)))
+ return;
+ dprintk("RPC: xs_tcp_state_change client %p...\n", xprt);
+ dprintk("RPC: state %x conn %d dead %d zapped %d sk_shutdown %d\n",
+ sk->sk_state, xprt_connected(xprt),
+ sock_flag(sk, SOCK_DEAD),
+ sock_flag(sk, SOCK_ZAPPED),
+ sk->sk_shutdown);
+
+ transport = container_of(xprt, struct sock_xprt, xprt);
+ trace_rpc_socket_state_change(xprt, sk->sk_socket);
+ switch (sk->sk_state) {
+ case TCP_ESTABLISHED:
+ if (!xprt_test_and_set_connected(xprt)) {
+ xprt->connect_cookie++;
+ clear_bit(XPRT_SOCK_CONNECTING, &transport->sock_state);
+ xprt_clear_connecting(xprt);
+
+ xprt->stat.connect_count++;
+ xprt->stat.connect_time += (long)jiffies -
+ xprt->stat.connect_start;
+ xs_run_error_worker(transport, XPRT_SOCK_WAKE_PENDING);
+ }
+ break;
+ case TCP_FIN_WAIT1:
+ /* The client initiated a shutdown of the socket */
+ xprt->connect_cookie++;
+ xprt->reestablish_timeout = 0;
+ set_bit(XPRT_CLOSING, &xprt->state);
+ smp_mb__before_atomic();
+ clear_bit(XPRT_CONNECTED, &xprt->state);
+ clear_bit(XPRT_CLOSE_WAIT, &xprt->state);
+ smp_mb__after_atomic();
+ break;
+ case TCP_CLOSE_WAIT:
+ /* The server initiated a shutdown of the socket */
+ xprt->connect_cookie++;
+ clear_bit(XPRT_CONNECTED, &xprt->state);
+ xs_run_error_worker(transport, XPRT_SOCK_WAKE_DISCONNECT);
+ fallthrough;
+ case TCP_CLOSING:
+ /*
+ * If the server closed down the connection, make sure that
+ * we back off before reconnecting
+ */
+ if (xprt->reestablish_timeout < XS_TCP_INIT_REEST_TO)
+ xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO;
+ break;
+ case TCP_LAST_ACK:
+ set_bit(XPRT_CLOSING, &xprt->state);
+ smp_mb__before_atomic();
+ clear_bit(XPRT_CONNECTED, &xprt->state);
+ smp_mb__after_atomic();
+ break;
+ case TCP_CLOSE:
+ if (test_and_clear_bit(XPRT_SOCK_CONNECTING,
+ &transport->sock_state))
+ xprt_clear_connecting(xprt);
+ clear_bit(XPRT_CLOSING, &xprt->state);
+ /* Trigger the socket release */
+ xs_run_error_worker(transport, XPRT_SOCK_WAKE_DISCONNECT);
+ }
+}
+
+static void xs_write_space(struct sock *sk)
+{
+ struct sock_xprt *transport;
+ struct rpc_xprt *xprt;
+
+ if (!sk->sk_socket)
+ return;
+ clear_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
+
+ if (unlikely(!(xprt = xprt_from_sock(sk))))
+ return;
+ transport = container_of(xprt, struct sock_xprt, xprt);
+ if (!test_and_clear_bit(XPRT_SOCK_NOSPACE, &transport->sock_state))
+ return;
+ xs_run_error_worker(transport, XPRT_SOCK_WAKE_WRITE);
+ sk->sk_write_pending--;
+}
+
+/**
+ * xs_udp_write_space - callback invoked when socket buffer space
+ * becomes available
+ * @sk: socket whose state has changed
+ *
+ * Called when more output buffer space is available for this socket.
+ * We try not to wake our writers until they can make "significant"
+ * progress, otherwise we'll waste resources thrashing kernel_sendmsg
+ * with a bunch of small requests.
+ */
+static void xs_udp_write_space(struct sock *sk)
+{
+ /* from net/core/sock.c:sock_def_write_space */
+ if (sock_writeable(sk))
+ xs_write_space(sk);
+}
+
+/**
+ * xs_tcp_write_space - callback invoked when socket buffer space
+ * becomes available
+ * @sk: socket whose state has changed
+ *
+ * Called when more output buffer space is available for this socket.
+ * We try not to wake our writers until they can make "significant"
+ * progress, otherwise we'll waste resources thrashing kernel_sendmsg
+ * with a bunch of small requests.
+ */
+static void xs_tcp_write_space(struct sock *sk)
+{
+ /* from net/core/stream.c:sk_stream_write_space */
+ if (sk_stream_is_writeable(sk))
+ xs_write_space(sk);
+}
+
+static void xs_udp_do_set_buffer_size(struct rpc_xprt *xprt)
+{
+ struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+ struct sock *sk = transport->inet;
+
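+	/* SOCK_RCVBUF_LOCK/SOCK_SNDBUF_LOCK mark the buffer sizes as
+	 * explicitly set so the network core won't auto-adjust them;
+	 * leave room for roughly two buffers per request slot.
+	 */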
+ if (transport->rcvsize) {
+ sk->sk_userlocks |= SOCK_RCVBUF_LOCK;
+ sk->sk_rcvbuf = transport->rcvsize * xprt->max_reqs * 2;
+ }
+ if (transport->sndsize) {
+ sk->sk_userlocks |= SOCK_SNDBUF_LOCK;
+ sk->sk_sndbuf = transport->sndsize * xprt->max_reqs * 2;
+ sk->sk_write_space(sk);
+ }
+}
+
+/**
+ * xs_udp_set_buffer_size - set send and receive limits
+ * @xprt: generic transport
+ * @sndsize: requested size of send buffer, in bytes
+ * @rcvsize: requested size of receive buffer, in bytes
+ *
+ * Set socket send and receive buffer size limits.
+ */
+static void xs_udp_set_buffer_size(struct rpc_xprt *xprt, size_t sndsize, size_t rcvsize)
+{
+ struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+
+ transport->sndsize = 0;
+ if (sndsize)
+ transport->sndsize = sndsize + 1024;
+ transport->rcvsize = 0;
+ if (rcvsize)
+ transport->rcvsize = rcvsize + 1024;
+
+ xs_udp_do_set_buffer_size(xprt);
+}
+
+/**
+ * xs_udp_timer - called when a retransmit timeout occurs on a UDP transport
+ * @xprt: controlling transport
+ * @task: task that timed out
+ *
+ * Adjust the congestion window after a retransmit timeout has occurred.
+ */
+static void xs_udp_timer(struct rpc_xprt *xprt, struct rpc_task *task)
+{
+ spin_lock(&xprt->transport_lock);
+ xprt_adjust_cwnd(xprt, task, -ETIMEDOUT);
+ spin_unlock(&xprt->transport_lock);
+}
+
+static int xs_get_random_port(void)
+{
+ unsigned short min = xprt_min_resvport, max = xprt_max_resvport;
+ unsigned short range;
+ unsigned short rand;
+
+ if (max < min)
+ return -EADDRINUSE;
+ range = max - min + 1;
+ rand = get_random_u32_below(range);
+ return rand + min;
+}
+
+static unsigned short xs_sock_getport(struct socket *sock)
+{
+ struct sockaddr_storage buf;
+ unsigned short port = 0;
+
+ if (kernel_getsockname(sock, (struct sockaddr *)&buf) < 0)
+ goto out;
+ switch (buf.ss_family) {
+ case AF_INET6:
+ port = ntohs(((struct sockaddr_in6 *)&buf)->sin6_port);
+ break;
+ case AF_INET:
+ port = ntohs(((struct sockaddr_in *)&buf)->sin_port);
+ }
+out:
+ return port;
+}
+
+/**
+ * xs_set_port - reset the port number in the remote endpoint address
+ * @xprt: generic transport
+ * @port: new port number
+ *
+ */
+static void xs_set_port(struct rpc_xprt *xprt, unsigned short port)
+{
+ dprintk("RPC: setting port for xprt %p to %u\n", xprt, port);
+
+ rpc_set_port(xs_addr(xprt), port);
+ xs_update_peer_port(xprt);
+}
+
+static void xs_set_srcport(struct sock_xprt *transport, struct socket *sock)
+{
+ if (transport->srcport == 0 && transport->xprt.reuseport)
+ transport->srcport = xs_sock_getport(sock);
+}
+
+static int xs_get_srcport(struct sock_xprt *transport)
+{
+ int port = transport->srcport;
+
+ if (port == 0 && transport->xprt.resvport)
+ port = xs_get_random_port();
+ return port;
+}
+
+static unsigned short xs_sock_srcport(struct rpc_xprt *xprt)
+{
+ struct sock_xprt *sock = container_of(xprt, struct sock_xprt, xprt);
+	unsigned short ret = 0;
+
+	mutex_lock(&sock->recv_mutex);
+ if (sock->sock)
+ ret = xs_sock_getport(sock->sock);
+ mutex_unlock(&sock->recv_mutex);
+ return ret;
+}
+
+static int xs_sock_srcaddr(struct rpc_xprt *xprt, char *buf, size_t buflen)
+{
+ struct sock_xprt *sock = container_of(xprt, struct sock_xprt, xprt);
+ union {
+ struct sockaddr sa;
+ struct sockaddr_storage st;
+ } saddr;
+ int ret = -ENOTCONN;
+
+ mutex_lock(&sock->recv_mutex);
+ if (sock->sock) {
+ ret = kernel_getsockname(sock->sock, &saddr.sa);
+ if (ret >= 0)
+ ret = snprintf(buf, buflen, "%pISc", &saddr.sa);
+ }
+ mutex_unlock(&sock->recv_mutex);
+ return ret;
+}
+
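+/*
+ * Pick the next source port to try after a failed bind: forget any cached
+ * source port and walk the reserved-port range downward, wrapping from the
+ * minimum back up to the maximum. Returns zero when no reserved port is
+ * required.
+ */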
+static unsigned short xs_next_srcport(struct sock_xprt *transport, unsigned short port)
+{
+ if (transport->srcport != 0)
+ transport->srcport = 0;
+ if (!transport->xprt.resvport)
+ return 0;
+ if (port <= xprt_min_resvport || port > xprt_max_resvport)
+ return xprt_max_resvport;
+ return --port;
+}
+
+static int xs_bind(struct sock_xprt *transport, struct socket *sock)
+{
+ struct sockaddr_storage myaddr;
+ int err, nloop = 0;
+ int port = xs_get_srcport(transport);
+ unsigned short last;
+
+ /*
+ * If we are asking for any ephemeral port (i.e. port == 0 &&
+ * transport->xprt.resvport == 0), don't bind. Let the local
+ * port selection happen implicitly when the socket is used
+ * (for example at connect time).
+ *
+	 * This ensures that we can continue to establish TCP
+	 * connections even when all local ephemeral ports are already
+	 * in use by other TCP connections. This makes no difference
+	 * for UDP sockets, but also doesn't harm them.
+ *
+ * If we're asking for any reserved port (i.e. port == 0 &&
+ * transport->xprt.resvport == 1) xs_get_srcport above will
+ * ensure that port is non-zero and we will bind as needed.
+ */
+ if (port <= 0)
+ return port;
+
+ memcpy(&myaddr, &transport->srcaddr, transport->xprt.addrlen);
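+	/* Try successive candidate ports; nloop counts wraparounds of the
+	 * reserved-port range so the scan stops once the whole range has
+	 * been covered.
+	 */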
+ do {
+ rpc_set_port((struct sockaddr *)&myaddr, port);
+ err = kernel_bind(sock, (struct sockaddr *)&myaddr,
+ transport->xprt.addrlen);
+ if (err == 0) {
+ if (transport->xprt.reuseport)
+ transport->srcport = port;
+ break;
+ }
+ last = port;
+ port = xs_next_srcport(transport, port);
+ if (port > last)
+ nloop++;
+ } while (err == -EADDRINUSE && nloop != 2);
+
+ if (myaddr.ss_family == AF_INET)
+ dprintk("RPC: %s %pI4:%u: %s (%d)\n", __func__,
+ &((struct sockaddr_in *)&myaddr)->sin_addr,
+ port, err ? "failed" : "ok", err);
+ else
+ dprintk("RPC: %s %pI6:%u: %s (%d)\n", __func__,
+ &((struct sockaddr_in6 *)&myaddr)->sin6_addr,
+ port, err ? "failed" : "ok", err);
+ return err;
+}
+
+/*
+ * We don't support autobind on AF_LOCAL sockets
+ */
+static void xs_local_rpcbind(struct rpc_task *task)
+{
+ xprt_set_bound(task->tk_xprt);
+}
+
+static void xs_local_set_port(struct rpc_xprt *xprt, unsigned short port)
+{
+}
+
+#ifdef CONFIG_DEBUG_LOCK_ALLOC
+static struct lock_class_key xs_key[3];
+static struct lock_class_key xs_slock_key[3];
+
+static inline void xs_reclassify_socketu(struct socket *sock)
+{
+ struct sock *sk = sock->sk;
+
+ sock_lock_init_class_and_name(sk, "slock-AF_LOCAL-RPC",
+ &xs_slock_key[0], "sk_lock-AF_LOCAL-RPC", &xs_key[0]);
+}
+
+static inline void xs_reclassify_socket4(struct socket *sock)
+{
+ struct sock *sk = sock->sk;
+
+ sock_lock_init_class_and_name(sk, "slock-AF_INET-RPC",
+ &xs_slock_key[1], "sk_lock-AF_INET-RPC", &xs_key[1]);
+}
+
+static inline void xs_reclassify_socket6(struct socket *sock)
+{
+ struct sock *sk = sock->sk;
+
+ sock_lock_init_class_and_name(sk, "slock-AF_INET6-RPC",
+ &xs_slock_key[2], "sk_lock-AF_INET6-RPC", &xs_key[2]);
+}
+
+static inline void xs_reclassify_socket(int family, struct socket *sock)
+{
+ if (WARN_ON_ONCE(!sock_allow_reclassification(sock->sk)))
+ return;
+
+ switch (family) {
+ case AF_LOCAL:
+ xs_reclassify_socketu(sock);
+ break;
+ case AF_INET:
+ xs_reclassify_socket4(sock);
+ break;
+ case AF_INET6:
+ xs_reclassify_socket6(sock);
+ break;
+ }
+}
+#else
+static inline void xs_reclassify_socket(int family, struct socket *sock)
+{
+}
+#endif
+
+static void xs_dummy_setup_socket(struct work_struct *work)
+{
+}
+
+static struct socket *xs_create_sock(struct rpc_xprt *xprt,
+ struct sock_xprt *transport, int family, int type,
+ int protocol, bool reuseport)
+{
+ struct file *filp;
+ struct socket *sock;
+ int err;
+
+ err = __sock_create(xprt->xprt_net, family, type, protocol, &sock, 1);
+ if (err < 0) {
+ dprintk("RPC: can't create %d transport socket (%d).\n",
+ protocol, -err);
+ goto out;
+ }
+ xs_reclassify_socket(family, sock);
+
+ if (reuseport)
+ sock_set_reuseport(sock->sk);
+
+ err = xs_bind(transport, sock);
+ if (err) {
+ sock_release(sock);
+ goto out;
+ }
+
+ filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
+ if (IS_ERR(filp))
+ return ERR_CAST(filp);
+ transport->file = filp;
+
+ return sock;
+out:
+ return ERR_PTR(err);
+}
+
+static int xs_local_finish_connecting(struct rpc_xprt *xprt,
+ struct socket *sock)
+{
+ struct sock_xprt *transport = container_of(xprt, struct sock_xprt,
+ xprt);
+
+ if (!transport->inet) {
+ struct sock *sk = sock->sk;
+
+ lock_sock(sk);
+
+ xs_save_old_callbacks(transport, sk);
+
+ sk->sk_user_data = xprt;
+ sk->sk_data_ready = xs_data_ready;
+ sk->sk_write_space = xs_udp_write_space;
+ sk->sk_state_change = xs_local_state_change;
+ sk->sk_error_report = xs_error_report;
+ sk->sk_use_task_frag = false;
+
+ xprt_clear_connected(xprt);
+
+ /* Reset to new socket */
+ transport->sock = sock;
+ transport->inet = sk;
+
+ release_sock(sk);
+ }
+
+ xs_stream_start_connect(transport);
+
+ return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, 0);
+}
+
+/**
+ * xs_local_setup_socket - create AF_LOCAL socket, connect to a local endpoint
+ * @transport: socket transport to connect
+ */
+static int xs_local_setup_socket(struct sock_xprt *transport)
+{
+ struct rpc_xprt *xprt = &transport->xprt;
+ struct file *filp;
+ struct socket *sock;
+ int status;
+
+ status = __sock_create(xprt->xprt_net, AF_LOCAL,
+ SOCK_STREAM, 0, &sock, 1);
+ if (status < 0) {
+ dprintk("RPC: can't create AF_LOCAL "
+ "transport socket (%d).\n", -status);
+ goto out;
+ }
+ xs_reclassify_socket(AF_LOCAL, sock);
+
+ filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
+ if (IS_ERR(filp)) {
+ status = PTR_ERR(filp);
+ goto out;
+ }
+ transport->file = filp;
+
+ dprintk("RPC: worker connecting xprt %p via AF_LOCAL to %s\n",
+ xprt, xprt->address_strings[RPC_DISPLAY_ADDR]);
+
+ status = xs_local_finish_connecting(xprt, sock);
+ trace_rpc_socket_connect(xprt, sock, status);
+ switch (status) {
+ case 0:
+ dprintk("RPC: xprt %p connected to %s\n",
+ xprt, xprt->address_strings[RPC_DISPLAY_ADDR]);
+ xprt->stat.connect_count++;
+ xprt->stat.connect_time += (long)jiffies -
+ xprt->stat.connect_start;
+ xprt_set_connected(xprt);
+ break;
+ case -ENOBUFS:
+ break;
+ case -ENOENT:
+ dprintk("RPC: xprt %p: socket %s does not exist\n",
+ xprt, xprt->address_strings[RPC_DISPLAY_ADDR]);
+ break;
+ case -ECONNREFUSED:
+ dprintk("RPC: xprt %p: connection refused for %s\n",
+ xprt, xprt->address_strings[RPC_DISPLAY_ADDR]);
+ break;
+ default:
+ printk(KERN_ERR "%s: unhandled error (%d) connecting to %s\n",
+ __func__, -status,
+ xprt->address_strings[RPC_DISPLAY_ADDR]);
+ }
+
+out:
+ xprt_clear_connecting(xprt);
+ xprt_wake_pending_tasks(xprt, status);
+ return status;
+}
+
+static void xs_local_connect(struct rpc_xprt *xprt, struct rpc_task *task)
+{
+ struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+ int ret;
+
+ if (transport->file)
+ goto force_disconnect;
+
+ if (RPC_IS_ASYNC(task)) {
+ /*
+ * We want the AF_LOCAL connect to be resolved in the
+ * filesystem namespace of the process making the rpc
+ * call. Thus we connect synchronously.
+ *
+ * If we want to support asynchronous AF_LOCAL calls,
+ * we'll need to figure out how to pass a namespace to
+ * connect.
+ */
+ rpc_task_set_rpc_status(task, -ENOTCONN);
+ goto out_wake;
+ }
+ ret = xs_local_setup_socket(transport);
+ if (ret && !RPC_IS_SOFTCONN(task))
+ msleep_interruptible(15000);
+ return;
+force_disconnect:
+ xprt_force_disconnect(xprt);
+out_wake:
+ xprt_clear_connecting(xprt);
+ xprt_wake_pending_tasks(xprt, -ENOTCONN);
+}
+
+#if IS_ENABLED(CONFIG_SUNRPC_SWAP)
+/*
+ * Note that this should be called with XPRT_LOCKED held, or recv_mutex
+ * held, or when we otherwise know that we have exclusive access to the
+ * socket, to guard against races with xs_reset_transport.
+ */
+static void xs_set_memalloc(struct rpc_xprt *xprt)
+{
+ struct sock_xprt *transport = container_of(xprt, struct sock_xprt,
+ xprt);
+
+ /*
+ * If there's no sock, then we have nothing to set. The
+ * reconnecting process will get it for us.
+ */
+ if (!transport->inet)
+ return;
+ if (atomic_read(&xprt->swapper))
+ sk_set_memalloc(transport->inet);
+}
+
+/**
+ * xs_enable_swap - Tag this transport as being used for swap.
+ * @xprt: transport to tag
+ *
+ * Take a reference to this transport on behalf of the rpc_clnt, and
+ * mark its socket for swapping if it wasn't already.
+ */
+static int
+xs_enable_swap(struct rpc_xprt *xprt)
+{
+ struct sock_xprt *xs = container_of(xprt, struct sock_xprt, xprt);
+
+ mutex_lock(&xs->recv_mutex);
+ if (atomic_inc_return(&xprt->swapper) == 1 &&
+ xs->inet)
+ sk_set_memalloc(xs->inet);
+ mutex_unlock(&xs->recv_mutex);
+ return 0;
+}
+
+/**
+ * xs_disable_swap - Untag this transport as being used for swap.
+ * @xprt: transport to untag
+ *
+ * Drop a "swapper" reference to this xprt on behalf of the rpc_clnt. If the
+ * swapper refcount goes to 0, untag the socket as a memalloc socket.
+ */
+static void
+xs_disable_swap(struct rpc_xprt *xprt)
+{
+ struct sock_xprt *xs = container_of(xprt, struct sock_xprt, xprt);
+
+ mutex_lock(&xs->recv_mutex);
+ if (atomic_dec_and_test(&xprt->swapper) &&
+ xs->inet)
+ sk_clear_memalloc(xs->inet);
+ mutex_unlock(&xs->recv_mutex);
+}
+#else
+static void xs_set_memalloc(struct rpc_xprt *xprt)
+{
+}
+
+static int
+xs_enable_swap(struct rpc_xprt *xprt)
+{
+ return -EINVAL;
+}
+
+static void
+xs_disable_swap(struct rpc_xprt *xprt)
+{
+}
+#endif
+
+static void xs_udp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock)
+{
+ struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+
+ if (!transport->inet) {
+ struct sock *sk = sock->sk;
+
+ lock_sock(sk);
+
+ xs_save_old_callbacks(transport, sk);
+
+ sk->sk_user_data = xprt;
+ sk->sk_data_ready = xs_data_ready;
+ sk->sk_write_space = xs_udp_write_space;
+ sk->sk_use_task_frag = false;
+
+ xprt_set_connected(xprt);
+
+ /* Reset to new socket */
+ transport->sock = sock;
+ transport->inet = sk;
+
+ xs_set_memalloc(xprt);
+
+ release_sock(sk);
+ }
+ xs_udp_do_set_buffer_size(xprt);
+
+ xprt->stat.connect_start = jiffies;
+}
+
+static void xs_udp_setup_socket(struct work_struct *work)
+{
+ struct sock_xprt *transport =
+ container_of(work, struct sock_xprt, connect_worker.work);
+ struct rpc_xprt *xprt = &transport->xprt;
+ struct socket *sock;
+ int status = -EIO;
+ unsigned int pflags = current->flags;
+
+ if (atomic_read(&xprt->swapper))
+ current->flags |= PF_MEMALLOC;
+ sock = xs_create_sock(xprt, transport,
+ xs_addr(xprt)->sa_family, SOCK_DGRAM,
+ IPPROTO_UDP, false);
+ if (IS_ERR(sock))
+ goto out;
+
+ dprintk("RPC: worker connecting xprt %p via %s to "
+ "%s (port %s)\n", xprt,
+ xprt->address_strings[RPC_DISPLAY_PROTO],
+ xprt->address_strings[RPC_DISPLAY_ADDR],
+ xprt->address_strings[RPC_DISPLAY_PORT]);
+
+ xs_udp_finish_connecting(xprt, sock);
+ trace_rpc_socket_connect(xprt, sock, 0);
+ status = 0;
+out:
+ xprt_clear_connecting(xprt);
+ xprt_unlock_connect(xprt, transport);
+ xprt_wake_pending_tasks(xprt, status);
+ current_restore_flags(pflags, PF_MEMALLOC);
+}
+
+/**
+ * xs_tcp_shutdown - gracefully shut down a TCP socket
+ * @xprt: transport
+ *
+ * Initiates a graceful shutdown of the TCP socket by calling the
+ * equivalent of shutdown(SHUT_RDWR).
+ */
+static void xs_tcp_shutdown(struct rpc_xprt *xprt)
+{
+ struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+ struct socket *sock = transport->sock;
+ int skst = transport->inet ? transport->inet->sk_state : TCP_CLOSE;
+
+ if (sock == NULL)
+ return;
+ if (!xprt->reuseport) {
+ xs_close(xprt);
+ return;
+ }
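+	/* A shutdown is already in progress in the FIN-WAIT/LAST-ACK
+	 * states; send our own FIN only while the connection is still
+	 * usable, and simply reset the transport in any other state.
+	 */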
+ switch (skst) {
+ case TCP_FIN_WAIT1:
+ case TCP_FIN_WAIT2:
+ case TCP_LAST_ACK:
+ break;
+ case TCP_ESTABLISHED:
+ case TCP_CLOSE_WAIT:
+ kernel_sock_shutdown(sock, SHUT_RDWR);
+ trace_rpc_socket_shutdown(xprt, sock);
+ break;
+ default:
+ xs_reset_transport(transport);
+ }
+}
+
+static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt,
+ struct socket *sock)
+{
+ struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+ struct net *net = sock_net(sock->sk);
+ unsigned long connect_timeout;
+ unsigned long syn_retries;
+ unsigned int keepidle;
+ unsigned int keepcnt;
+ unsigned int timeo;
+ unsigned long t;
+
+ spin_lock(&xprt->transport_lock);
+ keepidle = DIV_ROUND_UP(xprt->timeout->to_initval, HZ);
+ keepcnt = xprt->timeout->to_retries + 1;
+ timeo = jiffies_to_msecs(xprt->timeout->to_initval) *
+ (xprt->timeout->to_retries + 1);
+ clear_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state);
+ spin_unlock(&xprt->transport_lock);
+
+ /* TCP Keepalive options */
+ sock_set_keepalive(sock->sk);
+ tcp_sock_set_keepidle(sock->sk, keepidle);
+ tcp_sock_set_keepintvl(sock->sk, keepidle);
+ tcp_sock_set_keepcnt(sock->sk, keepcnt);
+
+ /* TCP user timeout (see RFC5482) */
+ tcp_sock_set_user_timeout(sock->sk, timeo);
+
+ /* Connect timeout */
+ connect_timeout = max_t(unsigned long,
+ DIV_ROUND_UP(xprt->connect_timeout, HZ), 1);
+ syn_retries = max_t(unsigned long,
+ READ_ONCE(net->ipv4.sysctl_tcp_syn_retries), 1);
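+	/* SYN retransmits back off exponentially, so t retransmits cover
+	 * roughly 1 << t seconds. Find the smallest t that reaches the
+	 * requested connect timeout; if it fits within the syn_retries
+	 * limit, cap the socket's SYN count accordingly.
+	 */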
+ for (t = 0; t <= syn_retries && (1UL << t) < connect_timeout; t++)
+ ;
+ if (t <= syn_retries)
+ tcp_sock_set_syncnt(sock->sk, t - 1);
+}
+
+static void xs_tcp_do_set_connect_timeout(struct rpc_xprt *xprt,
+ unsigned long connect_timeout)
+{
+ struct sock_xprt *transport =
+ container_of(xprt, struct sock_xprt, xprt);
+ struct rpc_timeout to;
+ unsigned long initval;
+
+ memcpy(&to, xprt->timeout, sizeof(to));
+ /* Arbitrary lower limit */
+ initval = max_t(unsigned long, connect_timeout, XS_TCP_INIT_REEST_TO);
+ to.to_initval = initval;
+ to.to_maxval = initval;
+ to.to_retries = 0;
+ memcpy(&transport->tcp_timeout, &to, sizeof(transport->tcp_timeout));
+ xprt->timeout = &transport->tcp_timeout;
+ xprt->connect_timeout = connect_timeout;
+}
+
+static void xs_tcp_set_connect_timeout(struct rpc_xprt *xprt,
+ unsigned long connect_timeout,
+ unsigned long reconnect_timeout)
+{
+ struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+
+ spin_lock(&xprt->transport_lock);
+ if (reconnect_timeout < xprt->max_reconnect_timeout)
+ xprt->max_reconnect_timeout = reconnect_timeout;
+ if (connect_timeout < xprt->connect_timeout)
+ xs_tcp_do_set_connect_timeout(xprt, connect_timeout);
+ set_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state);
+ spin_unlock(&xprt->transport_lock);
+}
+
+static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock)
+{
+ struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+
+ if (!transport->inet) {
+ struct sock *sk = sock->sk;
+
+		/* Avoid temporary addresses; they are bad for long-lived
+ * connections such as NFS mounts.
+ * RFC4941, section 3.6 suggests that:
+ * Individual applications, which have specific
+ * knowledge about the normal duration of connections,
+ * MAY override this as appropriate.
+ */
+ if (xs_addr(xprt)->sa_family == PF_INET6) {
+ ip6_sock_set_addr_preferences(sk,
+ IPV6_PREFER_SRC_PUBLIC);
+ }
+
+ xs_tcp_set_socket_timeouts(xprt, sock);
+ tcp_sock_set_nodelay(sk);
+
+ lock_sock(sk);
+
+ xs_save_old_callbacks(transport, sk);
+
+ sk->sk_user_data = xprt;
+ sk->sk_data_ready = xs_data_ready;
+ sk->sk_state_change = xs_tcp_state_change;
+ sk->sk_write_space = xs_tcp_write_space;
+ sk->sk_error_report = xs_error_report;
+ sk->sk_use_task_frag = false;
+
+ /* socket options */
+ sock_reset_flag(sk, SOCK_LINGER);
+
+ xprt_clear_connected(xprt);
+
+ /* Reset to new socket */
+ transport->sock = sock;
+ transport->inet = sk;
+
+ release_sock(sk);
+ }
+
+ if (!xprt_bound(xprt))
+ return -ENOTCONN;
+
+ xs_set_memalloc(xprt);
+
+ xs_stream_start_connect(transport);
+
+ /* Tell the socket layer to start connecting... */
+ set_bit(XPRT_SOCK_CONNECTING, &transport->sock_state);
+ return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, O_NONBLOCK);
+}
+
+/**
+ * xs_tcp_setup_socket - create a TCP socket and connect to a remote endpoint
+ * @work: queued work item
+ *
+ * Invoked from a work queue context.
+ */
+static void xs_tcp_setup_socket(struct work_struct *work)
+{
+ struct sock_xprt *transport =
+ container_of(work, struct sock_xprt, connect_worker.work);
+ struct socket *sock = transport->sock;
+ struct rpc_xprt *xprt = &transport->xprt;
+ int status;
+ unsigned int pflags = current->flags;
+
+ if (atomic_read(&xprt->swapper))
+ current->flags |= PF_MEMALLOC;
+
+ if (xprt_connected(xprt))
+ goto out;
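+	/* Tear down any socket left over from a previous connect attempt
+	 * and create a fresh one.
+	 */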
+ if (test_and_clear_bit(XPRT_SOCK_CONNECT_SENT,
+ &transport->sock_state) ||
+ !sock) {
+ xs_reset_transport(transport);
+ sock = xs_create_sock(xprt, transport, xs_addr(xprt)->sa_family,
+ SOCK_STREAM, IPPROTO_TCP, true);
+ if (IS_ERR(sock)) {
+ xprt_wake_pending_tasks(xprt, PTR_ERR(sock));
+ goto out;
+ }
+ }
+
+ dprintk("RPC: worker connecting xprt %p via %s to "
+ "%s (port %s)\n", xprt,
+ xprt->address_strings[RPC_DISPLAY_PROTO],
+ xprt->address_strings[RPC_DISPLAY_ADDR],
+ xprt->address_strings[RPC_DISPLAY_PORT]);
+
+ status = xs_tcp_finish_connecting(xprt, sock);
+ trace_rpc_socket_connect(xprt, sock, status);
+ dprintk("RPC: %p connect status %d connected %d sock state %d\n",
+ xprt, -status, xprt_connected(xprt),
+ sock->sk->sk_state);
+ switch (status) {
+ case 0:
+ case -EINPROGRESS:
+ /* SYN_SENT! */
+ set_bit(XPRT_SOCK_CONNECT_SENT, &transport->sock_state);
+ if (xprt->reestablish_timeout < XS_TCP_INIT_REEST_TO)
+ xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO;
+ fallthrough;
+ case -EALREADY:
+ goto out_unlock;
+ case -EADDRNOTAVAIL:
+ /* Source port number is unavailable. Try a new one! */
+ transport->srcport = 0;
+ status = -EAGAIN;
+ break;
+ case -EINVAL:
+		/* Happens, for instance, if the user specified a
+		 * link-local IPv6 address without a scope-id.
+ */
+ case -ECONNREFUSED:
+ case -ECONNRESET:
+ case -ENETDOWN:
+ case -ENETUNREACH:
+ case -EHOSTUNREACH:
+ case -EADDRINUSE:
+ case -ENOBUFS:
+ break;
+ default:
+ printk("%s: connect returned unhandled error %d\n",
+ __func__, status);
+ status = -EAGAIN;
+ }
+
+ /* xs_tcp_force_close() wakes tasks with a fixed error code.
+ * We need to wake them first to ensure the correct error code.
+ */
+ xprt_wake_pending_tasks(xprt, status);
+ xs_tcp_force_close(xprt);
+out:
+ xprt_clear_connecting(xprt);
+out_unlock:
+ xprt_unlock_connect(xprt, transport);
+ current_restore_flags(pflags, PF_MEMALLOC);
+}
+
+/*
+ * Transfer the connected socket to @upper_transport, then mark that
+ * xprt CONNECTED.
+ */
+static int xs_tcp_tls_finish_connecting(struct rpc_xprt *lower_xprt,
+ struct sock_xprt *upper_transport)
+{
+ struct sock_xprt *lower_transport =
+ container_of(lower_xprt, struct sock_xprt, xprt);
+ struct rpc_xprt *upper_xprt = &upper_transport->xprt;
+
+ if (!upper_transport->inet) {
+ struct socket *sock = lower_transport->sock;
+ struct sock *sk = sock->sk;
+
+		/* Avoid temporary addresses; they are bad for long-lived
+ * connections such as NFS mounts.
+ * RFC4941, section 3.6 suggests that:
+ * Individual applications, which have specific
+ * knowledge about the normal duration of connections,
+ * MAY override this as appropriate.
+ */
+ if (xs_addr(upper_xprt)->sa_family == PF_INET6)
+ ip6_sock_set_addr_preferences(sk, IPV6_PREFER_SRC_PUBLIC);
+
+ xs_tcp_set_socket_timeouts(upper_xprt, sock);
+ tcp_sock_set_nodelay(sk);
+
+ lock_sock(sk);
+
+ /* @sk is already connected, so it now has the RPC callbacks.
+ * Reach into @lower_transport to save the original ones.
+ */
+ upper_transport->old_data_ready = lower_transport->old_data_ready;
+ upper_transport->old_state_change = lower_transport->old_state_change;
+ upper_transport->old_write_space = lower_transport->old_write_space;
+ upper_transport->old_error_report = lower_transport->old_error_report;
+ sk->sk_user_data = upper_xprt;
+
+ /* socket options */
+ sock_reset_flag(sk, SOCK_LINGER);
+
+ xprt_clear_connected(upper_xprt);
+
+ upper_transport->sock = sock;
+ upper_transport->inet = sk;
+ upper_transport->file = lower_transport->file;
+
+ release_sock(sk);
+
+ /* Reset lower_transport before shutting down its clnt */
+ mutex_lock(&lower_transport->recv_mutex);
+ lower_transport->inet = NULL;
+ lower_transport->sock = NULL;
+ lower_transport->file = NULL;
+
+ xprt_clear_connected(lower_xprt);
+ xs_sock_reset_connection_flags(lower_xprt);
+ xs_stream_reset_connect(lower_transport);
+ mutex_unlock(&lower_transport->recv_mutex);
+ }
+
+ if (!xprt_bound(upper_xprt))
+ return -ENOTCONN;
+
+ xs_set_memalloc(upper_xprt);
+
+ if (!xprt_test_and_set_connected(upper_xprt)) {
+ upper_xprt->connect_cookie++;
+ clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
+ xprt_clear_connecting(upper_xprt);
+
+ upper_xprt->stat.connect_count++;
+ upper_xprt->stat.connect_time += (long)jiffies -
+ upper_xprt->stat.connect_start;
+ xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
+ }
+ return 0;
+}
+
+/**
+ * xs_tls_handshake_done - TLS handshake completion handler
+ * @data: address of xprt to wake
+ * @status: status of handshake
+ * @peerid: serial number of key containing the remote's identity
+ *
+ */
+static void xs_tls_handshake_done(void *data, int status, key_serial_t peerid)
+{
+ struct rpc_xprt *lower_xprt = data;
+ struct sock_xprt *lower_transport =
+ container_of(lower_xprt, struct sock_xprt, xprt);
+
+ lower_transport->xprt_err = status ? -EACCES : 0;
+ complete(&lower_transport->handshake_done);
+ xprt_put(lower_xprt);
+}
+
+static int xs_tls_handshake_sync(struct rpc_xprt *lower_xprt, struct xprtsec_parms *xprtsec)
+{
+ struct sock_xprt *lower_transport =
+ container_of(lower_xprt, struct sock_xprt, xprt);
+ struct tls_handshake_args args = {
+ .ta_sock = lower_transport->sock,
+ .ta_done = xs_tls_handshake_done,
+ .ta_data = xprt_get(lower_xprt),
+ .ta_peername = lower_xprt->servername,
+ };
+ struct sock *sk = lower_transport->inet;
+ int rc;
+
+ init_completion(&lower_transport->handshake_done);
+ set_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state);
+ lower_transport->xprt_err = -ETIMEDOUT;
+ switch (xprtsec->policy) {
+ case RPC_XPRTSEC_TLS_ANON:
+ rc = tls_client_hello_anon(&args, GFP_KERNEL);
+ if (rc)
+ goto out_put_xprt;
+ break;
+ case RPC_XPRTSEC_TLS_X509:
+ args.ta_my_cert = xprtsec->cert_serial;
+ args.ta_my_privkey = xprtsec->privkey_serial;
+ rc = tls_client_hello_x509(&args, GFP_KERNEL);
+ if (rc)
+ goto out_put_xprt;
+ break;
+ default:
+ rc = -EACCES;
+ goto out_put_xprt;
+ }
+
+ rc = wait_for_completion_interruptible_timeout(&lower_transport->handshake_done,
+ XS_TLS_HANDSHAKE_TO);
+ if (rc <= 0) {
+ if (!tls_handshake_cancel(sk)) {
+ if (rc == 0)
+ rc = -ETIMEDOUT;
+ goto out_put_xprt;
+ }
+ }
+
+ rc = lower_transport->xprt_err;
+
+out:
+ xs_stream_reset_connect(lower_transport);
+ clear_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state);
+ return rc;
+
+out_put_xprt:
+ xprt_put(lower_xprt);
+ goto out;
+}
+
+/**
+ * xs_tcp_tls_setup_socket - establish a TLS session on a TCP socket
+ * @work: queued work item
+ *
+ * Invoked from a work queue context.
+ *
+ * For RPC-with-TLS, there is a two-stage connection process.
+ *
+ * The "upper-layer xprt" is visible to the RPC consumer. Once it has
+ * been marked connected, the consumer knows that a TCP connection and
+ * a TLS session have been established.
+ *
+ * A "lower-layer xprt", created in this function, handles the mechanics
+ * of connecting the TCP socket, performing the RPC_AUTH_TLS probe, and
+ * then driving the TLS handshake. Once all that is complete, the upper
+ * layer xprt is marked connected.
+ */
+static void xs_tcp_tls_setup_socket(struct work_struct *work)
+{
+ struct sock_xprt *upper_transport =
+ container_of(work, struct sock_xprt, connect_worker.work);
+ struct rpc_clnt *upper_clnt = upper_transport->clnt;
+ struct rpc_xprt *upper_xprt = &upper_transport->xprt;
+ struct rpc_create_args args = {
+ .net = upper_xprt->xprt_net,
+ .protocol = upper_xprt->prot,
+ .address = (struct sockaddr *)&upper_xprt->addr,
+ .addrsize = upper_xprt->addrlen,
+ .timeout = upper_clnt->cl_timeout,
+ .servername = upper_xprt->servername,
+ .program = upper_clnt->cl_program,
+ .prognumber = upper_clnt->cl_prog,
+ .version = upper_clnt->cl_vers,
+ .authflavor = RPC_AUTH_TLS,
+ .cred = upper_clnt->cl_cred,
+ .xprtsec = {
+ .policy = RPC_XPRTSEC_NONE,
+ },
+ };
+ unsigned int pflags = current->flags;
+ struct rpc_clnt *lower_clnt;
+ struct rpc_xprt *lower_xprt;
+ int status;
+
+ if (atomic_read(&upper_xprt->swapper))
+ current->flags |= PF_MEMALLOC;
+
+ xs_stream_start_connect(upper_transport);
+
+ /* This implicitly sends an RPC_AUTH_TLS probe */
+ lower_clnt = rpc_create(&args);
+ if (IS_ERR(lower_clnt)) {
+ trace_rpc_tls_unavailable(upper_clnt, upper_xprt);
+ clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
+ xprt_clear_connecting(upper_xprt);
+ xprt_wake_pending_tasks(upper_xprt, PTR_ERR(lower_clnt));
+ xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
+ goto out_unlock;
+ }
+
+ /* RPC_AUTH_TLS probe was successful. Try a TLS handshake on
+ * the lower xprt.
+ */
+ rcu_read_lock();
+ lower_xprt = rcu_dereference(lower_clnt->cl_xprt);
+ rcu_read_unlock();
+
+ if (wait_on_bit_lock(&lower_xprt->state, XPRT_LOCKED, TASK_KILLABLE))
+ goto out_unlock;
+
+ status = xs_tls_handshake_sync(lower_xprt, &upper_xprt->xprtsec);
+ if (status) {
+ trace_rpc_tls_not_started(upper_clnt, upper_xprt);
+ goto out_close;
+ }
+
+ status = xs_tcp_tls_finish_connecting(lower_xprt, upper_transport);
+ if (status)
+ goto out_close;
+ xprt_release_write(lower_xprt, NULL);
+
+ trace_rpc_socket_connect(upper_xprt, upper_transport->sock, 0);
+ if (!xprt_test_and_set_connected(upper_xprt)) {
+ upper_xprt->connect_cookie++;
+ clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
+ xprt_clear_connecting(upper_xprt);
+
+ upper_xprt->stat.connect_count++;
+ upper_xprt->stat.connect_time += (long)jiffies -
+ upper_xprt->stat.connect_start;
+ xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
+ }
+ rpc_shutdown_client(lower_clnt);
+
+out_unlock:
+ current_restore_flags(pflags, PF_MEMALLOC);
+ upper_transport->clnt = NULL;
+ xprt_unlock_connect(upper_xprt, upper_transport);
+ return;
+
+out_close:
+ xprt_release_write(lower_xprt, NULL);
+ rpc_shutdown_client(lower_clnt);
+
+ /* xprt_force_disconnect() wakes tasks with a fixed tk_status code.
+ * Wake them first here to ensure they get our tk_status code.
+ */
+ xprt_wake_pending_tasks(upper_xprt, status);
+ xs_tcp_force_close(upper_xprt);
+ xprt_clear_connecting(upper_xprt);
+ goto out_unlock;
+}
+
+/**
+ * xs_connect - connect a socket to a remote endpoint
+ * @xprt: pointer to transport structure
+ * @task: address of RPC task that manages state of connect request
+ *
+ * TCP: If the remote end dropped the connection, delay reconnecting.
+ *
+ * UDP socket connects are synchronous, but we use a work queue anyway
+ * to guarantee that even unprivileged user processes can set up a
+ * socket on a privileged port.
+ *
+ * If a UDP socket connect fails, the delay behavior here prevents
+ * retry floods (hard mounts).
+ */
+static void xs_connect(struct rpc_xprt *xprt, struct rpc_task *task)
+{
+ struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+ unsigned long delay = 0;
+
+ WARN_ON_ONCE(!xprt_lock_connect(xprt, task, transport));
+
+ if (transport->sock != NULL) {
+ dprintk("RPC: xs_connect delayed xprt %p for %lu "
+ "seconds\n", xprt, xprt->reestablish_timeout / HZ);
+
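+		/* Delay this attempt by the current reconnect backoff,
+		 * then grow the backoff for the next attempt.
+		 */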
+ delay = xprt_reconnect_delay(xprt);
+ xprt_reconnect_backoff(xprt, XS_TCP_INIT_REEST_TO);
+
+ } else
+ dprintk("RPC: xs_connect scheduled xprt %p\n", xprt);
+
+ transport->clnt = task->tk_client;
+ queue_delayed_work(xprtiod_workqueue,
+ &transport->connect_worker,
+ delay);
+}
+
+static void xs_wake_disconnect(struct sock_xprt *transport)
+{
+ if (test_and_clear_bit(XPRT_SOCK_WAKE_DISCONNECT, &transport->sock_state))
+ xs_tcp_force_close(&transport->xprt);
+}
+
+static void xs_wake_write(struct sock_xprt *transport)
+{
+ if (test_and_clear_bit(XPRT_SOCK_WAKE_WRITE, &transport->sock_state))
+ xprt_write_space(&transport->xprt);
+}
+
+static void xs_wake_error(struct sock_xprt *transport)
+{
+ int sockerr;
+
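+	/* Cheap lockless check first; re-test and clear the bit under
+	 * recv_mutex so we don't race with socket teardown.
+	 */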
+ if (!test_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state))
+ return;
+ mutex_lock(&transport->recv_mutex);
+ if (transport->sock == NULL)
+ goto out;
+ if (!test_and_clear_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state))
+ goto out;
+ sockerr = xchg(&transport->xprt_err, 0);
+ if (sockerr < 0)
+ xprt_wake_pending_tasks(&transport->xprt, sockerr);
+out:
+ mutex_unlock(&transport->recv_mutex);
+}
+
+static void xs_wake_pending(struct sock_xprt *transport)
+{
+ if (test_and_clear_bit(XPRT_SOCK_WAKE_PENDING, &transport->sock_state))
+ xprt_wake_pending_tasks(&transport->xprt, -EAGAIN);
+}
+
+static void xs_error_handle(struct work_struct *work)
+{
+ struct sock_xprt *transport = container_of(work,
+ struct sock_xprt, error_worker);
+
+ xs_wake_disconnect(transport);
+ xs_wake_write(transport);
+ xs_wake_error(transport);
+ xs_wake_pending(transport);
+}
+
+/**
+ * xs_local_print_stats - display AF_LOCAL socket-specific stats
+ * @xprt: rpc_xprt struct containing statistics
+ * @seq: output file
+ *
+ */
+static void xs_local_print_stats(struct rpc_xprt *xprt, struct seq_file *seq)
+{
+ long idle_time = 0;
+
+ if (xprt_connected(xprt))
+ idle_time = (long)(jiffies - xprt->last_used) / HZ;
+
+ seq_printf(seq, "\txprt:\tlocal %lu %lu %lu %ld %lu %lu %lu "
+ "%llu %llu %lu %llu %llu\n",
+ xprt->stat.bind_count,
+ xprt->stat.connect_count,
+ xprt->stat.connect_time / HZ,
+ idle_time,
+ xprt->stat.sends,
+ xprt->stat.recvs,
+ xprt->stat.bad_xids,
+ xprt->stat.req_u,
+ xprt->stat.bklog_u,
+ xprt->stat.max_slots,
+ xprt->stat.sending_u,
+ xprt->stat.pending_u);
+}
+
+/**
+ * xs_udp_print_stats - display UDP socket-specific stats
+ * @xprt: rpc_xprt struct containing statistics
+ * @seq: output file
+ *
+ */
+static void xs_udp_print_stats(struct rpc_xprt *xprt, struct seq_file *seq)
+{
+ struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+
+ seq_printf(seq, "\txprt:\tudp %u %lu %lu %lu %lu %llu %llu "
+ "%lu %llu %llu\n",
+ transport->srcport,
+ xprt->stat.bind_count,
+ xprt->stat.sends,
+ xprt->stat.recvs,
+ xprt->stat.bad_xids,
+ xprt->stat.req_u,
+ xprt->stat.bklog_u,
+ xprt->stat.max_slots,
+ xprt->stat.sending_u,
+ xprt->stat.pending_u);
+}
+
+/**
+ * xs_tcp_print_stats - display TCP socket-specific stats
+ * @xprt: rpc_xprt struct containing statistics
+ * @seq: output file
+ *
+ */
+static void xs_tcp_print_stats(struct rpc_xprt *xprt, struct seq_file *seq)
+{
+ struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+ long idle_time = 0;
+
+ if (xprt_connected(xprt))
+ idle_time = (long)(jiffies - xprt->last_used) / HZ;
+
+ seq_printf(seq, "\txprt:\ttcp %u %lu %lu %lu %ld %lu %lu %lu "
+ "%llu %llu %lu %llu %llu\n",
+ transport->srcport,
+ xprt->stat.bind_count,
+ xprt->stat.connect_count,
+ xprt->stat.connect_time / HZ,
+ idle_time,
+ xprt->stat.sends,
+ xprt->stat.recvs,
+ xprt->stat.bad_xids,
+ xprt->stat.req_u,
+ xprt->stat.bklog_u,
+ xprt->stat.max_slots,
+ xprt->stat.sending_u,
+ xprt->stat.pending_u);
+}
+
+/*
+ * Allocate a page to use as a scratch buffer for the RPC code. We allocate
+ * a whole page instead of doing a kmalloc like rpc_malloc does because we
+ * want to use the server-side send routines.
+ */
+static int bc_malloc(struct rpc_task *task)
+{
+ struct rpc_rqst *rqst = task->tk_rqstp;
+ size_t size = rqst->rq_callsize;
+ struct page *page;
+ struct rpc_buffer *buf;
+
+ if (size > PAGE_SIZE - sizeof(struct rpc_buffer)) {
+ WARN_ONCE(1, "xprtsock: large bc buffer request (size %zu)\n",
+ size);
+ return -EINVAL;
+ }
+
+ page = alloc_page(GFP_KERNEL | __GFP_NORETRY | __GFP_NOWARN);
+ if (!page)
+ return -ENOMEM;
+
+ buf = page_address(page);
+ buf->len = PAGE_SIZE;
+
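+	/* The call buffer starts right after the rpc_buffer header; the
+	 * reply buffer is carved from the same page immediately after
+	 * the call area.
+	 */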
+ rqst->rq_buffer = buf->data;
+ rqst->rq_rbuffer = (char *)rqst->rq_buffer + rqst->rq_callsize;
+ return 0;
+}
+
+/*
+ * Free the space allocated by the bc_malloc routine
+ */
+static void bc_free(struct rpc_task *task)
+{
+ void *buffer = task->tk_rqstp->rq_buffer;
+ struct rpc_buffer *buf;
+
+ buf = container_of(buffer, struct rpc_buffer, data);
+ free_page((unsigned long)buf);
+}
+
+static int bc_sendto(struct rpc_rqst *req)
+{
+ struct xdr_buf *xdr = &req->rq_snd_buf;
+ struct sock_xprt *transport =
+ container_of(req->rq_xprt, struct sock_xprt, xprt);
+ struct msghdr msg = {
+ .msg_flags = 0,
+ };
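+	/* RPC-over-TCP record marker: last-fragment flag ORed with the
+	 * fragment length (RFC 5531 record marking).
+	 */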
+ rpc_fraghdr marker = cpu_to_be32(RPC_LAST_STREAM_FRAGMENT |
+ (u32)xdr->len);
+ unsigned int sent = 0;
+ int err;
+
+ req->rq_xtime = ktime_get();
+ err = xdr_alloc_bvec(xdr, rpc_task_gfp_mask());
+ if (err < 0)
+ return err;
+ err = xprt_sock_sendmsg(transport->sock, &msg, xdr, 0, marker, &sent);
+ xdr_free_bvec(xdr);
+ if (err < 0 || sent != (xdr->len + sizeof(marker)))
+ return -EAGAIN;
+ return sent;
+}
+
+/**
+ * bc_send_request - Send a backchannel Call on a TCP socket
+ * @req: rpc_rqst containing Call message to be sent
+ *
+ * xpt_mutex ensures @req's whole message is written to the socket
+ * without interruption.
+ *
+ * Return values:
+ * %0 if the message was sent successfully
+ * %ENOTCONN if the message was not sent
+ */
+static int bc_send_request(struct rpc_rqst *req)
+{
+ struct svc_xprt *xprt;
+ int len;
+
+ /*
+ * Get the server socket associated with this callback xprt
+ */
+ xprt = req->rq_xprt->bc_xprt;
+
+ /*
+ * Grab the mutex to serialize data as the connection is shared
+ * with the fore channel
+ */
+ mutex_lock(&xprt->xpt_mutex);
+ if (test_bit(XPT_DEAD, &xprt->xpt_flags))
+ len = -ENOTCONN;
+ else
+ len = bc_sendto(req);
+ mutex_unlock(&xprt->xpt_mutex);
+
+ if (len > 0)
+ len = 0;
+
+ return len;
+}
+
+/*
+ * The close routine. Since this connection is client-initiated, all we do
+ * here is signal that the disconnect has completed.
+ */
+
+static void bc_close(struct rpc_xprt *xprt)
+{
+ xprt_disconnect_done(xprt);
+}
+
+/*
+ * The xprt destroy routine. Again, because this connection is client
+ * initiated, there is no transport socket of our own to tear down; just
+ * free the xprt.
+ */
+
+static void bc_destroy(struct rpc_xprt *xprt)
+{
+ dprintk("RPC: bc_destroy xprt %p\n", xprt);
+
+ xs_xprt_free(xprt);
+ module_put(THIS_MODULE);
+}
+
+static const struct rpc_xprt_ops xs_local_ops = {
+ .reserve_xprt = xprt_reserve_xprt,
+ .release_xprt = xprt_release_xprt,
+ .alloc_slot = xprt_alloc_slot,
+ .free_slot = xprt_free_slot,
+ .rpcbind = xs_local_rpcbind,
+ .set_port = xs_local_set_port,
+ .connect = xs_local_connect,
+ .buf_alloc = rpc_malloc,
+ .buf_free = rpc_free,
+ .prepare_request = xs_stream_prepare_request,
+ .send_request = xs_local_send_request,
+ .wait_for_reply_request = xprt_wait_for_reply_request_def,
+ .close = xs_close,
+ .destroy = xs_destroy,
+ .print_stats = xs_local_print_stats,
+ .enable_swap = xs_enable_swap,
+ .disable_swap = xs_disable_swap,
+};
+
+static const struct rpc_xprt_ops xs_udp_ops = {
+ .set_buffer_size = xs_udp_set_buffer_size,
+ .reserve_xprt = xprt_reserve_xprt_cong,
+ .release_xprt = xprt_release_xprt_cong,
+ .alloc_slot = xprt_alloc_slot,
+ .free_slot = xprt_free_slot,
+ .rpcbind = rpcb_getport_async,
+ .set_port = xs_set_port,
+ .connect = xs_connect,
+ .get_srcaddr = xs_sock_srcaddr,
+ .get_srcport = xs_sock_srcport,
+ .buf_alloc = rpc_malloc,
+ .buf_free = rpc_free,
+ .send_request = xs_udp_send_request,
+ .wait_for_reply_request = xprt_wait_for_reply_request_rtt,
+ .timer = xs_udp_timer,
+ .release_request = xprt_release_rqst_cong,
+ .close = xs_close,
+ .destroy = xs_destroy,
+ .print_stats = xs_udp_print_stats,
+ .enable_swap = xs_enable_swap,
+ .disable_swap = xs_disable_swap,
+ .inject_disconnect = xs_inject_disconnect,
+};
+
+static const struct rpc_xprt_ops xs_tcp_ops = {
+ .reserve_xprt = xprt_reserve_xprt,
+ .release_xprt = xprt_release_xprt,
+ .alloc_slot = xprt_alloc_slot,
+ .free_slot = xprt_free_slot,
+ .rpcbind = rpcb_getport_async,
+ .set_port = xs_set_port,
+ .connect = xs_connect,
+ .get_srcaddr = xs_sock_srcaddr,
+ .get_srcport = xs_sock_srcport,
+ .buf_alloc = rpc_malloc,
+ .buf_free = rpc_free,
+ .prepare_request = xs_stream_prepare_request,
+ .send_request = xs_tcp_send_request,
+ .wait_for_reply_request = xprt_wait_for_reply_request_def,
+ .close = xs_tcp_shutdown,
+ .destroy = xs_destroy,
+ .set_connect_timeout = xs_tcp_set_connect_timeout,
+ .print_stats = xs_tcp_print_stats,
+ .enable_swap = xs_enable_swap,
+ .disable_swap = xs_disable_swap,
+ .inject_disconnect = xs_inject_disconnect,
+#ifdef CONFIG_SUNRPC_BACKCHANNEL
+ .bc_setup = xprt_setup_bc,
+ .bc_maxpayload = xs_tcp_bc_maxpayload,
+ .bc_num_slots = xprt_bc_max_slots,
+ .bc_free_rqst = xprt_free_bc_rqst,
+ .bc_destroy = xprt_destroy_bc,
+#endif
+};
+
+/*
+ * The rpc_xprt_ops for the server backchannel
+ */
+
+static const struct rpc_xprt_ops bc_tcp_ops = {
+ .reserve_xprt = xprt_reserve_xprt,
+ .release_xprt = xprt_release_xprt,
+ .alloc_slot = xprt_alloc_slot,
+ .free_slot = xprt_free_slot,
+ .buf_alloc = bc_malloc,
+ .buf_free = bc_free,
+ .send_request = bc_send_request,
+ .wait_for_reply_request = xprt_wait_for_reply_request_def,
+ .close = bc_close,
+ .destroy = bc_destroy,
+ .print_stats = xs_tcp_print_stats,
+ .enable_swap = xs_enable_swap,
+ .disable_swap = xs_disable_swap,
+ .inject_disconnect = xs_inject_disconnect,
+};
+
+static int xs_init_anyaddr(const int family, struct sockaddr *sap)
+{
+ static const struct sockaddr_in sin = {
+ .sin_family = AF_INET,
+ .sin_addr.s_addr = htonl(INADDR_ANY),
+ };
+ static const struct sockaddr_in6 sin6 = {
+ .sin6_family = AF_INET6,
+ .sin6_addr = IN6ADDR_ANY_INIT,
+ };
+
+ switch (family) {
+ case AF_LOCAL:
+ break;
+ case AF_INET:
+ memcpy(sap, &sin, sizeof(sin));
+ break;
+ case AF_INET6:
+ memcpy(sap, &sin6, sizeof(sin6));
+ break;
+ default:
+ dprintk("RPC: %s: Bad address family\n", __func__);
+ return -EAFNOSUPPORT;
+ }
+ return 0;
+}
+
+static struct rpc_xprt *xs_setup_xprt(struct xprt_create *args,
+ unsigned int slot_table_size,
+ unsigned int max_slot_table_size)
+{
+ struct rpc_xprt *xprt;
+ struct sock_xprt *new;
+
+ if (args->addrlen > sizeof(xprt->addr)) {
+ dprintk("RPC: xs_setup_xprt: address too large\n");
+ return ERR_PTR(-EBADF);
+ }
+
+ xprt = xprt_alloc(args->net, sizeof(*new), slot_table_size,
+ max_slot_table_size);
+ if (xprt == NULL) {
+ dprintk("RPC: xs_setup_xprt: couldn't allocate "
+ "rpc_xprt\n");
+ return ERR_PTR(-ENOMEM);
+ }
+
+ new = container_of(xprt, struct sock_xprt, xprt);
+ mutex_init(&new->recv_mutex);
+ memcpy(&xprt->addr, args->dstaddr, args->addrlen);
+ xprt->addrlen = args->addrlen;
+ if (args->srcaddr)
+ memcpy(&new->srcaddr, args->srcaddr, args->addrlen);
+ else {
+ int err;
+ err = xs_init_anyaddr(args->dstaddr->sa_family,
+ (struct sockaddr *)&new->srcaddr);
+ if (err != 0) {
+ xprt_free(xprt);
+ return ERR_PTR(err);
+ }
+ }
+
+ return xprt;
+}
+
+static const struct rpc_timeout xs_local_default_timeout = {
+ .to_initval = 10 * HZ,
+ .to_maxval = 10 * HZ,
+ .to_retries = 2,
+};
+
+/**
+ * xs_setup_local - Set up transport to use an AF_LOCAL socket
+ * @args: rpc transport creation arguments
+ *
+ * AF_LOCAL is a "tpi_cots_ord" transport, just like TCP
+ */
+static struct rpc_xprt *xs_setup_local(struct xprt_create *args)
+{
+ struct sockaddr_un *sun = (struct sockaddr_un *)args->dstaddr;
+ struct sock_xprt *transport;
+ struct rpc_xprt *xprt;
+ struct rpc_xprt *ret;
+
+ xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries,
+ xprt_max_tcp_slot_table_entries);
+ if (IS_ERR(xprt))
+ return xprt;
+ transport = container_of(xprt, struct sock_xprt, xprt);
+
+ xprt->prot = 0;
+ xprt->xprt_class = &xs_local_transport;
+ xprt->max_payload = RPC_MAX_FRAGMENT_SIZE;
+
+ xprt->bind_timeout = XS_BIND_TO;
+ xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO;
+ xprt->idle_timeout = XS_IDLE_DISC_TO;
+
+ xprt->ops = &xs_local_ops;
+ xprt->timeout = &xs_local_default_timeout;
+
+ INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn);
+ INIT_WORK(&transport->error_worker, xs_error_handle);
+ INIT_DELAYED_WORK(&transport->connect_worker, xs_dummy_setup_socket);
+
+ switch (sun->sun_family) {
+ case AF_LOCAL:
+ if (sun->sun_path[0] != '/' && sun->sun_path[0] != '\0') {
+ dprintk("RPC: bad AF_LOCAL address: %s\n",
+ sun->sun_path);
+ ret = ERR_PTR(-EINVAL);
+ goto out_err;
+ }
+ xprt_set_bound(xprt);
+ xs_format_peer_addresses(xprt, "local", RPCBIND_NETID_LOCAL);
+ break;
+ default:
+ ret = ERR_PTR(-EAFNOSUPPORT);
+ goto out_err;
+ }
+
+ dprintk("RPC: set up xprt to %s via AF_LOCAL\n",
+ xprt->address_strings[RPC_DISPLAY_ADDR]);
+
+ if (try_module_get(THIS_MODULE))
+ return xprt;
+ ret = ERR_PTR(-EINVAL);
+out_err:
+ xs_xprt_free(xprt);
+ return ret;
+}
+
+static const struct rpc_timeout xs_udp_default_timeout = {
+ .to_initval = 5 * HZ,
+ .to_maxval = 30 * HZ,
+ .to_increment = 5 * HZ,
+ .to_retries = 5,
+};
+
+/**
+ * xs_setup_udp - Set up transport to use a UDP socket
+ * @args: rpc transport creation arguments
+ *
+ */
+static struct rpc_xprt *xs_setup_udp(struct xprt_create *args)
+{
+ struct sockaddr *addr = args->dstaddr;
+ struct rpc_xprt *xprt;
+ struct sock_xprt *transport;
+ struct rpc_xprt *ret;
+
+ xprt = xs_setup_xprt(args, xprt_udp_slot_table_entries,
+ xprt_udp_slot_table_entries);
+ if (IS_ERR(xprt))
+ return xprt;
+ transport = container_of(xprt, struct sock_xprt, xprt);
+
+ xprt->prot = IPPROTO_UDP;
+ xprt->xprt_class = &xs_udp_transport;
+ /* XXX: header size can vary due to auth type, IPv6, etc. */
+ xprt->max_payload = (1U << 16) - (MAX_HEADER << 3);
+
+ xprt->bind_timeout = XS_BIND_TO;
+ xprt->reestablish_timeout = XS_UDP_REEST_TO;
+ xprt->idle_timeout = XS_IDLE_DISC_TO;
+
+ xprt->ops = &xs_udp_ops;
+
+ xprt->timeout = &xs_udp_default_timeout;
+
+ INIT_WORK(&transport->recv_worker, xs_udp_data_receive_workfn);
+ INIT_WORK(&transport->error_worker, xs_error_handle);
+ INIT_DELAYED_WORK(&transport->connect_worker, xs_udp_setup_socket);
+
+ switch (addr->sa_family) {
+ case AF_INET:
+ if (((struct sockaddr_in *)addr)->sin_port != htons(0))
+ xprt_set_bound(xprt);
+
+ xs_format_peer_addresses(xprt, "udp", RPCBIND_NETID_UDP);
+ break;
+ case AF_INET6:
+ if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0))
+ xprt_set_bound(xprt);
+
+ xs_format_peer_addresses(xprt, "udp", RPCBIND_NETID_UDP6);
+ break;
+ default:
+ ret = ERR_PTR(-EAFNOSUPPORT);
+ goto out_err;
+ }
+
+ if (xprt_bound(xprt))
+ dprintk("RPC: set up xprt to %s (port %s) via %s\n",
+ xprt->address_strings[RPC_DISPLAY_ADDR],
+ xprt->address_strings[RPC_DISPLAY_PORT],
+ xprt->address_strings[RPC_DISPLAY_PROTO]);
+ else
+ dprintk("RPC: set up xprt to %s (autobind) via %s\n",
+ xprt->address_strings[RPC_DISPLAY_ADDR],
+ xprt->address_strings[RPC_DISPLAY_PROTO]);
+
+ if (try_module_get(THIS_MODULE))
+ return xprt;
+ ret = ERR_PTR(-EINVAL);
+out_err:
+ xs_xprt_free(xprt);
+ return ret;
+}
+
+static const struct rpc_timeout xs_tcp_default_timeout = {
+ .to_initval = 60 * HZ,
+ .to_maxval = 60 * HZ,
+ .to_retries = 2,
+};
+
+/**
+ * xs_setup_tcp - Set up transport to use a TCP socket
+ * @args: rpc transport creation arguments
+ *
+ */
+static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args)
+{
+ struct sockaddr *addr = args->dstaddr;
+ struct rpc_xprt *xprt;
+ struct sock_xprt *transport;
+ struct rpc_xprt *ret;
+ unsigned int max_slot_table_size = xprt_max_tcp_slot_table_entries;
+
+ if (args->flags & XPRT_CREATE_INFINITE_SLOTS)
+ max_slot_table_size = RPC_MAX_SLOT_TABLE_LIMIT;
+
+ xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries,
+ max_slot_table_size);
+ if (IS_ERR(xprt))
+ return xprt;
+ transport = container_of(xprt, struct sock_xprt, xprt);
+
+ xprt->prot = IPPROTO_TCP;
+ xprt->xprt_class = &xs_tcp_transport;
+ xprt->max_payload = RPC_MAX_FRAGMENT_SIZE;
+
+ xprt->bind_timeout = XS_BIND_TO;
+ xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO;
+ xprt->idle_timeout = XS_IDLE_DISC_TO;
+
+ xprt->ops = &xs_tcp_ops;
+ xprt->timeout = &xs_tcp_default_timeout;
+
+ xprt->max_reconnect_timeout = xprt->timeout->to_maxval;
+ if (args->reconnect_timeout)
+ xprt->max_reconnect_timeout = args->reconnect_timeout;
+
+ xprt->connect_timeout = xprt->timeout->to_initval *
+ (xprt->timeout->to_retries + 1);
+ if (args->connect_timeout)
+ xs_tcp_do_set_connect_timeout(xprt, args->connect_timeout);
+
+ INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn);
+ INIT_WORK(&transport->error_worker, xs_error_handle);
+ INIT_DELAYED_WORK(&transport->connect_worker, xs_tcp_setup_socket);
+
+ switch (addr->sa_family) {
+ case AF_INET:
+ if (((struct sockaddr_in *)addr)->sin_port != htons(0))
+ xprt_set_bound(xprt);
+
+ xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP);
+ break;
+ case AF_INET6:
+ if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0))
+ xprt_set_bound(xprt);
+
+ xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP6);
+ break;
+ default:
+ ret = ERR_PTR(-EAFNOSUPPORT);
+ goto out_err;
+ }
+
+ if (xprt_bound(xprt))
+ dprintk("RPC: set up xprt to %s (port %s) via %s\n",
+ xprt->address_strings[RPC_DISPLAY_ADDR],
+ xprt->address_strings[RPC_DISPLAY_PORT],
+ xprt->address_strings[RPC_DISPLAY_PROTO]);
+ else
+ dprintk("RPC: set up xprt to %s (autobind) via %s\n",
+ xprt->address_strings[RPC_DISPLAY_ADDR],
+ xprt->address_strings[RPC_DISPLAY_PROTO]);
+
+ if (try_module_get(THIS_MODULE))
+ return xprt;
+ ret = ERR_PTR(-EINVAL);
+out_err:
+ xs_xprt_free(xprt);
+ return ret;
+}
+
+/**
+ * xs_setup_tcp_tls - Set up transport to use a TCP with TLS
+ * @args: rpc transport creation arguments
+ *
+ */
+static struct rpc_xprt *xs_setup_tcp_tls(struct xprt_create *args)
+{
+ struct sockaddr *addr = args->dstaddr;
+ struct rpc_xprt *xprt;
+ struct sock_xprt *transport;
+ struct rpc_xprt *ret;
+ unsigned int max_slot_table_size = xprt_max_tcp_slot_table_entries;
+
+ if (args->flags & XPRT_CREATE_INFINITE_SLOTS)
+ max_slot_table_size = RPC_MAX_SLOT_TABLE_LIMIT;
+
+ xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries,
+ max_slot_table_size);
+ if (IS_ERR(xprt))
+ return xprt;
+ transport = container_of(xprt, struct sock_xprt, xprt);
+
+ xprt->prot = IPPROTO_TCP;
+ xprt->xprt_class = &xs_tcp_transport;
+ xprt->max_payload = RPC_MAX_FRAGMENT_SIZE;
+
+ xprt->bind_timeout = XS_BIND_TO;
+ xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO;
+ xprt->idle_timeout = XS_IDLE_DISC_TO;
+
+ xprt->ops = &xs_tcp_ops;
+ xprt->timeout = &xs_tcp_default_timeout;
+
+ xprt->max_reconnect_timeout = xprt->timeout->to_maxval;
+ xprt->connect_timeout = xprt->timeout->to_initval *
+ (xprt->timeout->to_retries + 1);
+
+ INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn);
+ INIT_WORK(&transport->error_worker, xs_error_handle);
+
+ switch (args->xprtsec.policy) {
+ case RPC_XPRTSEC_TLS_ANON:
+ case RPC_XPRTSEC_TLS_X509:
+ xprt->xprtsec = args->xprtsec;
+ INIT_DELAYED_WORK(&transport->connect_worker,
+ xs_tcp_tls_setup_socket);
+ break;
+ default:
+ ret = ERR_PTR(-EACCES);
+ goto out_err;
+ }
+
+ switch (addr->sa_family) {
+ case AF_INET:
+ if (((struct sockaddr_in *)addr)->sin_port != htons(0))
+ xprt_set_bound(xprt);
+
+ xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP);
+ break;
+ case AF_INET6:
+ if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0))
+ xprt_set_bound(xprt);
+
+ xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP6);
+ break;
+ default:
+ ret = ERR_PTR(-EAFNOSUPPORT);
+ goto out_err;
+ }
+
+ if (xprt_bound(xprt))
+ dprintk("RPC: set up xprt to %s (port %s) via %s\n",
+ xprt->address_strings[RPC_DISPLAY_ADDR],
+ xprt->address_strings[RPC_DISPLAY_PORT],
+ xprt->address_strings[RPC_DISPLAY_PROTO]);
+ else
+ dprintk("RPC: set up xprt to %s (autobind) via %s\n",
+ xprt->address_strings[RPC_DISPLAY_ADDR],
+ xprt->address_strings[RPC_DISPLAY_PROTO]);
+
+ if (try_module_get(THIS_MODULE))
+ return xprt;
+ ret = ERR_PTR(-EINVAL);
+out_err:
+ xs_xprt_free(xprt);
+ return ret;
+}
+
+/**
+ * xs_setup_bc_tcp - Set up transport to use a TCP backchannel socket
+ * @args: rpc transport creation arguments
+ *
+ */
+static struct rpc_xprt *xs_setup_bc_tcp(struct xprt_create *args)
+{
+ struct sockaddr *addr = args->dstaddr;
+ struct rpc_xprt *xprt;
+ struct sock_xprt *transport;
+ struct svc_sock *bc_sock;
+ struct rpc_xprt *ret;
+
+ xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries,
+ xprt_tcp_slot_table_entries);
+ if (IS_ERR(xprt))
+ return xprt;
+ transport = container_of(xprt, struct sock_xprt, xprt);
+
+ xprt->prot = IPPROTO_TCP;
+ xprt->xprt_class = &xs_bc_tcp_transport;
+ xprt->max_payload = RPC_MAX_FRAGMENT_SIZE;
+ xprt->timeout = &xs_tcp_default_timeout;
+
+ /* backchannel */
+ xprt_set_bound(xprt);
+ xprt->bind_timeout = 0;
+ xprt->reestablish_timeout = 0;
+ xprt->idle_timeout = 0;
+
+ xprt->ops = &bc_tcp_ops;
+
+ switch (addr->sa_family) {
+ case AF_INET:
+ xs_format_peer_addresses(xprt, "tcp",
+ RPCBIND_NETID_TCP);
+ break;
+ case AF_INET6:
+ xs_format_peer_addresses(xprt, "tcp",
+ RPCBIND_NETID_TCP6);
+ break;
+ default:
+ ret = ERR_PTR(-EAFNOSUPPORT);
+ goto out_err;
+ }
+
+ dprintk("RPC: set up xprt to %s (port %s) via %s\n",
+ xprt->address_strings[RPC_DISPLAY_ADDR],
+ xprt->address_strings[RPC_DISPLAY_PORT],
+ xprt->address_strings[RPC_DISPLAY_PROTO]);
+
+ /*
+ * Once we've associated a backchannel xprt with a connection,
+ * we want to keep it around as long as the connection lasts,
+ * in case we need to start using it for a backchannel again;
+ * this reference won't be dropped until bc_xprt is destroyed.
+ */
+ xprt_get(xprt);
+ args->bc_xprt->xpt_bc_xprt = xprt;
+ xprt->bc_xprt = args->bc_xprt;
+ bc_sock = container_of(args->bc_xprt, struct svc_sock, sk_xprt);
+ transport->sock = bc_sock->sk_sock;
+ transport->inet = bc_sock->sk_sk;
+
+ /*
+ * Since we don't want connections for the backchannel, we set
+ * the xprt status to connected
+ */
+ xprt_set_connected(xprt);
+
+ if (try_module_get(THIS_MODULE))
+ return xprt;
+
+ args->bc_xprt->xpt_bc_xprt = NULL;
+ args->bc_xprt->xpt_bc_xps = NULL;
+ xprt_put(xprt);
+ ret = ERR_PTR(-EINVAL);
+out_err:
+ xs_xprt_free(xprt);
+ return ret;
+}
+
+static struct xprt_class xs_local_transport = {
+ .list = LIST_HEAD_INIT(xs_local_transport.list),
+ .name = "named UNIX socket",
+ .owner = THIS_MODULE,
+ .ident = XPRT_TRANSPORT_LOCAL,
+ .setup = xs_setup_local,
+ .netid = { "" },
+};
+
+static struct xprt_class xs_udp_transport = {
+ .list = LIST_HEAD_INIT(xs_udp_transport.list),
+ .name = "udp",
+ .owner = THIS_MODULE,
+ .ident = XPRT_TRANSPORT_UDP,
+ .setup = xs_setup_udp,
+ .netid = { "udp", "udp6", "" },
+};
+
+static struct xprt_class xs_tcp_transport = {
+ .list = LIST_HEAD_INIT(xs_tcp_transport.list),
+ .name = "tcp",
+ .owner = THIS_MODULE,
+ .ident = XPRT_TRANSPORT_TCP,
+ .setup = xs_setup_tcp,
+ .netid = { "tcp", "tcp6", "" },
+};
+
+static struct xprt_class xs_tcp_tls_transport = {
+ .list = LIST_HEAD_INIT(xs_tcp_tls_transport.list),
+ .name = "tcp-with-tls",
+ .owner = THIS_MODULE,
+ .ident = XPRT_TRANSPORT_TCP_TLS,
+ .setup = xs_setup_tcp_tls,
+ .netid = { "tcp", "tcp6", "" },
+};
+
+static struct xprt_class xs_bc_tcp_transport = {
+ .list = LIST_HEAD_INIT(xs_bc_tcp_transport.list),
+ .name = "tcp NFSv4.1 backchannel",
+ .owner = THIS_MODULE,
+ .ident = XPRT_TRANSPORT_BC_TCP,
+ .setup = xs_setup_bc_tcp,
+ .netid = { "" },
+};
+
+/**
+ * init_socket_xprt - set up xprtsock's sysctls, register with RPC client
+ *
+ */
+int init_socket_xprt(void)
+{
+ if (!sunrpc_table_header)
+ sunrpc_table_header = register_sysctl("sunrpc", xs_tunables_table);
+
+ xprt_register_transport(&xs_local_transport);
+ xprt_register_transport(&xs_udp_transport);
+ xprt_register_transport(&xs_tcp_transport);
+ xprt_register_transport(&xs_tcp_tls_transport);
+ xprt_register_transport(&xs_bc_tcp_transport);
+
+ return 0;
+}
+
+/**
+ * cleanup_socket_xprt - remove xprtsock's sysctls, unregister
+ *
+ */
+void cleanup_socket_xprt(void)
+{
+ if (sunrpc_table_header) {
+ unregister_sysctl_table(sunrpc_table_header);
+ sunrpc_table_header = NULL;
+ }
+
+ xprt_unregister_transport(&xs_local_transport);
+ xprt_unregister_transport(&xs_udp_transport);
+ xprt_unregister_transport(&xs_tcp_transport);
+ xprt_unregister_transport(&xs_tcp_tls_transport);
+ xprt_unregister_transport(&xs_bc_tcp_transport);
+}
+
+static int param_set_portnr(const char *val, const struct kernel_param *kp)
+{
+ return param_set_uint_minmax(val, kp,
+ RPC_MIN_RESVPORT,
+ RPC_MAX_RESVPORT);
+}
+
+static const struct kernel_param_ops param_ops_portnr = {
+ .set = param_set_portnr,
+ .get = param_get_uint,
+};
+
+#define param_check_portnr(name, p) \
+ __param_check(name, p, unsigned int);
+
+module_param_named(min_resvport, xprt_min_resvport, portnr, 0644);
+module_param_named(max_resvport, xprt_max_resvport, portnr, 0644);
+
+static int param_set_slot_table_size(const char *val,
+ const struct kernel_param *kp)
+{
+ return param_set_uint_minmax(val, kp,
+ RPC_MIN_SLOT_TABLE,
+ RPC_MAX_SLOT_TABLE);
+}
+
+static const struct kernel_param_ops param_ops_slot_table_size = {
+ .set = param_set_slot_table_size,
+ .get = param_get_uint,
+};
+
+#define param_check_slot_table_size(name, p) \
+ __param_check(name, p, unsigned int);
+
+static int param_set_max_slot_table_size(const char *val,
+ const struct kernel_param *kp)
+{
+ return param_set_uint_minmax(val, kp,
+ RPC_MIN_SLOT_TABLE,
+ RPC_MAX_SLOT_TABLE_LIMIT);
+}
+
+static const struct kernel_param_ops param_ops_max_slot_table_size = {
+ .set = param_set_max_slot_table_size,
+ .get = param_get_uint,
+};
+
+#define param_check_max_slot_table_size(name, p) \
+ __param_check(name, p, unsigned int);
+
+module_param_named(tcp_slot_table_entries, xprt_tcp_slot_table_entries,
+ slot_table_size, 0644);
+module_param_named(tcp_max_slot_table_entries, xprt_max_tcp_slot_table_entries,
+ max_slot_table_size, 0644);
+module_param_named(udp_slot_table_entries, xprt_udp_slot_table_entries,
+ slot_table_size, 0644);