diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-11 08:27:49 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-11 08:27:49 +0000 |
commit | ace9429bb58fd418f0c81d4c2835699bddf6bde6 (patch) | |
tree | b2d64bc10158fdd5497876388cd68142ca374ed3 /net/vmw_vsock | |
parent | Initial commit. (diff) | |
download | linux-ace9429bb58fd418f0c81d4c2835699bddf6bde6.tar.xz linux-ace9429bb58fd418f0c81d4c2835699bddf6bde6.zip |
Adding upstream version 6.6.15.upstream/6.6.15
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'net/vmw_vsock')
-rw-r--r-- | net/vmw_vsock/Kconfig | 83 | ||||
-rw-r--r-- | net/vmw_vsock/Makefile | 22 | ||||
-rw-r--r-- | net/vmw_vsock/af_vsock.c | 2516 | ||||
-rw-r--r-- | net/vmw_vsock/af_vsock_tap.c | 110 | ||||
-rw-r--r-- | net/vmw_vsock/diag.c | 178 | ||||
-rw-r--r-- | net/vmw_vsock/hyperv_transport.c | 954 | ||||
-rw-r--r-- | net/vmw_vsock/virtio_transport.c | 824 | ||||
-rw-r--r-- | net/vmw_vsock/virtio_transport_common.c | 1561 | ||||
-rw-r--r-- | net/vmw_vsock/vmci_transport.c | 2160 | ||||
-rw-r--r-- | net/vmw_vsock/vmci_transport.h | 130 | ||||
-rw-r--r-- | net/vmw_vsock/vmci_transport_notify.c | 672 | ||||
-rw-r--r-- | net/vmw_vsock/vmci_transport_notify.h | 75 | ||||
-rw-r--r-- | net/vmw_vsock/vmci_transport_notify_qstate.c | 430 | ||||
-rw-r--r-- | net/vmw_vsock/vsock_addr.c | 69 | ||||
-rw-r--r-- | net/vmw_vsock/vsock_bpf.c | 174 | ||||
-rw-r--r-- | net/vmw_vsock/vsock_loopback.c | 167 |
16 files changed, 10125 insertions, 0 deletions
diff --git a/net/vmw_vsock/Kconfig b/net/vmw_vsock/Kconfig new file mode 100644 index 0000000000..56356d2980 --- /dev/null +++ b/net/vmw_vsock/Kconfig @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Vsock protocol +# + +config VSOCKETS + tristate "Virtual Socket protocol" + help + Virtual Socket Protocol is a socket protocol similar to TCP/IP + allowing communication between Virtual Machines and hypervisor + or host. + + You should also select one or more hypervisor-specific transports + below. + + To compile this driver as a module, choose M here: the module + will be called vsock. If unsure, say N. + +config VSOCKETS_DIAG + tristate "Virtual Sockets monitoring interface" + depends on VSOCKETS + default y + help + Support for PF_VSOCK sockets monitoring interface used by the ss tool. + If unsure, say Y. + + Enable this module so userspace applications can query open sockets. + +config VSOCKETS_LOOPBACK + tristate "Virtual Sockets loopback transport" + depends on VSOCKETS + default y + select VIRTIO_VSOCKETS_COMMON + help + This module implements a loopback transport for Virtual Sockets, + using vmw_vsock_virtio_transport_common. + + To compile this driver as a module, choose M here: the module + will be called vsock_loopback. If unsure, say N. + +config VMWARE_VMCI_VSOCKETS + tristate "VMware VMCI transport for Virtual Sockets" + depends on VSOCKETS && VMWARE_VMCI + help + This module implements a VMCI transport for Virtual Sockets. + + Enable this transport if your Virtual Machine runs on a VMware + hypervisor. + + To compile this driver as a module, choose M here: the module + will be called vmw_vsock_vmci_transport. If unsure, say N. + +config VIRTIO_VSOCKETS + tristate "virtio transport for Virtual Sockets" + depends on VSOCKETS && VIRTIO + select VIRTIO_VSOCKETS_COMMON + help + This module implements a virtio transport for Virtual Sockets. + + Enable this transport if your Virtual Machine host supports Virtual + Sockets over virtio. + + To compile this driver as a module, choose M here: the module will be + called vmw_vsock_virtio_transport. If unsure, say N. + +config VIRTIO_VSOCKETS_COMMON + tristate + help + This option is selected by any driver which needs to access + the virtio_vsock. The module will be called + vmw_vsock_virtio_transport_common. + +config HYPERV_VSOCKETS + tristate "Hyper-V transport for Virtual Sockets" + depends on VSOCKETS && HYPERV + help + This module implements a Hyper-V transport for Virtual Sockets. + + Enable this transport if your Virtual Machine host supports Virtual + Sockets over Hyper-V VMBus. + + To compile this driver as a module, choose M here: the module will be + called hv_sock. If unsure, say N. diff --git a/net/vmw_vsock/Makefile b/net/vmw_vsock/Makefile new file mode 100644 index 0000000000..5da74c4a9f --- /dev/null +++ b/net/vmw_vsock/Makefile @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: GPL-2.0 +obj-$(CONFIG_VSOCKETS) += vsock.o +obj-$(CONFIG_VSOCKETS_DIAG) += vsock_diag.o +obj-$(CONFIG_VMWARE_VMCI_VSOCKETS) += vmw_vsock_vmci_transport.o +obj-$(CONFIG_VIRTIO_VSOCKETS) += vmw_vsock_virtio_transport.o +obj-$(CONFIG_VIRTIO_VSOCKETS_COMMON) += vmw_vsock_virtio_transport_common.o +obj-$(CONFIG_HYPERV_VSOCKETS) += hv_sock.o +obj-$(CONFIG_VSOCKETS_LOOPBACK) += vsock_loopback.o + +vsock-y += af_vsock.o af_vsock_tap.o vsock_addr.o +vsock-$(CONFIG_BPF_SYSCALL) += vsock_bpf.o + +vsock_diag-y += diag.o + +vmw_vsock_vmci_transport-y += vmci_transport.o vmci_transport_notify.o \ + vmci_transport_notify_qstate.o + +vmw_vsock_virtio_transport-y += virtio_transport.o + +vmw_vsock_virtio_transport_common-y += virtio_transport_common.o + +hv_sock-y += hyperv_transport.o diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c new file mode 100644 index 0000000000..4afb6a541c --- /dev/null +++ b/net/vmw_vsock/af_vsock.c @@ -0,0 +1,2516 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * VMware vSockets Driver + * + * Copyright (C) 2007-2013 VMware, Inc. All rights reserved. + */ + +/* Implementation notes: + * + * - There are two kinds of sockets: those created by user action (such as + * calling socket(2)) and those created by incoming connection request packets. + * + * - There are two "global" tables, one for bound sockets (sockets that have + * specified an address that they are responsible for) and one for connected + * sockets (sockets that have established a connection with another socket). + * These tables are "global" in that all sockets on the system are placed + * within them. - Note, though, that the bound table contains an extra entry + * for a list of unbound sockets and SOCK_DGRAM sockets will always remain in + * that list. The bound table is used solely for lookup of sockets when packets + * are received and that's not necessary for SOCK_DGRAM sockets since we create + * a datagram handle for each and need not perform a lookup. Keeping SOCK_DGRAM + * sockets out of the bound hash buckets will reduce the chance of collisions + * when looking for SOCK_STREAM sockets and prevents us from having to check the + * socket type in the hash table lookups. + * + * - Sockets created by user action will either be "client" sockets that + * initiate a connection or "server" sockets that listen for connections; we do + * not support simultaneous connects (two "client" sockets connecting). + * + * - "Server" sockets are referred to as listener sockets throughout this + * implementation because they are in the TCP_LISTEN state. When a + * connection request is received (the second kind of socket mentioned above), + * we create a new socket and refer to it as a pending socket. These pending + * sockets are placed on the pending connection list of the listener socket. + * When future packets are received for the address the listener socket is + * bound to, we check if the source of the packet is from one that has an + * existing pending connection. If it does, we process the packet for the + * pending socket. When that socket reaches the connected state, it is removed + * from the listener socket's pending list and enqueued in the listener + * socket's accept queue. Callers of accept(2) will accept connected sockets + * from the listener socket's accept queue. If the socket cannot be accepted + * for some reason then it is marked rejected. Once the connection is + * accepted, it is owned by the user process and the responsibility for cleanup + * falls with that user process. + * + * - It is possible that these pending sockets will never reach the connected + * state; in fact, we may never receive another packet after the connection + * request. Because of this, we must schedule a cleanup function to run in the + * future, after some amount of time passes where a connection should have been + * established. This function ensures that the socket is off all lists so it + * cannot be retrieved, then drops all references to the socket so it is cleaned + * up (sock_put() -> sk_free() -> our sk_destruct implementation). Note this + * function will also cleanup rejected sockets, those that reach the connected + * state but leave it before they have been accepted. + * + * - Lock ordering for pending or accept queue sockets is: + * + * lock_sock(listener); + * lock_sock_nested(pending, SINGLE_DEPTH_NESTING); + * + * Using explicit nested locking keeps lockdep happy since normally only one + * lock of a given class may be taken at a time. + * + * - Sockets created by user action will be cleaned up when the user process + * calls close(2), causing our release implementation to be called. Our release + * implementation will perform some cleanup then drop the last reference so our + * sk_destruct implementation is invoked. Our sk_destruct implementation will + * perform additional cleanup that's common for both types of sockets. + * + * - A socket's reference count is what ensures that the structure won't be + * freed. Each entry in a list (such as the "global" bound and connected tables + * and the listener socket's pending list and connected queue) ensures a + * reference. When we defer work until process context and pass a socket as our + * argument, we must ensure the reference count is increased to ensure the + * socket isn't freed before the function is run; the deferred function will + * then drop the reference. + * + * - sk->sk_state uses the TCP state constants because they are widely used by + * other address families and exposed to userspace tools like ss(8): + * + * TCP_CLOSE - unconnected + * TCP_SYN_SENT - connecting + * TCP_ESTABLISHED - connected + * TCP_CLOSING - disconnecting + * TCP_LISTEN - listening + */ + +#include <linux/compat.h> +#include <linux/types.h> +#include <linux/bitops.h> +#include <linux/cred.h> +#include <linux/errqueue.h> +#include <linux/init.h> +#include <linux/io.h> +#include <linux/kernel.h> +#include <linux/sched/signal.h> +#include <linux/kmod.h> +#include <linux/list.h> +#include <linux/miscdevice.h> +#include <linux/module.h> +#include <linux/mutex.h> +#include <linux/net.h> +#include <linux/poll.h> +#include <linux/random.h> +#include <linux/skbuff.h> +#include <linux/smp.h> +#include <linux/socket.h> +#include <linux/stddef.h> +#include <linux/unistd.h> +#include <linux/wait.h> +#include <linux/workqueue.h> +#include <net/sock.h> +#include <net/af_vsock.h> +#include <uapi/linux/vm_sockets.h> + +static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr); +static void vsock_sk_destruct(struct sock *sk); +static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb); + +/* Protocol family. */ +struct proto vsock_proto = { + .name = "AF_VSOCK", + .owner = THIS_MODULE, + .obj_size = sizeof(struct vsock_sock), +#ifdef CONFIG_BPF_SYSCALL + .psock_update_sk_prot = vsock_bpf_update_proto, +#endif +}; + +/* The default peer timeout indicates how long we will wait for a peer response + * to a control message. + */ +#define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ) + +#define VSOCK_DEFAULT_BUFFER_SIZE (1024 * 256) +#define VSOCK_DEFAULT_BUFFER_MAX_SIZE (1024 * 256) +#define VSOCK_DEFAULT_BUFFER_MIN_SIZE 128 + +/* Transport used for host->guest communication */ +static const struct vsock_transport *transport_h2g; +/* Transport used for guest->host communication */ +static const struct vsock_transport *transport_g2h; +/* Transport used for DGRAM communication */ +static const struct vsock_transport *transport_dgram; +/* Transport used for local communication */ +static const struct vsock_transport *transport_local; +static DEFINE_MUTEX(vsock_register_mutex); + +/**** UTILS ****/ + +/* Each bound VSocket is stored in the bind hash table and each connected + * VSocket is stored in the connected hash table. + * + * Unbound sockets are all put on the same list attached to the end of the hash + * table (vsock_unbound_sockets). Bound sockets are added to the hash table in + * the bucket that their local address hashes to (vsock_bound_sockets(addr) + * represents the list that addr hashes to). + * + * Specifically, we initialize the vsock_bind_table array to a size of + * VSOCK_HASH_SIZE + 1 so that vsock_bind_table[0] through + * vsock_bind_table[VSOCK_HASH_SIZE - 1] are for bound sockets and + * vsock_bind_table[VSOCK_HASH_SIZE] is for unbound sockets. The hash function + * mods with VSOCK_HASH_SIZE to ensure this. + */ +#define MAX_PORT_RETRIES 24 + +#define VSOCK_HASH(addr) ((addr)->svm_port % VSOCK_HASH_SIZE) +#define vsock_bound_sockets(addr) (&vsock_bind_table[VSOCK_HASH(addr)]) +#define vsock_unbound_sockets (&vsock_bind_table[VSOCK_HASH_SIZE]) + +/* XXX This can probably be implemented in a better way. */ +#define VSOCK_CONN_HASH(src, dst) \ + (((src)->svm_cid ^ (dst)->svm_port) % VSOCK_HASH_SIZE) +#define vsock_connected_sockets(src, dst) \ + (&vsock_connected_table[VSOCK_CONN_HASH(src, dst)]) +#define vsock_connected_sockets_vsk(vsk) \ + vsock_connected_sockets(&(vsk)->remote_addr, &(vsk)->local_addr) + +struct list_head vsock_bind_table[VSOCK_HASH_SIZE + 1]; +EXPORT_SYMBOL_GPL(vsock_bind_table); +struct list_head vsock_connected_table[VSOCK_HASH_SIZE]; +EXPORT_SYMBOL_GPL(vsock_connected_table); +DEFINE_SPINLOCK(vsock_table_lock); +EXPORT_SYMBOL_GPL(vsock_table_lock); + +/* Autobind this socket to the local address if necessary. */ +static int vsock_auto_bind(struct vsock_sock *vsk) +{ + struct sock *sk = sk_vsock(vsk); + struct sockaddr_vm local_addr; + + if (vsock_addr_bound(&vsk->local_addr)) + return 0; + vsock_addr_init(&local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); + return __vsock_bind(sk, &local_addr); +} + +static void vsock_init_tables(void) +{ + int i; + + for (i = 0; i < ARRAY_SIZE(vsock_bind_table); i++) + INIT_LIST_HEAD(&vsock_bind_table[i]); + + for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++) + INIT_LIST_HEAD(&vsock_connected_table[i]); +} + +static void __vsock_insert_bound(struct list_head *list, + struct vsock_sock *vsk) +{ + sock_hold(&vsk->sk); + list_add(&vsk->bound_table, list); +} + +static void __vsock_insert_connected(struct list_head *list, + struct vsock_sock *vsk) +{ + sock_hold(&vsk->sk); + list_add(&vsk->connected_table, list); +} + +static void __vsock_remove_bound(struct vsock_sock *vsk) +{ + list_del_init(&vsk->bound_table); + sock_put(&vsk->sk); +} + +static void __vsock_remove_connected(struct vsock_sock *vsk) +{ + list_del_init(&vsk->connected_table); + sock_put(&vsk->sk); +} + +static struct sock *__vsock_find_bound_socket(struct sockaddr_vm *addr) +{ + struct vsock_sock *vsk; + + list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table) { + if (vsock_addr_equals_addr(addr, &vsk->local_addr)) + return sk_vsock(vsk); + + if (addr->svm_port == vsk->local_addr.svm_port && + (vsk->local_addr.svm_cid == VMADDR_CID_ANY || + addr->svm_cid == VMADDR_CID_ANY)) + return sk_vsock(vsk); + } + + return NULL; +} + +static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src, + struct sockaddr_vm *dst) +{ + struct vsock_sock *vsk; + + list_for_each_entry(vsk, vsock_connected_sockets(src, dst), + connected_table) { + if (vsock_addr_equals_addr(src, &vsk->remote_addr) && + dst->svm_port == vsk->local_addr.svm_port) { + return sk_vsock(vsk); + } + } + + return NULL; +} + +static void vsock_insert_unbound(struct vsock_sock *vsk) +{ + spin_lock_bh(&vsock_table_lock); + __vsock_insert_bound(vsock_unbound_sockets, vsk); + spin_unlock_bh(&vsock_table_lock); +} + +void vsock_insert_connected(struct vsock_sock *vsk) +{ + struct list_head *list = vsock_connected_sockets( + &vsk->remote_addr, &vsk->local_addr); + + spin_lock_bh(&vsock_table_lock); + __vsock_insert_connected(list, vsk); + spin_unlock_bh(&vsock_table_lock); +} +EXPORT_SYMBOL_GPL(vsock_insert_connected); + +void vsock_remove_bound(struct vsock_sock *vsk) +{ + spin_lock_bh(&vsock_table_lock); + if (__vsock_in_bound_table(vsk)) + __vsock_remove_bound(vsk); + spin_unlock_bh(&vsock_table_lock); +} +EXPORT_SYMBOL_GPL(vsock_remove_bound); + +void vsock_remove_connected(struct vsock_sock *vsk) +{ + spin_lock_bh(&vsock_table_lock); + if (__vsock_in_connected_table(vsk)) + __vsock_remove_connected(vsk); + spin_unlock_bh(&vsock_table_lock); +} +EXPORT_SYMBOL_GPL(vsock_remove_connected); + +struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr) +{ + struct sock *sk; + + spin_lock_bh(&vsock_table_lock); + sk = __vsock_find_bound_socket(addr); + if (sk) + sock_hold(sk); + + spin_unlock_bh(&vsock_table_lock); + + return sk; +} +EXPORT_SYMBOL_GPL(vsock_find_bound_socket); + +struct sock *vsock_find_connected_socket(struct sockaddr_vm *src, + struct sockaddr_vm *dst) +{ + struct sock *sk; + + spin_lock_bh(&vsock_table_lock); + sk = __vsock_find_connected_socket(src, dst); + if (sk) + sock_hold(sk); + + spin_unlock_bh(&vsock_table_lock); + + return sk; +} +EXPORT_SYMBOL_GPL(vsock_find_connected_socket); + +void vsock_remove_sock(struct vsock_sock *vsk) +{ + vsock_remove_bound(vsk); + vsock_remove_connected(vsk); +} +EXPORT_SYMBOL_GPL(vsock_remove_sock); + +void vsock_for_each_connected_socket(struct vsock_transport *transport, + void (*fn)(struct sock *sk)) +{ + int i; + + spin_lock_bh(&vsock_table_lock); + + for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++) { + struct vsock_sock *vsk; + list_for_each_entry(vsk, &vsock_connected_table[i], + connected_table) { + if (vsk->transport != transport) + continue; + + fn(sk_vsock(vsk)); + } + } + + spin_unlock_bh(&vsock_table_lock); +} +EXPORT_SYMBOL_GPL(vsock_for_each_connected_socket); + +void vsock_add_pending(struct sock *listener, struct sock *pending) +{ + struct vsock_sock *vlistener; + struct vsock_sock *vpending; + + vlistener = vsock_sk(listener); + vpending = vsock_sk(pending); + + sock_hold(pending); + sock_hold(listener); + list_add_tail(&vpending->pending_links, &vlistener->pending_links); +} +EXPORT_SYMBOL_GPL(vsock_add_pending); + +void vsock_remove_pending(struct sock *listener, struct sock *pending) +{ + struct vsock_sock *vpending = vsock_sk(pending); + + list_del_init(&vpending->pending_links); + sock_put(listener); + sock_put(pending); +} +EXPORT_SYMBOL_GPL(vsock_remove_pending); + +void vsock_enqueue_accept(struct sock *listener, struct sock *connected) +{ + struct vsock_sock *vlistener; + struct vsock_sock *vconnected; + + vlistener = vsock_sk(listener); + vconnected = vsock_sk(connected); + + sock_hold(connected); + sock_hold(listener); + list_add_tail(&vconnected->accept_queue, &vlistener->accept_queue); +} +EXPORT_SYMBOL_GPL(vsock_enqueue_accept); + +static bool vsock_use_local_transport(unsigned int remote_cid) +{ + if (!transport_local) + return false; + + if (remote_cid == VMADDR_CID_LOCAL) + return true; + + if (transport_g2h) { + return remote_cid == transport_g2h->get_local_cid(); + } else { + return remote_cid == VMADDR_CID_HOST; + } +} + +static void vsock_deassign_transport(struct vsock_sock *vsk) +{ + if (!vsk->transport) + return; + + vsk->transport->destruct(vsk); + module_put(vsk->transport->module); + vsk->transport = NULL; +} + +/* Assign a transport to a socket and call the .init transport callback. + * + * Note: for connection oriented socket this must be called when vsk->remote_addr + * is set (e.g. during the connect() or when a connection request on a listener + * socket is received). + * The vsk->remote_addr is used to decide which transport to use: + * - remote CID == VMADDR_CID_LOCAL or g2h->local_cid or VMADDR_CID_HOST if + * g2h is not loaded, will use local transport; + * - remote CID <= VMADDR_CID_HOST or h2g is not loaded or remote flags field + * includes VMADDR_FLAG_TO_HOST flag value, will use guest->host transport; + * - remote CID > VMADDR_CID_HOST will use host->guest transport; + */ +int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) +{ + const struct vsock_transport *new_transport; + struct sock *sk = sk_vsock(vsk); + unsigned int remote_cid = vsk->remote_addr.svm_cid; + __u8 remote_flags; + int ret; + + /* If the packet is coming with the source and destination CIDs higher + * than VMADDR_CID_HOST, then a vsock channel where all the packets are + * forwarded to the host should be established. Then the host will + * need to forward the packets to the guest. + * + * The flag is set on the (listen) receive path (psk is not NULL). On + * the connect path the flag can be set by the user space application. + */ + if (psk && vsk->local_addr.svm_cid > VMADDR_CID_HOST && + vsk->remote_addr.svm_cid > VMADDR_CID_HOST) + vsk->remote_addr.svm_flags |= VMADDR_FLAG_TO_HOST; + + remote_flags = vsk->remote_addr.svm_flags; + + switch (sk->sk_type) { + case SOCK_DGRAM: + new_transport = transport_dgram; + break; + case SOCK_STREAM: + case SOCK_SEQPACKET: + if (vsock_use_local_transport(remote_cid)) + new_transport = transport_local; + else if (remote_cid <= VMADDR_CID_HOST || !transport_h2g || + (remote_flags & VMADDR_FLAG_TO_HOST)) + new_transport = transport_g2h; + else + new_transport = transport_h2g; + break; + default: + return -ESOCKTNOSUPPORT; + } + + if (vsk->transport) { + if (vsk->transport == new_transport) + return 0; + + /* transport->release() must be called with sock lock acquired. + * This path can only be taken during vsock_connect(), where we + * have already held the sock lock. In the other cases, this + * function is called on a new socket which is not assigned to + * any transport. + */ + vsk->transport->release(vsk); + vsock_deassign_transport(vsk); + } + + /* We increase the module refcnt to prevent the transport unloading + * while there are open sockets assigned to it. + */ + if (!new_transport || !try_module_get(new_transport->module)) + return -ENODEV; + + if (sk->sk_type == SOCK_SEQPACKET) { + if (!new_transport->seqpacket_allow || + !new_transport->seqpacket_allow(remote_cid)) { + module_put(new_transport->module); + return -ESOCKTNOSUPPORT; + } + } + + ret = new_transport->init(vsk, psk); + if (ret) { + module_put(new_transport->module); + return ret; + } + + vsk->transport = new_transport; + + return 0; +} +EXPORT_SYMBOL_GPL(vsock_assign_transport); + +bool vsock_find_cid(unsigned int cid) +{ + if (transport_g2h && cid == transport_g2h->get_local_cid()) + return true; + + if (transport_h2g && cid == VMADDR_CID_HOST) + return true; + + if (transport_local && cid == VMADDR_CID_LOCAL) + return true; + + return false; +} +EXPORT_SYMBOL_GPL(vsock_find_cid); + +static struct sock *vsock_dequeue_accept(struct sock *listener) +{ + struct vsock_sock *vlistener; + struct vsock_sock *vconnected; + + vlistener = vsock_sk(listener); + + if (list_empty(&vlistener->accept_queue)) + return NULL; + + vconnected = list_entry(vlistener->accept_queue.next, + struct vsock_sock, accept_queue); + + list_del_init(&vconnected->accept_queue); + sock_put(listener); + /* The caller will need a reference on the connected socket so we let + * it call sock_put(). + */ + + return sk_vsock(vconnected); +} + +static bool vsock_is_accept_queue_empty(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + return list_empty(&vsk->accept_queue); +} + +static bool vsock_is_pending(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + return !list_empty(&vsk->pending_links); +} + +static int vsock_send_shutdown(struct sock *sk, int mode) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + if (!vsk->transport) + return -ENODEV; + + return vsk->transport->shutdown(vsk, mode); +} + +static void vsock_pending_work(struct work_struct *work) +{ + struct sock *sk; + struct sock *listener; + struct vsock_sock *vsk; + bool cleanup; + + vsk = container_of(work, struct vsock_sock, pending_work.work); + sk = sk_vsock(vsk); + listener = vsk->listener; + cleanup = true; + + lock_sock(listener); + lock_sock_nested(sk, SINGLE_DEPTH_NESTING); + + if (vsock_is_pending(sk)) { + vsock_remove_pending(listener, sk); + + sk_acceptq_removed(listener); + } else if (!vsk->rejected) { + /* We are not on the pending list and accept() did not reject + * us, so we must have been accepted by our user process. We + * just need to drop our references to the sockets and be on + * our way. + */ + cleanup = false; + goto out; + } + + /* We need to remove ourself from the global connected sockets list so + * incoming packets can't find this socket, and to reduce the reference + * count. + */ + vsock_remove_connected(vsk); + + sk->sk_state = TCP_CLOSE; + +out: + release_sock(sk); + release_sock(listener); + if (cleanup) + sock_put(sk); + + sock_put(sk); + sock_put(listener); +} + +/**** SOCKET OPERATIONS ****/ + +static int __vsock_bind_connectible(struct vsock_sock *vsk, + struct sockaddr_vm *addr) +{ + static u32 port; + struct sockaddr_vm new_addr; + + if (!port) + port = get_random_u32_above(LAST_RESERVED_PORT); + + vsock_addr_init(&new_addr, addr->svm_cid, addr->svm_port); + + if (addr->svm_port == VMADDR_PORT_ANY) { + bool found = false; + unsigned int i; + + for (i = 0; i < MAX_PORT_RETRIES; i++) { + if (port <= LAST_RESERVED_PORT) + port = LAST_RESERVED_PORT + 1; + + new_addr.svm_port = port++; + + if (!__vsock_find_bound_socket(&new_addr)) { + found = true; + break; + } + } + + if (!found) + return -EADDRNOTAVAIL; + } else { + /* If port is in reserved range, ensure caller + * has necessary privileges. + */ + if (addr->svm_port <= LAST_RESERVED_PORT && + !capable(CAP_NET_BIND_SERVICE)) { + return -EACCES; + } + + if (__vsock_find_bound_socket(&new_addr)) + return -EADDRINUSE; + } + + vsock_addr_init(&vsk->local_addr, new_addr.svm_cid, new_addr.svm_port); + + /* Remove connection oriented sockets from the unbound list and add them + * to the hash table for easy lookup by its address. The unbound list + * is simply an extra entry at the end of the hash table, a trick used + * by AF_UNIX. + */ + __vsock_remove_bound(vsk); + __vsock_insert_bound(vsock_bound_sockets(&vsk->local_addr), vsk); + + return 0; +} + +static int __vsock_bind_dgram(struct vsock_sock *vsk, + struct sockaddr_vm *addr) +{ + return vsk->transport->dgram_bind(vsk, addr); +} + +static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr) +{ + struct vsock_sock *vsk = vsock_sk(sk); + int retval; + + /* First ensure this socket isn't already bound. */ + if (vsock_addr_bound(&vsk->local_addr)) + return -EINVAL; + + /* Now bind to the provided address or select appropriate values if + * none are provided (VMADDR_CID_ANY and VMADDR_PORT_ANY). Note that + * like AF_INET prevents binding to a non-local IP address (in most + * cases), we only allow binding to a local CID. + */ + if (addr->svm_cid != VMADDR_CID_ANY && !vsock_find_cid(addr->svm_cid)) + return -EADDRNOTAVAIL; + + switch (sk->sk_socket->type) { + case SOCK_STREAM: + case SOCK_SEQPACKET: + spin_lock_bh(&vsock_table_lock); + retval = __vsock_bind_connectible(vsk, addr); + spin_unlock_bh(&vsock_table_lock); + break; + + case SOCK_DGRAM: + retval = __vsock_bind_dgram(vsk, addr); + break; + + default: + retval = -EINVAL; + break; + } + + return retval; +} + +static void vsock_connect_timeout(struct work_struct *work); + +static struct sock *__vsock_create(struct net *net, + struct socket *sock, + struct sock *parent, + gfp_t priority, + unsigned short type, + int kern) +{ + struct sock *sk; + struct vsock_sock *psk; + struct vsock_sock *vsk; + + sk = sk_alloc(net, AF_VSOCK, priority, &vsock_proto, kern); + if (!sk) + return NULL; + + sock_init_data(sock, sk); + + /* sk->sk_type is normally set in sock_init_data, but only if sock is + * non-NULL. We make sure that our sockets always have a type by + * setting it here if needed. + */ + if (!sock) + sk->sk_type = type; + + vsk = vsock_sk(sk); + vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); + vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); + + sk->sk_destruct = vsock_sk_destruct; + sk->sk_backlog_rcv = vsock_queue_rcv_skb; + sock_reset_flag(sk, SOCK_DONE); + + INIT_LIST_HEAD(&vsk->bound_table); + INIT_LIST_HEAD(&vsk->connected_table); + vsk->listener = NULL; + INIT_LIST_HEAD(&vsk->pending_links); + INIT_LIST_HEAD(&vsk->accept_queue); + vsk->rejected = false; + vsk->sent_request = false; + vsk->ignore_connecting_rst = false; + vsk->peer_shutdown = 0; + INIT_DELAYED_WORK(&vsk->connect_work, vsock_connect_timeout); + INIT_DELAYED_WORK(&vsk->pending_work, vsock_pending_work); + + psk = parent ? vsock_sk(parent) : NULL; + if (parent) { + vsk->trusted = psk->trusted; + vsk->owner = get_cred(psk->owner); + vsk->connect_timeout = psk->connect_timeout; + vsk->buffer_size = psk->buffer_size; + vsk->buffer_min_size = psk->buffer_min_size; + vsk->buffer_max_size = psk->buffer_max_size; + security_sk_clone(parent, sk); + } else { + vsk->trusted = ns_capable_noaudit(&init_user_ns, CAP_NET_ADMIN); + vsk->owner = get_current_cred(); + vsk->connect_timeout = VSOCK_DEFAULT_CONNECT_TIMEOUT; + vsk->buffer_size = VSOCK_DEFAULT_BUFFER_SIZE; + vsk->buffer_min_size = VSOCK_DEFAULT_BUFFER_MIN_SIZE; + vsk->buffer_max_size = VSOCK_DEFAULT_BUFFER_MAX_SIZE; + } + + return sk; +} + +static bool sock_type_connectible(u16 type) +{ + return (type == SOCK_STREAM) || (type == SOCK_SEQPACKET); +} + +static void __vsock_release(struct sock *sk, int level) +{ + if (sk) { + struct sock *pending; + struct vsock_sock *vsk; + + vsk = vsock_sk(sk); + pending = NULL; /* Compiler warning. */ + + /* When "level" is SINGLE_DEPTH_NESTING, use the nested + * version to avoid the warning "possible recursive locking + * detected". When "level" is 0, lock_sock_nested(sk, level) + * is the same as lock_sock(sk). + */ + lock_sock_nested(sk, level); + + if (vsk->transport) + vsk->transport->release(vsk); + else if (sock_type_connectible(sk->sk_type)) + vsock_remove_sock(vsk); + + sock_orphan(sk); + sk->sk_shutdown = SHUTDOWN_MASK; + + skb_queue_purge(&sk->sk_receive_queue); + + /* Clean up any sockets that never were accepted. */ + while ((pending = vsock_dequeue_accept(sk)) != NULL) { + __vsock_release(pending, SINGLE_DEPTH_NESTING); + sock_put(pending); + } + + release_sock(sk); + sock_put(sk); + } +} + +static void vsock_sk_destruct(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + vsock_deassign_transport(vsk); + + /* When clearing these addresses, there's no need to set the family and + * possibly register the address family with the kernel. + */ + vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); + vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); + + put_cred(vsk->owner); +} + +static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) +{ + int err; + + err = sock_queue_rcv_skb(sk, skb); + if (err) + kfree_skb(skb); + + return err; +} + +struct sock *vsock_create_connected(struct sock *parent) +{ + return __vsock_create(sock_net(parent), NULL, parent, GFP_KERNEL, + parent->sk_type, 0); +} +EXPORT_SYMBOL_GPL(vsock_create_connected); + +s64 vsock_stream_has_data(struct vsock_sock *vsk) +{ + return vsk->transport->stream_has_data(vsk); +} +EXPORT_SYMBOL_GPL(vsock_stream_has_data); + +s64 vsock_connectible_has_data(struct vsock_sock *vsk) +{ + struct sock *sk = sk_vsock(vsk); + + if (sk->sk_type == SOCK_SEQPACKET) + return vsk->transport->seqpacket_has_data(vsk); + else + return vsock_stream_has_data(vsk); +} +EXPORT_SYMBOL_GPL(vsock_connectible_has_data); + +s64 vsock_stream_has_space(struct vsock_sock *vsk) +{ + return vsk->transport->stream_has_space(vsk); +} +EXPORT_SYMBOL_GPL(vsock_stream_has_space); + +void vsock_data_ready(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + if (vsock_stream_has_data(vsk) >= sk->sk_rcvlowat || + sock_flag(sk, SOCK_DONE)) + sk->sk_data_ready(sk); +} +EXPORT_SYMBOL_GPL(vsock_data_ready); + +static int vsock_release(struct socket *sock) +{ + __vsock_release(sock->sk, 0); + sock->sk = NULL; + sock->state = SS_FREE; + + return 0; +} + +static int +vsock_bind(struct socket *sock, struct sockaddr *addr, int addr_len) +{ + int err; + struct sock *sk; + struct sockaddr_vm *vm_addr; + + sk = sock->sk; + + if (vsock_addr_cast(addr, addr_len, &vm_addr) != 0) + return -EINVAL; + + lock_sock(sk); + err = __vsock_bind(sk, vm_addr); + release_sock(sk); + + return err; +} + +static int vsock_getname(struct socket *sock, + struct sockaddr *addr, int peer) +{ + int err; + struct sock *sk; + struct vsock_sock *vsk; + struct sockaddr_vm *vm_addr; + + sk = sock->sk; + vsk = vsock_sk(sk); + err = 0; + + lock_sock(sk); + + if (peer) { + if (sock->state != SS_CONNECTED) { + err = -ENOTCONN; + goto out; + } + vm_addr = &vsk->remote_addr; + } else { + vm_addr = &vsk->local_addr; + } + + if (!vm_addr) { + err = -EINVAL; + goto out; + } + + /* sys_getsockname() and sys_getpeername() pass us a + * MAX_SOCK_ADDR-sized buffer and don't set addr_len. Unfortunately + * that macro is defined in socket.c instead of .h, so we hardcode its + * value here. + */ + BUILD_BUG_ON(sizeof(*vm_addr) > 128); + memcpy(addr, vm_addr, sizeof(*vm_addr)); + err = sizeof(*vm_addr); + +out: + release_sock(sk); + return err; +} + +static int vsock_shutdown(struct socket *sock, int mode) +{ + int err; + struct sock *sk; + + /* User level uses SHUT_RD (0) and SHUT_WR (1), but the kernel uses + * RCV_SHUTDOWN (1) and SEND_SHUTDOWN (2), so we must increment mode + * here like the other address families do. Note also that the + * increment makes SHUT_RDWR (2) into RCV_SHUTDOWN | SEND_SHUTDOWN (3), + * which is what we want. + */ + mode++; + + if ((mode & ~SHUTDOWN_MASK) || !mode) + return -EINVAL; + + /* If this is a connection oriented socket and it is not connected then + * bail out immediately. If it is a DGRAM socket then we must first + * kick the socket so that it wakes up from any sleeping calls, for + * example recv(), and then afterwards return the error. + */ + + sk = sock->sk; + + lock_sock(sk); + if (sock->state == SS_UNCONNECTED) { + err = -ENOTCONN; + if (sock_type_connectible(sk->sk_type)) + goto out; + } else { + sock->state = SS_DISCONNECTING; + err = 0; + } + + /* Receive and send shutdowns are treated alike. */ + mode = mode & (RCV_SHUTDOWN | SEND_SHUTDOWN); + if (mode) { + sk->sk_shutdown |= mode; + sk->sk_state_change(sk); + + if (sock_type_connectible(sk->sk_type)) { + sock_reset_flag(sk, SOCK_DONE); + vsock_send_shutdown(sk, mode); + } + } + +out: + release_sock(sk); + return err; +} + +static __poll_t vsock_poll(struct file *file, struct socket *sock, + poll_table *wait) +{ + struct sock *sk; + __poll_t mask; + struct vsock_sock *vsk; + + sk = sock->sk; + vsk = vsock_sk(sk); + + poll_wait(file, sk_sleep(sk), wait); + mask = 0; + + if (sk->sk_err) + /* Signify that there has been an error on this socket. */ + mask |= EPOLLERR; + + /* INET sockets treat local write shutdown and peer write shutdown as a + * case of EPOLLHUP set. + */ + if ((sk->sk_shutdown == SHUTDOWN_MASK) || + ((sk->sk_shutdown & SEND_SHUTDOWN) && + (vsk->peer_shutdown & SEND_SHUTDOWN))) { + mask |= EPOLLHUP; + } + + if (sk->sk_shutdown & RCV_SHUTDOWN || + vsk->peer_shutdown & SEND_SHUTDOWN) { + mask |= EPOLLRDHUP; + } + + if (sock->type == SOCK_DGRAM) { + /* For datagram sockets we can read if there is something in + * the queue and write as long as the socket isn't shutdown for + * sending. + */ + if (!skb_queue_empty_lockless(&sk->sk_receive_queue) || + (sk->sk_shutdown & RCV_SHUTDOWN)) { + mask |= EPOLLIN | EPOLLRDNORM; + } + + if (!(sk->sk_shutdown & SEND_SHUTDOWN)) + mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND; + + } else if (sock_type_connectible(sk->sk_type)) { + const struct vsock_transport *transport; + + lock_sock(sk); + + transport = vsk->transport; + + /* Listening sockets that have connections in their accept + * queue can be read. + */ + if (sk->sk_state == TCP_LISTEN + && !vsock_is_accept_queue_empty(sk)) + mask |= EPOLLIN | EPOLLRDNORM; + + /* If there is something in the queue then we can read. */ + if (transport && transport->stream_is_active(vsk) && + !(sk->sk_shutdown & RCV_SHUTDOWN)) { + bool data_ready_now = false; + int target = sock_rcvlowat(sk, 0, INT_MAX); + int ret = transport->notify_poll_in( + vsk, target, &data_ready_now); + if (ret < 0) { + mask |= EPOLLERR; + } else { + if (data_ready_now) + mask |= EPOLLIN | EPOLLRDNORM; + + } + } + + /* Sockets whose connections have been closed, reset, or + * terminated should also be considered read, and we check the + * shutdown flag for that. + */ + if (sk->sk_shutdown & RCV_SHUTDOWN || + vsk->peer_shutdown & SEND_SHUTDOWN) { + mask |= EPOLLIN | EPOLLRDNORM; + } + + /* Connected sockets that can produce data can be written. */ + if (transport && sk->sk_state == TCP_ESTABLISHED) { + if (!(sk->sk_shutdown & SEND_SHUTDOWN)) { + bool space_avail_now = false; + int ret = transport->notify_poll_out( + vsk, 1, &space_avail_now); + if (ret < 0) { + mask |= EPOLLERR; + } else { + if (space_avail_now) + /* Remove EPOLLWRBAND since INET + * sockets are not setting it. + */ + mask |= EPOLLOUT | EPOLLWRNORM; + + } + } + } + + /* Simulate INET socket poll behaviors, which sets + * EPOLLOUT|EPOLLWRNORM when peer is closed and nothing to read, + * but local send is not shutdown. + */ + if (sk->sk_state == TCP_CLOSE || sk->sk_state == TCP_CLOSING) { + if (!(sk->sk_shutdown & SEND_SHUTDOWN)) + mask |= EPOLLOUT | EPOLLWRNORM; + + } + + release_sock(sk); + } + + return mask; +} + +static int vsock_read_skb(struct sock *sk, skb_read_actor_t read_actor) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + return vsk->transport->read_skb(vsk, read_actor); +} + +static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) +{ + int err; + struct sock *sk; + struct vsock_sock *vsk; + struct sockaddr_vm *remote_addr; + const struct vsock_transport *transport; + + if (msg->msg_flags & MSG_OOB) + return -EOPNOTSUPP; + + /* For now, MSG_DONTWAIT is always assumed... */ + err = 0; + sk = sock->sk; + vsk = vsock_sk(sk); + + lock_sock(sk); + + transport = vsk->transport; + + err = vsock_auto_bind(vsk); + if (err) + goto out; + + + /* If the provided message contains an address, use that. Otherwise + * fall back on the socket's remote handle (if it has been connected). + */ + if (msg->msg_name && + vsock_addr_cast(msg->msg_name, msg->msg_namelen, + &remote_addr) == 0) { + /* Ensure this address is of the right type and is a valid + * destination. + */ + + if (remote_addr->svm_cid == VMADDR_CID_ANY) + remote_addr->svm_cid = transport->get_local_cid(); + + if (!vsock_addr_bound(remote_addr)) { + err = -EINVAL; + goto out; + } + } else if (sock->state == SS_CONNECTED) { + remote_addr = &vsk->remote_addr; + + if (remote_addr->svm_cid == VMADDR_CID_ANY) + remote_addr->svm_cid = transport->get_local_cid(); + + /* XXX Should connect() or this function ensure remote_addr is + * bound? + */ + if (!vsock_addr_bound(&vsk->remote_addr)) { + err = -EINVAL; + goto out; + } + } else { + err = -EINVAL; + goto out; + } + + if (!transport->dgram_allow(remote_addr->svm_cid, + remote_addr->svm_port)) { + err = -EINVAL; + goto out; + } + + err = transport->dgram_enqueue(vsk, remote_addr, msg, len); + +out: + release_sock(sk); + return err; +} + +static int vsock_dgram_connect(struct socket *sock, + struct sockaddr *addr, int addr_len, int flags) +{ + int err; + struct sock *sk; + struct vsock_sock *vsk; + struct sockaddr_vm *remote_addr; + + sk = sock->sk; + vsk = vsock_sk(sk); + + err = vsock_addr_cast(addr, addr_len, &remote_addr); + if (err == -EAFNOSUPPORT && remote_addr->svm_family == AF_UNSPEC) { + lock_sock(sk); + vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, + VMADDR_PORT_ANY); + sock->state = SS_UNCONNECTED; + release_sock(sk); + return 0; + } else if (err != 0) + return -EINVAL; + + lock_sock(sk); + + err = vsock_auto_bind(vsk); + if (err) + goto out; + + if (!vsk->transport->dgram_allow(remote_addr->svm_cid, + remote_addr->svm_port)) { + err = -EINVAL; + goto out; + } + + memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr)); + sock->state = SS_CONNECTED; + + /* sock map disallows redirection of non-TCP sockets with sk_state != + * TCP_ESTABLISHED (see sock_map_redirect_allowed()), so we set + * TCP_ESTABLISHED here to allow redirection of connected vsock dgrams. + * + * This doesn't seem to be abnormal state for datagram sockets, as the + * same approach can be see in other datagram socket types as well + * (such as unix sockets). + */ + sk->sk_state = TCP_ESTABLISHED; + +out: + release_sock(sk); + return err; +} + +int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg, + size_t len, int flags) +{ +#ifdef CONFIG_BPF_SYSCALL + const struct proto *prot; +#endif + struct vsock_sock *vsk; + struct sock *sk; + + sk = sock->sk; + vsk = vsock_sk(sk); + +#ifdef CONFIG_BPF_SYSCALL + prot = READ_ONCE(sk->sk_prot); + if (prot != &vsock_proto) + return prot->recvmsg(sk, msg, len, flags, NULL); +#endif + + return vsk->transport->dgram_dequeue(vsk, msg, len, flags); +} +EXPORT_SYMBOL_GPL(vsock_dgram_recvmsg); + +static const struct proto_ops vsock_dgram_ops = { + .family = PF_VSOCK, + .owner = THIS_MODULE, + .release = vsock_release, + .bind = vsock_bind, + .connect = vsock_dgram_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = vsock_getname, + .poll = vsock_poll, + .ioctl = sock_no_ioctl, + .listen = sock_no_listen, + .shutdown = vsock_shutdown, + .sendmsg = vsock_dgram_sendmsg, + .recvmsg = vsock_dgram_recvmsg, + .mmap = sock_no_mmap, + .read_skb = vsock_read_skb, +}; + +static int vsock_transport_cancel_pkt(struct vsock_sock *vsk) +{ + const struct vsock_transport *transport = vsk->transport; + + if (!transport || !transport->cancel_pkt) + return -EOPNOTSUPP; + + return transport->cancel_pkt(vsk); +} + +static void vsock_connect_timeout(struct work_struct *work) +{ + struct sock *sk; + struct vsock_sock *vsk; + + vsk = container_of(work, struct vsock_sock, connect_work.work); + sk = sk_vsock(vsk); + + lock_sock(sk); + if (sk->sk_state == TCP_SYN_SENT && + (sk->sk_shutdown != SHUTDOWN_MASK)) { + sk->sk_state = TCP_CLOSE; + sk->sk_socket->state = SS_UNCONNECTED; + sk->sk_err = ETIMEDOUT; + sk_error_report(sk); + vsock_transport_cancel_pkt(vsk); + } + release_sock(sk); + + sock_put(sk); +} + +static int vsock_connect(struct socket *sock, struct sockaddr *addr, + int addr_len, int flags) +{ + int err; + struct sock *sk; + struct vsock_sock *vsk; + const struct vsock_transport *transport; + struct sockaddr_vm *remote_addr; + long timeout; + DEFINE_WAIT(wait); + + err = 0; + sk = sock->sk; + vsk = vsock_sk(sk); + + lock_sock(sk); + + /* XXX AF_UNSPEC should make us disconnect like AF_INET. */ + switch (sock->state) { + case SS_CONNECTED: + err = -EISCONN; + goto out; + case SS_DISCONNECTING: + err = -EINVAL; + goto out; + case SS_CONNECTING: + /* This continues on so we can move sock into the SS_CONNECTED + * state once the connection has completed (at which point err + * will be set to zero also). Otherwise, we will either wait + * for the connection or return -EALREADY should this be a + * non-blocking call. + */ + err = -EALREADY; + if (flags & O_NONBLOCK) + goto out; + break; + default: + if ((sk->sk_state == TCP_LISTEN) || + vsock_addr_cast(addr, addr_len, &remote_addr) != 0) { + err = -EINVAL; + goto out; + } + + /* Set the remote address that we are connecting to. */ + memcpy(&vsk->remote_addr, remote_addr, + sizeof(vsk->remote_addr)); + + err = vsock_assign_transport(vsk, NULL); + if (err) + goto out; + + transport = vsk->transport; + + /* The hypervisor and well-known contexts do not have socket + * endpoints. + */ + if (!transport || + !transport->stream_allow(remote_addr->svm_cid, + remote_addr->svm_port)) { + err = -ENETUNREACH; + goto out; + } + + err = vsock_auto_bind(vsk); + if (err) + goto out; + + sk->sk_state = TCP_SYN_SENT; + + err = transport->connect(vsk); + if (err < 0) + goto out; + + /* Mark sock as connecting and set the error code to in + * progress in case this is a non-blocking connect. + */ + sock->state = SS_CONNECTING; + err = -EINPROGRESS; + } + + /* The receive path will handle all communication until we are able to + * enter the connected state. Here we wait for the connection to be + * completed or a notification of an error. + */ + timeout = vsk->connect_timeout; + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + + while (sk->sk_state != TCP_ESTABLISHED && sk->sk_err == 0) { + if (flags & O_NONBLOCK) { + /* If we're not going to block, we schedule a timeout + * function to generate a timeout on the connection + * attempt, in case the peer doesn't respond in a + * timely manner. We hold on to the socket until the + * timeout fires. + */ + sock_hold(sk); + + /* If the timeout function is already scheduled, + * reschedule it, then ungrab the socket refcount to + * keep it balanced. + */ + if (mod_delayed_work(system_wq, &vsk->connect_work, + timeout)) + sock_put(sk); + + /* Skip ahead to preserve error code set above. */ + goto out_wait; + } + + release_sock(sk); + timeout = schedule_timeout(timeout); + lock_sock(sk); + + if (signal_pending(current)) { + err = sock_intr_errno(timeout); + sk->sk_state = sk->sk_state == TCP_ESTABLISHED ? TCP_CLOSING : TCP_CLOSE; + sock->state = SS_UNCONNECTED; + vsock_transport_cancel_pkt(vsk); + vsock_remove_connected(vsk); + goto out_wait; + } else if ((sk->sk_state != TCP_ESTABLISHED) && (timeout == 0)) { + err = -ETIMEDOUT; + sk->sk_state = TCP_CLOSE; + sock->state = SS_UNCONNECTED; + vsock_transport_cancel_pkt(vsk); + goto out_wait; + } + + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + } + + if (sk->sk_err) { + err = -sk->sk_err; + sk->sk_state = TCP_CLOSE; + sock->state = SS_UNCONNECTED; + } else { + err = 0; + } + +out_wait: + finish_wait(sk_sleep(sk), &wait); +out: + release_sock(sk); + return err; +} + +static int vsock_accept(struct socket *sock, struct socket *newsock, int flags, + bool kern) +{ + struct sock *listener; + int err; + struct sock *connected; + struct vsock_sock *vconnected; + long timeout; + DEFINE_WAIT(wait); + + err = 0; + listener = sock->sk; + + lock_sock(listener); + + if (!sock_type_connectible(sock->type)) { + err = -EOPNOTSUPP; + goto out; + } + + if (listener->sk_state != TCP_LISTEN) { + err = -EINVAL; + goto out; + } + + /* Wait for children sockets to appear; these are the new sockets + * created upon connection establishment. + */ + timeout = sock_rcvtimeo(listener, flags & O_NONBLOCK); + prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE); + + while ((connected = vsock_dequeue_accept(listener)) == NULL && + listener->sk_err == 0) { + release_sock(listener); + timeout = schedule_timeout(timeout); + finish_wait(sk_sleep(listener), &wait); + lock_sock(listener); + + if (signal_pending(current)) { + err = sock_intr_errno(timeout); + goto out; + } else if (timeout == 0) { + err = -EAGAIN; + goto out; + } + + prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE); + } + finish_wait(sk_sleep(listener), &wait); + + if (listener->sk_err) + err = -listener->sk_err; + + if (connected) { + sk_acceptq_removed(listener); + + lock_sock_nested(connected, SINGLE_DEPTH_NESTING); + vconnected = vsock_sk(connected); + + /* If the listener socket has received an error, then we should + * reject this socket and return. Note that we simply mark the + * socket rejected, drop our reference, and let the cleanup + * function handle the cleanup; the fact that we found it in + * the listener's accept queue guarantees that the cleanup + * function hasn't run yet. + */ + if (err) { + vconnected->rejected = true; + } else { + newsock->state = SS_CONNECTED; + sock_graft(connected, newsock); + } + + release_sock(connected); + sock_put(connected); + } + +out: + release_sock(listener); + return err; +} + +static int vsock_listen(struct socket *sock, int backlog) +{ + int err; + struct sock *sk; + struct vsock_sock *vsk; + + sk = sock->sk; + + lock_sock(sk); + + if (!sock_type_connectible(sk->sk_type)) { + err = -EOPNOTSUPP; + goto out; + } + + if (sock->state != SS_UNCONNECTED) { + err = -EINVAL; + goto out; + } + + vsk = vsock_sk(sk); + + if (!vsock_addr_bound(&vsk->local_addr)) { + err = -EINVAL; + goto out; + } + + sk->sk_max_ack_backlog = backlog; + sk->sk_state = TCP_LISTEN; + + err = 0; + +out: + release_sock(sk); + return err; +} + +static void vsock_update_buffer_size(struct vsock_sock *vsk, + const struct vsock_transport *transport, + u64 val) +{ + if (val > vsk->buffer_max_size) + val = vsk->buffer_max_size; + + if (val < vsk->buffer_min_size) + val = vsk->buffer_min_size; + + if (val != vsk->buffer_size && + transport && transport->notify_buffer_size) + transport->notify_buffer_size(vsk, &val); + + vsk->buffer_size = val; +} + +static int vsock_connectible_setsockopt(struct socket *sock, + int level, + int optname, + sockptr_t optval, + unsigned int optlen) +{ + int err; + struct sock *sk; + struct vsock_sock *vsk; + const struct vsock_transport *transport; + u64 val; + + if (level != AF_VSOCK) + return -ENOPROTOOPT; + +#define COPY_IN(_v) \ + do { \ + if (optlen < sizeof(_v)) { \ + err = -EINVAL; \ + goto exit; \ + } \ + if (copy_from_sockptr(&_v, optval, sizeof(_v)) != 0) { \ + err = -EFAULT; \ + goto exit; \ + } \ + } while (0) + + err = 0; + sk = sock->sk; + vsk = vsock_sk(sk); + + lock_sock(sk); + + transport = vsk->transport; + + switch (optname) { + case SO_VM_SOCKETS_BUFFER_SIZE: + COPY_IN(val); + vsock_update_buffer_size(vsk, transport, val); + break; + + case SO_VM_SOCKETS_BUFFER_MAX_SIZE: + COPY_IN(val); + vsk->buffer_max_size = val; + vsock_update_buffer_size(vsk, transport, vsk->buffer_size); + break; + + case SO_VM_SOCKETS_BUFFER_MIN_SIZE: + COPY_IN(val); + vsk->buffer_min_size = val; + vsock_update_buffer_size(vsk, transport, vsk->buffer_size); + break; + + case SO_VM_SOCKETS_CONNECT_TIMEOUT_NEW: + case SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD: { + struct __kernel_sock_timeval tv; + + err = sock_copy_user_timeval(&tv, optval, optlen, + optname == SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD); + if (err) + break; + if (tv.tv_sec >= 0 && tv.tv_usec < USEC_PER_SEC && + tv.tv_sec < (MAX_SCHEDULE_TIMEOUT / HZ - 1)) { + vsk->connect_timeout = tv.tv_sec * HZ + + DIV_ROUND_UP((unsigned long)tv.tv_usec, (USEC_PER_SEC / HZ)); + if (vsk->connect_timeout == 0) + vsk->connect_timeout = + VSOCK_DEFAULT_CONNECT_TIMEOUT; + + } else { + err = -ERANGE; + } + break; + } + + default: + err = -ENOPROTOOPT; + break; + } + +#undef COPY_IN + +exit: + release_sock(sk); + return err; +} + +static int vsock_connectible_getsockopt(struct socket *sock, + int level, int optname, + char __user *optval, + int __user *optlen) +{ + struct sock *sk = sock->sk; + struct vsock_sock *vsk = vsock_sk(sk); + + union { + u64 val64; + struct old_timeval32 tm32; + struct __kernel_old_timeval tm; + struct __kernel_sock_timeval stm; + } v; + + int lv = sizeof(v.val64); + int len; + + if (level != AF_VSOCK) + return -ENOPROTOOPT; + + if (get_user(len, optlen)) + return -EFAULT; + + memset(&v, 0, sizeof(v)); + + switch (optname) { + case SO_VM_SOCKETS_BUFFER_SIZE: + v.val64 = vsk->buffer_size; + break; + + case SO_VM_SOCKETS_BUFFER_MAX_SIZE: + v.val64 = vsk->buffer_max_size; + break; + + case SO_VM_SOCKETS_BUFFER_MIN_SIZE: + v.val64 = vsk->buffer_min_size; + break; + + case SO_VM_SOCKETS_CONNECT_TIMEOUT_NEW: + case SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD: + lv = sock_get_timeout(vsk->connect_timeout, &v, + optname == SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD); + break; + + default: + return -ENOPROTOOPT; + } + + if (len < lv) + return -EINVAL; + if (len > lv) + len = lv; + if (copy_to_user(optval, &v, len)) + return -EFAULT; + + if (put_user(len, optlen)) + return -EFAULT; + + return 0; +} + +static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) +{ + struct sock *sk; + struct vsock_sock *vsk; + const struct vsock_transport *transport; + ssize_t total_written; + long timeout; + int err; + struct vsock_transport_send_notify_data send_data; + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + sk = sock->sk; + vsk = vsock_sk(sk); + total_written = 0; + err = 0; + + if (msg->msg_flags & MSG_OOB) + return -EOPNOTSUPP; + + lock_sock(sk); + + transport = vsk->transport; + + /* Callers should not provide a destination with connection oriented + * sockets. + */ + if (msg->msg_namelen) { + err = sk->sk_state == TCP_ESTABLISHED ? -EISCONN : -EOPNOTSUPP; + goto out; + } + + /* Send data only if both sides are not shutdown in the direction. */ + if (sk->sk_shutdown & SEND_SHUTDOWN || + vsk->peer_shutdown & RCV_SHUTDOWN) { + err = -EPIPE; + goto out; + } + + if (!transport || sk->sk_state != TCP_ESTABLISHED || + !vsock_addr_bound(&vsk->local_addr)) { + err = -ENOTCONN; + goto out; + } + + if (!vsock_addr_bound(&vsk->remote_addr)) { + err = -EDESTADDRREQ; + goto out; + } + + /* Wait for room in the produce queue to enqueue our user's data. */ + timeout = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); + + err = transport->notify_send_init(vsk, &send_data); + if (err < 0) + goto out; + + while (total_written < len) { + ssize_t written; + + add_wait_queue(sk_sleep(sk), &wait); + while (vsock_stream_has_space(vsk) == 0 && + sk->sk_err == 0 && + !(sk->sk_shutdown & SEND_SHUTDOWN) && + !(vsk->peer_shutdown & RCV_SHUTDOWN)) { + + /* Don't wait for non-blocking sockets. */ + if (timeout == 0) { + err = -EAGAIN; + remove_wait_queue(sk_sleep(sk), &wait); + goto out_err; + } + + err = transport->notify_send_pre_block(vsk, &send_data); + if (err < 0) { + remove_wait_queue(sk_sleep(sk), &wait); + goto out_err; + } + + release_sock(sk); + timeout = wait_woken(&wait, TASK_INTERRUPTIBLE, timeout); + lock_sock(sk); + if (signal_pending(current)) { + err = sock_intr_errno(timeout); + remove_wait_queue(sk_sleep(sk), &wait); + goto out_err; + } else if (timeout == 0) { + err = -EAGAIN; + remove_wait_queue(sk_sleep(sk), &wait); + goto out_err; + } + } + remove_wait_queue(sk_sleep(sk), &wait); + + /* These checks occur both as part of and after the loop + * conditional since we need to check before and after + * sleeping. + */ + if (sk->sk_err) { + err = -sk->sk_err; + goto out_err; + } else if ((sk->sk_shutdown & SEND_SHUTDOWN) || + (vsk->peer_shutdown & RCV_SHUTDOWN)) { + err = -EPIPE; + goto out_err; + } + + err = transport->notify_send_pre_enqueue(vsk, &send_data); + if (err < 0) + goto out_err; + + /* Note that enqueue will only write as many bytes as are free + * in the produce queue, so we don't need to ensure len is + * smaller than the queue size. It is the caller's + * responsibility to check how many bytes we were able to send. + */ + + if (sk->sk_type == SOCK_SEQPACKET) { + written = transport->seqpacket_enqueue(vsk, + msg, len - total_written); + } else { + written = transport->stream_enqueue(vsk, + msg, len - total_written); + } + + if (written < 0) { + err = written; + goto out_err; + } + + total_written += written; + + err = transport->notify_send_post_enqueue( + vsk, written, &send_data); + if (err < 0) + goto out_err; + + } + +out_err: + if (total_written > 0) { + /* Return number of written bytes only if: + * 1) SOCK_STREAM socket. + * 2) SOCK_SEQPACKET socket when whole buffer is sent. + */ + if (sk->sk_type == SOCK_STREAM || total_written == len) + err = total_written; + } +out: + release_sock(sk); + return err; +} + +static int vsock_connectible_wait_data(struct sock *sk, + struct wait_queue_entry *wait, + long timeout, + struct vsock_transport_recv_notify_data *recv_data, + size_t target) +{ + const struct vsock_transport *transport; + struct vsock_sock *vsk; + s64 data; + int err; + + vsk = vsock_sk(sk); + err = 0; + transport = vsk->transport; + + while (1) { + prepare_to_wait(sk_sleep(sk), wait, TASK_INTERRUPTIBLE); + data = vsock_connectible_has_data(vsk); + if (data != 0) + break; + + if (sk->sk_err != 0 || + (sk->sk_shutdown & RCV_SHUTDOWN) || + (vsk->peer_shutdown & SEND_SHUTDOWN)) { + break; + } + + /* Don't wait for non-blocking sockets. */ + if (timeout == 0) { + err = -EAGAIN; + break; + } + + if (recv_data) { + err = transport->notify_recv_pre_block(vsk, target, recv_data); + if (err < 0) + break; + } + + release_sock(sk); + timeout = schedule_timeout(timeout); + lock_sock(sk); + + if (signal_pending(current)) { + err = sock_intr_errno(timeout); + break; + } else if (timeout == 0) { + err = -EAGAIN; + break; + } + } + + finish_wait(sk_sleep(sk), wait); + + if (err) + return err; + + /* Internal transport error when checking for available + * data. XXX This should be changed to a connection + * reset in a later change. + */ + if (data < 0) + return -ENOMEM; + + return data; +} + +static int __vsock_stream_recvmsg(struct sock *sk, struct msghdr *msg, + size_t len, int flags) +{ + struct vsock_transport_recv_notify_data recv_data; + const struct vsock_transport *transport; + struct vsock_sock *vsk; + ssize_t copied; + size_t target; + long timeout; + int err; + + DEFINE_WAIT(wait); + + vsk = vsock_sk(sk); + transport = vsk->transport; + + /* We must not copy less than target bytes into the user's buffer + * before returning successfully, so we wait for the consume queue to + * have that much data to consume before dequeueing. Note that this + * makes it impossible to handle cases where target is greater than the + * queue size. + */ + target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); + if (target >= transport->stream_rcvhiwat(vsk)) { + err = -ENOMEM; + goto out; + } + timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + copied = 0; + + err = transport->notify_recv_init(vsk, target, &recv_data); + if (err < 0) + goto out; + + + while (1) { + ssize_t read; + + err = vsock_connectible_wait_data(sk, &wait, timeout, + &recv_data, target); + if (err <= 0) + break; + + err = transport->notify_recv_pre_dequeue(vsk, target, + &recv_data); + if (err < 0) + break; + + read = transport->stream_dequeue(vsk, msg, len - copied, flags); + if (read < 0) { + err = read; + break; + } + + copied += read; + + err = transport->notify_recv_post_dequeue(vsk, target, read, + !(flags & MSG_PEEK), &recv_data); + if (err < 0) + goto out; + + if (read >= target || flags & MSG_PEEK) + break; + + target -= read; + } + + if (sk->sk_err) + err = -sk->sk_err; + else if (sk->sk_shutdown & RCV_SHUTDOWN) + err = 0; + + if (copied > 0) + err = copied; + +out: + return err; +} + +static int __vsock_seqpacket_recvmsg(struct sock *sk, struct msghdr *msg, + size_t len, int flags) +{ + const struct vsock_transport *transport; + struct vsock_sock *vsk; + ssize_t msg_len; + long timeout; + int err = 0; + DEFINE_WAIT(wait); + + vsk = vsock_sk(sk); + transport = vsk->transport; + + timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + + err = vsock_connectible_wait_data(sk, &wait, timeout, NULL, 0); + if (err <= 0) + goto out; + + msg_len = transport->seqpacket_dequeue(vsk, msg, flags); + + if (msg_len < 0) { + err = msg_len; + goto out; + } + + if (sk->sk_err) { + err = -sk->sk_err; + } else if (sk->sk_shutdown & RCV_SHUTDOWN) { + err = 0; + } else { + /* User sets MSG_TRUNC, so return real length of + * packet. + */ + if (flags & MSG_TRUNC) + err = msg_len; + else + err = len - msg_data_left(msg); + + /* Always set MSG_TRUNC if real length of packet is + * bigger than user's buffer. + */ + if (msg_len > len) + msg->msg_flags |= MSG_TRUNC; + } + +out: + return err; +} + +int +vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, + int flags) +{ + struct sock *sk; + struct vsock_sock *vsk; + const struct vsock_transport *transport; +#ifdef CONFIG_BPF_SYSCALL + const struct proto *prot; +#endif + int err; + + sk = sock->sk; + + if (unlikely(flags & MSG_ERRQUEUE)) + return sock_recv_errqueue(sk, msg, len, SOL_VSOCK, VSOCK_RECVERR); + + vsk = vsock_sk(sk); + err = 0; + + lock_sock(sk); + + transport = vsk->transport; + + if (!transport || sk->sk_state != TCP_ESTABLISHED) { + /* Recvmsg is supposed to return 0 if a peer performs an + * orderly shutdown. Differentiate between that case and when a + * peer has not connected or a local shutdown occurred with the + * SOCK_DONE flag. + */ + if (sock_flag(sk, SOCK_DONE)) + err = 0; + else + err = -ENOTCONN; + + goto out; + } + + if (flags & MSG_OOB) { + err = -EOPNOTSUPP; + goto out; + } + + /* We don't check peer_shutdown flag here since peer may actually shut + * down, but there can be data in the queue that a local socket can + * receive. + */ + if (sk->sk_shutdown & RCV_SHUTDOWN) { + err = 0; + goto out; + } + + /* It is valid on Linux to pass in a zero-length receive buffer. This + * is not an error. We may as well bail out now. + */ + if (!len) { + err = 0; + goto out; + } + +#ifdef CONFIG_BPF_SYSCALL + prot = READ_ONCE(sk->sk_prot); + if (prot != &vsock_proto) { + release_sock(sk); + return prot->recvmsg(sk, msg, len, flags, NULL); + } +#endif + + if (sk->sk_type == SOCK_STREAM) + err = __vsock_stream_recvmsg(sk, msg, len, flags); + else + err = __vsock_seqpacket_recvmsg(sk, msg, len, flags); + +out: + release_sock(sk); + return err; +} +EXPORT_SYMBOL_GPL(vsock_connectible_recvmsg); + +static int vsock_set_rcvlowat(struct sock *sk, int val) +{ + const struct vsock_transport *transport; + struct vsock_sock *vsk; + + vsk = vsock_sk(sk); + + if (val > vsk->buffer_size) + return -EINVAL; + + transport = vsk->transport; + + if (transport && transport->notify_set_rcvlowat) { + int err; + + err = transport->notify_set_rcvlowat(vsk, val); + if (err) + return err; + } + + WRITE_ONCE(sk->sk_rcvlowat, val ? : 1); + return 0; +} + +static const struct proto_ops vsock_stream_ops = { + .family = PF_VSOCK, + .owner = THIS_MODULE, + .release = vsock_release, + .bind = vsock_bind, + .connect = vsock_connect, + .socketpair = sock_no_socketpair, + .accept = vsock_accept, + .getname = vsock_getname, + .poll = vsock_poll, + .ioctl = sock_no_ioctl, + .listen = vsock_listen, + .shutdown = vsock_shutdown, + .setsockopt = vsock_connectible_setsockopt, + .getsockopt = vsock_connectible_getsockopt, + .sendmsg = vsock_connectible_sendmsg, + .recvmsg = vsock_connectible_recvmsg, + .mmap = sock_no_mmap, + .set_rcvlowat = vsock_set_rcvlowat, + .read_skb = vsock_read_skb, +}; + +static const struct proto_ops vsock_seqpacket_ops = { + .family = PF_VSOCK, + .owner = THIS_MODULE, + .release = vsock_release, + .bind = vsock_bind, + .connect = vsock_connect, + .socketpair = sock_no_socketpair, + .accept = vsock_accept, + .getname = vsock_getname, + .poll = vsock_poll, + .ioctl = sock_no_ioctl, + .listen = vsock_listen, + .shutdown = vsock_shutdown, + .setsockopt = vsock_connectible_setsockopt, + .getsockopt = vsock_connectible_getsockopt, + .sendmsg = vsock_connectible_sendmsg, + .recvmsg = vsock_connectible_recvmsg, + .mmap = sock_no_mmap, + .read_skb = vsock_read_skb, +}; + +static int vsock_create(struct net *net, struct socket *sock, + int protocol, int kern) +{ + struct vsock_sock *vsk; + struct sock *sk; + int ret; + + if (!sock) + return -EINVAL; + + if (protocol && protocol != PF_VSOCK) + return -EPROTONOSUPPORT; + + switch (sock->type) { + case SOCK_DGRAM: + sock->ops = &vsock_dgram_ops; + break; + case SOCK_STREAM: + sock->ops = &vsock_stream_ops; + break; + case SOCK_SEQPACKET: + sock->ops = &vsock_seqpacket_ops; + break; + default: + return -ESOCKTNOSUPPORT; + } + + sock->state = SS_UNCONNECTED; + + sk = __vsock_create(net, sock, NULL, GFP_KERNEL, 0, kern); + if (!sk) + return -ENOMEM; + + vsk = vsock_sk(sk); + + if (sock->type == SOCK_DGRAM) { + ret = vsock_assign_transport(vsk, NULL); + if (ret < 0) { + sock_put(sk); + return ret; + } + } + + vsock_insert_unbound(vsk); + + return 0; +} + +static const struct net_proto_family vsock_family_ops = { + .family = AF_VSOCK, + .create = vsock_create, + .owner = THIS_MODULE, +}; + +static long vsock_dev_do_ioctl(struct file *filp, + unsigned int cmd, void __user *ptr) +{ + u32 __user *p = ptr; + u32 cid = VMADDR_CID_ANY; + int retval = 0; + + switch (cmd) { + case IOCTL_VM_SOCKETS_GET_LOCAL_CID: + /* To be compatible with the VMCI behavior, we prioritize the + * guest CID instead of well-know host CID (VMADDR_CID_HOST). + */ + if (transport_g2h) + cid = transport_g2h->get_local_cid(); + else if (transport_h2g) + cid = transport_h2g->get_local_cid(); + + if (put_user(cid, p) != 0) + retval = -EFAULT; + break; + + default: + retval = -ENOIOCTLCMD; + } + + return retval; +} + +static long vsock_dev_ioctl(struct file *filp, + unsigned int cmd, unsigned long arg) +{ + return vsock_dev_do_ioctl(filp, cmd, (void __user *)arg); +} + +#ifdef CONFIG_COMPAT +static long vsock_dev_compat_ioctl(struct file *filp, + unsigned int cmd, unsigned long arg) +{ + return vsock_dev_do_ioctl(filp, cmd, compat_ptr(arg)); +} +#endif + +static const struct file_operations vsock_device_ops = { + .owner = THIS_MODULE, + .unlocked_ioctl = vsock_dev_ioctl, +#ifdef CONFIG_COMPAT + .compat_ioctl = vsock_dev_compat_ioctl, +#endif + .open = nonseekable_open, +}; + +static struct miscdevice vsock_device = { + .name = "vsock", + .fops = &vsock_device_ops, +}; + +static int __init vsock_init(void) +{ + int err = 0; + + vsock_init_tables(); + + vsock_proto.owner = THIS_MODULE; + vsock_device.minor = MISC_DYNAMIC_MINOR; + err = misc_register(&vsock_device); + if (err) { + pr_err("Failed to register misc device\n"); + goto err_reset_transport; + } + + err = proto_register(&vsock_proto, 1); /* we want our slab */ + if (err) { + pr_err("Cannot register vsock protocol\n"); + goto err_deregister_misc; + } + + err = sock_register(&vsock_family_ops); + if (err) { + pr_err("could not register af_vsock (%d) address family: %d\n", + AF_VSOCK, err); + goto err_unregister_proto; + } + + vsock_bpf_build_proto(); + + return 0; + +err_unregister_proto: + proto_unregister(&vsock_proto); +err_deregister_misc: + misc_deregister(&vsock_device); +err_reset_transport: + return err; +} + +static void __exit vsock_exit(void) +{ + misc_deregister(&vsock_device); + sock_unregister(AF_VSOCK); + proto_unregister(&vsock_proto); +} + +const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk) +{ + return vsk->transport; +} +EXPORT_SYMBOL_GPL(vsock_core_get_transport); + +int vsock_core_register(const struct vsock_transport *t, int features) +{ + const struct vsock_transport *t_h2g, *t_g2h, *t_dgram, *t_local; + int err = mutex_lock_interruptible(&vsock_register_mutex); + + if (err) + return err; + + t_h2g = transport_h2g; + t_g2h = transport_g2h; + t_dgram = transport_dgram; + t_local = transport_local; + + if (features & VSOCK_TRANSPORT_F_H2G) { + if (t_h2g) { + err = -EBUSY; + goto err_busy; + } + t_h2g = t; + } + + if (features & VSOCK_TRANSPORT_F_G2H) { + if (t_g2h) { + err = -EBUSY; + goto err_busy; + } + t_g2h = t; + } + + if (features & VSOCK_TRANSPORT_F_DGRAM) { + if (t_dgram) { + err = -EBUSY; + goto err_busy; + } + t_dgram = t; + } + + if (features & VSOCK_TRANSPORT_F_LOCAL) { + if (t_local) { + err = -EBUSY; + goto err_busy; + } + t_local = t; + } + + transport_h2g = t_h2g; + transport_g2h = t_g2h; + transport_dgram = t_dgram; + transport_local = t_local; + +err_busy: + mutex_unlock(&vsock_register_mutex); + return err; +} +EXPORT_SYMBOL_GPL(vsock_core_register); + +void vsock_core_unregister(const struct vsock_transport *t) +{ + mutex_lock(&vsock_register_mutex); + + if (transport_h2g == t) + transport_h2g = NULL; + + if (transport_g2h == t) + transport_g2h = NULL; + + if (transport_dgram == t) + transport_dgram = NULL; + + if (transport_local == t) + transport_local = NULL; + + mutex_unlock(&vsock_register_mutex); +} +EXPORT_SYMBOL_GPL(vsock_core_unregister); + +module_init(vsock_init); +module_exit(vsock_exit); + +MODULE_AUTHOR("VMware, Inc."); +MODULE_DESCRIPTION("VMware Virtual Socket Family"); +MODULE_VERSION("1.0.2.0-k"); +MODULE_LICENSE("GPL v2"); diff --git a/net/vmw_vsock/af_vsock_tap.c b/net/vmw_vsock/af_vsock_tap.c new file mode 100644 index 0000000000..30ee7e4fe2 --- /dev/null +++ b/net/vmw_vsock/af_vsock_tap.c @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Tap functions for AF_VSOCK sockets. + * + * Code based on net/netlink/af_netlink.c tap functions. + */ + +#include <linux/module.h> +#include <net/sock.h> +#include <net/af_vsock.h> +#include <linux/if_arp.h> + +static DEFINE_SPINLOCK(vsock_tap_lock); +static struct list_head vsock_tap_all __read_mostly = + LIST_HEAD_INIT(vsock_tap_all); + +int vsock_add_tap(struct vsock_tap *vt) +{ + if (unlikely(vt->dev->type != ARPHRD_VSOCKMON)) + return -EINVAL; + + __module_get(vt->module); + + spin_lock(&vsock_tap_lock); + list_add_rcu(&vt->list, &vsock_tap_all); + spin_unlock(&vsock_tap_lock); + + return 0; +} +EXPORT_SYMBOL_GPL(vsock_add_tap); + +int vsock_remove_tap(struct vsock_tap *vt) +{ + struct vsock_tap *tmp; + bool found = false; + + spin_lock(&vsock_tap_lock); + + list_for_each_entry(tmp, &vsock_tap_all, list) { + if (vt == tmp) { + list_del_rcu(&vt->list); + found = true; + goto out; + } + } + + pr_warn("vsock_remove_tap: %p not found\n", vt); +out: + spin_unlock(&vsock_tap_lock); + + synchronize_net(); + + if (found) + module_put(vt->module); + + return found ? 0 : -ENODEV; +} +EXPORT_SYMBOL_GPL(vsock_remove_tap); + +static int __vsock_deliver_tap_skb(struct sk_buff *skb, + struct net_device *dev) +{ + int ret = 0; + struct sk_buff *nskb = skb_clone(skb, GFP_ATOMIC); + + if (nskb) { + dev_hold(dev); + + nskb->dev = dev; + ret = dev_queue_xmit(nskb); + if (unlikely(ret > 0)) + ret = net_xmit_errno(ret); + + dev_put(dev); + } + + return ret; +} + +static void __vsock_deliver_tap(struct sk_buff *skb) +{ + int ret; + struct vsock_tap *tmp; + + list_for_each_entry_rcu(tmp, &vsock_tap_all, list) { + ret = __vsock_deliver_tap_skb(skb, tmp->dev); + if (unlikely(ret)) + break; + } +} + +void vsock_deliver_tap(struct sk_buff *build_skb(void *opaque), void *opaque) +{ + struct sk_buff *skb; + + rcu_read_lock(); + + if (likely(list_empty(&vsock_tap_all))) + goto out; + + skb = build_skb(opaque); + if (skb) { + __vsock_deliver_tap(skb); + consume_skb(skb); + } + +out: + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(vsock_deliver_tap); diff --git a/net/vmw_vsock/diag.c b/net/vmw_vsock/diag.c new file mode 100644 index 0000000000..a2823b1c5e --- /dev/null +++ b/net/vmw_vsock/diag.c @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * vsock sock_diag(7) module + * + * Copyright (C) 2017 Red Hat, Inc. + * Author: Stefan Hajnoczi <stefanha@redhat.com> + */ + +#include <linux/module.h> +#include <linux/sock_diag.h> +#include <linux/vm_sockets_diag.h> +#include <net/af_vsock.h> + +static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, + u32 portid, u32 seq, u32 flags) +{ + struct vsock_sock *vsk = vsock_sk(sk); + struct vsock_diag_msg *rep; + struct nlmsghdr *nlh; + + nlh = nlmsg_put(skb, portid, seq, SOCK_DIAG_BY_FAMILY, sizeof(*rep), + flags); + if (!nlh) + return -EMSGSIZE; + + rep = nlmsg_data(nlh); + rep->vdiag_family = AF_VSOCK; + + /* Lock order dictates that sk_lock is acquired before + * vsock_table_lock, so we cannot lock here. Simply don't take + * sk_lock; sk is guaranteed to stay alive since vsock_table_lock is + * held. + */ + rep->vdiag_type = sk->sk_type; + rep->vdiag_state = sk->sk_state; + rep->vdiag_shutdown = sk->sk_shutdown; + rep->vdiag_src_cid = vsk->local_addr.svm_cid; + rep->vdiag_src_port = vsk->local_addr.svm_port; + rep->vdiag_dst_cid = vsk->remote_addr.svm_cid; + rep->vdiag_dst_port = vsk->remote_addr.svm_port; + rep->vdiag_ino = sock_i_ino(sk); + + sock_diag_save_cookie(sk, rep->vdiag_cookie); + + return 0; +} + +static int vsock_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct vsock_diag_req *req; + struct vsock_sock *vsk; + unsigned int bucket; + unsigned int last_i; + unsigned int table; + struct net *net; + unsigned int i; + + req = nlmsg_data(cb->nlh); + net = sock_net(skb->sk); + + /* State saved between calls: */ + table = cb->args[0]; + bucket = cb->args[1]; + i = last_i = cb->args[2]; + + /* TODO VMCI pending sockets? */ + + spin_lock_bh(&vsock_table_lock); + + /* Bind table (locally created sockets) */ + if (table == 0) { + while (bucket < ARRAY_SIZE(vsock_bind_table)) { + struct list_head *head = &vsock_bind_table[bucket]; + + i = 0; + list_for_each_entry(vsk, head, bound_table) { + struct sock *sk = sk_vsock(vsk); + + if (!net_eq(sock_net(sk), net)) + continue; + if (i < last_i) + goto next_bind; + if (!(req->vdiag_states & (1 << sk->sk_state))) + goto next_bind; + if (sk_diag_fill(sk, skb, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI) < 0) + goto done; +next_bind: + i++; + } + last_i = 0; + bucket++; + } + + table++; + bucket = 0; + } + + /* Connected table (accepted connections) */ + while (bucket < ARRAY_SIZE(vsock_connected_table)) { + struct list_head *head = &vsock_connected_table[bucket]; + + i = 0; + list_for_each_entry(vsk, head, connected_table) { + struct sock *sk = sk_vsock(vsk); + + /* Skip sockets we've already seen above */ + if (__vsock_in_bound_table(vsk)) + continue; + + if (!net_eq(sock_net(sk), net)) + continue; + if (i < last_i) + goto next_connected; + if (!(req->vdiag_states & (1 << sk->sk_state))) + goto next_connected; + if (sk_diag_fill(sk, skb, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI) < 0) + goto done; +next_connected: + i++; + } + last_i = 0; + bucket++; + } + +done: + spin_unlock_bh(&vsock_table_lock); + + cb->args[0] = table; + cb->args[1] = bucket; + cb->args[2] = i; + + return skb->len; +} + +static int vsock_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h) +{ + int hdrlen = sizeof(struct vsock_diag_req); + struct net *net = sock_net(skb->sk); + + if (nlmsg_len(h) < hdrlen) + return -EINVAL; + + if (h->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .dump = vsock_diag_dump, + }; + return netlink_dump_start(net->diag_nlsk, skb, h, &c); + } + + return -EOPNOTSUPP; +} + +static const struct sock_diag_handler vsock_diag_handler = { + .family = AF_VSOCK, + .dump = vsock_diag_handler_dump, +}; + +static int __init vsock_diag_init(void) +{ + return sock_diag_register(&vsock_diag_handler); +} + +static void __exit vsock_diag_exit(void) +{ + sock_diag_unregister(&vsock_diag_handler); +} + +module_init(vsock_diag_init); +module_exit(vsock_diag_exit); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, + 40 /* AF_VSOCK */); diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c new file mode 100644 index 0000000000..e2157e3872 --- /dev/null +++ b/net/vmw_vsock/hyperv_transport.c @@ -0,0 +1,954 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Hyper-V transport for vsock + * + * Hyper-V Sockets supplies a byte-stream based communication mechanism + * between the host and the VM. This driver implements the necessary + * support in the VM by introducing the new vsock transport. + * + * Copyright (c) 2017, Microsoft Corporation. + */ +#include <linux/module.h> +#include <linux/vmalloc.h> +#include <linux/hyperv.h> +#include <net/sock.h> +#include <net/af_vsock.h> +#include <asm/hyperv-tlfs.h> + +/* Older (VMBUS version 'VERSION_WIN10' or before) Windows hosts have some + * stricter requirements on the hv_sock ring buffer size of six 4K pages. + * hyperv-tlfs defines HV_HYP_PAGE_SIZE as 4K. Newer hosts don't have this + * limitation; but, keep the defaults the same for compat. + */ +#define RINGBUFFER_HVS_RCV_SIZE (HV_HYP_PAGE_SIZE * 6) +#define RINGBUFFER_HVS_SND_SIZE (HV_HYP_PAGE_SIZE * 6) +#define RINGBUFFER_HVS_MAX_SIZE (HV_HYP_PAGE_SIZE * 64) + +/* The MTU is 16KB per the host side's design */ +#define HVS_MTU_SIZE (1024 * 16) + +/* How long to wait for graceful shutdown of a connection */ +#define HVS_CLOSE_TIMEOUT (8 * HZ) + +struct vmpipe_proto_header { + u32 pkt_type; + u32 data_size; +}; + +/* For recv, we use the VMBus in-place packet iterator APIs to directly copy + * data from the ringbuffer into the userspace buffer. + */ +struct hvs_recv_buf { + /* The header before the payload data */ + struct vmpipe_proto_header hdr; + + /* The payload */ + u8 data[HVS_MTU_SIZE]; +}; + +/* We can send up to HVS_MTU_SIZE bytes of payload to the host, but let's use + * a smaller size, i.e. HVS_SEND_BUF_SIZE, to maximize concurrency between the + * guest and the host processing as one VMBUS packet is the smallest processing + * unit. + * + * Note: the buffer can be eliminated in the future when we add new VMBus + * ringbuffer APIs that allow us to directly copy data from userspace buffer + * to VMBus ringbuffer. + */ +#define HVS_SEND_BUF_SIZE \ + (HV_HYP_PAGE_SIZE - sizeof(struct vmpipe_proto_header)) + +struct hvs_send_buf { + /* The header before the payload data */ + struct vmpipe_proto_header hdr; + + /* The payload */ + u8 data[HVS_SEND_BUF_SIZE]; +}; + +#define HVS_HEADER_LEN (sizeof(struct vmpacket_descriptor) + \ + sizeof(struct vmpipe_proto_header)) + +/* See 'prev_indices' in hv_ringbuffer_read(), hv_ringbuffer_write(), and + * __hv_pkt_iter_next(). + */ +#define VMBUS_PKT_TRAILER_SIZE (sizeof(u64)) + +#define HVS_PKT_LEN(payload_len) (HVS_HEADER_LEN + \ + ALIGN((payload_len), 8) + \ + VMBUS_PKT_TRAILER_SIZE) + +/* Upper bound on the size of a VMbus packet for hv_sock */ +#define HVS_MAX_PKT_SIZE HVS_PKT_LEN(HVS_MTU_SIZE) + +union hvs_service_id { + guid_t srv_id; + + struct { + unsigned int svm_port; + unsigned char b[sizeof(guid_t) - sizeof(unsigned int)]; + }; +}; + +/* Per-socket state (accessed via vsk->trans) */ +struct hvsock { + struct vsock_sock *vsk; + + guid_t vm_srv_id; + guid_t host_srv_id; + + struct vmbus_channel *chan; + struct vmpacket_descriptor *recv_desc; + + /* The length of the payload not delivered to userland yet */ + u32 recv_data_len; + /* The offset of the payload */ + u32 recv_data_off; + + /* Have we sent the zero-length packet (FIN)? */ + bool fin_sent; +}; + +/* In the VM, we support Hyper-V Sockets with AF_VSOCK, and the endpoint is + * <cid, port> (see struct sockaddr_vm). Note: cid is not really used here: + * when we write apps to connect to the host, we can only use VMADDR_CID_ANY + * or VMADDR_CID_HOST (both are equivalent) as the remote cid, and when we + * write apps to bind() & listen() in the VM, we can only use VMADDR_CID_ANY + * as the local cid. + * + * On the host, Hyper-V Sockets are supported by Winsock AF_HYPERV: + * https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user- + * guide/make-integration-service, and the endpoint is <VmID, ServiceId> with + * the below sockaddr: + * + * struct SOCKADDR_HV + * { + * ADDRESS_FAMILY Family; + * USHORT Reserved; + * GUID VmId; + * GUID ServiceId; + * }; + * Note: VmID is not used by Linux VM and actually it isn't transmitted via + * VMBus, because here it's obvious the host and the VM can easily identify + * each other. Though the VmID is useful on the host, especially in the case + * of Windows container, Linux VM doesn't need it at all. + * + * To make use of the AF_VSOCK infrastructure in Linux VM, we have to limit + * the available GUID space of SOCKADDR_HV so that we can create a mapping + * between AF_VSOCK port and SOCKADDR_HV Service GUID. The rule of writing + * Hyper-V Sockets apps on the host and in Linux VM is: + * + **************************************************************************** + * The only valid Service GUIDs, from the perspectives of both the host and * + * Linux VM, that can be connected by the other end, must conform to this * + * format: <port>-facb-11e6-bd58-64006a7986d3. * + **************************************************************************** + * + * When we write apps on the host to connect(), the GUID ServiceID is used. + * When we write apps in Linux VM to connect(), we only need to specify the + * port and the driver will form the GUID and use that to request the host. + * + */ + +/* 00000000-facb-11e6-bd58-64006a7986d3 */ +static const guid_t srv_id_template = + GUID_INIT(0x00000000, 0xfacb, 0x11e6, 0xbd, 0x58, + 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3); + +static bool hvs_check_transport(struct vsock_sock *vsk); + +static bool is_valid_srv_id(const guid_t *id) +{ + return !memcmp(&id->b[4], &srv_id_template.b[4], sizeof(guid_t) - 4); +} + +static unsigned int get_port_by_srv_id(const guid_t *svr_id) +{ + return *((unsigned int *)svr_id); +} + +static void hvs_addr_init(struct sockaddr_vm *addr, const guid_t *svr_id) +{ + unsigned int port = get_port_by_srv_id(svr_id); + + vsock_addr_init(addr, VMADDR_CID_ANY, port); +} + +static void hvs_set_channel_pending_send_size(struct vmbus_channel *chan) +{ + set_channel_pending_send_size(chan, + HVS_PKT_LEN(HVS_SEND_BUF_SIZE)); + + virt_mb(); +} + +static bool hvs_channel_readable(struct vmbus_channel *chan) +{ + u32 readable = hv_get_bytes_to_read(&chan->inbound); + + /* 0-size payload means FIN */ + return readable >= HVS_PKT_LEN(0); +} + +static int hvs_channel_readable_payload(struct vmbus_channel *chan) +{ + u32 readable = hv_get_bytes_to_read(&chan->inbound); + + if (readable > HVS_PKT_LEN(0)) { + /* At least we have 1 byte to read. We don't need to return + * the exact readable bytes: see vsock_stream_recvmsg() -> + * vsock_stream_has_data(). + */ + return 1; + } + + if (readable == HVS_PKT_LEN(0)) { + /* 0-size payload means FIN */ + return 0; + } + + /* No payload or FIN */ + return -1; +} + +static size_t hvs_channel_writable_bytes(struct vmbus_channel *chan) +{ + u32 writeable = hv_get_bytes_to_write(&chan->outbound); + size_t ret; + + /* The ringbuffer mustn't be 100% full, and we should reserve a + * zero-length-payload packet for the FIN: see hv_ringbuffer_write() + * and hvs_shutdown(). + */ + if (writeable <= HVS_PKT_LEN(1) + HVS_PKT_LEN(0)) + return 0; + + ret = writeable - HVS_PKT_LEN(1) - HVS_PKT_LEN(0); + + return round_down(ret, 8); +} + +static int __hvs_send_data(struct vmbus_channel *chan, + struct vmpipe_proto_header *hdr, + size_t to_write) +{ + hdr->pkt_type = 1; + hdr->data_size = to_write; + return vmbus_sendpacket(chan, hdr, sizeof(*hdr) + to_write, + 0, VM_PKT_DATA_INBAND, 0); +} + +static int hvs_send_data(struct vmbus_channel *chan, + struct hvs_send_buf *send_buf, size_t to_write) +{ + return __hvs_send_data(chan, &send_buf->hdr, to_write); +} + +static void hvs_channel_cb(void *ctx) +{ + struct sock *sk = (struct sock *)ctx; + struct vsock_sock *vsk = vsock_sk(sk); + struct hvsock *hvs = vsk->trans; + struct vmbus_channel *chan = hvs->chan; + + if (hvs_channel_readable(chan)) + sk->sk_data_ready(sk); + + if (hv_get_bytes_to_write(&chan->outbound) > 0) + sk->sk_write_space(sk); +} + +static void hvs_do_close_lock_held(struct vsock_sock *vsk, + bool cancel_timeout) +{ + struct sock *sk = sk_vsock(vsk); + + sock_set_flag(sk, SOCK_DONE); + vsk->peer_shutdown = SHUTDOWN_MASK; + if (vsock_stream_has_data(vsk) <= 0) + sk->sk_state = TCP_CLOSING; + sk->sk_state_change(sk); + if (vsk->close_work_scheduled && + (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) { + vsk->close_work_scheduled = false; + vsock_remove_sock(vsk); + + /* Release the reference taken while scheduling the timeout */ + sock_put(sk); + } +} + +static void hvs_close_connection(struct vmbus_channel *chan) +{ + struct sock *sk = get_per_channel_state(chan); + + lock_sock(sk); + hvs_do_close_lock_held(vsock_sk(sk), true); + release_sock(sk); + + /* Release the refcnt for the channel that's opened in + * hvs_open_connection(). + */ + sock_put(sk); +} + +static void hvs_open_connection(struct vmbus_channel *chan) +{ + guid_t *if_instance, *if_type; + unsigned char conn_from_host; + + struct sockaddr_vm addr; + struct sock *sk, *new = NULL; + struct vsock_sock *vnew = NULL; + struct hvsock *hvs = NULL; + struct hvsock *hvs_new = NULL; + int rcvbuf; + int ret; + int sndbuf; + + if_type = &chan->offermsg.offer.if_type; + if_instance = &chan->offermsg.offer.if_instance; + conn_from_host = chan->offermsg.offer.u.pipe.user_def[0]; + if (!is_valid_srv_id(if_type)) + return; + + hvs_addr_init(&addr, conn_from_host ? if_type : if_instance); + sk = vsock_find_bound_socket(&addr); + if (!sk) + return; + + lock_sock(sk); + if ((conn_from_host && sk->sk_state != TCP_LISTEN) || + (!conn_from_host && sk->sk_state != TCP_SYN_SENT)) + goto out; + + if (conn_from_host) { + if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog) + goto out; + + new = vsock_create_connected(sk); + if (!new) + goto out; + + new->sk_state = TCP_SYN_SENT; + vnew = vsock_sk(new); + + hvs_addr_init(&vnew->local_addr, if_type); + + /* Remote peer is always the host */ + vsock_addr_init(&vnew->remote_addr, + VMADDR_CID_HOST, VMADDR_PORT_ANY); + vnew->remote_addr.svm_port = get_port_by_srv_id(if_instance); + ret = vsock_assign_transport(vnew, vsock_sk(sk)); + /* Transport assigned (looking at remote_addr) must be the + * same where we received the request. + */ + if (ret || !hvs_check_transport(vnew)) { + sock_put(new); + goto out; + } + hvs_new = vnew->trans; + hvs_new->chan = chan; + } else { + hvs = vsock_sk(sk)->trans; + hvs->chan = chan; + } + + set_channel_read_mode(chan, HV_CALL_DIRECT); + + /* Use the socket buffer sizes as hints for the VMBUS ring size. For + * server side sockets, 'sk' is the parent socket and thus, this will + * allow the child sockets to inherit the size from the parent. Keep + * the mins to the default value and align to page size as per VMBUS + * requirements. + * For the max, the socket core library will limit the socket buffer + * size that can be set by the user, but, since currently, the hv_sock + * VMBUS ring buffer is physically contiguous allocation, restrict it + * further. + * Older versions of hv_sock host side code cannot handle bigger VMBUS + * ring buffer size. Use the version number to limit the change to newer + * versions. + */ + if (vmbus_proto_version < VERSION_WIN10_V5) { + sndbuf = RINGBUFFER_HVS_SND_SIZE; + rcvbuf = RINGBUFFER_HVS_RCV_SIZE; + } else { + sndbuf = max_t(int, sk->sk_sndbuf, RINGBUFFER_HVS_SND_SIZE); + sndbuf = min_t(int, sndbuf, RINGBUFFER_HVS_MAX_SIZE); + sndbuf = ALIGN(sndbuf, HV_HYP_PAGE_SIZE); + rcvbuf = max_t(int, sk->sk_rcvbuf, RINGBUFFER_HVS_RCV_SIZE); + rcvbuf = min_t(int, rcvbuf, RINGBUFFER_HVS_MAX_SIZE); + rcvbuf = ALIGN(rcvbuf, HV_HYP_PAGE_SIZE); + } + + chan->max_pkt_size = HVS_MAX_PKT_SIZE; + + ret = vmbus_open(chan, sndbuf, rcvbuf, NULL, 0, hvs_channel_cb, + conn_from_host ? new : sk); + if (ret != 0) { + if (conn_from_host) { + hvs_new->chan = NULL; + sock_put(new); + } else { + hvs->chan = NULL; + } + goto out; + } + + set_per_channel_state(chan, conn_from_host ? new : sk); + + /* This reference will be dropped by hvs_close_connection(). */ + sock_hold(conn_from_host ? new : sk); + vmbus_set_chn_rescind_callback(chan, hvs_close_connection); + + /* Set the pending send size to max packet size to always get + * notifications from the host when there is enough writable space. + * The host is optimized to send notifications only when the pending + * size boundary is crossed, and not always. + */ + hvs_set_channel_pending_send_size(chan); + + if (conn_from_host) { + new->sk_state = TCP_ESTABLISHED; + sk_acceptq_added(sk); + + hvs_new->vm_srv_id = *if_type; + hvs_new->host_srv_id = *if_instance; + + vsock_insert_connected(vnew); + + vsock_enqueue_accept(sk, new); + } else { + sk->sk_state = TCP_ESTABLISHED; + sk->sk_socket->state = SS_CONNECTED; + + vsock_insert_connected(vsock_sk(sk)); + } + + sk->sk_state_change(sk); + +out: + /* Release refcnt obtained when we called vsock_find_bound_socket() */ + sock_put(sk); + + release_sock(sk); +} + +static u32 hvs_get_local_cid(void) +{ + return VMADDR_CID_ANY; +} + +static int hvs_sock_init(struct vsock_sock *vsk, struct vsock_sock *psk) +{ + struct hvsock *hvs; + struct sock *sk = sk_vsock(vsk); + + hvs = kzalloc(sizeof(*hvs), GFP_KERNEL); + if (!hvs) + return -ENOMEM; + + vsk->trans = hvs; + hvs->vsk = vsk; + sk->sk_sndbuf = RINGBUFFER_HVS_SND_SIZE; + sk->sk_rcvbuf = RINGBUFFER_HVS_RCV_SIZE; + return 0; +} + +static int hvs_connect(struct vsock_sock *vsk) +{ + union hvs_service_id vm, host; + struct hvsock *h = vsk->trans; + + vm.srv_id = srv_id_template; + vm.svm_port = vsk->local_addr.svm_port; + h->vm_srv_id = vm.srv_id; + + host.srv_id = srv_id_template; + host.svm_port = vsk->remote_addr.svm_port; + h->host_srv_id = host.srv_id; + + return vmbus_send_tl_connect_request(&h->vm_srv_id, &h->host_srv_id); +} + +static void hvs_shutdown_lock_held(struct hvsock *hvs, int mode) +{ + struct vmpipe_proto_header hdr; + + if (hvs->fin_sent || !hvs->chan) + return; + + /* It can't fail: see hvs_channel_writable_bytes(). */ + (void)__hvs_send_data(hvs->chan, &hdr, 0); + hvs->fin_sent = true; +} + +static int hvs_shutdown(struct vsock_sock *vsk, int mode) +{ + if (!(mode & SEND_SHUTDOWN)) + return 0; + + hvs_shutdown_lock_held(vsk->trans, mode); + return 0; +} + +static void hvs_close_timeout(struct work_struct *work) +{ + struct vsock_sock *vsk = + container_of(work, struct vsock_sock, close_work.work); + struct sock *sk = sk_vsock(vsk); + + sock_hold(sk); + lock_sock(sk); + if (!sock_flag(sk, SOCK_DONE)) + hvs_do_close_lock_held(vsk, false); + + vsk->close_work_scheduled = false; + release_sock(sk); + sock_put(sk); +} + +/* Returns true, if it is safe to remove socket; false otherwise */ +static bool hvs_close_lock_held(struct vsock_sock *vsk) +{ + struct sock *sk = sk_vsock(vsk); + + if (!(sk->sk_state == TCP_ESTABLISHED || + sk->sk_state == TCP_CLOSING)) + return true; + + if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK) + hvs_shutdown_lock_held(vsk->trans, SHUTDOWN_MASK); + + if (sock_flag(sk, SOCK_DONE)) + return true; + + /* This reference will be dropped by the delayed close routine */ + sock_hold(sk); + INIT_DELAYED_WORK(&vsk->close_work, hvs_close_timeout); + vsk->close_work_scheduled = true; + schedule_delayed_work(&vsk->close_work, HVS_CLOSE_TIMEOUT); + return false; +} + +static void hvs_release(struct vsock_sock *vsk) +{ + bool remove_sock; + + remove_sock = hvs_close_lock_held(vsk); + if (remove_sock) + vsock_remove_sock(vsk); +} + +static void hvs_destruct(struct vsock_sock *vsk) +{ + struct hvsock *hvs = vsk->trans; + struct vmbus_channel *chan = hvs->chan; + + if (chan) + vmbus_hvsock_device_unregister(chan); + + kfree(hvs); +} + +static int hvs_dgram_bind(struct vsock_sock *vsk, struct sockaddr_vm *addr) +{ + return -EOPNOTSUPP; +} + +static int hvs_dgram_dequeue(struct vsock_sock *vsk, struct msghdr *msg, + size_t len, int flags) +{ + return -EOPNOTSUPP; +} + +static int hvs_dgram_enqueue(struct vsock_sock *vsk, + struct sockaddr_vm *remote, struct msghdr *msg, + size_t dgram_len) +{ + return -EOPNOTSUPP; +} + +static bool hvs_dgram_allow(u32 cid, u32 port) +{ + return false; +} + +static int hvs_update_recv_data(struct hvsock *hvs) +{ + struct hvs_recv_buf *recv_buf; + u32 pkt_len, payload_len; + + pkt_len = hv_pkt_len(hvs->recv_desc); + + if (pkt_len < HVS_HEADER_LEN) + return -EIO; + + recv_buf = (struct hvs_recv_buf *)(hvs->recv_desc + 1); + payload_len = recv_buf->hdr.data_size; + + if (payload_len > pkt_len - HVS_HEADER_LEN || + payload_len > HVS_MTU_SIZE) + return -EIO; + + if (payload_len == 0) + hvs->vsk->peer_shutdown |= SEND_SHUTDOWN; + + hvs->recv_data_len = payload_len; + hvs->recv_data_off = 0; + + return 0; +} + +static ssize_t hvs_stream_dequeue(struct vsock_sock *vsk, struct msghdr *msg, + size_t len, int flags) +{ + struct hvsock *hvs = vsk->trans; + bool need_refill = !hvs->recv_desc; + struct hvs_recv_buf *recv_buf; + u32 to_read; + int ret; + + if (flags & MSG_PEEK) + return -EOPNOTSUPP; + + if (need_refill) { + hvs->recv_desc = hv_pkt_iter_first(hvs->chan); + if (!hvs->recv_desc) + return -ENOBUFS; + ret = hvs_update_recv_data(hvs); + if (ret) + return ret; + } + + recv_buf = (struct hvs_recv_buf *)(hvs->recv_desc + 1); + to_read = min_t(u32, len, hvs->recv_data_len); + ret = memcpy_to_msg(msg, recv_buf->data + hvs->recv_data_off, to_read); + if (ret != 0) + return ret; + + hvs->recv_data_len -= to_read; + if (hvs->recv_data_len == 0) { + hvs->recv_desc = hv_pkt_iter_next(hvs->chan, hvs->recv_desc); + if (hvs->recv_desc) { + ret = hvs_update_recv_data(hvs); + if (ret) + return ret; + } + } else { + hvs->recv_data_off += to_read; + } + + return to_read; +} + +static ssize_t hvs_stream_enqueue(struct vsock_sock *vsk, struct msghdr *msg, + size_t len) +{ + struct hvsock *hvs = vsk->trans; + struct vmbus_channel *chan = hvs->chan; + struct hvs_send_buf *send_buf; + ssize_t to_write, max_writable; + ssize_t ret = 0; + ssize_t bytes_written = 0; + + BUILD_BUG_ON(sizeof(*send_buf) != HV_HYP_PAGE_SIZE); + + send_buf = kmalloc(sizeof(*send_buf), GFP_KERNEL); + if (!send_buf) + return -ENOMEM; + + /* Reader(s) could be draining data from the channel as we write. + * Maximize bandwidth, by iterating until the channel is found to be + * full. + */ + while (len) { + max_writable = hvs_channel_writable_bytes(chan); + if (!max_writable) + break; + to_write = min_t(ssize_t, len, max_writable); + to_write = min_t(ssize_t, to_write, HVS_SEND_BUF_SIZE); + /* memcpy_from_msg is safe for loop as it advances the offsets + * within the message iterator. + */ + ret = memcpy_from_msg(send_buf->data, msg, to_write); + if (ret < 0) + goto out; + + ret = hvs_send_data(hvs->chan, send_buf, to_write); + if (ret < 0) + goto out; + + bytes_written += to_write; + len -= to_write; + } +out: + /* If any data has been sent, return that */ + if (bytes_written) + ret = bytes_written; + kfree(send_buf); + return ret; +} + +static s64 hvs_stream_has_data(struct vsock_sock *vsk) +{ + struct hvsock *hvs = vsk->trans; + s64 ret; + + if (hvs->recv_data_len > 0) + return 1; + + switch (hvs_channel_readable_payload(hvs->chan)) { + case 1: + ret = 1; + break; + case 0: + vsk->peer_shutdown |= SEND_SHUTDOWN; + ret = 0; + break; + default: /* -1 */ + ret = 0; + break; + } + + return ret; +} + +static s64 hvs_stream_has_space(struct vsock_sock *vsk) +{ + struct hvsock *hvs = vsk->trans; + + return hvs_channel_writable_bytes(hvs->chan); +} + +static u64 hvs_stream_rcvhiwat(struct vsock_sock *vsk) +{ + return HVS_MTU_SIZE + 1; +} + +static bool hvs_stream_is_active(struct vsock_sock *vsk) +{ + struct hvsock *hvs = vsk->trans; + + return hvs->chan != NULL; +} + +static bool hvs_stream_allow(u32 cid, u32 port) +{ + if (cid == VMADDR_CID_HOST) + return true; + + return false; +} + +static +int hvs_notify_poll_in(struct vsock_sock *vsk, size_t target, bool *readable) +{ + struct hvsock *hvs = vsk->trans; + + *readable = hvs_channel_readable(hvs->chan); + return 0; +} + +static +int hvs_notify_poll_out(struct vsock_sock *vsk, size_t target, bool *writable) +{ + *writable = hvs_stream_has_space(vsk) > 0; + + return 0; +} + +static +int hvs_notify_recv_init(struct vsock_sock *vsk, size_t target, + struct vsock_transport_recv_notify_data *d) +{ + return 0; +} + +static +int hvs_notify_recv_pre_block(struct vsock_sock *vsk, size_t target, + struct vsock_transport_recv_notify_data *d) +{ + return 0; +} + +static +int hvs_notify_recv_pre_dequeue(struct vsock_sock *vsk, size_t target, + struct vsock_transport_recv_notify_data *d) +{ + return 0; +} + +static +int hvs_notify_recv_post_dequeue(struct vsock_sock *vsk, size_t target, + ssize_t copied, bool data_read, + struct vsock_transport_recv_notify_data *d) +{ + return 0; +} + +static +int hvs_notify_send_init(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *d) +{ + return 0; +} + +static +int hvs_notify_send_pre_block(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *d) +{ + return 0; +} + +static +int hvs_notify_send_pre_enqueue(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *d) +{ + return 0; +} + +static +int hvs_notify_send_post_enqueue(struct vsock_sock *vsk, ssize_t written, + struct vsock_transport_send_notify_data *d) +{ + return 0; +} + +static +int hvs_notify_set_rcvlowat(struct vsock_sock *vsk, int val) +{ + return -EOPNOTSUPP; +} + +static struct vsock_transport hvs_transport = { + .module = THIS_MODULE, + + .get_local_cid = hvs_get_local_cid, + + .init = hvs_sock_init, + .destruct = hvs_destruct, + .release = hvs_release, + .connect = hvs_connect, + .shutdown = hvs_shutdown, + + .dgram_bind = hvs_dgram_bind, + .dgram_dequeue = hvs_dgram_dequeue, + .dgram_enqueue = hvs_dgram_enqueue, + .dgram_allow = hvs_dgram_allow, + + .stream_dequeue = hvs_stream_dequeue, + .stream_enqueue = hvs_stream_enqueue, + .stream_has_data = hvs_stream_has_data, + .stream_has_space = hvs_stream_has_space, + .stream_rcvhiwat = hvs_stream_rcvhiwat, + .stream_is_active = hvs_stream_is_active, + .stream_allow = hvs_stream_allow, + + .notify_poll_in = hvs_notify_poll_in, + .notify_poll_out = hvs_notify_poll_out, + .notify_recv_init = hvs_notify_recv_init, + .notify_recv_pre_block = hvs_notify_recv_pre_block, + .notify_recv_pre_dequeue = hvs_notify_recv_pre_dequeue, + .notify_recv_post_dequeue = hvs_notify_recv_post_dequeue, + .notify_send_init = hvs_notify_send_init, + .notify_send_pre_block = hvs_notify_send_pre_block, + .notify_send_pre_enqueue = hvs_notify_send_pre_enqueue, + .notify_send_post_enqueue = hvs_notify_send_post_enqueue, + + .notify_set_rcvlowat = hvs_notify_set_rcvlowat +}; + +static bool hvs_check_transport(struct vsock_sock *vsk) +{ + return vsk->transport == &hvs_transport; +} + +static int hvs_probe(struct hv_device *hdev, + const struct hv_vmbus_device_id *dev_id) +{ + struct vmbus_channel *chan = hdev->channel; + + hvs_open_connection(chan); + + /* Always return success to suppress the unnecessary error message + * in vmbus_probe(): on error the host will rescind the device in + * 30 seconds and we can do cleanup at that time in + * vmbus_onoffer_rescind(). + */ + return 0; +} + +static void hvs_remove(struct hv_device *hdev) +{ + struct vmbus_channel *chan = hdev->channel; + + vmbus_close(chan); +} + +/* hv_sock connections can not persist across hibernation, and all the hv_sock + * channels are forced to be rescinded before hibernation: see + * vmbus_bus_suspend(). Here the dummy hvs_suspend() and hvs_resume() + * are only needed because hibernation requires that every vmbus device's + * driver should have a .suspend and .resume callback: see vmbus_suspend(). + */ +static int hvs_suspend(struct hv_device *hv_dev) +{ + /* Dummy */ + return 0; +} + +static int hvs_resume(struct hv_device *dev) +{ + /* Dummy */ + return 0; +} + +/* This isn't really used. See vmbus_match() and vmbus_probe() */ +static const struct hv_vmbus_device_id id_table[] = { + {}, +}; + +static struct hv_driver hvs_drv = { + .name = "hv_sock", + .hvsock = true, + .id_table = id_table, + .probe = hvs_probe, + .remove = hvs_remove, + .suspend = hvs_suspend, + .resume = hvs_resume, +}; + +static int __init hvs_init(void) +{ + int ret; + + if (vmbus_proto_version < VERSION_WIN10) + return -ENODEV; + + ret = vmbus_driver_register(&hvs_drv); + if (ret != 0) + return ret; + + ret = vsock_core_register(&hvs_transport, VSOCK_TRANSPORT_F_G2H); + if (ret) { + vmbus_driver_unregister(&hvs_drv); + return ret; + } + + return 0; +} + +static void __exit hvs_exit(void) +{ + vsock_core_unregister(&hvs_transport); + vmbus_driver_unregister(&hvs_drv); +} + +module_init(hvs_init); +module_exit(hvs_exit); + +MODULE_DESCRIPTION("Hyper-V Sockets"); +MODULE_VERSION("1.0.0"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NETPROTO(PF_VSOCK); diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c new file mode 100644 index 0000000000..a64bf601b4 --- /dev/null +++ b/net/vmw_vsock/virtio_transport.c @@ -0,0 +1,824 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * virtio transport for vsock + * + * Copyright (C) 2013-2015 Red Hat, Inc. + * Author: Asias He <asias@redhat.com> + * Stefan Hajnoczi <stefanha@redhat.com> + * + * Some of the code is take from Gerd Hoffmann <kraxel@redhat.com>'s + * early virtio-vsock proof-of-concept bits. + */ +#include <linux/spinlock.h> +#include <linux/module.h> +#include <linux/list.h> +#include <linux/atomic.h> +#include <linux/virtio.h> +#include <linux/virtio_ids.h> +#include <linux/virtio_config.h> +#include <linux/virtio_vsock.h> +#include <net/sock.h> +#include <linux/mutex.h> +#include <net/af_vsock.h> + +static struct workqueue_struct *virtio_vsock_workqueue; +static struct virtio_vsock __rcu *the_virtio_vsock; +static DEFINE_MUTEX(the_virtio_vsock_mutex); /* protects the_virtio_vsock */ +static struct virtio_transport virtio_transport; /* forward declaration */ + +struct virtio_vsock { + struct virtio_device *vdev; + struct virtqueue *vqs[VSOCK_VQ_MAX]; + + /* Virtqueue processing is deferred to a workqueue */ + struct work_struct tx_work; + struct work_struct rx_work; + struct work_struct event_work; + + /* The following fields are protected by tx_lock. vqs[VSOCK_VQ_TX] + * must be accessed with tx_lock held. + */ + struct mutex tx_lock; + bool tx_run; + + struct work_struct send_pkt_work; + struct sk_buff_head send_pkt_queue; + + atomic_t queued_replies; + + /* The following fields are protected by rx_lock. vqs[VSOCK_VQ_RX] + * must be accessed with rx_lock held. + */ + struct mutex rx_lock; + bool rx_run; + int rx_buf_nr; + int rx_buf_max_nr; + + /* The following fields are protected by event_lock. + * vqs[VSOCK_VQ_EVENT] must be accessed with event_lock held. + */ + struct mutex event_lock; + bool event_run; + struct virtio_vsock_event event_list[8]; + + u32 guest_cid; + bool seqpacket_allow; +}; + +static u32 virtio_transport_get_local_cid(void) +{ + struct virtio_vsock *vsock; + u32 ret; + + rcu_read_lock(); + vsock = rcu_dereference(the_virtio_vsock); + if (!vsock) { + ret = VMADDR_CID_ANY; + goto out_rcu; + } + + ret = vsock->guest_cid; +out_rcu: + rcu_read_unlock(); + return ret; +} + +static void +virtio_transport_send_pkt_work(struct work_struct *work) +{ + struct virtio_vsock *vsock = + container_of(work, struct virtio_vsock, send_pkt_work); + struct virtqueue *vq; + bool added = false; + bool restart_rx = false; + + mutex_lock(&vsock->tx_lock); + + if (!vsock->tx_run) + goto out; + + vq = vsock->vqs[VSOCK_VQ_TX]; + + for (;;) { + struct scatterlist hdr, buf, *sgs[2]; + int ret, in_sg = 0, out_sg = 0; + struct sk_buff *skb; + bool reply; + + skb = virtio_vsock_skb_dequeue(&vsock->send_pkt_queue); + if (!skb) + break; + + virtio_transport_deliver_tap_pkt(skb); + reply = virtio_vsock_skb_reply(skb); + + sg_init_one(&hdr, virtio_vsock_hdr(skb), sizeof(*virtio_vsock_hdr(skb))); + sgs[out_sg++] = &hdr; + if (skb->len > 0) { + sg_init_one(&buf, skb->data, skb->len); + sgs[out_sg++] = &buf; + } + + ret = virtqueue_add_sgs(vq, sgs, out_sg, in_sg, skb, GFP_KERNEL); + /* Usually this means that there is no more space available in + * the vq + */ + if (ret < 0) { + virtio_vsock_skb_queue_head(&vsock->send_pkt_queue, skb); + break; + } + + if (reply) { + struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX]; + int val; + + val = atomic_dec_return(&vsock->queued_replies); + + /* Do we now have resources to resume rx processing? */ + if (val + 1 == virtqueue_get_vring_size(rx_vq)) + restart_rx = true; + } + + added = true; + } + + if (added) + virtqueue_kick(vq); + +out: + mutex_unlock(&vsock->tx_lock); + + if (restart_rx) + queue_work(virtio_vsock_workqueue, &vsock->rx_work); +} + +static int +virtio_transport_send_pkt(struct sk_buff *skb) +{ + struct virtio_vsock_hdr *hdr; + struct virtio_vsock *vsock; + int len = skb->len; + + hdr = virtio_vsock_hdr(skb); + + rcu_read_lock(); + vsock = rcu_dereference(the_virtio_vsock); + if (!vsock) { + kfree_skb(skb); + len = -ENODEV; + goto out_rcu; + } + + if (le64_to_cpu(hdr->dst_cid) == vsock->guest_cid) { + kfree_skb(skb); + len = -ENODEV; + goto out_rcu; + } + + if (virtio_vsock_skb_reply(skb)) + atomic_inc(&vsock->queued_replies); + + virtio_vsock_skb_queue_tail(&vsock->send_pkt_queue, skb); + queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work); + +out_rcu: + rcu_read_unlock(); + return len; +} + +static int +virtio_transport_cancel_pkt(struct vsock_sock *vsk) +{ + struct virtio_vsock *vsock; + int cnt = 0, ret; + + rcu_read_lock(); + vsock = rcu_dereference(the_virtio_vsock); + if (!vsock) { + ret = -ENODEV; + goto out_rcu; + } + + cnt = virtio_transport_purge_skbs(vsk, &vsock->send_pkt_queue); + + if (cnt) { + struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX]; + int new_cnt; + + new_cnt = atomic_sub_return(cnt, &vsock->queued_replies); + if (new_cnt + cnt >= virtqueue_get_vring_size(rx_vq) && + new_cnt < virtqueue_get_vring_size(rx_vq)) + queue_work(virtio_vsock_workqueue, &vsock->rx_work); + } + + ret = 0; + +out_rcu: + rcu_read_unlock(); + return ret; +} + +static void virtio_vsock_rx_fill(struct virtio_vsock *vsock) +{ + int total_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE + VIRTIO_VSOCK_SKB_HEADROOM; + struct scatterlist pkt, *p; + struct virtqueue *vq; + struct sk_buff *skb; + int ret; + + vq = vsock->vqs[VSOCK_VQ_RX]; + + do { + skb = virtio_vsock_alloc_skb(total_len, GFP_KERNEL); + if (!skb) + break; + + memset(skb->head, 0, VIRTIO_VSOCK_SKB_HEADROOM); + sg_init_one(&pkt, virtio_vsock_hdr(skb), total_len); + p = &pkt; + ret = virtqueue_add_sgs(vq, &p, 0, 1, skb, GFP_KERNEL); + if (ret < 0) { + kfree_skb(skb); + break; + } + + vsock->rx_buf_nr++; + } while (vq->num_free); + if (vsock->rx_buf_nr > vsock->rx_buf_max_nr) + vsock->rx_buf_max_nr = vsock->rx_buf_nr; + virtqueue_kick(vq); +} + +static void virtio_transport_tx_work(struct work_struct *work) +{ + struct virtio_vsock *vsock = + container_of(work, struct virtio_vsock, tx_work); + struct virtqueue *vq; + bool added = false; + + vq = vsock->vqs[VSOCK_VQ_TX]; + mutex_lock(&vsock->tx_lock); + + if (!vsock->tx_run) + goto out; + + do { + struct sk_buff *skb; + unsigned int len; + + virtqueue_disable_cb(vq); + while ((skb = virtqueue_get_buf(vq, &len)) != NULL) { + consume_skb(skb); + added = true; + } + } while (!virtqueue_enable_cb(vq)); + +out: + mutex_unlock(&vsock->tx_lock); + + if (added) + queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work); +} + +/* Is there space left for replies to rx packets? */ +static bool virtio_transport_more_replies(struct virtio_vsock *vsock) +{ + struct virtqueue *vq = vsock->vqs[VSOCK_VQ_RX]; + int val; + + smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */ + val = atomic_read(&vsock->queued_replies); + + return val < virtqueue_get_vring_size(vq); +} + +/* event_lock must be held */ +static int virtio_vsock_event_fill_one(struct virtio_vsock *vsock, + struct virtio_vsock_event *event) +{ + struct scatterlist sg; + struct virtqueue *vq; + + vq = vsock->vqs[VSOCK_VQ_EVENT]; + + sg_init_one(&sg, event, sizeof(*event)); + + return virtqueue_add_inbuf(vq, &sg, 1, event, GFP_KERNEL); +} + +/* event_lock must be held */ +static void virtio_vsock_event_fill(struct virtio_vsock *vsock) +{ + size_t i; + + for (i = 0; i < ARRAY_SIZE(vsock->event_list); i++) { + struct virtio_vsock_event *event = &vsock->event_list[i]; + + virtio_vsock_event_fill_one(vsock, event); + } + + virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]); +} + +static void virtio_vsock_reset_sock(struct sock *sk) +{ + /* vmci_transport.c doesn't take sk_lock here either. At least we're + * under vsock_table_lock so the sock cannot disappear while we're + * executing. + */ + + sk->sk_state = TCP_CLOSE; + sk->sk_err = ECONNRESET; + sk_error_report(sk); +} + +static void virtio_vsock_update_guest_cid(struct virtio_vsock *vsock) +{ + struct virtio_device *vdev = vsock->vdev; + __le64 guest_cid; + + vdev->config->get(vdev, offsetof(struct virtio_vsock_config, guest_cid), + &guest_cid, sizeof(guest_cid)); + vsock->guest_cid = le64_to_cpu(guest_cid); +} + +/* event_lock must be held */ +static void virtio_vsock_event_handle(struct virtio_vsock *vsock, + struct virtio_vsock_event *event) +{ + switch (le32_to_cpu(event->id)) { + case VIRTIO_VSOCK_EVENT_TRANSPORT_RESET: + virtio_vsock_update_guest_cid(vsock); + vsock_for_each_connected_socket(&virtio_transport.transport, + virtio_vsock_reset_sock); + break; + } +} + +static void virtio_transport_event_work(struct work_struct *work) +{ + struct virtio_vsock *vsock = + container_of(work, struct virtio_vsock, event_work); + struct virtqueue *vq; + + vq = vsock->vqs[VSOCK_VQ_EVENT]; + + mutex_lock(&vsock->event_lock); + + if (!vsock->event_run) + goto out; + + do { + struct virtio_vsock_event *event; + unsigned int len; + + virtqueue_disable_cb(vq); + while ((event = virtqueue_get_buf(vq, &len)) != NULL) { + if (len == sizeof(*event)) + virtio_vsock_event_handle(vsock, event); + + virtio_vsock_event_fill_one(vsock, event); + } + } while (!virtqueue_enable_cb(vq)); + + virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]); +out: + mutex_unlock(&vsock->event_lock); +} + +static void virtio_vsock_event_done(struct virtqueue *vq) +{ + struct virtio_vsock *vsock = vq->vdev->priv; + + if (!vsock) + return; + queue_work(virtio_vsock_workqueue, &vsock->event_work); +} + +static void virtio_vsock_tx_done(struct virtqueue *vq) +{ + struct virtio_vsock *vsock = vq->vdev->priv; + + if (!vsock) + return; + queue_work(virtio_vsock_workqueue, &vsock->tx_work); +} + +static void virtio_vsock_rx_done(struct virtqueue *vq) +{ + struct virtio_vsock *vsock = vq->vdev->priv; + + if (!vsock) + return; + queue_work(virtio_vsock_workqueue, &vsock->rx_work); +} + +static bool virtio_transport_seqpacket_allow(u32 remote_cid); + +static struct virtio_transport virtio_transport = { + .transport = { + .module = THIS_MODULE, + + .get_local_cid = virtio_transport_get_local_cid, + + .init = virtio_transport_do_socket_init, + .destruct = virtio_transport_destruct, + .release = virtio_transport_release, + .connect = virtio_transport_connect, + .shutdown = virtio_transport_shutdown, + .cancel_pkt = virtio_transport_cancel_pkt, + + .dgram_bind = virtio_transport_dgram_bind, + .dgram_dequeue = virtio_transport_dgram_dequeue, + .dgram_enqueue = virtio_transport_dgram_enqueue, + .dgram_allow = virtio_transport_dgram_allow, + + .stream_dequeue = virtio_transport_stream_dequeue, + .stream_enqueue = virtio_transport_stream_enqueue, + .stream_has_data = virtio_transport_stream_has_data, + .stream_has_space = virtio_transport_stream_has_space, + .stream_rcvhiwat = virtio_transport_stream_rcvhiwat, + .stream_is_active = virtio_transport_stream_is_active, + .stream_allow = virtio_transport_stream_allow, + + .seqpacket_dequeue = virtio_transport_seqpacket_dequeue, + .seqpacket_enqueue = virtio_transport_seqpacket_enqueue, + .seqpacket_allow = virtio_transport_seqpacket_allow, + .seqpacket_has_data = virtio_transport_seqpacket_has_data, + + .notify_poll_in = virtio_transport_notify_poll_in, + .notify_poll_out = virtio_transport_notify_poll_out, + .notify_recv_init = virtio_transport_notify_recv_init, + .notify_recv_pre_block = virtio_transport_notify_recv_pre_block, + .notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue, + .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue, + .notify_send_init = virtio_transport_notify_send_init, + .notify_send_pre_block = virtio_transport_notify_send_pre_block, + .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue, + .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue, + .notify_buffer_size = virtio_transport_notify_buffer_size, + .notify_set_rcvlowat = virtio_transport_notify_set_rcvlowat, + + .read_skb = virtio_transport_read_skb, + }, + + .send_pkt = virtio_transport_send_pkt, +}; + +static bool virtio_transport_seqpacket_allow(u32 remote_cid) +{ + struct virtio_vsock *vsock; + bool seqpacket_allow; + + seqpacket_allow = false; + rcu_read_lock(); + vsock = rcu_dereference(the_virtio_vsock); + if (vsock) + seqpacket_allow = vsock->seqpacket_allow; + rcu_read_unlock(); + + return seqpacket_allow; +} + +static void virtio_transport_rx_work(struct work_struct *work) +{ + struct virtio_vsock *vsock = + container_of(work, struct virtio_vsock, rx_work); + struct virtqueue *vq; + + vq = vsock->vqs[VSOCK_VQ_RX]; + + mutex_lock(&vsock->rx_lock); + + if (!vsock->rx_run) + goto out; + + do { + virtqueue_disable_cb(vq); + for (;;) { + struct sk_buff *skb; + unsigned int len; + + if (!virtio_transport_more_replies(vsock)) { + /* Stop rx until the device processes already + * pending replies. Leave rx virtqueue + * callbacks disabled. + */ + goto out; + } + + skb = virtqueue_get_buf(vq, &len); + if (!skb) + break; + + vsock->rx_buf_nr--; + + /* Drop short/long packets */ + if (unlikely(len < sizeof(struct virtio_vsock_hdr) || + len > virtio_vsock_skb_len(skb))) { + kfree_skb(skb); + continue; + } + + virtio_vsock_skb_rx_put(skb); + virtio_transport_deliver_tap_pkt(skb); + virtio_transport_recv_pkt(&virtio_transport, skb); + } + } while (!virtqueue_enable_cb(vq)); + +out: + if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2) + virtio_vsock_rx_fill(vsock); + mutex_unlock(&vsock->rx_lock); +} + +static int virtio_vsock_vqs_init(struct virtio_vsock *vsock) +{ + struct virtio_device *vdev = vsock->vdev; + static const char * const names[] = { + "rx", + "tx", + "event", + }; + vq_callback_t *callbacks[] = { + virtio_vsock_rx_done, + virtio_vsock_tx_done, + virtio_vsock_event_done, + }; + int ret; + + ret = virtio_find_vqs(vdev, VSOCK_VQ_MAX, vsock->vqs, callbacks, names, + NULL); + if (ret < 0) + return ret; + + virtio_vsock_update_guest_cid(vsock); + + virtio_device_ready(vdev); + + return 0; +} + +static void virtio_vsock_vqs_start(struct virtio_vsock *vsock) +{ + mutex_lock(&vsock->tx_lock); + vsock->tx_run = true; + mutex_unlock(&vsock->tx_lock); + + mutex_lock(&vsock->rx_lock); + virtio_vsock_rx_fill(vsock); + vsock->rx_run = true; + mutex_unlock(&vsock->rx_lock); + + mutex_lock(&vsock->event_lock); + virtio_vsock_event_fill(vsock); + vsock->event_run = true; + mutex_unlock(&vsock->event_lock); + + /* virtio_transport_send_pkt() can queue packets once + * the_virtio_vsock is set, but they won't be processed until + * vsock->tx_run is set to true. We queue vsock->send_pkt_work + * when initialization finishes to send those packets queued + * earlier. + * We don't need to queue the other workers (rx, event) because + * as long as we don't fill the queues with empty buffers, the + * host can't send us any notification. + */ + queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work); +} + +static void virtio_vsock_vqs_del(struct virtio_vsock *vsock) +{ + struct virtio_device *vdev = vsock->vdev; + struct sk_buff *skb; + + /* Reset all connected sockets when the VQs disappear */ + vsock_for_each_connected_socket(&virtio_transport.transport, + virtio_vsock_reset_sock); + + /* Stop all work handlers to make sure no one is accessing the device, + * so we can safely call virtio_reset_device(). + */ + mutex_lock(&vsock->rx_lock); + vsock->rx_run = false; + mutex_unlock(&vsock->rx_lock); + + mutex_lock(&vsock->tx_lock); + vsock->tx_run = false; + mutex_unlock(&vsock->tx_lock); + + mutex_lock(&vsock->event_lock); + vsock->event_run = false; + mutex_unlock(&vsock->event_lock); + + /* Flush all device writes and interrupts, device will not use any + * more buffers. + */ + virtio_reset_device(vdev); + + mutex_lock(&vsock->rx_lock); + while ((skb = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_RX]))) + kfree_skb(skb); + mutex_unlock(&vsock->rx_lock); + + mutex_lock(&vsock->tx_lock); + while ((skb = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_TX]))) + kfree_skb(skb); + mutex_unlock(&vsock->tx_lock); + + virtio_vsock_skb_queue_purge(&vsock->send_pkt_queue); + + /* Delete virtqueues and flush outstanding callbacks if any */ + vdev->config->del_vqs(vdev); +} + +static int virtio_vsock_probe(struct virtio_device *vdev) +{ + struct virtio_vsock *vsock = NULL; + int ret; + + ret = mutex_lock_interruptible(&the_virtio_vsock_mutex); + if (ret) + return ret; + + /* Only one virtio-vsock device per guest is supported */ + if (rcu_dereference_protected(the_virtio_vsock, + lockdep_is_held(&the_virtio_vsock_mutex))) { + ret = -EBUSY; + goto out; + } + + vsock = kzalloc(sizeof(*vsock), GFP_KERNEL); + if (!vsock) { + ret = -ENOMEM; + goto out; + } + + vsock->vdev = vdev; + + vsock->rx_buf_nr = 0; + vsock->rx_buf_max_nr = 0; + atomic_set(&vsock->queued_replies, 0); + + mutex_init(&vsock->tx_lock); + mutex_init(&vsock->rx_lock); + mutex_init(&vsock->event_lock); + skb_queue_head_init(&vsock->send_pkt_queue); + INIT_WORK(&vsock->rx_work, virtio_transport_rx_work); + INIT_WORK(&vsock->tx_work, virtio_transport_tx_work); + INIT_WORK(&vsock->event_work, virtio_transport_event_work); + INIT_WORK(&vsock->send_pkt_work, virtio_transport_send_pkt_work); + + if (virtio_has_feature(vdev, VIRTIO_VSOCK_F_SEQPACKET)) + vsock->seqpacket_allow = true; + + vdev->priv = vsock; + + ret = virtio_vsock_vqs_init(vsock); + if (ret < 0) + goto out; + + rcu_assign_pointer(the_virtio_vsock, vsock); + virtio_vsock_vqs_start(vsock); + + mutex_unlock(&the_virtio_vsock_mutex); + + return 0; + +out: + kfree(vsock); + mutex_unlock(&the_virtio_vsock_mutex); + return ret; +} + +static void virtio_vsock_remove(struct virtio_device *vdev) +{ + struct virtio_vsock *vsock = vdev->priv; + + mutex_lock(&the_virtio_vsock_mutex); + + vdev->priv = NULL; + rcu_assign_pointer(the_virtio_vsock, NULL); + synchronize_rcu(); + + virtio_vsock_vqs_del(vsock); + + /* Other works can be queued before 'config->del_vqs()', so we flush + * all works before to free the vsock object to avoid use after free. + */ + flush_work(&vsock->rx_work); + flush_work(&vsock->tx_work); + flush_work(&vsock->event_work); + flush_work(&vsock->send_pkt_work); + + mutex_unlock(&the_virtio_vsock_mutex); + + kfree(vsock); +} + +#ifdef CONFIG_PM_SLEEP +static int virtio_vsock_freeze(struct virtio_device *vdev) +{ + struct virtio_vsock *vsock = vdev->priv; + + mutex_lock(&the_virtio_vsock_mutex); + + rcu_assign_pointer(the_virtio_vsock, NULL); + synchronize_rcu(); + + virtio_vsock_vqs_del(vsock); + + mutex_unlock(&the_virtio_vsock_mutex); + + return 0; +} + +static int virtio_vsock_restore(struct virtio_device *vdev) +{ + struct virtio_vsock *vsock = vdev->priv; + int ret; + + mutex_lock(&the_virtio_vsock_mutex); + + /* Only one virtio-vsock device per guest is supported */ + if (rcu_dereference_protected(the_virtio_vsock, + lockdep_is_held(&the_virtio_vsock_mutex))) { + ret = -EBUSY; + goto out; + } + + ret = virtio_vsock_vqs_init(vsock); + if (ret < 0) + goto out; + + rcu_assign_pointer(the_virtio_vsock, vsock); + virtio_vsock_vqs_start(vsock); + +out: + mutex_unlock(&the_virtio_vsock_mutex); + return ret; +} +#endif /* CONFIG_PM_SLEEP */ + +static struct virtio_device_id id_table[] = { + { VIRTIO_ID_VSOCK, VIRTIO_DEV_ANY_ID }, + { 0 }, +}; + +static unsigned int features[] = { + VIRTIO_VSOCK_F_SEQPACKET +}; + +static struct virtio_driver virtio_vsock_driver = { + .feature_table = features, + .feature_table_size = ARRAY_SIZE(features), + .driver.name = KBUILD_MODNAME, + .driver.owner = THIS_MODULE, + .id_table = id_table, + .probe = virtio_vsock_probe, + .remove = virtio_vsock_remove, +#ifdef CONFIG_PM_SLEEP + .freeze = virtio_vsock_freeze, + .restore = virtio_vsock_restore, +#endif +}; + +static int __init virtio_vsock_init(void) +{ + int ret; + + virtio_vsock_workqueue = alloc_workqueue("virtio_vsock", 0, 0); + if (!virtio_vsock_workqueue) + return -ENOMEM; + + ret = vsock_core_register(&virtio_transport.transport, + VSOCK_TRANSPORT_F_G2H); + if (ret) + goto out_wq; + + ret = register_virtio_driver(&virtio_vsock_driver); + if (ret) + goto out_vci; + + return 0; + +out_vci: + vsock_core_unregister(&virtio_transport.transport); +out_wq: + destroy_workqueue(virtio_vsock_workqueue); + return ret; +} + +static void __exit virtio_vsock_exit(void) +{ + unregister_virtio_driver(&virtio_vsock_driver); + vsock_core_unregister(&virtio_transport.transport); + destroy_workqueue(virtio_vsock_workqueue); +} + +module_init(virtio_vsock_init); +module_exit(virtio_vsock_exit); +MODULE_LICENSE("GPL v2"); +MODULE_AUTHOR("Asias He"); +MODULE_DESCRIPTION("virtio transport for vsock"); +MODULE_DEVICE_TABLE(virtio, id_table); diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c new file mode 100644 index 0000000000..e87fd9480a --- /dev/null +++ b/net/vmw_vsock/virtio_transport_common.c @@ -0,0 +1,1561 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * common code for virtio vsock + * + * Copyright (C) 2013-2015 Red Hat, Inc. + * Author: Asias He <asias@redhat.com> + * Stefan Hajnoczi <stefanha@redhat.com> + */ +#include <linux/spinlock.h> +#include <linux/module.h> +#include <linux/sched/signal.h> +#include <linux/ctype.h> +#include <linux/list.h> +#include <linux/virtio_vsock.h> +#include <uapi/linux/vsockmon.h> + +#include <net/sock.h> +#include <net/af_vsock.h> + +#define CREATE_TRACE_POINTS +#include <trace/events/vsock_virtio_transport_common.h> + +/* How long to wait for graceful shutdown of a connection */ +#define VSOCK_CLOSE_TIMEOUT (8 * HZ) + +/* Threshold for detecting small packets to copy */ +#define GOOD_COPY_LEN 128 + +static const struct virtio_transport * +virtio_transport_get_ops(struct vsock_sock *vsk) +{ + const struct vsock_transport *t = vsock_core_get_transport(vsk); + + if (WARN_ON(!t)) + return NULL; + + return container_of(t, struct virtio_transport, transport); +} + +/* Returns a new packet on success, otherwise returns NULL. + * + * If NULL is returned, errp is set to a negative errno. + */ +static struct sk_buff * +virtio_transport_alloc_skb(struct virtio_vsock_pkt_info *info, + size_t len, + u32 src_cid, + u32 src_port, + u32 dst_cid, + u32 dst_port) +{ + const size_t skb_len = VIRTIO_VSOCK_SKB_HEADROOM + len; + struct virtio_vsock_hdr *hdr; + struct sk_buff *skb; + void *payload; + int err; + + skb = virtio_vsock_alloc_skb(skb_len, GFP_KERNEL); + if (!skb) + return NULL; + + hdr = virtio_vsock_hdr(skb); + hdr->type = cpu_to_le16(info->type); + hdr->op = cpu_to_le16(info->op); + hdr->src_cid = cpu_to_le64(src_cid); + hdr->dst_cid = cpu_to_le64(dst_cid); + hdr->src_port = cpu_to_le32(src_port); + hdr->dst_port = cpu_to_le32(dst_port); + hdr->flags = cpu_to_le32(info->flags); + hdr->len = cpu_to_le32(len); + hdr->buf_alloc = cpu_to_le32(0); + hdr->fwd_cnt = cpu_to_le32(0); + + if (info->msg && len > 0) { + payload = skb_put(skb, len); + err = memcpy_from_msg(payload, info->msg, len); + if (err) + goto out; + + if (msg_data_left(info->msg) == 0 && + info->type == VIRTIO_VSOCK_TYPE_SEQPACKET) { + hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM); + + if (info->msg->msg_flags & MSG_EOR) + hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR); + } + } + + if (info->reply) + virtio_vsock_skb_set_reply(skb); + + trace_virtio_transport_alloc_pkt(src_cid, src_port, + dst_cid, dst_port, + len, + info->type, + info->op, + info->flags); + + if (info->vsk && !skb_set_owner_sk_safe(skb, sk_vsock(info->vsk))) { + WARN_ONCE(1, "failed to allocate skb on vsock socket with sk_refcnt == 0\n"); + goto out; + } + + return skb; + +out: + kfree_skb(skb); + return NULL; +} + +/* Packet capture */ +static struct sk_buff *virtio_transport_build_skb(void *opaque) +{ + struct virtio_vsock_hdr *pkt_hdr; + struct sk_buff *pkt = opaque; + struct af_vsockmon_hdr *hdr; + struct sk_buff *skb; + size_t payload_len; + void *payload_buf; + + /* A packet could be split to fit the RX buffer, so we can retrieve + * the payload length from the header and the buffer pointer taking + * care of the offset in the original packet. + */ + pkt_hdr = virtio_vsock_hdr(pkt); + payload_len = pkt->len; + payload_buf = pkt->data; + + skb = alloc_skb(sizeof(*hdr) + sizeof(*pkt_hdr) + payload_len, + GFP_ATOMIC); + if (!skb) + return NULL; + + hdr = skb_put(skb, sizeof(*hdr)); + + /* pkt->hdr is little-endian so no need to byteswap here */ + hdr->src_cid = pkt_hdr->src_cid; + hdr->src_port = pkt_hdr->src_port; + hdr->dst_cid = pkt_hdr->dst_cid; + hdr->dst_port = pkt_hdr->dst_port; + + hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO); + hdr->len = cpu_to_le16(sizeof(*pkt_hdr)); + memset(hdr->reserved, 0, sizeof(hdr->reserved)); + + switch (le16_to_cpu(pkt_hdr->op)) { + case VIRTIO_VSOCK_OP_REQUEST: + case VIRTIO_VSOCK_OP_RESPONSE: + hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT); + break; + case VIRTIO_VSOCK_OP_RST: + case VIRTIO_VSOCK_OP_SHUTDOWN: + hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT); + break; + case VIRTIO_VSOCK_OP_RW: + hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD); + break; + case VIRTIO_VSOCK_OP_CREDIT_UPDATE: + case VIRTIO_VSOCK_OP_CREDIT_REQUEST: + hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL); + break; + default: + hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN); + break; + } + + skb_put_data(skb, pkt_hdr, sizeof(*pkt_hdr)); + + if (payload_len) { + skb_put_data(skb, payload_buf, payload_len); + } + + return skb; +} + +void virtio_transport_deliver_tap_pkt(struct sk_buff *skb) +{ + if (virtio_vsock_skb_tap_delivered(skb)) + return; + + vsock_deliver_tap(virtio_transport_build_skb, skb); + virtio_vsock_skb_set_tap_delivered(skb); +} +EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt); + +static u16 virtio_transport_get_type(struct sock *sk) +{ + if (sk->sk_type == SOCK_STREAM) + return VIRTIO_VSOCK_TYPE_STREAM; + else + return VIRTIO_VSOCK_TYPE_SEQPACKET; +} + +/* This function can only be used on connecting/connected sockets, + * since a socket assigned to a transport is required. + * + * Do not use on listener sockets! + */ +static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, + struct virtio_vsock_pkt_info *info) +{ + u32 src_cid, src_port, dst_cid, dst_port; + const struct virtio_transport *t_ops; + struct virtio_vsock_sock *vvs; + u32 pkt_len = info->pkt_len; + u32 rest_len; + int ret; + + info->type = virtio_transport_get_type(sk_vsock(vsk)); + + t_ops = virtio_transport_get_ops(vsk); + if (unlikely(!t_ops)) + return -EFAULT; + + src_cid = t_ops->transport.get_local_cid(); + src_port = vsk->local_addr.svm_port; + if (!info->remote_cid) { + dst_cid = vsk->remote_addr.svm_cid; + dst_port = vsk->remote_addr.svm_port; + } else { + dst_cid = info->remote_cid; + dst_port = info->remote_port; + } + + vvs = vsk->trans; + + /* virtio_transport_get_credit might return less than pkt_len credit */ + pkt_len = virtio_transport_get_credit(vvs, pkt_len); + + /* Do not send zero length OP_RW pkt */ + if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW) + return pkt_len; + + rest_len = pkt_len; + + do { + struct sk_buff *skb; + size_t skb_len; + + skb_len = min_t(u32, VIRTIO_VSOCK_MAX_PKT_BUF_SIZE, rest_len); + + skb = virtio_transport_alloc_skb(info, skb_len, + src_cid, src_port, + dst_cid, dst_port); + if (!skb) { + ret = -ENOMEM; + break; + } + + virtio_transport_inc_tx_pkt(vvs, skb); + + ret = t_ops->send_pkt(skb); + if (ret < 0) + break; + + /* Both virtio and vhost 'send_pkt()' returns 'skb_len', + * but for reliability use 'ret' instead of 'skb_len'. + * Also if partial send happens (e.g. 'ret' != 'skb_len') + * somehow, we break this loop, but account such returned + * value in 'virtio_transport_put_credit()'. + */ + rest_len -= ret; + + if (WARN_ONCE(ret != skb_len, + "'send_pkt()' returns %i, but %zu expected\n", + ret, skb_len)) + break; + } while (rest_len); + + virtio_transport_put_credit(vvs, rest_len); + + /* Return number of bytes, if any data has been sent. */ + if (rest_len != pkt_len) + ret = pkt_len - rest_len; + + return ret; +} + +static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs, + u32 len) +{ + if (vvs->rx_bytes + len > vvs->buf_alloc) + return false; + + vvs->rx_bytes += len; + return true; +} + +static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs, + u32 len) +{ + vvs->rx_bytes -= len; + vvs->fwd_cnt += len; +} + +void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct sk_buff *skb) +{ + struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); + + spin_lock_bh(&vvs->rx_lock); + vvs->last_fwd_cnt = vvs->fwd_cnt; + hdr->fwd_cnt = cpu_to_le32(vvs->fwd_cnt); + hdr->buf_alloc = cpu_to_le32(vvs->buf_alloc); + spin_unlock_bh(&vvs->rx_lock); +} +EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt); + +u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit) +{ + u32 ret; + + if (!credit) + return 0; + + spin_lock_bh(&vvs->tx_lock); + ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); + if (ret > credit) + ret = credit; + vvs->tx_cnt += ret; + spin_unlock_bh(&vvs->tx_lock); + + return ret; +} +EXPORT_SYMBOL_GPL(virtio_transport_get_credit); + +void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit) +{ + if (!credit) + return; + + spin_lock_bh(&vvs->tx_lock); + vvs->tx_cnt -= credit; + spin_unlock_bh(&vvs->tx_lock); +} +EXPORT_SYMBOL_GPL(virtio_transport_put_credit); + +static int virtio_transport_send_credit_update(struct vsock_sock *vsk) +{ + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE, + .vsk = vsk, + }; + + return virtio_transport_send_pkt_info(vsk, &info); +} + +static ssize_t +virtio_transport_stream_do_peek(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + struct sk_buff *skb; + size_t total = 0; + int err; + + spin_lock_bh(&vvs->rx_lock); + + skb_queue_walk(&vvs->rx_queue, skb) { + size_t bytes; + + bytes = len - total; + if (bytes > skb->len) + bytes = skb->len; + + spin_unlock_bh(&vvs->rx_lock); + + /* sk_lock is held by caller so no one else can dequeue. + * Unlock rx_lock since memcpy_to_msg() may sleep. + */ + err = memcpy_to_msg(msg, skb->data, bytes); + if (err) + goto out; + + total += bytes; + + spin_lock_bh(&vvs->rx_lock); + + if (total == len) + break; + } + + spin_unlock_bh(&vvs->rx_lock); + + return total; + +out: + if (total) + err = total; + return err; +} + +static ssize_t +virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + size_t bytes, total = 0; + struct sk_buff *skb; + u32 fwd_cnt_delta; + bool low_rx_bytes; + int err = -EFAULT; + u32 free_space; + + spin_lock_bh(&vvs->rx_lock); + + if (WARN_ONCE(skb_queue_empty(&vvs->rx_queue) && vvs->rx_bytes, + "rx_queue is empty, but rx_bytes is non-zero\n")) { + spin_unlock_bh(&vvs->rx_lock); + return err; + } + + while (total < len && !skb_queue_empty(&vvs->rx_queue)) { + skb = skb_peek(&vvs->rx_queue); + + bytes = len - total; + if (bytes > skb->len) + bytes = skb->len; + + /* sk_lock is held by caller so no one else can dequeue. + * Unlock rx_lock since memcpy_to_msg() may sleep. + */ + spin_unlock_bh(&vvs->rx_lock); + + err = memcpy_to_msg(msg, skb->data, bytes); + if (err) + goto out; + + spin_lock_bh(&vvs->rx_lock); + + total += bytes; + skb_pull(skb, bytes); + + if (skb->len == 0) { + u32 pkt_len = le32_to_cpu(virtio_vsock_hdr(skb)->len); + + virtio_transport_dec_rx_pkt(vvs, pkt_len); + __skb_unlink(skb, &vvs->rx_queue); + consume_skb(skb); + } + } + + fwd_cnt_delta = vvs->fwd_cnt - vvs->last_fwd_cnt; + free_space = vvs->buf_alloc - fwd_cnt_delta; + low_rx_bytes = (vvs->rx_bytes < + sock_rcvlowat(sk_vsock(vsk), 0, INT_MAX)); + + spin_unlock_bh(&vvs->rx_lock); + + /* To reduce the number of credit update messages, + * don't update credits as long as lots of space is available. + * Note: the limit chosen here is arbitrary. Setting the limit + * too high causes extra messages. Too low causes transmitter + * stalls. As stalls are in theory more expensive than extra + * messages, we set the limit to a high value. TODO: experiment + * with different values. Also send credit update message when + * number of bytes in rx queue is not enough to wake up reader. + */ + if (fwd_cnt_delta && + (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE || low_rx_bytes)) + virtio_transport_send_credit_update(vsk); + + return total; + +out: + if (total) + err = total; + return err; +} + +static ssize_t +virtio_transport_seqpacket_do_peek(struct vsock_sock *vsk, + struct msghdr *msg) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + struct sk_buff *skb; + size_t total, len; + + spin_lock_bh(&vvs->rx_lock); + + if (!vvs->msg_count) { + spin_unlock_bh(&vvs->rx_lock); + return 0; + } + + total = 0; + len = msg_data_left(msg); + + skb_queue_walk(&vvs->rx_queue, skb) { + struct virtio_vsock_hdr *hdr; + + if (total < len) { + size_t bytes; + int err; + + bytes = len - total; + if (bytes > skb->len) + bytes = skb->len; + + spin_unlock_bh(&vvs->rx_lock); + + /* sk_lock is held by caller so no one else can dequeue. + * Unlock rx_lock since memcpy_to_msg() may sleep. + */ + err = memcpy_to_msg(msg, skb->data, bytes); + if (err) + return err; + + spin_lock_bh(&vvs->rx_lock); + } + + total += skb->len; + hdr = virtio_vsock_hdr(skb); + + if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) { + if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR) + msg->msg_flags |= MSG_EOR; + + break; + } + } + + spin_unlock_bh(&vvs->rx_lock); + + return total; +} + +static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, + int flags) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + int dequeued_len = 0; + size_t user_buf_len = msg_data_left(msg); + bool msg_ready = false; + struct sk_buff *skb; + + spin_lock_bh(&vvs->rx_lock); + + if (vvs->msg_count == 0) { + spin_unlock_bh(&vvs->rx_lock); + return 0; + } + + while (!msg_ready) { + struct virtio_vsock_hdr *hdr; + size_t pkt_len; + + skb = __skb_dequeue(&vvs->rx_queue); + if (!skb) + break; + hdr = virtio_vsock_hdr(skb); + pkt_len = (size_t)le32_to_cpu(hdr->len); + + if (dequeued_len >= 0) { + size_t bytes_to_copy; + + bytes_to_copy = min(user_buf_len, pkt_len); + + if (bytes_to_copy) { + int err; + + /* sk_lock is held by caller so no one else can dequeue. + * Unlock rx_lock since memcpy_to_msg() may sleep. + */ + spin_unlock_bh(&vvs->rx_lock); + + err = memcpy_to_msg(msg, skb->data, bytes_to_copy); + if (err) { + /* Copy of message failed. Rest of + * fragments will be freed without copy. + */ + dequeued_len = err; + } else { + user_buf_len -= bytes_to_copy; + } + + spin_lock_bh(&vvs->rx_lock); + } + + if (dequeued_len >= 0) + dequeued_len += pkt_len; + } + + if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) { + msg_ready = true; + vvs->msg_count--; + + if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR) + msg->msg_flags |= MSG_EOR; + } + + virtio_transport_dec_rx_pkt(vvs, pkt_len); + kfree_skb(skb); + } + + spin_unlock_bh(&vvs->rx_lock); + + virtio_transport_send_credit_update(vsk); + + return dequeued_len; +} + +ssize_t +virtio_transport_stream_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len, int flags) +{ + if (flags & MSG_PEEK) + return virtio_transport_stream_do_peek(vsk, msg, len); + else + return virtio_transport_stream_do_dequeue(vsk, msg, len); +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue); + +ssize_t +virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, + int flags) +{ + if (flags & MSG_PEEK) + return virtio_transport_seqpacket_do_peek(vsk, msg); + else + return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags); +} +EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue); + +int +virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + + spin_lock_bh(&vvs->tx_lock); + + if (len > vvs->peer_buf_alloc) { + spin_unlock_bh(&vvs->tx_lock); + return -EMSGSIZE; + } + + spin_unlock_bh(&vvs->tx_lock); + + return virtio_transport_stream_enqueue(vsk, msg, len); +} +EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_enqueue); + +int +virtio_transport_dgram_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len, int flags) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue); + +s64 virtio_transport_stream_has_data(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + s64 bytes; + + spin_lock_bh(&vvs->rx_lock); + bytes = vvs->rx_bytes; + spin_unlock_bh(&vvs->rx_lock); + + return bytes; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data); + +u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + u32 msg_count; + + spin_lock_bh(&vvs->rx_lock); + msg_count = vvs->msg_count; + spin_unlock_bh(&vvs->rx_lock); + + return msg_count; +} +EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_has_data); + +static s64 virtio_transport_has_space(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + s64 bytes; + + bytes = (s64)vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); + if (bytes < 0) + bytes = 0; + + return bytes; +} + +s64 virtio_transport_stream_has_space(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + s64 bytes; + + spin_lock_bh(&vvs->tx_lock); + bytes = virtio_transport_has_space(vsk); + spin_unlock_bh(&vvs->tx_lock); + + return bytes; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space); + +int virtio_transport_do_socket_init(struct vsock_sock *vsk, + struct vsock_sock *psk) +{ + struct virtio_vsock_sock *vvs; + + vvs = kzalloc(sizeof(*vvs), GFP_KERNEL); + if (!vvs) + return -ENOMEM; + + vsk->trans = vvs; + vvs->vsk = vsk; + if (psk && psk->trans) { + struct virtio_vsock_sock *ptrans = psk->trans; + + vvs->peer_buf_alloc = ptrans->peer_buf_alloc; + } + + if (vsk->buffer_size > VIRTIO_VSOCK_MAX_BUF_SIZE) + vsk->buffer_size = VIRTIO_VSOCK_MAX_BUF_SIZE; + + vvs->buf_alloc = vsk->buffer_size; + + spin_lock_init(&vvs->rx_lock); + spin_lock_init(&vvs->tx_lock); + skb_queue_head_init(&vvs->rx_queue); + + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init); + +/* sk_lock held by the caller */ +void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + + if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE) + *val = VIRTIO_VSOCK_MAX_BUF_SIZE; + + vvs->buf_alloc = *val; + + virtio_transport_send_credit_update(vsk); +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size); + +int +virtio_transport_notify_poll_in(struct vsock_sock *vsk, + size_t target, + bool *data_ready_now) +{ + *data_ready_now = vsock_stream_has_data(vsk) >= target; + + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in); + +int +virtio_transport_notify_poll_out(struct vsock_sock *vsk, + size_t target, + bool *space_avail_now) +{ + s64 free_space; + + free_space = vsock_stream_has_space(vsk); + if (free_space > 0) + *space_avail_now = true; + else if (free_space == 0) + *space_avail_now = false; + + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out); + +int virtio_transport_notify_recv_init(struct vsock_sock *vsk, + size_t target, struct vsock_transport_recv_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init); + +int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk, + size_t target, struct vsock_transport_recv_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block); + +int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk, + size_t target, struct vsock_transport_recv_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue); + +int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk, + size_t target, ssize_t copied, bool data_read, + struct vsock_transport_recv_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue); + +int virtio_transport_notify_send_init(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init); + +int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block); + +int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue); + +int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk, + ssize_t written, struct vsock_transport_send_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue); + +u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk) +{ + return vsk->buffer_size; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat); + +bool virtio_transport_stream_is_active(struct vsock_sock *vsk) +{ + return true; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active); + +bool virtio_transport_stream_allow(u32 cid, u32 port) +{ + return true; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_allow); + +int virtio_transport_dgram_bind(struct vsock_sock *vsk, + struct sockaddr_vm *addr) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind); + +bool virtio_transport_dgram_allow(u32 cid, u32 port) +{ + return false; +} +EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow); + +int virtio_transport_connect(struct vsock_sock *vsk) +{ + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_REQUEST, + .vsk = vsk, + }; + + return virtio_transport_send_pkt_info(vsk, &info); +} +EXPORT_SYMBOL_GPL(virtio_transport_connect); + +int virtio_transport_shutdown(struct vsock_sock *vsk, int mode) +{ + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_SHUTDOWN, + .flags = (mode & RCV_SHUTDOWN ? + VIRTIO_VSOCK_SHUTDOWN_RCV : 0) | + (mode & SEND_SHUTDOWN ? + VIRTIO_VSOCK_SHUTDOWN_SEND : 0), + .vsk = vsk, + }; + + return virtio_transport_send_pkt_info(vsk, &info); +} +EXPORT_SYMBOL_GPL(virtio_transport_shutdown); + +int +virtio_transport_dgram_enqueue(struct vsock_sock *vsk, + struct sockaddr_vm *remote_addr, + struct msghdr *msg, + size_t dgram_len) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue); + +ssize_t +virtio_transport_stream_enqueue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len) +{ + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_RW, + .msg = msg, + .pkt_len = len, + .vsk = vsk, + }; + + return virtio_transport_send_pkt_info(vsk, &info); +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue); + +void virtio_transport_destruct(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + + kfree(vvs); +} +EXPORT_SYMBOL_GPL(virtio_transport_destruct); + +static int virtio_transport_reset(struct vsock_sock *vsk, + struct sk_buff *skb) +{ + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_RST, + .reply = !!skb, + .vsk = vsk, + }; + + /* Send RST only if the original pkt is not a RST pkt */ + if (skb && le16_to_cpu(virtio_vsock_hdr(skb)->op) == VIRTIO_VSOCK_OP_RST) + return 0; + + return virtio_transport_send_pkt_info(vsk, &info); +} + +/* Normally packets are associated with a socket. There may be no socket if an + * attempt was made to connect to a socket that does not exist. + */ +static int virtio_transport_reset_no_sock(const struct virtio_transport *t, + struct sk_buff *skb) +{ + struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_RST, + .type = le16_to_cpu(hdr->type), + .reply = true, + }; + struct sk_buff *reply; + + /* Send RST only if the original pkt is not a RST pkt */ + if (le16_to_cpu(hdr->op) == VIRTIO_VSOCK_OP_RST) + return 0; + + if (!t) + return -ENOTCONN; + + reply = virtio_transport_alloc_skb(&info, 0, + le64_to_cpu(hdr->dst_cid), + le32_to_cpu(hdr->dst_port), + le64_to_cpu(hdr->src_cid), + le32_to_cpu(hdr->src_port)); + if (!reply) + return -ENOMEM; + + return t->send_pkt(reply); +} + +/* This function should be called with sk_lock held and SOCK_DONE set */ +static void virtio_transport_remove_sock(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + + /* We don't need to take rx_lock, as the socket is closing and we are + * removing it. + */ + __skb_queue_purge(&vvs->rx_queue); + vsock_remove_sock(vsk); +} + +static void virtio_transport_wait_close(struct sock *sk, long timeout) +{ + if (timeout) { + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + add_wait_queue(sk_sleep(sk), &wait); + + do { + if (sk_wait_event(sk, &timeout, + sock_flag(sk, SOCK_DONE), &wait)) + break; + } while (!signal_pending(current) && timeout); + + remove_wait_queue(sk_sleep(sk), &wait); + } +} + +static void virtio_transport_do_close(struct vsock_sock *vsk, + bool cancel_timeout) +{ + struct sock *sk = sk_vsock(vsk); + + sock_set_flag(sk, SOCK_DONE); + vsk->peer_shutdown = SHUTDOWN_MASK; + if (vsock_stream_has_data(vsk) <= 0) + sk->sk_state = TCP_CLOSING; + sk->sk_state_change(sk); + + if (vsk->close_work_scheduled && + (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) { + vsk->close_work_scheduled = false; + + virtio_transport_remove_sock(vsk); + + /* Release refcnt obtained when we scheduled the timeout */ + sock_put(sk); + } +} + +static void virtio_transport_close_timeout(struct work_struct *work) +{ + struct vsock_sock *vsk = + container_of(work, struct vsock_sock, close_work.work); + struct sock *sk = sk_vsock(vsk); + + sock_hold(sk); + lock_sock(sk); + + if (!sock_flag(sk, SOCK_DONE)) { + (void)virtio_transport_reset(vsk, NULL); + + virtio_transport_do_close(vsk, false); + } + + vsk->close_work_scheduled = false; + + release_sock(sk); + sock_put(sk); +} + +/* User context, vsk->sk is locked */ +static bool virtio_transport_close(struct vsock_sock *vsk) +{ + struct sock *sk = &vsk->sk; + + if (!(sk->sk_state == TCP_ESTABLISHED || + sk->sk_state == TCP_CLOSING)) + return true; + + /* Already received SHUTDOWN from peer, reply with RST */ + if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) { + (void)virtio_transport_reset(vsk, NULL); + return true; + } + + if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK) + (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK); + + if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING)) + virtio_transport_wait_close(sk, sk->sk_lingertime); + + if (sock_flag(sk, SOCK_DONE)) { + return true; + } + + sock_hold(sk); + INIT_DELAYED_WORK(&vsk->close_work, + virtio_transport_close_timeout); + vsk->close_work_scheduled = true; + schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT); + return false; +} + +void virtio_transport_release(struct vsock_sock *vsk) +{ + struct sock *sk = &vsk->sk; + bool remove_sock = true; + + if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) + remove_sock = virtio_transport_close(vsk); + + if (remove_sock) { + sock_set_flag(sk, SOCK_DONE); + virtio_transport_remove_sock(vsk); + } +} +EXPORT_SYMBOL_GPL(virtio_transport_release); + +static int +virtio_transport_recv_connecting(struct sock *sk, + struct sk_buff *skb) +{ + struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); + struct vsock_sock *vsk = vsock_sk(sk); + int skerr; + int err; + + switch (le16_to_cpu(hdr->op)) { + case VIRTIO_VSOCK_OP_RESPONSE: + sk->sk_state = TCP_ESTABLISHED; + sk->sk_socket->state = SS_CONNECTED; + vsock_insert_connected(vsk); + sk->sk_state_change(sk); + break; + case VIRTIO_VSOCK_OP_INVALID: + break; + case VIRTIO_VSOCK_OP_RST: + skerr = ECONNRESET; + err = 0; + goto destroy; + default: + skerr = EPROTO; + err = -EINVAL; + goto destroy; + } + return 0; + +destroy: + virtio_transport_reset(vsk, skb); + sk->sk_state = TCP_CLOSE; + sk->sk_err = skerr; + sk_error_report(sk); + return err; +} + +static void +virtio_transport_recv_enqueue(struct vsock_sock *vsk, + struct sk_buff *skb) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + bool can_enqueue, free_pkt = false; + struct virtio_vsock_hdr *hdr; + u32 len; + + hdr = virtio_vsock_hdr(skb); + len = le32_to_cpu(hdr->len); + + spin_lock_bh(&vvs->rx_lock); + + can_enqueue = virtio_transport_inc_rx_pkt(vvs, len); + if (!can_enqueue) { + free_pkt = true; + goto out; + } + + if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) + vvs->msg_count++; + + /* Try to copy small packets into the buffer of last packet queued, + * to avoid wasting memory queueing the entire buffer with a small + * payload. + */ + if (len <= GOOD_COPY_LEN && !skb_queue_empty(&vvs->rx_queue)) { + struct virtio_vsock_hdr *last_hdr; + struct sk_buff *last_skb; + + last_skb = skb_peek_tail(&vvs->rx_queue); + last_hdr = virtio_vsock_hdr(last_skb); + + /* If there is space in the last packet queued, we copy the + * new packet in its buffer. We avoid this if the last packet + * queued has VIRTIO_VSOCK_SEQ_EOM set, because this is + * delimiter of SEQPACKET message, so 'pkt' is the first packet + * of a new message. + */ + if (skb->len < skb_tailroom(last_skb) && + !(le32_to_cpu(last_hdr->flags) & VIRTIO_VSOCK_SEQ_EOM)) { + memcpy(skb_put(last_skb, skb->len), skb->data, skb->len); + free_pkt = true; + last_hdr->flags |= hdr->flags; + le32_add_cpu(&last_hdr->len, len); + goto out; + } + } + + __skb_queue_tail(&vvs->rx_queue, skb); + +out: + spin_unlock_bh(&vvs->rx_lock); + if (free_pkt) + kfree_skb(skb); +} + +static int +virtio_transport_recv_connected(struct sock *sk, + struct sk_buff *skb) +{ + struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); + struct vsock_sock *vsk = vsock_sk(sk); + int err = 0; + + switch (le16_to_cpu(hdr->op)) { + case VIRTIO_VSOCK_OP_RW: + virtio_transport_recv_enqueue(vsk, skb); + vsock_data_ready(sk); + return err; + case VIRTIO_VSOCK_OP_CREDIT_REQUEST: + virtio_transport_send_credit_update(vsk); + break; + case VIRTIO_VSOCK_OP_CREDIT_UPDATE: + sk->sk_write_space(sk); + break; + case VIRTIO_VSOCK_OP_SHUTDOWN: + if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_RCV) + vsk->peer_shutdown |= RCV_SHUTDOWN; + if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_SEND) + vsk->peer_shutdown |= SEND_SHUTDOWN; + if (vsk->peer_shutdown == SHUTDOWN_MASK) { + if (vsock_stream_has_data(vsk) <= 0 && !sock_flag(sk, SOCK_DONE)) { + (void)virtio_transport_reset(vsk, NULL); + virtio_transport_do_close(vsk, true); + } + /* Remove this socket anyway because the remote peer sent + * the shutdown. This way a new connection will succeed + * if the remote peer uses the same source port, + * even if the old socket is still unreleased, but now disconnected. + */ + vsock_remove_sock(vsk); + } + if (le32_to_cpu(virtio_vsock_hdr(skb)->flags)) + sk->sk_state_change(sk); + break; + case VIRTIO_VSOCK_OP_RST: + virtio_transport_do_close(vsk, true); + break; + default: + err = -EINVAL; + break; + } + + kfree_skb(skb); + return err; +} + +static void +virtio_transport_recv_disconnecting(struct sock *sk, + struct sk_buff *skb) +{ + struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); + struct vsock_sock *vsk = vsock_sk(sk); + + if (le16_to_cpu(hdr->op) == VIRTIO_VSOCK_OP_RST) + virtio_transport_do_close(vsk, true); +} + +static int +virtio_transport_send_response(struct vsock_sock *vsk, + struct sk_buff *skb) +{ + struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_RESPONSE, + .remote_cid = le64_to_cpu(hdr->src_cid), + .remote_port = le32_to_cpu(hdr->src_port), + .reply = true, + .vsk = vsk, + }; + + return virtio_transport_send_pkt_info(vsk, &info); +} + +static bool virtio_transport_space_update(struct sock *sk, + struct sk_buff *skb) +{ + struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); + struct vsock_sock *vsk = vsock_sk(sk); + struct virtio_vsock_sock *vvs = vsk->trans; + bool space_available; + + /* Listener sockets are not associated with any transport, so we are + * not able to take the state to see if there is space available in the + * remote peer, but since they are only used to receive requests, we + * can assume that there is always space available in the other peer. + */ + if (!vvs) + return true; + + /* buf_alloc and fwd_cnt is always included in the hdr */ + spin_lock_bh(&vvs->tx_lock); + vvs->peer_buf_alloc = le32_to_cpu(hdr->buf_alloc); + vvs->peer_fwd_cnt = le32_to_cpu(hdr->fwd_cnt); + space_available = virtio_transport_has_space(vsk); + spin_unlock_bh(&vvs->tx_lock); + return space_available; +} + +/* Handle server socket */ +static int +virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb, + struct virtio_transport *t) +{ + struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); + struct vsock_sock *vsk = vsock_sk(sk); + struct vsock_sock *vchild; + struct sock *child; + int ret; + + if (le16_to_cpu(hdr->op) != VIRTIO_VSOCK_OP_REQUEST) { + virtio_transport_reset_no_sock(t, skb); + return -EINVAL; + } + + if (sk_acceptq_is_full(sk)) { + virtio_transport_reset_no_sock(t, skb); + return -ENOMEM; + } + + child = vsock_create_connected(sk); + if (!child) { + virtio_transport_reset_no_sock(t, skb); + return -ENOMEM; + } + + sk_acceptq_added(sk); + + lock_sock_nested(child, SINGLE_DEPTH_NESTING); + + child->sk_state = TCP_ESTABLISHED; + + vchild = vsock_sk(child); + vsock_addr_init(&vchild->local_addr, le64_to_cpu(hdr->dst_cid), + le32_to_cpu(hdr->dst_port)); + vsock_addr_init(&vchild->remote_addr, le64_to_cpu(hdr->src_cid), + le32_to_cpu(hdr->src_port)); + + ret = vsock_assign_transport(vchild, vsk); + /* Transport assigned (looking at remote_addr) must be the same + * where we received the request. + */ + if (ret || vchild->transport != &t->transport) { + release_sock(child); + virtio_transport_reset_no_sock(t, skb); + sock_put(child); + return ret; + } + + if (virtio_transport_space_update(child, skb)) + child->sk_write_space(child); + + vsock_insert_connected(vchild); + vsock_enqueue_accept(sk, child); + virtio_transport_send_response(vchild, skb); + + release_sock(child); + + sk->sk_data_ready(sk); + return 0; +} + +static bool virtio_transport_valid_type(u16 type) +{ + return (type == VIRTIO_VSOCK_TYPE_STREAM) || + (type == VIRTIO_VSOCK_TYPE_SEQPACKET); +} + +/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex + * lock. + */ +void virtio_transport_recv_pkt(struct virtio_transport *t, + struct sk_buff *skb) +{ + struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); + struct sockaddr_vm src, dst; + struct vsock_sock *vsk; + struct sock *sk; + bool space_available; + + vsock_addr_init(&src, le64_to_cpu(hdr->src_cid), + le32_to_cpu(hdr->src_port)); + vsock_addr_init(&dst, le64_to_cpu(hdr->dst_cid), + le32_to_cpu(hdr->dst_port)); + + trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port, + dst.svm_cid, dst.svm_port, + le32_to_cpu(hdr->len), + le16_to_cpu(hdr->type), + le16_to_cpu(hdr->op), + le32_to_cpu(hdr->flags), + le32_to_cpu(hdr->buf_alloc), + le32_to_cpu(hdr->fwd_cnt)); + + if (!virtio_transport_valid_type(le16_to_cpu(hdr->type))) { + (void)virtio_transport_reset_no_sock(t, skb); + goto free_pkt; + } + + /* The socket must be in connected or bound table + * otherwise send reset back + */ + sk = vsock_find_connected_socket(&src, &dst); + if (!sk) { + sk = vsock_find_bound_socket(&dst); + if (!sk) { + (void)virtio_transport_reset_no_sock(t, skb); + goto free_pkt; + } + } + + if (virtio_transport_get_type(sk) != le16_to_cpu(hdr->type)) { + (void)virtio_transport_reset_no_sock(t, skb); + sock_put(sk); + goto free_pkt; + } + + if (!skb_set_owner_sk_safe(skb, sk)) { + WARN_ONCE(1, "receiving vsock socket has sk_refcnt == 0\n"); + goto free_pkt; + } + + vsk = vsock_sk(sk); + + lock_sock(sk); + + /* Check if sk has been closed before lock_sock */ + if (sock_flag(sk, SOCK_DONE)) { + (void)virtio_transport_reset_no_sock(t, skb); + release_sock(sk); + sock_put(sk); + goto free_pkt; + } + + space_available = virtio_transport_space_update(sk, skb); + + /* Update CID in case it has changed after a transport reset event */ + if (vsk->local_addr.svm_cid != VMADDR_CID_ANY) + vsk->local_addr.svm_cid = dst.svm_cid; + + if (space_available) + sk->sk_write_space(sk); + + switch (sk->sk_state) { + case TCP_LISTEN: + virtio_transport_recv_listen(sk, skb, t); + kfree_skb(skb); + break; + case TCP_SYN_SENT: + virtio_transport_recv_connecting(sk, skb); + kfree_skb(skb); + break; + case TCP_ESTABLISHED: + virtio_transport_recv_connected(sk, skb); + break; + case TCP_CLOSING: + virtio_transport_recv_disconnecting(sk, skb); + kfree_skb(skb); + break; + default: + (void)virtio_transport_reset_no_sock(t, skb); + kfree_skb(skb); + break; + } + + release_sock(sk); + + /* Release refcnt obtained when we fetched this socket out of the + * bound or connected list. + */ + sock_put(sk); + return; + +free_pkt: + kfree_skb(skb); +} +EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt); + +/* Remove skbs found in a queue that have a vsk that matches. + * + * Each skb is freed. + * + * Returns the count of skbs that were reply packets. + */ +int virtio_transport_purge_skbs(void *vsk, struct sk_buff_head *queue) +{ + struct sk_buff_head freeme; + struct sk_buff *skb, *tmp; + int cnt = 0; + + skb_queue_head_init(&freeme); + + spin_lock_bh(&queue->lock); + skb_queue_walk_safe(queue, skb, tmp) { + if (vsock_sk(skb->sk) != vsk) + continue; + + __skb_unlink(skb, queue); + __skb_queue_tail(&freeme, skb); + + if (virtio_vsock_skb_reply(skb)) + cnt++; + } + spin_unlock_bh(&queue->lock); + + __skb_queue_purge(&freeme); + + return cnt; +} +EXPORT_SYMBOL_GPL(virtio_transport_purge_skbs); + +int virtio_transport_read_skb(struct vsock_sock *vsk, skb_read_actor_t recv_actor) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + struct sock *sk = sk_vsock(vsk); + struct sk_buff *skb; + int off = 0; + int err; + + spin_lock_bh(&vvs->rx_lock); + /* Use __skb_recv_datagram() for race-free handling of the receive. It + * works for types other than dgrams. + */ + skb = __skb_recv_datagram(sk, &vvs->rx_queue, MSG_DONTWAIT, &off, &err); + spin_unlock_bh(&vvs->rx_lock); + + if (!skb) + return err; + + return recv_actor(sk, skb); +} +EXPORT_SYMBOL_GPL(virtio_transport_read_skb); + +int virtio_transport_notify_set_rcvlowat(struct vsock_sock *vsk, int val) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + bool send_update; + + spin_lock_bh(&vvs->rx_lock); + + /* If number of available bytes is less than new SO_RCVLOWAT value, + * kick sender to send more data, because sender may sleep in its + * 'send()' syscall waiting for enough space at our side. Also + * don't send credit update when peer already knows actual value - + * such transmission will be useless. + */ + send_update = (vvs->rx_bytes < val) && + (vvs->fwd_cnt != vvs->last_fwd_cnt); + + spin_unlock_bh(&vvs->rx_lock); + + if (send_update) { + int err; + + err = virtio_transport_send_credit_update(vsk); + if (err < 0) + return err; + } + + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_set_rcvlowat); + +MODULE_LICENSE("GPL v2"); +MODULE_AUTHOR("Asias He"); +MODULE_DESCRIPTION("common code for virtio vsock"); diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c new file mode 100644 index 0000000000..b370070194 --- /dev/null +++ b/net/vmw_vsock/vmci_transport.c @@ -0,0 +1,2160 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * VMware vSockets Driver + * + * Copyright (C) 2007-2013 VMware, Inc. All rights reserved. + */ + +#include <linux/types.h> +#include <linux/bitops.h> +#include <linux/cred.h> +#include <linux/init.h> +#include <linux/io.h> +#include <linux/kernel.h> +#include <linux/kmod.h> +#include <linux/list.h> +#include <linux/module.h> +#include <linux/mutex.h> +#include <linux/net.h> +#include <linux/poll.h> +#include <linux/skbuff.h> +#include <linux/smp.h> +#include <linux/socket.h> +#include <linux/stddef.h> +#include <linux/unistd.h> +#include <linux/wait.h> +#include <linux/workqueue.h> +#include <net/sock.h> +#include <net/af_vsock.h> + +#include "vmci_transport_notify.h" + +static int vmci_transport_recv_dgram_cb(void *data, struct vmci_datagram *dg); +static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg); +static void vmci_transport_peer_detach_cb(u32 sub_id, + const struct vmci_event_data *ed, + void *client_data); +static void vmci_transport_recv_pkt_work(struct work_struct *work); +static void vmci_transport_cleanup(struct work_struct *work); +static int vmci_transport_recv_listen(struct sock *sk, + struct vmci_transport_packet *pkt); +static int vmci_transport_recv_connecting_server( + struct sock *sk, + struct sock *pending, + struct vmci_transport_packet *pkt); +static int vmci_transport_recv_connecting_client( + struct sock *sk, + struct vmci_transport_packet *pkt); +static int vmci_transport_recv_connecting_client_negotiate( + struct sock *sk, + struct vmci_transport_packet *pkt); +static int vmci_transport_recv_connecting_client_invalid( + struct sock *sk, + struct vmci_transport_packet *pkt); +static int vmci_transport_recv_connected(struct sock *sk, + struct vmci_transport_packet *pkt); +static bool vmci_transport_old_proto_override(bool *old_pkt_proto); +static u16 vmci_transport_new_proto_supported_versions(void); +static bool vmci_transport_proto_to_notify_struct(struct sock *sk, u16 *proto, + bool old_pkt_proto); +static bool vmci_check_transport(struct vsock_sock *vsk); + +struct vmci_transport_recv_pkt_info { + struct work_struct work; + struct sock *sk; + struct vmci_transport_packet pkt; +}; + +static LIST_HEAD(vmci_transport_cleanup_list); +static DEFINE_SPINLOCK(vmci_transport_cleanup_lock); +static DECLARE_WORK(vmci_transport_cleanup_work, vmci_transport_cleanup); + +static struct vmci_handle vmci_transport_stream_handle = { VMCI_INVALID_ID, + VMCI_INVALID_ID }; +static u32 vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID; + +static int PROTOCOL_OVERRIDE = -1; + +static struct vsock_transport vmci_transport; /* forward declaration */ + +/* Helper function to convert from a VMCI error code to a VSock error code. */ + +static s32 vmci_transport_error_to_vsock_error(s32 vmci_error) +{ + switch (vmci_error) { + case VMCI_ERROR_NO_MEM: + return -ENOMEM; + case VMCI_ERROR_DUPLICATE_ENTRY: + case VMCI_ERROR_ALREADY_EXISTS: + return -EADDRINUSE; + case VMCI_ERROR_NO_ACCESS: + return -EPERM; + case VMCI_ERROR_NO_RESOURCES: + return -ENOBUFS; + case VMCI_ERROR_INVALID_RESOURCE: + return -EHOSTUNREACH; + case VMCI_ERROR_INVALID_ARGS: + default: + break; + } + return -EINVAL; +} + +static u32 vmci_transport_peer_rid(u32 peer_cid) +{ + if (VMADDR_CID_HYPERVISOR == peer_cid) + return VMCI_TRANSPORT_HYPERVISOR_PACKET_RID; + + return VMCI_TRANSPORT_PACKET_RID; +} + +static inline void +vmci_transport_packet_init(struct vmci_transport_packet *pkt, + struct sockaddr_vm *src, + struct sockaddr_vm *dst, + u8 type, + u64 size, + u64 mode, + struct vmci_transport_waiting_info *wait, + u16 proto, + struct vmci_handle handle) +{ + /* We register the stream control handler as an any cid handle so we + * must always send from a source address of VMADDR_CID_ANY + */ + pkt->dg.src = vmci_make_handle(VMADDR_CID_ANY, + VMCI_TRANSPORT_PACKET_RID); + pkt->dg.dst = vmci_make_handle(dst->svm_cid, + vmci_transport_peer_rid(dst->svm_cid)); + pkt->dg.payload_size = sizeof(*pkt) - sizeof(pkt->dg); + pkt->version = VMCI_TRANSPORT_PACKET_VERSION; + pkt->type = type; + pkt->src_port = src->svm_port; + pkt->dst_port = dst->svm_port; + memset(&pkt->proto, 0, sizeof(pkt->proto)); + memset(&pkt->_reserved2, 0, sizeof(pkt->_reserved2)); + + switch (pkt->type) { + case VMCI_TRANSPORT_PACKET_TYPE_INVALID: + pkt->u.size = 0; + break; + + case VMCI_TRANSPORT_PACKET_TYPE_REQUEST: + case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE: + pkt->u.size = size; + break; + + case VMCI_TRANSPORT_PACKET_TYPE_OFFER: + case VMCI_TRANSPORT_PACKET_TYPE_ATTACH: + pkt->u.handle = handle; + break; + + case VMCI_TRANSPORT_PACKET_TYPE_WROTE: + case VMCI_TRANSPORT_PACKET_TYPE_READ: + case VMCI_TRANSPORT_PACKET_TYPE_RST: + pkt->u.size = 0; + break; + + case VMCI_TRANSPORT_PACKET_TYPE_SHUTDOWN: + pkt->u.mode = mode; + break; + + case VMCI_TRANSPORT_PACKET_TYPE_WAITING_READ: + case VMCI_TRANSPORT_PACKET_TYPE_WAITING_WRITE: + memcpy(&pkt->u.wait, wait, sizeof(pkt->u.wait)); + break; + + case VMCI_TRANSPORT_PACKET_TYPE_REQUEST2: + case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2: + pkt->u.size = size; + pkt->proto = proto; + break; + } +} + +static inline void +vmci_transport_packet_get_addresses(struct vmci_transport_packet *pkt, + struct sockaddr_vm *local, + struct sockaddr_vm *remote) +{ + vsock_addr_init(local, pkt->dg.dst.context, pkt->dst_port); + vsock_addr_init(remote, pkt->dg.src.context, pkt->src_port); +} + +static int +__vmci_transport_send_control_pkt(struct vmci_transport_packet *pkt, + struct sockaddr_vm *src, + struct sockaddr_vm *dst, + enum vmci_transport_packet_type type, + u64 size, + u64 mode, + struct vmci_transport_waiting_info *wait, + u16 proto, + struct vmci_handle handle, + bool convert_error) +{ + int err; + + vmci_transport_packet_init(pkt, src, dst, type, size, mode, wait, + proto, handle); + err = vmci_datagram_send(&pkt->dg); + if (convert_error && (err < 0)) + return vmci_transport_error_to_vsock_error(err); + + return err; +} + +static int +vmci_transport_reply_control_pkt_fast(struct vmci_transport_packet *pkt, + enum vmci_transport_packet_type type, + u64 size, + u64 mode, + struct vmci_transport_waiting_info *wait, + struct vmci_handle handle) +{ + struct vmci_transport_packet reply; + struct sockaddr_vm src, dst; + + if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST) { + return 0; + } else { + vmci_transport_packet_get_addresses(pkt, &src, &dst); + return __vmci_transport_send_control_pkt(&reply, &src, &dst, + type, + size, mode, wait, + VSOCK_PROTO_INVALID, + handle, true); + } +} + +static int +vmci_transport_send_control_pkt_bh(struct sockaddr_vm *src, + struct sockaddr_vm *dst, + enum vmci_transport_packet_type type, + u64 size, + u64 mode, + struct vmci_transport_waiting_info *wait, + struct vmci_handle handle) +{ + /* Note that it is safe to use a single packet across all CPUs since + * two tasklets of the same type are guaranteed to not ever run + * simultaneously. If that ever changes, or VMCI stops using tasklets, + * we can use per-cpu packets. + */ + static struct vmci_transport_packet pkt; + + return __vmci_transport_send_control_pkt(&pkt, src, dst, type, + size, mode, wait, + VSOCK_PROTO_INVALID, handle, + false); +} + +static int +vmci_transport_alloc_send_control_pkt(struct sockaddr_vm *src, + struct sockaddr_vm *dst, + enum vmci_transport_packet_type type, + u64 size, + u64 mode, + struct vmci_transport_waiting_info *wait, + u16 proto, + struct vmci_handle handle) +{ + struct vmci_transport_packet *pkt; + int err; + + pkt = kmalloc(sizeof(*pkt), GFP_KERNEL); + if (!pkt) + return -ENOMEM; + + err = __vmci_transport_send_control_pkt(pkt, src, dst, type, size, + mode, wait, proto, handle, + true); + kfree(pkt); + + return err; +} + +static int +vmci_transport_send_control_pkt(struct sock *sk, + enum vmci_transport_packet_type type, + u64 size, + u64 mode, + struct vmci_transport_waiting_info *wait, + u16 proto, + struct vmci_handle handle) +{ + struct vsock_sock *vsk; + + vsk = vsock_sk(sk); + + if (!vsock_addr_bound(&vsk->local_addr)) + return -EINVAL; + + if (!vsock_addr_bound(&vsk->remote_addr)) + return -EINVAL; + + return vmci_transport_alloc_send_control_pkt(&vsk->local_addr, + &vsk->remote_addr, + type, size, mode, + wait, proto, handle); +} + +static int vmci_transport_send_reset_bh(struct sockaddr_vm *dst, + struct sockaddr_vm *src, + struct vmci_transport_packet *pkt) +{ + if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST) + return 0; + return vmci_transport_send_control_pkt_bh( + dst, src, + VMCI_TRANSPORT_PACKET_TYPE_RST, 0, + 0, NULL, VMCI_INVALID_HANDLE); +} + +static int vmci_transport_send_reset(struct sock *sk, + struct vmci_transport_packet *pkt) +{ + struct sockaddr_vm *dst_ptr; + struct sockaddr_vm dst; + struct vsock_sock *vsk; + + if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST) + return 0; + + vsk = vsock_sk(sk); + + if (!vsock_addr_bound(&vsk->local_addr)) + return -EINVAL; + + if (vsock_addr_bound(&vsk->remote_addr)) { + dst_ptr = &vsk->remote_addr; + } else { + vsock_addr_init(&dst, pkt->dg.src.context, + pkt->src_port); + dst_ptr = &dst; + } + return vmci_transport_alloc_send_control_pkt(&vsk->local_addr, dst_ptr, + VMCI_TRANSPORT_PACKET_TYPE_RST, + 0, 0, NULL, VSOCK_PROTO_INVALID, + VMCI_INVALID_HANDLE); +} + +static int vmci_transport_send_negotiate(struct sock *sk, size_t size) +{ + return vmci_transport_send_control_pkt( + sk, + VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE, + size, 0, NULL, + VSOCK_PROTO_INVALID, + VMCI_INVALID_HANDLE); +} + +static int vmci_transport_send_negotiate2(struct sock *sk, size_t size, + u16 version) +{ + return vmci_transport_send_control_pkt( + sk, + VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2, + size, 0, NULL, version, + VMCI_INVALID_HANDLE); +} + +static int vmci_transport_send_qp_offer(struct sock *sk, + struct vmci_handle handle) +{ + return vmci_transport_send_control_pkt( + sk, VMCI_TRANSPORT_PACKET_TYPE_OFFER, 0, + 0, NULL, + VSOCK_PROTO_INVALID, handle); +} + +static int vmci_transport_send_attach(struct sock *sk, + struct vmci_handle handle) +{ + return vmci_transport_send_control_pkt( + sk, VMCI_TRANSPORT_PACKET_TYPE_ATTACH, + 0, 0, NULL, VSOCK_PROTO_INVALID, + handle); +} + +static int vmci_transport_reply_reset(struct vmci_transport_packet *pkt) +{ + return vmci_transport_reply_control_pkt_fast( + pkt, + VMCI_TRANSPORT_PACKET_TYPE_RST, + 0, 0, NULL, + VMCI_INVALID_HANDLE); +} + +static int vmci_transport_send_invalid_bh(struct sockaddr_vm *dst, + struct sockaddr_vm *src) +{ + return vmci_transport_send_control_pkt_bh( + dst, src, + VMCI_TRANSPORT_PACKET_TYPE_INVALID, + 0, 0, NULL, VMCI_INVALID_HANDLE); +} + +int vmci_transport_send_wrote_bh(struct sockaddr_vm *dst, + struct sockaddr_vm *src) +{ + return vmci_transport_send_control_pkt_bh( + dst, src, + VMCI_TRANSPORT_PACKET_TYPE_WROTE, 0, + 0, NULL, VMCI_INVALID_HANDLE); +} + +int vmci_transport_send_read_bh(struct sockaddr_vm *dst, + struct sockaddr_vm *src) +{ + return vmci_transport_send_control_pkt_bh( + dst, src, + VMCI_TRANSPORT_PACKET_TYPE_READ, 0, + 0, NULL, VMCI_INVALID_HANDLE); +} + +int vmci_transport_send_wrote(struct sock *sk) +{ + return vmci_transport_send_control_pkt( + sk, VMCI_TRANSPORT_PACKET_TYPE_WROTE, 0, + 0, NULL, VSOCK_PROTO_INVALID, + VMCI_INVALID_HANDLE); +} + +int vmci_transport_send_read(struct sock *sk) +{ + return vmci_transport_send_control_pkt( + sk, VMCI_TRANSPORT_PACKET_TYPE_READ, 0, + 0, NULL, VSOCK_PROTO_INVALID, + VMCI_INVALID_HANDLE); +} + +int vmci_transport_send_waiting_write(struct sock *sk, + struct vmci_transport_waiting_info *wait) +{ + return vmci_transport_send_control_pkt( + sk, VMCI_TRANSPORT_PACKET_TYPE_WAITING_WRITE, + 0, 0, wait, VSOCK_PROTO_INVALID, + VMCI_INVALID_HANDLE); +} + +int vmci_transport_send_waiting_read(struct sock *sk, + struct vmci_transport_waiting_info *wait) +{ + return vmci_transport_send_control_pkt( + sk, VMCI_TRANSPORT_PACKET_TYPE_WAITING_READ, + 0, 0, wait, VSOCK_PROTO_INVALID, + VMCI_INVALID_HANDLE); +} + +static int vmci_transport_shutdown(struct vsock_sock *vsk, int mode) +{ + return vmci_transport_send_control_pkt( + &vsk->sk, + VMCI_TRANSPORT_PACKET_TYPE_SHUTDOWN, + 0, mode, NULL, + VSOCK_PROTO_INVALID, + VMCI_INVALID_HANDLE); +} + +static int vmci_transport_send_conn_request(struct sock *sk, size_t size) +{ + return vmci_transport_send_control_pkt(sk, + VMCI_TRANSPORT_PACKET_TYPE_REQUEST, + size, 0, NULL, + VSOCK_PROTO_INVALID, + VMCI_INVALID_HANDLE); +} + +static int vmci_transport_send_conn_request2(struct sock *sk, size_t size, + u16 version) +{ + return vmci_transport_send_control_pkt( + sk, VMCI_TRANSPORT_PACKET_TYPE_REQUEST2, + size, 0, NULL, version, + VMCI_INVALID_HANDLE); +} + +static struct sock *vmci_transport_get_pending( + struct sock *listener, + struct vmci_transport_packet *pkt) +{ + struct vsock_sock *vlistener; + struct vsock_sock *vpending; + struct sock *pending; + struct sockaddr_vm src; + + vsock_addr_init(&src, pkt->dg.src.context, pkt->src_port); + + vlistener = vsock_sk(listener); + + list_for_each_entry(vpending, &vlistener->pending_links, + pending_links) { + if (vsock_addr_equals_addr(&src, &vpending->remote_addr) && + pkt->dst_port == vpending->local_addr.svm_port) { + pending = sk_vsock(vpending); + sock_hold(pending); + goto found; + } + } + + pending = NULL; +found: + return pending; + +} + +static void vmci_transport_release_pending(struct sock *pending) +{ + sock_put(pending); +} + +/* We allow two kinds of sockets to communicate with a restricted VM: 1) + * trusted sockets 2) sockets from applications running as the same user as the + * VM (this is only true for the host side and only when using hosted products) + */ + +static bool vmci_transport_is_trusted(struct vsock_sock *vsock, u32 peer_cid) +{ + return vsock->trusted || + vmci_is_context_owner(peer_cid, vsock->owner->uid); +} + +/* We allow sending datagrams to and receiving datagrams from a restricted VM + * only if it is trusted as described in vmci_transport_is_trusted. + */ + +static bool vmci_transport_allow_dgram(struct vsock_sock *vsock, u32 peer_cid) +{ + if (VMADDR_CID_HYPERVISOR == peer_cid) + return true; + + if (vsock->cached_peer != peer_cid) { + vsock->cached_peer = peer_cid; + if (!vmci_transport_is_trusted(vsock, peer_cid) && + (vmci_context_get_priv_flags(peer_cid) & + VMCI_PRIVILEGE_FLAG_RESTRICTED)) { + vsock->cached_peer_allow_dgram = false; + } else { + vsock->cached_peer_allow_dgram = true; + } + } + + return vsock->cached_peer_allow_dgram; +} + +static int +vmci_transport_queue_pair_alloc(struct vmci_qp **qpair, + struct vmci_handle *handle, + u64 produce_size, + u64 consume_size, + u32 peer, u32 flags, bool trusted) +{ + int err = 0; + + if (trusted) { + /* Try to allocate our queue pair as trusted. This will only + * work if vsock is running in the host. + */ + + err = vmci_qpair_alloc(qpair, handle, produce_size, + consume_size, + peer, flags, + VMCI_PRIVILEGE_FLAG_TRUSTED); + if (err != VMCI_ERROR_NO_ACCESS) + goto out; + + } + + err = vmci_qpair_alloc(qpair, handle, produce_size, consume_size, + peer, flags, VMCI_NO_PRIVILEGE_FLAGS); +out: + if (err < 0) { + pr_err_once("Could not attach to queue pair with %d\n", err); + err = vmci_transport_error_to_vsock_error(err); + } + + return err; +} + +static int +vmci_transport_datagram_create_hnd(u32 resource_id, + u32 flags, + vmci_datagram_recv_cb recv_cb, + void *client_data, + struct vmci_handle *out_handle) +{ + int err = 0; + + /* Try to allocate our datagram handler as trusted. This will only work + * if vsock is running in the host. + */ + + err = vmci_datagram_create_handle_priv(resource_id, flags, + VMCI_PRIVILEGE_FLAG_TRUSTED, + recv_cb, + client_data, out_handle); + + if (err == VMCI_ERROR_NO_ACCESS) + err = vmci_datagram_create_handle(resource_id, flags, + recv_cb, client_data, + out_handle); + + return err; +} + +/* This is invoked as part of a tasklet that's scheduled when the VMCI + * interrupt fires. This is run in bottom-half context and if it ever needs to + * sleep it should defer that work to a work queue. + */ + +static int vmci_transport_recv_dgram_cb(void *data, struct vmci_datagram *dg) +{ + struct sock *sk; + size_t size; + struct sk_buff *skb; + struct vsock_sock *vsk; + + sk = (struct sock *)data; + + /* This handler is privileged when this module is running on the host. + * We will get datagrams from all endpoints (even VMs that are in a + * restricted context). If we get one from a restricted context then + * the destination socket must be trusted. + * + * NOTE: We access the socket struct without holding the lock here. + * This is ok because the field we are interested is never modified + * outside of the create and destruct socket functions. + */ + vsk = vsock_sk(sk); + if (!vmci_transport_allow_dgram(vsk, dg->src.context)) + return VMCI_ERROR_NO_ACCESS; + + size = VMCI_DG_SIZE(dg); + + /* Attach the packet to the socket's receive queue as an sk_buff. */ + skb = alloc_skb(size, GFP_ATOMIC); + if (!skb) + return VMCI_ERROR_NO_MEM; + + /* sk_receive_skb() will do a sock_put(), so hold here. */ + sock_hold(sk); + skb_put(skb, size); + memcpy(skb->data, dg, size); + sk_receive_skb(sk, skb, 0); + + return VMCI_SUCCESS; +} + +static bool vmci_transport_stream_allow(u32 cid, u32 port) +{ + static const u32 non_socket_contexts[] = { + VMADDR_CID_LOCAL, + }; + int i; + + BUILD_BUG_ON(sizeof(cid) != sizeof(*non_socket_contexts)); + + for (i = 0; i < ARRAY_SIZE(non_socket_contexts); i++) { + if (cid == non_socket_contexts[i]) + return false; + } + + return true; +} + +/* This is invoked as part of a tasklet that's scheduled when the VMCI + * interrupt fires. This is run in bottom-half context but it defers most of + * its work to the packet handling work queue. + */ + +static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg) +{ + struct sock *sk; + struct sockaddr_vm dst; + struct sockaddr_vm src; + struct vmci_transport_packet *pkt; + struct vsock_sock *vsk; + bool bh_process_pkt; + int err; + + sk = NULL; + err = VMCI_SUCCESS; + bh_process_pkt = false; + + /* Ignore incoming packets from contexts without sockets, or resources + * that aren't vsock implementations. + */ + + if (!vmci_transport_stream_allow(dg->src.context, -1) + || vmci_transport_peer_rid(dg->src.context) != dg->src.resource) + return VMCI_ERROR_NO_ACCESS; + + if (VMCI_DG_SIZE(dg) < sizeof(*pkt)) + /* Drop datagrams that do not contain full VSock packets. */ + return VMCI_ERROR_INVALID_ARGS; + + pkt = (struct vmci_transport_packet *)dg; + + /* Find the socket that should handle this packet. First we look for a + * connected socket and if there is none we look for a socket bound to + * the destintation address. + */ + vsock_addr_init(&src, pkt->dg.src.context, pkt->src_port); + vsock_addr_init(&dst, pkt->dg.dst.context, pkt->dst_port); + + sk = vsock_find_connected_socket(&src, &dst); + if (!sk) { + sk = vsock_find_bound_socket(&dst); + if (!sk) { + /* We could not find a socket for this specified + * address. If this packet is a RST, we just drop it. + * If it is another packet, we send a RST. Note that + * we do not send a RST reply to RSTs so that we do not + * continually send RSTs between two endpoints. + * + * Note that since this is a reply, dst is src and src + * is dst. + */ + if (vmci_transport_send_reset_bh(&dst, &src, pkt) < 0) + pr_err("unable to send reset\n"); + + err = VMCI_ERROR_NOT_FOUND; + goto out; + } + } + + /* If the received packet type is beyond all types known to this + * implementation, reply with an invalid message. Hopefully this will + * help when implementing backwards compatibility in the future. + */ + if (pkt->type >= VMCI_TRANSPORT_PACKET_TYPE_MAX) { + vmci_transport_send_invalid_bh(&dst, &src); + err = VMCI_ERROR_INVALID_ARGS; + goto out; + } + + /* This handler is privileged when this module is running on the host. + * We will get datagram connect requests from all endpoints (even VMs + * that are in a restricted context). If we get one from a restricted + * context then the destination socket must be trusted. + * + * NOTE: We access the socket struct without holding the lock here. + * This is ok because the field we are interested is never modified + * outside of the create and destruct socket functions. + */ + vsk = vsock_sk(sk); + if (!vmci_transport_allow_dgram(vsk, pkt->dg.src.context)) { + err = VMCI_ERROR_NO_ACCESS; + goto out; + } + + /* We do most everything in a work queue, but let's fast path the + * notification of reads and writes to help data transfer performance. + * We can only do this if there is no process context code executing + * for this socket since that may change the state. + */ + bh_lock_sock(sk); + + if (!sock_owned_by_user(sk)) { + /* The local context ID may be out of date, update it. */ + vsk->local_addr.svm_cid = dst.svm_cid; + + if (sk->sk_state == TCP_ESTABLISHED) + vmci_trans(vsk)->notify_ops->handle_notify_pkt( + sk, pkt, true, &dst, &src, + &bh_process_pkt); + } + + bh_unlock_sock(sk); + + if (!bh_process_pkt) { + struct vmci_transport_recv_pkt_info *recv_pkt_info; + + recv_pkt_info = kmalloc(sizeof(*recv_pkt_info), GFP_ATOMIC); + if (!recv_pkt_info) { + if (vmci_transport_send_reset_bh(&dst, &src, pkt) < 0) + pr_err("unable to send reset\n"); + + err = VMCI_ERROR_NO_MEM; + goto out; + } + + recv_pkt_info->sk = sk; + memcpy(&recv_pkt_info->pkt, pkt, sizeof(recv_pkt_info->pkt)); + INIT_WORK(&recv_pkt_info->work, vmci_transport_recv_pkt_work); + + schedule_work(&recv_pkt_info->work); + /* Clear sk so that the reference count incremented by one of + * the Find functions above is not decremented below. We need + * that reference count for the packet handler we've scheduled + * to run. + */ + sk = NULL; + } + +out: + if (sk) + sock_put(sk); + + return err; +} + +static void vmci_transport_handle_detach(struct sock *sk) +{ + struct vsock_sock *vsk; + + vsk = vsock_sk(sk); + if (!vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle)) { + sock_set_flag(sk, SOCK_DONE); + + /* On a detach the peer will not be sending or receiving + * anymore. + */ + vsk->peer_shutdown = SHUTDOWN_MASK; + + /* We should not be sending anymore since the peer won't be + * there to receive, but we can still receive if there is data + * left in our consume queue. If the local endpoint is a host, + * we can't call vsock_stream_has_data, since that may block, + * but a host endpoint can't read data once the VM has + * detached, so there is no available data in that case. + */ + if (vsk->local_addr.svm_cid == VMADDR_CID_HOST || + vsock_stream_has_data(vsk) <= 0) { + if (sk->sk_state == TCP_SYN_SENT) { + /* The peer may detach from a queue pair while + * we are still in the connecting state, i.e., + * if the peer VM is killed after attaching to + * a queue pair, but before we complete the + * handshake. In that case, we treat the detach + * event like a reset. + */ + + sk->sk_state = TCP_CLOSE; + sk->sk_err = ECONNRESET; + sk_error_report(sk); + return; + } + sk->sk_state = TCP_CLOSE; + } + sk->sk_state_change(sk); + } +} + +static void vmci_transport_peer_detach_cb(u32 sub_id, + const struct vmci_event_data *e_data, + void *client_data) +{ + struct vmci_transport *trans = client_data; + const struct vmci_event_payload_qp *e_payload; + + e_payload = vmci_event_data_const_payload(e_data); + + /* XXX This is lame, we should provide a way to lookup sockets by + * qp_handle. + */ + if (vmci_handle_is_invalid(e_payload->handle) || + !vmci_handle_is_equal(trans->qp_handle, e_payload->handle)) + return; + + /* We don't ask for delayed CBs when we subscribe to this event (we + * pass 0 as flags to vmci_event_subscribe()). VMCI makes no + * guarantees in that case about what context we might be running in, + * so it could be BH or process, blockable or non-blockable. So we + * need to account for all possible contexts here. + */ + spin_lock_bh(&trans->lock); + if (!trans->sk) + goto out; + + /* Apart from here, trans->lock is only grabbed as part of sk destruct, + * where trans->sk isn't locked. + */ + bh_lock_sock(trans->sk); + + vmci_transport_handle_detach(trans->sk); + + bh_unlock_sock(trans->sk); + out: + spin_unlock_bh(&trans->lock); +} + +static void vmci_transport_qp_resumed_cb(u32 sub_id, + const struct vmci_event_data *e_data, + void *client_data) +{ + vsock_for_each_connected_socket(&vmci_transport, + vmci_transport_handle_detach); +} + +static void vmci_transport_recv_pkt_work(struct work_struct *work) +{ + struct vmci_transport_recv_pkt_info *recv_pkt_info; + struct vmci_transport_packet *pkt; + struct sock *sk; + + recv_pkt_info = + container_of(work, struct vmci_transport_recv_pkt_info, work); + sk = recv_pkt_info->sk; + pkt = &recv_pkt_info->pkt; + + lock_sock(sk); + + /* The local context ID may be out of date. */ + vsock_sk(sk)->local_addr.svm_cid = pkt->dg.dst.context; + + switch (sk->sk_state) { + case TCP_LISTEN: + vmci_transport_recv_listen(sk, pkt); + break; + case TCP_SYN_SENT: + /* Processing of pending connections for servers goes through + * the listening socket, so see vmci_transport_recv_listen() + * for that path. + */ + vmci_transport_recv_connecting_client(sk, pkt); + break; + case TCP_ESTABLISHED: + vmci_transport_recv_connected(sk, pkt); + break; + default: + /* Because this function does not run in the same context as + * vmci_transport_recv_stream_cb it is possible that the + * socket has closed. We need to let the other side know or it + * could be sitting in a connect and hang forever. Send a + * reset to prevent that. + */ + vmci_transport_send_reset(sk, pkt); + break; + } + + release_sock(sk); + kfree(recv_pkt_info); + /* Release reference obtained in the stream callback when we fetched + * this socket out of the bound or connected list. + */ + sock_put(sk); +} + +static int vmci_transport_recv_listen(struct sock *sk, + struct vmci_transport_packet *pkt) +{ + struct sock *pending; + struct vsock_sock *vpending; + int err; + u64 qp_size; + bool old_request = false; + bool old_pkt_proto = false; + + /* Because we are in the listen state, we could be receiving a packet + * for ourself or any previous connection requests that we received. + * If it's the latter, we try to find a socket in our list of pending + * connections and, if we do, call the appropriate handler for the + * state that socket is in. Otherwise we try to service the + * connection request. + */ + pending = vmci_transport_get_pending(sk, pkt); + if (pending) { + lock_sock(pending); + + /* The local context ID may be out of date. */ + vsock_sk(pending)->local_addr.svm_cid = pkt->dg.dst.context; + + switch (pending->sk_state) { + case TCP_SYN_SENT: + err = vmci_transport_recv_connecting_server(sk, + pending, + pkt); + break; + default: + vmci_transport_send_reset(pending, pkt); + err = -EINVAL; + } + + if (err < 0) + vsock_remove_pending(sk, pending); + + release_sock(pending); + vmci_transport_release_pending(pending); + + return err; + } + + /* The listen state only accepts connection requests. Reply with a + * reset unless we received a reset. + */ + + if (!(pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST || + pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST2)) { + vmci_transport_reply_reset(pkt); + return -EINVAL; + } + + if (pkt->u.size == 0) { + vmci_transport_reply_reset(pkt); + return -EINVAL; + } + + /* If this socket can't accommodate this connection request, we send a + * reset. Otherwise we create and initialize a child socket and reply + * with a connection negotiation. + */ + if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog) { + vmci_transport_reply_reset(pkt); + return -ECONNREFUSED; + } + + pending = vsock_create_connected(sk); + if (!pending) { + vmci_transport_send_reset(sk, pkt); + return -ENOMEM; + } + + vpending = vsock_sk(pending); + + vsock_addr_init(&vpending->local_addr, pkt->dg.dst.context, + pkt->dst_port); + vsock_addr_init(&vpending->remote_addr, pkt->dg.src.context, + pkt->src_port); + + err = vsock_assign_transport(vpending, vsock_sk(sk)); + /* Transport assigned (looking at remote_addr) must be the same + * where we received the request. + */ + if (err || !vmci_check_transport(vpending)) { + vmci_transport_send_reset(sk, pkt); + sock_put(pending); + return err; + } + + /* If the proposed size fits within our min/max, accept it. Otherwise + * propose our own size. + */ + if (pkt->u.size >= vpending->buffer_min_size && + pkt->u.size <= vpending->buffer_max_size) { + qp_size = pkt->u.size; + } else { + qp_size = vpending->buffer_size; + } + + /* Figure out if we are using old or new requests based on the + * overrides pkt types sent by our peer. + */ + if (vmci_transport_old_proto_override(&old_pkt_proto)) { + old_request = old_pkt_proto; + } else { + if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST) + old_request = true; + else if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST2) + old_request = false; + + } + + if (old_request) { + /* Handle a REQUEST (or override) */ + u16 version = VSOCK_PROTO_INVALID; + if (vmci_transport_proto_to_notify_struct( + pending, &version, true)) + err = vmci_transport_send_negotiate(pending, qp_size); + else + err = -EINVAL; + + } else { + /* Handle a REQUEST2 (or override) */ + int proto_int = pkt->proto; + int pos; + u16 active_proto_version = 0; + + /* The list of possible protocols is the intersection of all + * protocols the client supports ... plus all the protocols we + * support. + */ + proto_int &= vmci_transport_new_proto_supported_versions(); + + /* We choose the highest possible protocol version and use that + * one. + */ + pos = fls(proto_int); + if (pos) { + active_proto_version = (1 << (pos - 1)); + if (vmci_transport_proto_to_notify_struct( + pending, &active_proto_version, false)) + err = vmci_transport_send_negotiate2(pending, + qp_size, + active_proto_version); + else + err = -EINVAL; + + } else { + err = -EINVAL; + } + } + + if (err < 0) { + vmci_transport_send_reset(sk, pkt); + sock_put(pending); + err = vmci_transport_error_to_vsock_error(err); + goto out; + } + + vsock_add_pending(sk, pending); + sk_acceptq_added(sk); + + pending->sk_state = TCP_SYN_SENT; + vmci_trans(vpending)->produce_size = + vmci_trans(vpending)->consume_size = qp_size; + vpending->buffer_size = qp_size; + + vmci_trans(vpending)->notify_ops->process_request(pending); + + /* We might never receive another message for this socket and it's not + * connected to any process, so we have to ensure it gets cleaned up + * ourself. Our delayed work function will take care of that. Note + * that we do not ever cancel this function since we have few + * guarantees about its state when calling cancel_delayed_work(). + * Instead we hold a reference on the socket for that function and make + * it capable of handling cases where it needs to do nothing but + * release that reference. + */ + vpending->listener = sk; + sock_hold(sk); + sock_hold(pending); + schedule_delayed_work(&vpending->pending_work, HZ); + +out: + return err; +} + +static int +vmci_transport_recv_connecting_server(struct sock *listener, + struct sock *pending, + struct vmci_transport_packet *pkt) +{ + struct vsock_sock *vpending; + struct vmci_handle handle; + struct vmci_qp *qpair; + bool is_local; + u32 flags; + u32 detach_sub_id; + int err; + int skerr; + + vpending = vsock_sk(pending); + detach_sub_id = VMCI_INVALID_ID; + + switch (pkt->type) { + case VMCI_TRANSPORT_PACKET_TYPE_OFFER: + if (vmci_handle_is_invalid(pkt->u.handle)) { + vmci_transport_send_reset(pending, pkt); + skerr = EPROTO; + err = -EINVAL; + goto destroy; + } + break; + default: + /* Close and cleanup the connection. */ + vmci_transport_send_reset(pending, pkt); + skerr = EPROTO; + err = pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST ? 0 : -EINVAL; + goto destroy; + } + + /* In order to complete the connection we need to attach to the offered + * queue pair and send an attach notification. We also subscribe to the + * detach event so we know when our peer goes away, and we do that + * before attaching so we don't miss an event. If all this succeeds, + * we update our state and wakeup anything waiting in accept() for a + * connection. + */ + + /* We don't care about attach since we ensure the other side has + * attached by specifying the ATTACH_ONLY flag below. + */ + err = vmci_event_subscribe(VMCI_EVENT_QP_PEER_DETACH, + vmci_transport_peer_detach_cb, + vmci_trans(vpending), &detach_sub_id); + if (err < VMCI_SUCCESS) { + vmci_transport_send_reset(pending, pkt); + err = vmci_transport_error_to_vsock_error(err); + skerr = -err; + goto destroy; + } + + vmci_trans(vpending)->detach_sub_id = detach_sub_id; + + /* Now attach to the queue pair the client created. */ + handle = pkt->u.handle; + + /* vpending->local_addr always has a context id so we do not need to + * worry about VMADDR_CID_ANY in this case. + */ + is_local = + vpending->remote_addr.svm_cid == vpending->local_addr.svm_cid; + flags = VMCI_QPFLAG_ATTACH_ONLY; + flags |= is_local ? VMCI_QPFLAG_LOCAL : 0; + + err = vmci_transport_queue_pair_alloc( + &qpair, + &handle, + vmci_trans(vpending)->produce_size, + vmci_trans(vpending)->consume_size, + pkt->dg.src.context, + flags, + vmci_transport_is_trusted( + vpending, + vpending->remote_addr.svm_cid)); + if (err < 0) { + vmci_transport_send_reset(pending, pkt); + skerr = -err; + goto destroy; + } + + vmci_trans(vpending)->qp_handle = handle; + vmci_trans(vpending)->qpair = qpair; + + /* When we send the attach message, we must be ready to handle incoming + * control messages on the newly connected socket. So we move the + * pending socket to the connected state before sending the attach + * message. Otherwise, an incoming packet triggered by the attach being + * received by the peer may be processed concurrently with what happens + * below after sending the attach message, and that incoming packet + * will find the listening socket instead of the (currently) pending + * socket. Note that enqueueing the socket increments the reference + * count, so even if a reset comes before the connection is accepted, + * the socket will be valid until it is removed from the queue. + * + * If we fail sending the attach below, we remove the socket from the + * connected list and move the socket to TCP_CLOSE before + * releasing the lock, so a pending slow path processing of an incoming + * packet will not see the socket in the connected state in that case. + */ + pending->sk_state = TCP_ESTABLISHED; + + vsock_insert_connected(vpending); + + /* Notify our peer of our attach. */ + err = vmci_transport_send_attach(pending, handle); + if (err < 0) { + vsock_remove_connected(vpending); + pr_err("Could not send attach\n"); + vmci_transport_send_reset(pending, pkt); + err = vmci_transport_error_to_vsock_error(err); + skerr = -err; + goto destroy; + } + + /* We have a connection. Move the now connected socket from the + * listener's pending list to the accept queue so callers of accept() + * can find it. + */ + vsock_remove_pending(listener, pending); + vsock_enqueue_accept(listener, pending); + + /* Callers of accept() will be waiting on the listening socket, not + * the pending socket. + */ + listener->sk_data_ready(listener); + + return 0; + +destroy: + pending->sk_err = skerr; + pending->sk_state = TCP_CLOSE; + /* As long as we drop our reference, all necessary cleanup will handle + * when the cleanup function drops its reference and our destruct + * implementation is called. Note that since the listen handler will + * remove pending from the pending list upon our failure, the cleanup + * function won't drop the additional reference, which is why we do it + * here. + */ + sock_put(pending); + + return err; +} + +static int +vmci_transport_recv_connecting_client(struct sock *sk, + struct vmci_transport_packet *pkt) +{ + struct vsock_sock *vsk; + int err; + int skerr; + + vsk = vsock_sk(sk); + + switch (pkt->type) { + case VMCI_TRANSPORT_PACKET_TYPE_ATTACH: + if (vmci_handle_is_invalid(pkt->u.handle) || + !vmci_handle_is_equal(pkt->u.handle, + vmci_trans(vsk)->qp_handle)) { + skerr = EPROTO; + err = -EINVAL; + goto destroy; + } + + /* Signify the socket is connected and wakeup the waiter in + * connect(). Also place the socket in the connected table for + * accounting (it can already be found since it's in the bound + * table). + */ + sk->sk_state = TCP_ESTABLISHED; + sk->sk_socket->state = SS_CONNECTED; + vsock_insert_connected(vsk); + sk->sk_state_change(sk); + + break; + case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE: + case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2: + if (pkt->u.size == 0 + || pkt->dg.src.context != vsk->remote_addr.svm_cid + || pkt->src_port != vsk->remote_addr.svm_port + || !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle) + || vmci_trans(vsk)->qpair + || vmci_trans(vsk)->produce_size != 0 + || vmci_trans(vsk)->consume_size != 0 + || vmci_trans(vsk)->detach_sub_id != VMCI_INVALID_ID) { + skerr = EPROTO; + err = -EINVAL; + + goto destroy; + } + + err = vmci_transport_recv_connecting_client_negotiate(sk, pkt); + if (err) { + skerr = -err; + goto destroy; + } + + break; + case VMCI_TRANSPORT_PACKET_TYPE_INVALID: + err = vmci_transport_recv_connecting_client_invalid(sk, pkt); + if (err) { + skerr = -err; + goto destroy; + } + + break; + case VMCI_TRANSPORT_PACKET_TYPE_RST: + /* Older versions of the linux code (WS 6.5 / ESX 4.0) used to + * continue processing here after they sent an INVALID packet. + * This meant that we got a RST after the INVALID. We ignore a + * RST after an INVALID. The common code doesn't send the RST + * ... so we can hang if an old version of the common code + * fails between getting a REQUEST and sending an OFFER back. + * Not much we can do about it... except hope that it doesn't + * happen. + */ + if (vsk->ignore_connecting_rst) { + vsk->ignore_connecting_rst = false; + } else { + skerr = ECONNRESET; + err = 0; + goto destroy; + } + + break; + default: + /* Close and cleanup the connection. */ + skerr = EPROTO; + err = -EINVAL; + goto destroy; + } + + return 0; + +destroy: + vmci_transport_send_reset(sk, pkt); + + sk->sk_state = TCP_CLOSE; + sk->sk_err = skerr; + sk_error_report(sk); + return err; +} + +static int vmci_transport_recv_connecting_client_negotiate( + struct sock *sk, + struct vmci_transport_packet *pkt) +{ + int err; + struct vsock_sock *vsk; + struct vmci_handle handle; + struct vmci_qp *qpair; + u32 detach_sub_id; + bool is_local; + u32 flags; + bool old_proto = true; + bool old_pkt_proto; + u16 version; + + vsk = vsock_sk(sk); + handle = VMCI_INVALID_HANDLE; + detach_sub_id = VMCI_INVALID_ID; + + /* If we have gotten here then we should be past the point where old + * linux vsock could have sent the bogus rst. + */ + vsk->sent_request = false; + vsk->ignore_connecting_rst = false; + + /* Verify that we're OK with the proposed queue pair size */ + if (pkt->u.size < vsk->buffer_min_size || + pkt->u.size > vsk->buffer_max_size) { + err = -EINVAL; + goto destroy; + } + + /* At this point we know the CID the peer is using to talk to us. */ + + if (vsk->local_addr.svm_cid == VMADDR_CID_ANY) + vsk->local_addr.svm_cid = pkt->dg.dst.context; + + /* Setup the notify ops to be the highest supported version that both + * the server and the client support. + */ + + if (vmci_transport_old_proto_override(&old_pkt_proto)) { + old_proto = old_pkt_proto; + } else { + if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE) + old_proto = true; + else if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2) + old_proto = false; + + } + + if (old_proto) + version = VSOCK_PROTO_INVALID; + else + version = pkt->proto; + + if (!vmci_transport_proto_to_notify_struct(sk, &version, old_proto)) { + err = -EINVAL; + goto destroy; + } + + /* Subscribe to detach events first. + * + * XXX We attach once for each queue pair created for now so it is easy + * to find the socket (it's provided), but later we should only + * subscribe once and add a way to lookup sockets by queue pair handle. + */ + err = vmci_event_subscribe(VMCI_EVENT_QP_PEER_DETACH, + vmci_transport_peer_detach_cb, + vmci_trans(vsk), &detach_sub_id); + if (err < VMCI_SUCCESS) { + err = vmci_transport_error_to_vsock_error(err); + goto destroy; + } + + /* Make VMCI select the handle for us. */ + handle = VMCI_INVALID_HANDLE; + is_local = vsk->remote_addr.svm_cid == vsk->local_addr.svm_cid; + flags = is_local ? VMCI_QPFLAG_LOCAL : 0; + + err = vmci_transport_queue_pair_alloc(&qpair, + &handle, + pkt->u.size, + pkt->u.size, + vsk->remote_addr.svm_cid, + flags, + vmci_transport_is_trusted( + vsk, + vsk-> + remote_addr.svm_cid)); + if (err < 0) + goto destroy; + + err = vmci_transport_send_qp_offer(sk, handle); + if (err < 0) { + err = vmci_transport_error_to_vsock_error(err); + goto destroy; + } + + vmci_trans(vsk)->qp_handle = handle; + vmci_trans(vsk)->qpair = qpair; + + vmci_trans(vsk)->produce_size = vmci_trans(vsk)->consume_size = + pkt->u.size; + + vmci_trans(vsk)->detach_sub_id = detach_sub_id; + + vmci_trans(vsk)->notify_ops->process_negotiate(sk); + + return 0; + +destroy: + if (detach_sub_id != VMCI_INVALID_ID) + vmci_event_unsubscribe(detach_sub_id); + + if (!vmci_handle_is_invalid(handle)) + vmci_qpair_detach(&qpair); + + return err; +} + +static int +vmci_transport_recv_connecting_client_invalid(struct sock *sk, + struct vmci_transport_packet *pkt) +{ + int err = 0; + struct vsock_sock *vsk = vsock_sk(sk); + + if (vsk->sent_request) { + vsk->sent_request = false; + vsk->ignore_connecting_rst = true; + + err = vmci_transport_send_conn_request(sk, vsk->buffer_size); + if (err < 0) + err = vmci_transport_error_to_vsock_error(err); + else + err = 0; + + } + + return err; +} + +static int vmci_transport_recv_connected(struct sock *sk, + struct vmci_transport_packet *pkt) +{ + struct vsock_sock *vsk; + bool pkt_processed = false; + + /* In cases where we are closing the connection, it's sufficient to + * mark the state change (and maybe error) and wake up any waiting + * threads. Since this is a connected socket, it's owned by a user + * process and will be cleaned up when the failure is passed back on + * the current or next system call. Our system call implementations + * must therefore check for error and state changes on entry and when + * being awoken. + */ + switch (pkt->type) { + case VMCI_TRANSPORT_PACKET_TYPE_SHUTDOWN: + if (pkt->u.mode) { + vsk = vsock_sk(sk); + + vsk->peer_shutdown |= pkt->u.mode; + sk->sk_state_change(sk); + } + break; + + case VMCI_TRANSPORT_PACKET_TYPE_RST: + vsk = vsock_sk(sk); + /* It is possible that we sent our peer a message (e.g a + * WAITING_READ) right before we got notified that the peer had + * detached. If that happens then we can get a RST pkt back + * from our peer even though there is data available for us to + * read. In that case, don't shutdown the socket completely but + * instead allow the local client to finish reading data off + * the queuepair. Always treat a RST pkt in connected mode like + * a clean shutdown. + */ + sock_set_flag(sk, SOCK_DONE); + vsk->peer_shutdown = SHUTDOWN_MASK; + if (vsock_stream_has_data(vsk) <= 0) + sk->sk_state = TCP_CLOSING; + + sk->sk_state_change(sk); + break; + + default: + vsk = vsock_sk(sk); + vmci_trans(vsk)->notify_ops->handle_notify_pkt( + sk, pkt, false, NULL, NULL, + &pkt_processed); + if (!pkt_processed) + return -EINVAL; + + break; + } + + return 0; +} + +static int vmci_transport_socket_init(struct vsock_sock *vsk, + struct vsock_sock *psk) +{ + vsk->trans = kmalloc(sizeof(struct vmci_transport), GFP_KERNEL); + if (!vsk->trans) + return -ENOMEM; + + vmci_trans(vsk)->dg_handle = VMCI_INVALID_HANDLE; + vmci_trans(vsk)->qp_handle = VMCI_INVALID_HANDLE; + vmci_trans(vsk)->qpair = NULL; + vmci_trans(vsk)->produce_size = vmci_trans(vsk)->consume_size = 0; + vmci_trans(vsk)->detach_sub_id = VMCI_INVALID_ID; + vmci_trans(vsk)->notify_ops = NULL; + INIT_LIST_HEAD(&vmci_trans(vsk)->elem); + vmci_trans(vsk)->sk = &vsk->sk; + spin_lock_init(&vmci_trans(vsk)->lock); + + return 0; +} + +static void vmci_transport_free_resources(struct list_head *transport_list) +{ + while (!list_empty(transport_list)) { + struct vmci_transport *transport = + list_first_entry(transport_list, struct vmci_transport, + elem); + list_del(&transport->elem); + + if (transport->detach_sub_id != VMCI_INVALID_ID) { + vmci_event_unsubscribe(transport->detach_sub_id); + transport->detach_sub_id = VMCI_INVALID_ID; + } + + if (!vmci_handle_is_invalid(transport->qp_handle)) { + vmci_qpair_detach(&transport->qpair); + transport->qp_handle = VMCI_INVALID_HANDLE; + transport->produce_size = 0; + transport->consume_size = 0; + } + + kfree(transport); + } +} + +static void vmci_transport_cleanup(struct work_struct *work) +{ + LIST_HEAD(pending); + + spin_lock_bh(&vmci_transport_cleanup_lock); + list_replace_init(&vmci_transport_cleanup_list, &pending); + spin_unlock_bh(&vmci_transport_cleanup_lock); + vmci_transport_free_resources(&pending); +} + +static void vmci_transport_destruct(struct vsock_sock *vsk) +{ + /* transport can be NULL if we hit a failure at init() time */ + if (!vmci_trans(vsk)) + return; + + /* Ensure that the detach callback doesn't use the sk/vsk + * we are about to destruct. + */ + spin_lock_bh(&vmci_trans(vsk)->lock); + vmci_trans(vsk)->sk = NULL; + spin_unlock_bh(&vmci_trans(vsk)->lock); + + if (vmci_trans(vsk)->notify_ops) + vmci_trans(vsk)->notify_ops->socket_destruct(vsk); + + spin_lock_bh(&vmci_transport_cleanup_lock); + list_add(&vmci_trans(vsk)->elem, &vmci_transport_cleanup_list); + spin_unlock_bh(&vmci_transport_cleanup_lock); + schedule_work(&vmci_transport_cleanup_work); + + vsk->trans = NULL; +} + +static void vmci_transport_release(struct vsock_sock *vsk) +{ + vsock_remove_sock(vsk); + + if (!vmci_handle_is_invalid(vmci_trans(vsk)->dg_handle)) { + vmci_datagram_destroy_handle(vmci_trans(vsk)->dg_handle); + vmci_trans(vsk)->dg_handle = VMCI_INVALID_HANDLE; + } +} + +static int vmci_transport_dgram_bind(struct vsock_sock *vsk, + struct sockaddr_vm *addr) +{ + u32 port; + u32 flags; + int err; + + /* VMCI will select a resource ID for us if we provide + * VMCI_INVALID_ID. + */ + port = addr->svm_port == VMADDR_PORT_ANY ? + VMCI_INVALID_ID : addr->svm_port; + + if (port <= LAST_RESERVED_PORT && !capable(CAP_NET_BIND_SERVICE)) + return -EACCES; + + flags = addr->svm_cid == VMADDR_CID_ANY ? + VMCI_FLAG_ANYCID_DG_HND : 0; + + err = vmci_transport_datagram_create_hnd(port, flags, + vmci_transport_recv_dgram_cb, + &vsk->sk, + &vmci_trans(vsk)->dg_handle); + if (err < VMCI_SUCCESS) + return vmci_transport_error_to_vsock_error(err); + vsock_addr_init(&vsk->local_addr, addr->svm_cid, + vmci_trans(vsk)->dg_handle.resource); + + return 0; +} + +static int vmci_transport_dgram_enqueue( + struct vsock_sock *vsk, + struct sockaddr_vm *remote_addr, + struct msghdr *msg, + size_t len) +{ + int err; + struct vmci_datagram *dg; + + if (len > VMCI_MAX_DG_PAYLOAD_SIZE) + return -EMSGSIZE; + + if (!vmci_transport_allow_dgram(vsk, remote_addr->svm_cid)) + return -EPERM; + + /* Allocate a buffer for the user's message and our packet header. */ + dg = kmalloc(len + sizeof(*dg), GFP_KERNEL); + if (!dg) + return -ENOMEM; + + err = memcpy_from_msg(VMCI_DG_PAYLOAD(dg), msg, len); + if (err) { + kfree(dg); + return err; + } + + dg->dst = vmci_make_handle(remote_addr->svm_cid, + remote_addr->svm_port); + dg->src = vmci_make_handle(vsk->local_addr.svm_cid, + vsk->local_addr.svm_port); + dg->payload_size = len; + + err = vmci_datagram_send(dg); + kfree(dg); + if (err < 0) + return vmci_transport_error_to_vsock_error(err); + + return err - sizeof(*dg); +} + +static int vmci_transport_dgram_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, size_t len, + int flags) +{ + int err; + struct vmci_datagram *dg; + size_t payload_len; + struct sk_buff *skb; + + if (flags & MSG_OOB || flags & MSG_ERRQUEUE) + return -EOPNOTSUPP; + + /* Retrieve the head sk_buff from the socket's receive queue. */ + err = 0; + skb = skb_recv_datagram(&vsk->sk, flags, &err); + if (!skb) + return err; + + dg = (struct vmci_datagram *)skb->data; + if (!dg) + /* err is 0, meaning we read zero bytes. */ + goto out; + + payload_len = dg->payload_size; + /* Ensure the sk_buff matches the payload size claimed in the packet. */ + if (payload_len != skb->len - sizeof(*dg)) { + err = -EINVAL; + goto out; + } + + if (payload_len > len) { + payload_len = len; + msg->msg_flags |= MSG_TRUNC; + } + + /* Place the datagram payload in the user's iovec. */ + err = skb_copy_datagram_msg(skb, sizeof(*dg), msg, payload_len); + if (err) + goto out; + + if (msg->msg_name) { + /* Provide the address of the sender. */ + DECLARE_SOCKADDR(struct sockaddr_vm *, vm_addr, msg->msg_name); + vsock_addr_init(vm_addr, dg->src.context, dg->src.resource); + msg->msg_namelen = sizeof(*vm_addr); + } + err = payload_len; + +out: + skb_free_datagram(&vsk->sk, skb); + return err; +} + +static bool vmci_transport_dgram_allow(u32 cid, u32 port) +{ + if (cid == VMADDR_CID_HYPERVISOR) { + /* Registrations of PBRPC Servers do not modify VMX/Hypervisor + * state and are allowed. + */ + return port == VMCI_UNITY_PBRPC_REGISTER; + } + + return true; +} + +static int vmci_transport_connect(struct vsock_sock *vsk) +{ + int err; + bool old_pkt_proto = false; + struct sock *sk = &vsk->sk; + + if (vmci_transport_old_proto_override(&old_pkt_proto) && + old_pkt_proto) { + err = vmci_transport_send_conn_request(sk, vsk->buffer_size); + if (err < 0) { + sk->sk_state = TCP_CLOSE; + return err; + } + } else { + int supported_proto_versions = + vmci_transport_new_proto_supported_versions(); + err = vmci_transport_send_conn_request2(sk, vsk->buffer_size, + supported_proto_versions); + if (err < 0) { + sk->sk_state = TCP_CLOSE; + return err; + } + + vsk->sent_request = true; + } + + return err; +} + +static ssize_t vmci_transport_stream_dequeue( + struct vsock_sock *vsk, + struct msghdr *msg, + size_t len, + int flags) +{ + ssize_t err; + + if (flags & MSG_PEEK) + err = vmci_qpair_peekv(vmci_trans(vsk)->qpair, msg, len, 0); + else + err = vmci_qpair_dequev(vmci_trans(vsk)->qpair, msg, len, 0); + + if (err < 0) + err = -ENOMEM; + + return err; +} + +static ssize_t vmci_transport_stream_enqueue( + struct vsock_sock *vsk, + struct msghdr *msg, + size_t len) +{ + ssize_t err; + + err = vmci_qpair_enquev(vmci_trans(vsk)->qpair, msg, len, 0); + if (err < 0) + err = -ENOMEM; + + return err; +} + +static s64 vmci_transport_stream_has_data(struct vsock_sock *vsk) +{ + return vmci_qpair_consume_buf_ready(vmci_trans(vsk)->qpair); +} + +static s64 vmci_transport_stream_has_space(struct vsock_sock *vsk) +{ + return vmci_qpair_produce_free_space(vmci_trans(vsk)->qpair); +} + +static u64 vmci_transport_stream_rcvhiwat(struct vsock_sock *vsk) +{ + return vmci_trans(vsk)->consume_size; +} + +static bool vmci_transport_stream_is_active(struct vsock_sock *vsk) +{ + return !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle); +} + +static int vmci_transport_notify_poll_in( + struct vsock_sock *vsk, + size_t target, + bool *data_ready_now) +{ + return vmci_trans(vsk)->notify_ops->poll_in( + &vsk->sk, target, data_ready_now); +} + +static int vmci_transport_notify_poll_out( + struct vsock_sock *vsk, + size_t target, + bool *space_available_now) +{ + return vmci_trans(vsk)->notify_ops->poll_out( + &vsk->sk, target, space_available_now); +} + +static int vmci_transport_notify_recv_init( + struct vsock_sock *vsk, + size_t target, + struct vsock_transport_recv_notify_data *data) +{ + return vmci_trans(vsk)->notify_ops->recv_init( + &vsk->sk, target, + (struct vmci_transport_recv_notify_data *)data); +} + +static int vmci_transport_notify_recv_pre_block( + struct vsock_sock *vsk, + size_t target, + struct vsock_transport_recv_notify_data *data) +{ + return vmci_trans(vsk)->notify_ops->recv_pre_block( + &vsk->sk, target, + (struct vmci_transport_recv_notify_data *)data); +} + +static int vmci_transport_notify_recv_pre_dequeue( + struct vsock_sock *vsk, + size_t target, + struct vsock_transport_recv_notify_data *data) +{ + return vmci_trans(vsk)->notify_ops->recv_pre_dequeue( + &vsk->sk, target, + (struct vmci_transport_recv_notify_data *)data); +} + +static int vmci_transport_notify_recv_post_dequeue( + struct vsock_sock *vsk, + size_t target, + ssize_t copied, + bool data_read, + struct vsock_transport_recv_notify_data *data) +{ + return vmci_trans(vsk)->notify_ops->recv_post_dequeue( + &vsk->sk, target, copied, data_read, + (struct vmci_transport_recv_notify_data *)data); +} + +static int vmci_transport_notify_send_init( + struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data) +{ + return vmci_trans(vsk)->notify_ops->send_init( + &vsk->sk, + (struct vmci_transport_send_notify_data *)data); +} + +static int vmci_transport_notify_send_pre_block( + struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data) +{ + return vmci_trans(vsk)->notify_ops->send_pre_block( + &vsk->sk, + (struct vmci_transport_send_notify_data *)data); +} + +static int vmci_transport_notify_send_pre_enqueue( + struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data) +{ + return vmci_trans(vsk)->notify_ops->send_pre_enqueue( + &vsk->sk, + (struct vmci_transport_send_notify_data *)data); +} + +static int vmci_transport_notify_send_post_enqueue( + struct vsock_sock *vsk, + ssize_t written, + struct vsock_transport_send_notify_data *data) +{ + return vmci_trans(vsk)->notify_ops->send_post_enqueue( + &vsk->sk, written, + (struct vmci_transport_send_notify_data *)data); +} + +static bool vmci_transport_old_proto_override(bool *old_pkt_proto) +{ + if (PROTOCOL_OVERRIDE != -1) { + if (PROTOCOL_OVERRIDE == 0) + *old_pkt_proto = true; + else + *old_pkt_proto = false; + + pr_info("Proto override in use\n"); + return true; + } + + return false; +} + +static bool vmci_transport_proto_to_notify_struct(struct sock *sk, + u16 *proto, + bool old_pkt_proto) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + if (old_pkt_proto) { + if (*proto != VSOCK_PROTO_INVALID) { + pr_err("Can't set both an old and new protocol\n"); + return false; + } + vmci_trans(vsk)->notify_ops = &vmci_transport_notify_pkt_ops; + goto exit; + } + + switch (*proto) { + case VSOCK_PROTO_PKT_ON_NOTIFY: + vmci_trans(vsk)->notify_ops = + &vmci_transport_notify_pkt_q_state_ops; + break; + default: + pr_err("Unknown notify protocol version\n"); + return false; + } + +exit: + vmci_trans(vsk)->notify_ops->socket_init(sk); + return true; +} + +static u16 vmci_transport_new_proto_supported_versions(void) +{ + if (PROTOCOL_OVERRIDE != -1) + return PROTOCOL_OVERRIDE; + + return VSOCK_PROTO_ALL_SUPPORTED; +} + +static u32 vmci_transport_get_local_cid(void) +{ + return vmci_get_context_id(); +} + +static struct vsock_transport vmci_transport = { + .module = THIS_MODULE, + .init = vmci_transport_socket_init, + .destruct = vmci_transport_destruct, + .release = vmci_transport_release, + .connect = vmci_transport_connect, + .dgram_bind = vmci_transport_dgram_bind, + .dgram_dequeue = vmci_transport_dgram_dequeue, + .dgram_enqueue = vmci_transport_dgram_enqueue, + .dgram_allow = vmci_transport_dgram_allow, + .stream_dequeue = vmci_transport_stream_dequeue, + .stream_enqueue = vmci_transport_stream_enqueue, + .stream_has_data = vmci_transport_stream_has_data, + .stream_has_space = vmci_transport_stream_has_space, + .stream_rcvhiwat = vmci_transport_stream_rcvhiwat, + .stream_is_active = vmci_transport_stream_is_active, + .stream_allow = vmci_transport_stream_allow, + .notify_poll_in = vmci_transport_notify_poll_in, + .notify_poll_out = vmci_transport_notify_poll_out, + .notify_recv_init = vmci_transport_notify_recv_init, + .notify_recv_pre_block = vmci_transport_notify_recv_pre_block, + .notify_recv_pre_dequeue = vmci_transport_notify_recv_pre_dequeue, + .notify_recv_post_dequeue = vmci_transport_notify_recv_post_dequeue, + .notify_send_init = vmci_transport_notify_send_init, + .notify_send_pre_block = vmci_transport_notify_send_pre_block, + .notify_send_pre_enqueue = vmci_transport_notify_send_pre_enqueue, + .notify_send_post_enqueue = vmci_transport_notify_send_post_enqueue, + .shutdown = vmci_transport_shutdown, + .get_local_cid = vmci_transport_get_local_cid, +}; + +static bool vmci_check_transport(struct vsock_sock *vsk) +{ + return vsk->transport == &vmci_transport; +} + +static void vmci_vsock_transport_cb(bool is_host) +{ + int features; + + if (is_host) + features = VSOCK_TRANSPORT_F_H2G; + else + features = VSOCK_TRANSPORT_F_G2H; + + vsock_core_register(&vmci_transport, features); +} + +static int __init vmci_transport_init(void) +{ + int err; + + /* Create the datagram handle that we will use to send and receive all + * VSocket control messages for this context. + */ + err = vmci_transport_datagram_create_hnd(VMCI_TRANSPORT_PACKET_RID, + VMCI_FLAG_ANYCID_DG_HND, + vmci_transport_recv_stream_cb, + NULL, + &vmci_transport_stream_handle); + if (err < VMCI_SUCCESS) { + pr_err("Unable to create datagram handle. (%d)\n", err); + return vmci_transport_error_to_vsock_error(err); + } + err = vmci_event_subscribe(VMCI_EVENT_QP_RESUMED, + vmci_transport_qp_resumed_cb, + NULL, &vmci_transport_qp_resumed_sub_id); + if (err < VMCI_SUCCESS) { + pr_err("Unable to subscribe to resumed event. (%d)\n", err); + err = vmci_transport_error_to_vsock_error(err); + vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID; + goto err_destroy_stream_handle; + } + + /* Register only with dgram feature, other features (H2G, G2H) will be + * registered when the first host or guest becomes active. + */ + err = vsock_core_register(&vmci_transport, VSOCK_TRANSPORT_F_DGRAM); + if (err < 0) + goto err_unsubscribe; + + err = vmci_register_vsock_callback(vmci_vsock_transport_cb); + if (err < 0) + goto err_unregister; + + return 0; + +err_unregister: + vsock_core_unregister(&vmci_transport); +err_unsubscribe: + vmci_event_unsubscribe(vmci_transport_qp_resumed_sub_id); +err_destroy_stream_handle: + vmci_datagram_destroy_handle(vmci_transport_stream_handle); + return err; +} +module_init(vmci_transport_init); + +static void __exit vmci_transport_exit(void) +{ + cancel_work_sync(&vmci_transport_cleanup_work); + vmci_transport_free_resources(&vmci_transport_cleanup_list); + + if (!vmci_handle_is_invalid(vmci_transport_stream_handle)) { + if (vmci_datagram_destroy_handle( + vmci_transport_stream_handle) != VMCI_SUCCESS) + pr_err("Couldn't destroy datagram handle\n"); + vmci_transport_stream_handle = VMCI_INVALID_HANDLE; + } + + if (vmci_transport_qp_resumed_sub_id != VMCI_INVALID_ID) { + vmci_event_unsubscribe(vmci_transport_qp_resumed_sub_id); + vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID; + } + + vmci_register_vsock_callback(NULL); + vsock_core_unregister(&vmci_transport); +} +module_exit(vmci_transport_exit); + +MODULE_AUTHOR("VMware, Inc."); +MODULE_DESCRIPTION("VMCI transport for Virtual Sockets"); +MODULE_VERSION("1.0.5.0-k"); +MODULE_LICENSE("GPL v2"); +MODULE_ALIAS("vmware_vsock"); +MODULE_ALIAS_NETPROTO(PF_VSOCK); diff --git a/net/vmw_vsock/vmci_transport.h b/net/vmw_vsock/vmci_transport.h new file mode 100644 index 0000000000..dbda3ababa --- /dev/null +++ b/net/vmw_vsock/vmci_transport.h @@ -0,0 +1,130 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * VMware vSockets Driver + * + * Copyright (C) 2013 VMware, Inc. All rights reserved. + */ + +#ifndef _VMCI_TRANSPORT_H_ +#define _VMCI_TRANSPORT_H_ + +#include <linux/vmw_vmci_defs.h> +#include <linux/vmw_vmci_api.h> + +#include <net/vsock_addr.h> +#include <net/af_vsock.h> + +/* If the packet format changes in a release then this should change too. */ +#define VMCI_TRANSPORT_PACKET_VERSION 1 + +/* The resource ID on which control packets are sent. */ +#define VMCI_TRANSPORT_PACKET_RID 1 + +/* The resource ID on which control packets are sent to the hypervisor. */ +#define VMCI_TRANSPORT_HYPERVISOR_PACKET_RID 15 + +#define VSOCK_PROTO_INVALID 0 +#define VSOCK_PROTO_PKT_ON_NOTIFY (1 << 0) +#define VSOCK_PROTO_ALL_SUPPORTED (VSOCK_PROTO_PKT_ON_NOTIFY) + +#define vmci_trans(_vsk) ((struct vmci_transport *)((_vsk)->trans)) + +enum vmci_transport_packet_type { + VMCI_TRANSPORT_PACKET_TYPE_INVALID = 0, + VMCI_TRANSPORT_PACKET_TYPE_REQUEST, + VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE, + VMCI_TRANSPORT_PACKET_TYPE_OFFER, + VMCI_TRANSPORT_PACKET_TYPE_ATTACH, + VMCI_TRANSPORT_PACKET_TYPE_WROTE, + VMCI_TRANSPORT_PACKET_TYPE_READ, + VMCI_TRANSPORT_PACKET_TYPE_RST, + VMCI_TRANSPORT_PACKET_TYPE_SHUTDOWN, + VMCI_TRANSPORT_PACKET_TYPE_WAITING_WRITE, + VMCI_TRANSPORT_PACKET_TYPE_WAITING_READ, + VMCI_TRANSPORT_PACKET_TYPE_REQUEST2, + VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2, + VMCI_TRANSPORT_PACKET_TYPE_MAX +}; + +struct vmci_transport_waiting_info { + u64 generation; + u64 offset; +}; + +/* Control packet type for STREAM sockets. DGRAMs have no control packets nor + * special packet header for data packets, they are just raw VMCI DGRAM + * messages. For STREAMs, control packets are sent over the control channel + * while data is written and read directly from queue pairs with no packet + * format. + */ +struct vmci_transport_packet { + struct vmci_datagram dg; + u8 version; + u8 type; + u16 proto; + u32 src_port; + u32 dst_port; + u32 _reserved2; + union { + u64 size; + u64 mode; + struct vmci_handle handle; + struct vmci_transport_waiting_info wait; + } u; +}; + +struct vmci_transport_notify_pkt { + u64 write_notify_window; + u64 write_notify_min_window; + bool peer_waiting_read; + bool peer_waiting_write; + bool peer_waiting_write_detected; + bool sent_waiting_read; + bool sent_waiting_write; + struct vmci_transport_waiting_info peer_waiting_read_info; + struct vmci_transport_waiting_info peer_waiting_write_info; + u64 produce_q_generation; + u64 consume_q_generation; +}; + +struct vmci_transport_notify_pkt_q_state { + u64 write_notify_window; + u64 write_notify_min_window; + bool peer_waiting_write; + bool peer_waiting_write_detected; +}; + +union vmci_transport_notify { + struct vmci_transport_notify_pkt pkt; + struct vmci_transport_notify_pkt_q_state pkt_q_state; +}; + +/* Our transport-specific data. */ +struct vmci_transport { + /* For DGRAMs. */ + struct vmci_handle dg_handle; + /* For STREAMs. */ + struct vmci_handle qp_handle; + struct vmci_qp *qpair; + u64 produce_size; + u64 consume_size; + u32 detach_sub_id; + union vmci_transport_notify notify; + const struct vmci_transport_notify_ops *notify_ops; + struct list_head elem; + struct sock *sk; + spinlock_t lock; /* protects sk. */ +}; + +int vmci_transport_send_wrote_bh(struct sockaddr_vm *dst, + struct sockaddr_vm *src); +int vmci_transport_send_read_bh(struct sockaddr_vm *dst, + struct sockaddr_vm *src); +int vmci_transport_send_wrote(struct sock *sk); +int vmci_transport_send_read(struct sock *sk); +int vmci_transport_send_waiting_write(struct sock *sk, + struct vmci_transport_waiting_info *wait); +int vmci_transport_send_waiting_read(struct sock *sk, + struct vmci_transport_waiting_info *wait); + +#endif diff --git a/net/vmw_vsock/vmci_transport_notify.c b/net/vmw_vsock/vmci_transport_notify.c new file mode 100644 index 0000000000..7c3a7db134 --- /dev/null +++ b/net/vmw_vsock/vmci_transport_notify.c @@ -0,0 +1,672 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * VMware vSockets Driver + * + * Copyright (C) 2009-2013 VMware, Inc. All rights reserved. + */ + +#include <linux/types.h> +#include <linux/socket.h> +#include <linux/stddef.h> +#include <net/sock.h> + +#include "vmci_transport_notify.h" + +#define PKT_FIELD(vsk, field_name) (vmci_trans(vsk)->notify.pkt.field_name) + +static bool vmci_transport_notify_waiting_write(struct vsock_sock *vsk) +{ +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + bool retval; + u64 notify_limit; + + if (!PKT_FIELD(vsk, peer_waiting_write)) + return false; + +#ifdef VSOCK_OPTIMIZATION_FLOW_CONTROL + /* When the sender blocks, we take that as a sign that the sender is + * faster than the receiver. To reduce the transmit rate of the sender, + * we delay the sending of the read notification by decreasing the + * write_notify_window. The notification is delayed until the number of + * bytes used in the queue drops below the write_notify_window. + */ + + if (!PKT_FIELD(vsk, peer_waiting_write_detected)) { + PKT_FIELD(vsk, peer_waiting_write_detected) = true; + if (PKT_FIELD(vsk, write_notify_window) < PAGE_SIZE) { + PKT_FIELD(vsk, write_notify_window) = + PKT_FIELD(vsk, write_notify_min_window); + } else { + PKT_FIELD(vsk, write_notify_window) -= PAGE_SIZE; + if (PKT_FIELD(vsk, write_notify_window) < + PKT_FIELD(vsk, write_notify_min_window)) + PKT_FIELD(vsk, write_notify_window) = + PKT_FIELD(vsk, write_notify_min_window); + + } + } + notify_limit = vmci_trans(vsk)->consume_size - + PKT_FIELD(vsk, write_notify_window); +#else + notify_limit = 0; +#endif + + /* For now we ignore the wait information and just see if the free + * space exceeds the notify limit. Note that improving this function + * to be more intelligent will not require a protocol change and will + * retain compatibility between endpoints with mixed versions of this + * function. + * + * The notify_limit is used to delay notifications in the case where + * flow control is enabled. Below the test is expressed in terms of + * free space in the queue: if free_space > ConsumeSize - + * write_notify_window then notify An alternate way of expressing this + * is to rewrite the expression to use the data ready in the receive + * queue: if write_notify_window > bufferReady then notify as + * free_space == ConsumeSize - bufferReady. + */ + retval = vmci_qpair_consume_free_space(vmci_trans(vsk)->qpair) > + notify_limit; +#ifdef VSOCK_OPTIMIZATION_FLOW_CONTROL + if (retval) { + /* + * Once we notify the peer, we reset the detected flag so the + * next wait will again cause a decrease in the window size. + */ + + PKT_FIELD(vsk, peer_waiting_write_detected) = false; + } +#endif + return retval; +#else + return true; +#endif +} + +static bool vmci_transport_notify_waiting_read(struct vsock_sock *vsk) +{ +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + if (!PKT_FIELD(vsk, peer_waiting_read)) + return false; + + /* For now we ignore the wait information and just see if there is any + * data for our peer to read. Note that improving this function to be + * more intelligent will not require a protocol change and will retain + * compatibility between endpoints with mixed versions of this + * function. + */ + return vmci_qpair_produce_buf_ready(vmci_trans(vsk)->qpair) > 0; +#else + return true; +#endif +} + +static void +vmci_transport_handle_waiting_read(struct sock *sk, + struct vmci_transport_packet *pkt, + bool bottom_half, + struct sockaddr_vm *dst, + struct sockaddr_vm *src) +{ +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + struct vsock_sock *vsk; + + vsk = vsock_sk(sk); + + PKT_FIELD(vsk, peer_waiting_read) = true; + memcpy(&PKT_FIELD(vsk, peer_waiting_read_info), &pkt->u.wait, + sizeof(PKT_FIELD(vsk, peer_waiting_read_info))); + + if (vmci_transport_notify_waiting_read(vsk)) { + bool sent; + + if (bottom_half) + sent = vmci_transport_send_wrote_bh(dst, src) > 0; + else + sent = vmci_transport_send_wrote(sk) > 0; + + if (sent) + PKT_FIELD(vsk, peer_waiting_read) = false; + } +#endif +} + +static void +vmci_transport_handle_waiting_write(struct sock *sk, + struct vmci_transport_packet *pkt, + bool bottom_half, + struct sockaddr_vm *dst, + struct sockaddr_vm *src) +{ +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + struct vsock_sock *vsk; + + vsk = vsock_sk(sk); + + PKT_FIELD(vsk, peer_waiting_write) = true; + memcpy(&PKT_FIELD(vsk, peer_waiting_write_info), &pkt->u.wait, + sizeof(PKT_FIELD(vsk, peer_waiting_write_info))); + + if (vmci_transport_notify_waiting_write(vsk)) { + bool sent; + + if (bottom_half) + sent = vmci_transport_send_read_bh(dst, src) > 0; + else + sent = vmci_transport_send_read(sk) > 0; + + if (sent) + PKT_FIELD(vsk, peer_waiting_write) = false; + } +#endif +} + +static void +vmci_transport_handle_read(struct sock *sk, + struct vmci_transport_packet *pkt, + bool bottom_half, + struct sockaddr_vm *dst, struct sockaddr_vm *src) +{ +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + struct vsock_sock *vsk; + + vsk = vsock_sk(sk); + PKT_FIELD(vsk, sent_waiting_write) = false; +#endif + + sk->sk_write_space(sk); +} + +static bool send_waiting_read(struct sock *sk, u64 room_needed) +{ +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + struct vsock_sock *vsk; + struct vmci_transport_waiting_info waiting_info; + u64 tail; + u64 head; + u64 room_left; + bool ret; + + vsk = vsock_sk(sk); + + if (PKT_FIELD(vsk, sent_waiting_read)) + return true; + + if (PKT_FIELD(vsk, write_notify_window) < + vmci_trans(vsk)->consume_size) + PKT_FIELD(vsk, write_notify_window) = + min(PKT_FIELD(vsk, write_notify_window) + PAGE_SIZE, + vmci_trans(vsk)->consume_size); + + vmci_qpair_get_consume_indexes(vmci_trans(vsk)->qpair, &tail, &head); + room_left = vmci_trans(vsk)->consume_size - head; + if (room_needed >= room_left) { + waiting_info.offset = room_needed - room_left; + waiting_info.generation = + PKT_FIELD(vsk, consume_q_generation) + 1; + } else { + waiting_info.offset = head + room_needed; + waiting_info.generation = PKT_FIELD(vsk, consume_q_generation); + } + + ret = vmci_transport_send_waiting_read(sk, &waiting_info) > 0; + if (ret) + PKT_FIELD(vsk, sent_waiting_read) = true; + + return ret; +#else + return true; +#endif +} + +static bool send_waiting_write(struct sock *sk, u64 room_needed) +{ +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + struct vsock_sock *vsk; + struct vmci_transport_waiting_info waiting_info; + u64 tail; + u64 head; + u64 room_left; + bool ret; + + vsk = vsock_sk(sk); + + if (PKT_FIELD(vsk, sent_waiting_write)) + return true; + + vmci_qpair_get_produce_indexes(vmci_trans(vsk)->qpair, &tail, &head); + room_left = vmci_trans(vsk)->produce_size - tail; + if (room_needed + 1 >= room_left) { + /* Wraps around to current generation. */ + waiting_info.offset = room_needed + 1 - room_left; + waiting_info.generation = PKT_FIELD(vsk, produce_q_generation); + } else { + waiting_info.offset = tail + room_needed + 1; + waiting_info.generation = + PKT_FIELD(vsk, produce_q_generation) - 1; + } + + ret = vmci_transport_send_waiting_write(sk, &waiting_info) > 0; + if (ret) + PKT_FIELD(vsk, sent_waiting_write) = true; + + return ret; +#else + return true; +#endif +} + +static int vmci_transport_send_read_notification(struct sock *sk) +{ + struct vsock_sock *vsk; + bool sent_read; + unsigned int retries; + int err; + + vsk = vsock_sk(sk); + sent_read = false; + retries = 0; + err = 0; + + if (vmci_transport_notify_waiting_write(vsk)) { + /* Notify the peer that we have read, retrying the send on + * failure up to our maximum value. XXX For now we just log + * the failure, but later we should schedule a work item to + * handle the resend until it succeeds. That would require + * keeping track of work items in the vsk and cleaning them up + * upon socket close. + */ + while (!(vsk->peer_shutdown & RCV_SHUTDOWN) && + !sent_read && + retries < VMCI_TRANSPORT_MAX_DGRAM_RESENDS) { + err = vmci_transport_send_read(sk); + if (err >= 0) + sent_read = true; + + retries++; + } + + if (retries >= VMCI_TRANSPORT_MAX_DGRAM_RESENDS) + pr_err("%p unable to send read notify to peer\n", sk); + else +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + PKT_FIELD(vsk, peer_waiting_write) = false; +#endif + + } + return err; +} + +static void +vmci_transport_handle_wrote(struct sock *sk, + struct vmci_transport_packet *pkt, + bool bottom_half, + struct sockaddr_vm *dst, struct sockaddr_vm *src) +{ +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + struct vsock_sock *vsk = vsock_sk(sk); + PKT_FIELD(vsk, sent_waiting_read) = false; +#endif + vsock_data_ready(sk); +} + +static void vmci_transport_notify_pkt_socket_init(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + PKT_FIELD(vsk, write_notify_window) = PAGE_SIZE; + PKT_FIELD(vsk, write_notify_min_window) = PAGE_SIZE; + PKT_FIELD(vsk, peer_waiting_read) = false; + PKT_FIELD(vsk, peer_waiting_write) = false; + PKT_FIELD(vsk, peer_waiting_write_detected) = false; + PKT_FIELD(vsk, sent_waiting_read) = false; + PKT_FIELD(vsk, sent_waiting_write) = false; + PKT_FIELD(vsk, produce_q_generation) = 0; + PKT_FIELD(vsk, consume_q_generation) = 0; + + memset(&PKT_FIELD(vsk, peer_waiting_read_info), 0, + sizeof(PKT_FIELD(vsk, peer_waiting_read_info))); + memset(&PKT_FIELD(vsk, peer_waiting_write_info), 0, + sizeof(PKT_FIELD(vsk, peer_waiting_write_info))); +} + +static void vmci_transport_notify_pkt_socket_destruct(struct vsock_sock *vsk) +{ +} + +static int +vmci_transport_notify_pkt_poll_in(struct sock *sk, + size_t target, bool *data_ready_now) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + if (vsock_stream_has_data(vsk) >= target) { + *data_ready_now = true; + } else { + /* We can't read right now because there is not enough data + * in the queue. Ask for notifications when there is something + * to read. + */ + if (sk->sk_state == TCP_ESTABLISHED) { + if (!send_waiting_read(sk, 1)) + return -1; + + } + *data_ready_now = false; + } + + return 0; +} + +static int +vmci_transport_notify_pkt_poll_out(struct sock *sk, + size_t target, bool *space_avail_now) +{ + s64 produce_q_free_space; + struct vsock_sock *vsk = vsock_sk(sk); + + produce_q_free_space = vsock_stream_has_space(vsk); + if (produce_q_free_space > 0) { + *space_avail_now = true; + return 0; + } else if (produce_q_free_space == 0) { + /* This is a connected socket but we can't currently send data. + * Notify the peer that we are waiting if the queue is full. We + * only send a waiting write if the queue is full because + * otherwise we end up in an infinite WAITING_WRITE, READ, + * WAITING_WRITE, READ, etc. loop. Treat failing to send the + * notification as a socket error, passing that back through + * the mask. + */ + if (!send_waiting_write(sk, 1)) + return -1; + + *space_avail_now = false; + } + + return 0; +} + +static int +vmci_transport_notify_pkt_recv_init( + struct sock *sk, + size_t target, + struct vmci_transport_recv_notify_data *data) +{ + struct vsock_sock *vsk = vsock_sk(sk); + +#ifdef VSOCK_OPTIMIZATION_WAITING_NOTIFY + data->consume_head = 0; + data->produce_tail = 0; +#ifdef VSOCK_OPTIMIZATION_FLOW_CONTROL + data->notify_on_block = false; + + if (PKT_FIELD(vsk, write_notify_min_window) < target + 1) { + PKT_FIELD(vsk, write_notify_min_window) = target + 1; + if (PKT_FIELD(vsk, write_notify_window) < + PKT_FIELD(vsk, write_notify_min_window)) { + /* If the current window is smaller than the new + * minimal window size, we need to reevaluate whether + * we need to notify the sender. If the number of ready + * bytes are smaller than the new window, we need to + * send a notification to the sender before we block. + */ + + PKT_FIELD(vsk, write_notify_window) = + PKT_FIELD(vsk, write_notify_min_window); + data->notify_on_block = true; + } + } +#endif +#endif + + return 0; +} + +static int +vmci_transport_notify_pkt_recv_pre_block( + struct sock *sk, + size_t target, + struct vmci_transport_recv_notify_data *data) +{ + int err = 0; + + /* Notify our peer that we are waiting for data to read. */ + if (!send_waiting_read(sk, target)) { + err = -EHOSTUNREACH; + return err; + } +#ifdef VSOCK_OPTIMIZATION_FLOW_CONTROL + if (data->notify_on_block) { + err = vmci_transport_send_read_notification(sk); + if (err < 0) + return err; + + data->notify_on_block = false; + } +#endif + + return err; +} + +static int +vmci_transport_notify_pkt_recv_pre_dequeue( + struct sock *sk, + size_t target, + struct vmci_transport_recv_notify_data *data) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + /* Now consume up to len bytes from the queue. Note that since we have + * the socket locked we should copy at least ready bytes. + */ +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + vmci_qpair_get_consume_indexes(vmci_trans(vsk)->qpair, + &data->produce_tail, + &data->consume_head); +#endif + + return 0; +} + +static int +vmci_transport_notify_pkt_recv_post_dequeue( + struct sock *sk, + size_t target, + ssize_t copied, + bool data_read, + struct vmci_transport_recv_notify_data *data) +{ + struct vsock_sock *vsk; + int err; + + vsk = vsock_sk(sk); + err = 0; + + if (data_read) { +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + /* Detect a wrap-around to maintain queue generation. Note + * that this is safe since we hold the socket lock across the + * two queue pair operations. + */ + if (copied >= + vmci_trans(vsk)->consume_size - data->consume_head) + PKT_FIELD(vsk, consume_q_generation)++; +#endif + + err = vmci_transport_send_read_notification(sk); + if (err < 0) + return err; + + } + return err; +} + +static int +vmci_transport_notify_pkt_send_init( + struct sock *sk, + struct vmci_transport_send_notify_data *data) +{ +#ifdef VSOCK_OPTIMIZATION_WAITING_NOTIFY + data->consume_head = 0; + data->produce_tail = 0; +#endif + + return 0; +} + +static int +vmci_transport_notify_pkt_send_pre_block( + struct sock *sk, + struct vmci_transport_send_notify_data *data) +{ + /* Notify our peer that we are waiting for room to write. */ + if (!send_waiting_write(sk, 1)) + return -EHOSTUNREACH; + + return 0; +} + +static int +vmci_transport_notify_pkt_send_pre_enqueue( + struct sock *sk, + struct vmci_transport_send_notify_data *data) +{ + struct vsock_sock *vsk = vsock_sk(sk); + +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + vmci_qpair_get_produce_indexes(vmci_trans(vsk)->qpair, + &data->produce_tail, + &data->consume_head); +#endif + + return 0; +} + +static int +vmci_transport_notify_pkt_send_post_enqueue( + struct sock *sk, + ssize_t written, + struct vmci_transport_send_notify_data *data) +{ + int err = 0; + struct vsock_sock *vsk; + bool sent_wrote = false; + int retries = 0; + + vsk = vsock_sk(sk); + +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + /* Detect a wrap-around to maintain queue generation. Note that this + * is safe since we hold the socket lock across the two queue pair + * operations. + */ + if (written >= vmci_trans(vsk)->produce_size - data->produce_tail) + PKT_FIELD(vsk, produce_q_generation)++; + +#endif + + if (vmci_transport_notify_waiting_read(vsk)) { + /* Notify the peer that we have written, retrying the send on + * failure up to our maximum value. See the XXX comment for the + * corresponding piece of code in StreamRecvmsg() for potential + * improvements. + */ + while (!(vsk->peer_shutdown & RCV_SHUTDOWN) && + !sent_wrote && + retries < VMCI_TRANSPORT_MAX_DGRAM_RESENDS) { + err = vmci_transport_send_wrote(sk); + if (err >= 0) + sent_wrote = true; + + retries++; + } + + if (retries >= VMCI_TRANSPORT_MAX_DGRAM_RESENDS) { + pr_err("%p unable to send wrote notify to peer\n", sk); + return err; + } else { +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + PKT_FIELD(vsk, peer_waiting_read) = false; +#endif + } + } + return err; +} + +static void +vmci_transport_notify_pkt_handle_pkt( + struct sock *sk, + struct vmci_transport_packet *pkt, + bool bottom_half, + struct sockaddr_vm *dst, + struct sockaddr_vm *src, bool *pkt_processed) +{ + bool processed = false; + + switch (pkt->type) { + case VMCI_TRANSPORT_PACKET_TYPE_WROTE: + vmci_transport_handle_wrote(sk, pkt, bottom_half, dst, src); + processed = true; + break; + case VMCI_TRANSPORT_PACKET_TYPE_READ: + vmci_transport_handle_read(sk, pkt, bottom_half, dst, src); + processed = true; + break; + case VMCI_TRANSPORT_PACKET_TYPE_WAITING_WRITE: + vmci_transport_handle_waiting_write(sk, pkt, bottom_half, + dst, src); + processed = true; + break; + + case VMCI_TRANSPORT_PACKET_TYPE_WAITING_READ: + vmci_transport_handle_waiting_read(sk, pkt, bottom_half, + dst, src); + processed = true; + break; + } + + if (pkt_processed) + *pkt_processed = processed; +} + +static void vmci_transport_notify_pkt_process_request(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + PKT_FIELD(vsk, write_notify_window) = vmci_trans(vsk)->consume_size; + if (vmci_trans(vsk)->consume_size < + PKT_FIELD(vsk, write_notify_min_window)) + PKT_FIELD(vsk, write_notify_min_window) = + vmci_trans(vsk)->consume_size; +} + +static void vmci_transport_notify_pkt_process_negotiate(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + PKT_FIELD(vsk, write_notify_window) = vmci_trans(vsk)->consume_size; + if (vmci_trans(vsk)->consume_size < + PKT_FIELD(vsk, write_notify_min_window)) + PKT_FIELD(vsk, write_notify_min_window) = + vmci_trans(vsk)->consume_size; +} + +/* Socket control packet based operations. */ +const struct vmci_transport_notify_ops vmci_transport_notify_pkt_ops = { + .socket_init = vmci_transport_notify_pkt_socket_init, + .socket_destruct = vmci_transport_notify_pkt_socket_destruct, + .poll_in = vmci_transport_notify_pkt_poll_in, + .poll_out = vmci_transport_notify_pkt_poll_out, + .handle_notify_pkt = vmci_transport_notify_pkt_handle_pkt, + .recv_init = vmci_transport_notify_pkt_recv_init, + .recv_pre_block = vmci_transport_notify_pkt_recv_pre_block, + .recv_pre_dequeue = vmci_transport_notify_pkt_recv_pre_dequeue, + .recv_post_dequeue = vmci_transport_notify_pkt_recv_post_dequeue, + .send_init = vmci_transport_notify_pkt_send_init, + .send_pre_block = vmci_transport_notify_pkt_send_pre_block, + .send_pre_enqueue = vmci_transport_notify_pkt_send_pre_enqueue, + .send_post_enqueue = vmci_transport_notify_pkt_send_post_enqueue, + .process_request = vmci_transport_notify_pkt_process_request, + .process_negotiate = vmci_transport_notify_pkt_process_negotiate, +}; diff --git a/net/vmw_vsock/vmci_transport_notify.h b/net/vmw_vsock/vmci_transport_notify.h new file mode 100644 index 0000000000..a1aa5a998c --- /dev/null +++ b/net/vmw_vsock/vmci_transport_notify.h @@ -0,0 +1,75 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * VMware vSockets Driver + * + * Copyright (C) 2009-2013 VMware, Inc. All rights reserved. + */ + +#ifndef __VMCI_TRANSPORT_NOTIFY_H__ +#define __VMCI_TRANSPORT_NOTIFY_H__ + +#include <linux/types.h> +#include <linux/vmw_vmci_defs.h> +#include <linux/vmw_vmci_api.h> + +#include "vmci_transport.h" + +/* Comment this out to compare with old protocol. */ +#define VSOCK_OPTIMIZATION_WAITING_NOTIFY 1 +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) +/* Comment this out to remove flow control for "new" protocol */ +#define VSOCK_OPTIMIZATION_FLOW_CONTROL 1 +#endif + +#define VMCI_TRANSPORT_MAX_DGRAM_RESENDS 10 + +struct vmci_transport_recv_notify_data { + u64 consume_head; + u64 produce_tail; + bool notify_on_block; +}; + +struct vmci_transport_send_notify_data { + u64 consume_head; + u64 produce_tail; +}; + +/* Socket notification callbacks. */ +struct vmci_transport_notify_ops { + void (*socket_init) (struct sock *sk); + void (*socket_destruct) (struct vsock_sock *vsk); + int (*poll_in) (struct sock *sk, size_t target, + bool *data_ready_now); + int (*poll_out) (struct sock *sk, size_t target, + bool *space_avail_now); + void (*handle_notify_pkt) (struct sock *sk, + struct vmci_transport_packet *pkt, + bool bottom_half, struct sockaddr_vm *dst, + struct sockaddr_vm *src, + bool *pkt_processed); + int (*recv_init) (struct sock *sk, size_t target, + struct vmci_transport_recv_notify_data *data); + int (*recv_pre_block) (struct sock *sk, size_t target, + struct vmci_transport_recv_notify_data *data); + int (*recv_pre_dequeue) (struct sock *sk, size_t target, + struct vmci_transport_recv_notify_data *data); + int (*recv_post_dequeue) (struct sock *sk, size_t target, + ssize_t copied, bool data_read, + struct vmci_transport_recv_notify_data *data); + int (*send_init) (struct sock *sk, + struct vmci_transport_send_notify_data *data); + int (*send_pre_block) (struct sock *sk, + struct vmci_transport_send_notify_data *data); + int (*send_pre_enqueue) (struct sock *sk, + struct vmci_transport_send_notify_data *data); + int (*send_post_enqueue) (struct sock *sk, ssize_t written, + struct vmci_transport_send_notify_data *data); + void (*process_request) (struct sock *sk); + void (*process_negotiate) (struct sock *sk); +}; + +extern const struct vmci_transport_notify_ops vmci_transport_notify_pkt_ops; +extern const +struct vmci_transport_notify_ops vmci_transport_notify_pkt_q_state_ops; + +#endif /* __VMCI_TRANSPORT_NOTIFY_H__ */ diff --git a/net/vmw_vsock/vmci_transport_notify_qstate.c b/net/vmw_vsock/vmci_transport_notify_qstate.c new file mode 100644 index 0000000000..e96a88d850 --- /dev/null +++ b/net/vmw_vsock/vmci_transport_notify_qstate.c @@ -0,0 +1,430 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * VMware vSockets Driver + * + * Copyright (C) 2009-2013 VMware, Inc. All rights reserved. + */ + +#include <linux/types.h> +#include <linux/socket.h> +#include <linux/stddef.h> +#include <net/sock.h> + +#include "vmci_transport_notify.h" + +#define PKT_FIELD(vsk, field_name) \ + (vmci_trans(vsk)->notify.pkt_q_state.field_name) + +static bool vmci_transport_notify_waiting_write(struct vsock_sock *vsk) +{ + bool retval; + u64 notify_limit; + + if (!PKT_FIELD(vsk, peer_waiting_write)) + return false; + + /* When the sender blocks, we take that as a sign that the sender is + * faster than the receiver. To reduce the transmit rate of the sender, + * we delay the sending of the read notification by decreasing the + * write_notify_window. The notification is delayed until the number of + * bytes used in the queue drops below the write_notify_window. + */ + + if (!PKT_FIELD(vsk, peer_waiting_write_detected)) { + PKT_FIELD(vsk, peer_waiting_write_detected) = true; + if (PKT_FIELD(vsk, write_notify_window) < PAGE_SIZE) { + PKT_FIELD(vsk, write_notify_window) = + PKT_FIELD(vsk, write_notify_min_window); + } else { + PKT_FIELD(vsk, write_notify_window) -= PAGE_SIZE; + if (PKT_FIELD(vsk, write_notify_window) < + PKT_FIELD(vsk, write_notify_min_window)) + PKT_FIELD(vsk, write_notify_window) = + PKT_FIELD(vsk, write_notify_min_window); + + } + } + notify_limit = vmci_trans(vsk)->consume_size - + PKT_FIELD(vsk, write_notify_window); + + /* The notify_limit is used to delay notifications in the case where + * flow control is enabled. Below the test is expressed in terms of + * free space in the queue: if free_space > ConsumeSize - + * write_notify_window then notify An alternate way of expressing this + * is to rewrite the expression to use the data ready in the receive + * queue: if write_notify_window > bufferReady then notify as + * free_space == ConsumeSize - bufferReady. + */ + + retval = vmci_qpair_consume_free_space(vmci_trans(vsk)->qpair) > + notify_limit; + + if (retval) { + /* Once we notify the peer, we reset the detected flag so the + * next wait will again cause a decrease in the window size. + */ + + PKT_FIELD(vsk, peer_waiting_write_detected) = false; + } + return retval; +} + +static void +vmci_transport_handle_read(struct sock *sk, + struct vmci_transport_packet *pkt, + bool bottom_half, + struct sockaddr_vm *dst, struct sockaddr_vm *src) +{ + sk->sk_write_space(sk); +} + +static void +vmci_transport_handle_wrote(struct sock *sk, + struct vmci_transport_packet *pkt, + bool bottom_half, + struct sockaddr_vm *dst, struct sockaddr_vm *src) +{ + vsock_data_ready(sk); +} + +static void vsock_block_update_write_window(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + if (PKT_FIELD(vsk, write_notify_window) < vmci_trans(vsk)->consume_size) + PKT_FIELD(vsk, write_notify_window) = + min(PKT_FIELD(vsk, write_notify_window) + PAGE_SIZE, + vmci_trans(vsk)->consume_size); +} + +static int vmci_transport_send_read_notification(struct sock *sk) +{ + struct vsock_sock *vsk; + bool sent_read; + unsigned int retries; + int err; + + vsk = vsock_sk(sk); + sent_read = false; + retries = 0; + err = 0; + + if (vmci_transport_notify_waiting_write(vsk)) { + /* Notify the peer that we have read, retrying the send on + * failure up to our maximum value. XXX For now we just log + * the failure, but later we should schedule a work item to + * handle the resend until it succeeds. That would require + * keeping track of work items in the vsk and cleaning them up + * upon socket close. + */ + while (!(vsk->peer_shutdown & RCV_SHUTDOWN) && + !sent_read && + retries < VMCI_TRANSPORT_MAX_DGRAM_RESENDS) { + err = vmci_transport_send_read(sk); + if (err >= 0) + sent_read = true; + + retries++; + } + + if (retries >= VMCI_TRANSPORT_MAX_DGRAM_RESENDS && !sent_read) + pr_err("%p unable to send read notification to peer\n", + sk); + else + PKT_FIELD(vsk, peer_waiting_write) = false; + + } + return err; +} + +static void vmci_transport_notify_pkt_socket_init(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + PKT_FIELD(vsk, write_notify_window) = PAGE_SIZE; + PKT_FIELD(vsk, write_notify_min_window) = PAGE_SIZE; + PKT_FIELD(vsk, peer_waiting_write) = false; + PKT_FIELD(vsk, peer_waiting_write_detected) = false; +} + +static void vmci_transport_notify_pkt_socket_destruct(struct vsock_sock *vsk) +{ + PKT_FIELD(vsk, write_notify_window) = PAGE_SIZE; + PKT_FIELD(vsk, write_notify_min_window) = PAGE_SIZE; + PKT_FIELD(vsk, peer_waiting_write) = false; + PKT_FIELD(vsk, peer_waiting_write_detected) = false; +} + +static int +vmci_transport_notify_pkt_poll_in(struct sock *sk, + size_t target, bool *data_ready_now) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + if (vsock_stream_has_data(vsk) >= target) { + *data_ready_now = true; + } else { + /* We can't read right now because there is not enough data + * in the queue. Ask for notifications when there is something + * to read. + */ + if (sk->sk_state == TCP_ESTABLISHED) + vsock_block_update_write_window(sk); + *data_ready_now = false; + } + + return 0; +} + +static int +vmci_transport_notify_pkt_poll_out(struct sock *sk, + size_t target, bool *space_avail_now) +{ + s64 produce_q_free_space; + struct vsock_sock *vsk = vsock_sk(sk); + + produce_q_free_space = vsock_stream_has_space(vsk); + if (produce_q_free_space > 0) { + *space_avail_now = true; + return 0; + } else if (produce_q_free_space == 0) { + /* This is a connected socket but we can't currently send data. + * Nothing else to do. + */ + *space_avail_now = false; + } + + return 0; +} + +static int +vmci_transport_notify_pkt_recv_init( + struct sock *sk, + size_t target, + struct vmci_transport_recv_notify_data *data) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + data->consume_head = 0; + data->produce_tail = 0; + data->notify_on_block = false; + + if (PKT_FIELD(vsk, write_notify_min_window) < target + 1) { + PKT_FIELD(vsk, write_notify_min_window) = target + 1; + if (PKT_FIELD(vsk, write_notify_window) < + PKT_FIELD(vsk, write_notify_min_window)) { + /* If the current window is smaller than the new + * minimal window size, we need to reevaluate whether + * we need to notify the sender. If the number of ready + * bytes are smaller than the new window, we need to + * send a notification to the sender before we block. + */ + + PKT_FIELD(vsk, write_notify_window) = + PKT_FIELD(vsk, write_notify_min_window); + data->notify_on_block = true; + } + } + + return 0; +} + +static int +vmci_transport_notify_pkt_recv_pre_block( + struct sock *sk, + size_t target, + struct vmci_transport_recv_notify_data *data) +{ + int err = 0; + + vsock_block_update_write_window(sk); + + if (data->notify_on_block) { + err = vmci_transport_send_read_notification(sk); + if (err < 0) + return err; + data->notify_on_block = false; + } + + return err; +} + +static int +vmci_transport_notify_pkt_recv_post_dequeue( + struct sock *sk, + size_t target, + ssize_t copied, + bool data_read, + struct vmci_transport_recv_notify_data *data) +{ + struct vsock_sock *vsk; + int err; + bool was_full = false; + u64 free_space; + + vsk = vsock_sk(sk); + err = 0; + + if (data_read) { + smp_mb(); + + free_space = + vmci_qpair_consume_free_space(vmci_trans(vsk)->qpair); + was_full = free_space == copied; + + if (was_full) + PKT_FIELD(vsk, peer_waiting_write) = true; + + err = vmci_transport_send_read_notification(sk); + if (err < 0) + return err; + + /* See the comment in + * vmci_transport_notify_pkt_send_post_enqueue(). + */ + vsock_data_ready(sk); + } + + return err; +} + +static int +vmci_transport_notify_pkt_send_init( + struct sock *sk, + struct vmci_transport_send_notify_data *data) +{ + data->consume_head = 0; + data->produce_tail = 0; + + return 0; +} + +static int +vmci_transport_notify_pkt_send_post_enqueue( + struct sock *sk, + ssize_t written, + struct vmci_transport_send_notify_data *data) +{ + int err = 0; + struct vsock_sock *vsk; + bool sent_wrote = false; + bool was_empty; + int retries = 0; + + vsk = vsock_sk(sk); + + smp_mb(); + + was_empty = + vmci_qpair_produce_buf_ready(vmci_trans(vsk)->qpair) == written; + if (was_empty) { + while (!(vsk->peer_shutdown & RCV_SHUTDOWN) && + !sent_wrote && + retries < VMCI_TRANSPORT_MAX_DGRAM_RESENDS) { + err = vmci_transport_send_wrote(sk); + if (err >= 0) + sent_wrote = true; + + retries++; + } + } + + if (retries >= VMCI_TRANSPORT_MAX_DGRAM_RESENDS && !sent_wrote) { + pr_err("%p unable to send wrote notification to peer\n", + sk); + return err; + } + + return err; +} + +static void +vmci_transport_notify_pkt_handle_pkt( + struct sock *sk, + struct vmci_transport_packet *pkt, + bool bottom_half, + struct sockaddr_vm *dst, + struct sockaddr_vm *src, bool *pkt_processed) +{ + bool processed = false; + + switch (pkt->type) { + case VMCI_TRANSPORT_PACKET_TYPE_WROTE: + vmci_transport_handle_wrote(sk, pkt, bottom_half, dst, src); + processed = true; + break; + case VMCI_TRANSPORT_PACKET_TYPE_READ: + vmci_transport_handle_read(sk, pkt, bottom_half, dst, src); + processed = true; + break; + } + + if (pkt_processed) + *pkt_processed = processed; +} + +static void vmci_transport_notify_pkt_process_request(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + PKT_FIELD(vsk, write_notify_window) = vmci_trans(vsk)->consume_size; + if (vmci_trans(vsk)->consume_size < + PKT_FIELD(vsk, write_notify_min_window)) + PKT_FIELD(vsk, write_notify_min_window) = + vmci_trans(vsk)->consume_size; +} + +static void vmci_transport_notify_pkt_process_negotiate(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + PKT_FIELD(vsk, write_notify_window) = vmci_trans(vsk)->consume_size; + if (vmci_trans(vsk)->consume_size < + PKT_FIELD(vsk, write_notify_min_window)) + PKT_FIELD(vsk, write_notify_min_window) = + vmci_trans(vsk)->consume_size; +} + +static int +vmci_transport_notify_pkt_recv_pre_dequeue( + struct sock *sk, + size_t target, + struct vmci_transport_recv_notify_data *data) +{ + return 0; /* NOP for QState. */ +} + +static int +vmci_transport_notify_pkt_send_pre_block( + struct sock *sk, + struct vmci_transport_send_notify_data *data) +{ + return 0; /* NOP for QState. */ +} + +static int +vmci_transport_notify_pkt_send_pre_enqueue( + struct sock *sk, + struct vmci_transport_send_notify_data *data) +{ + return 0; /* NOP for QState. */ +} + +/* Socket always on control packet based operations. */ +const struct vmci_transport_notify_ops vmci_transport_notify_pkt_q_state_ops = { + .socket_init = vmci_transport_notify_pkt_socket_init, + .socket_destruct = vmci_transport_notify_pkt_socket_destruct, + .poll_in = vmci_transport_notify_pkt_poll_in, + .poll_out = vmci_transport_notify_pkt_poll_out, + .handle_notify_pkt = vmci_transport_notify_pkt_handle_pkt, + .recv_init = vmci_transport_notify_pkt_recv_init, + .recv_pre_block = vmci_transport_notify_pkt_recv_pre_block, + .recv_pre_dequeue = vmci_transport_notify_pkt_recv_pre_dequeue, + .recv_post_dequeue = vmci_transport_notify_pkt_recv_post_dequeue, + .send_init = vmci_transport_notify_pkt_send_init, + .send_pre_block = vmci_transport_notify_pkt_send_pre_block, + .send_pre_enqueue = vmci_transport_notify_pkt_send_pre_enqueue, + .send_post_enqueue = vmci_transport_notify_pkt_send_post_enqueue, + .process_request = vmci_transport_notify_pkt_process_request, + .process_negotiate = vmci_transport_notify_pkt_process_negotiate, +}; diff --git a/net/vmw_vsock/vsock_addr.c b/net/vmw_vsock/vsock_addr.c new file mode 100644 index 0000000000..223b9660a7 --- /dev/null +++ b/net/vmw_vsock/vsock_addr.c @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * VMware vSockets Driver + * + * Copyright (C) 2007-2012 VMware, Inc. All rights reserved. + */ + +#include <linux/types.h> +#include <linux/socket.h> +#include <linux/stddef.h> +#include <net/sock.h> +#include <net/vsock_addr.h> + +void vsock_addr_init(struct sockaddr_vm *addr, u32 cid, u32 port) +{ + memset(addr, 0, sizeof(*addr)); + addr->svm_family = AF_VSOCK; + addr->svm_cid = cid; + addr->svm_port = port; +} +EXPORT_SYMBOL_GPL(vsock_addr_init); + +int vsock_addr_validate(const struct sockaddr_vm *addr) +{ + __u8 svm_valid_flags = VMADDR_FLAG_TO_HOST; + + if (!addr) + return -EFAULT; + + if (addr->svm_family != AF_VSOCK) + return -EAFNOSUPPORT; + + if (addr->svm_flags & ~svm_valid_flags) + return -EINVAL; + + return 0; +} +EXPORT_SYMBOL_GPL(vsock_addr_validate); + +bool vsock_addr_bound(const struct sockaddr_vm *addr) +{ + return addr->svm_port != VMADDR_PORT_ANY; +} +EXPORT_SYMBOL_GPL(vsock_addr_bound); + +void vsock_addr_unbind(struct sockaddr_vm *addr) +{ + vsock_addr_init(addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); +} +EXPORT_SYMBOL_GPL(vsock_addr_unbind); + +bool vsock_addr_equals_addr(const struct sockaddr_vm *addr, + const struct sockaddr_vm *other) +{ + return addr->svm_cid == other->svm_cid && + addr->svm_port == other->svm_port; +} +EXPORT_SYMBOL_GPL(vsock_addr_equals_addr); + +int vsock_addr_cast(const struct sockaddr *addr, + size_t len, struct sockaddr_vm **out_addr) +{ + if (len < sizeof(**out_addr)) + return -EFAULT; + + *out_addr = (struct sockaddr_vm *)addr; + return vsock_addr_validate(*out_addr); +} +EXPORT_SYMBOL_GPL(vsock_addr_cast); diff --git a/net/vmw_vsock/vsock_bpf.c b/net/vmw_vsock/vsock_bpf.c new file mode 100644 index 0000000000..a3c97546ab --- /dev/null +++ b/net/vmw_vsock/vsock_bpf.c @@ -0,0 +1,174 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2022 Bobby Eshleman <bobby.eshleman@bytedance.com> + * + * Based off of net/unix/unix_bpf.c + */ + +#include <linux/bpf.h> +#include <linux/module.h> +#include <linux/skmsg.h> +#include <linux/socket.h> +#include <linux/wait.h> +#include <net/af_vsock.h> +#include <net/sock.h> + +#define vsock_sk_has_data(__sk, __psock) \ + ({ !skb_queue_empty(&(__sk)->sk_receive_queue) || \ + !skb_queue_empty(&(__psock)->ingress_skb) || \ + !list_empty(&(__psock)->ingress_msg); \ + }) + +static struct proto *vsock_prot_saved __read_mostly; +static DEFINE_SPINLOCK(vsock_prot_lock); +static struct proto vsock_bpf_prot; + +static bool vsock_has_data(struct sock *sk, struct sk_psock *psock) +{ + struct vsock_sock *vsk = vsock_sk(sk); + s64 ret; + + ret = vsock_connectible_has_data(vsk); + if (ret > 0) + return true; + + return vsock_sk_has_data(sk, psock); +} + +static bool vsock_msg_wait_data(struct sock *sk, struct sk_psock *psock, long timeo) +{ + bool ret; + + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + if (sk->sk_shutdown & RCV_SHUTDOWN) + return true; + + if (!timeo) + return false; + + add_wait_queue(sk_sleep(sk), &wait); + sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); + ret = vsock_has_data(sk, psock); + if (!ret) { + wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); + ret = vsock_has_data(sk, psock); + } + sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); + remove_wait_queue(sk_sleep(sk), &wait); + return ret; +} + +static int __vsock_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags) +{ + struct socket *sock = sk->sk_socket; + int err; + + if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) + err = vsock_connectible_recvmsg(sock, msg, len, flags); + else if (sk->sk_type == SOCK_DGRAM) + err = vsock_dgram_recvmsg(sock, msg, len, flags); + else + err = -EPROTOTYPE; + + return err; +} + +static int vsock_bpf_recvmsg(struct sock *sk, struct msghdr *msg, + size_t len, int flags, int *addr_len) +{ + struct sk_psock *psock; + int copied; + + psock = sk_psock_get(sk); + if (unlikely(!psock)) + return __vsock_recvmsg(sk, msg, len, flags); + + lock_sock(sk); + if (vsock_has_data(sk, psock) && sk_psock_queue_empty(psock)) { + release_sock(sk); + sk_psock_put(sk, psock); + return __vsock_recvmsg(sk, msg, len, flags); + } + + copied = sk_msg_recvmsg(sk, psock, msg, len, flags); + while (copied == 0) { + long timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + + if (!vsock_msg_wait_data(sk, psock, timeo)) { + copied = -EAGAIN; + break; + } + + if (sk_psock_queue_empty(psock)) { + release_sock(sk); + sk_psock_put(sk, psock); + return __vsock_recvmsg(sk, msg, len, flags); + } + + copied = sk_msg_recvmsg(sk, psock, msg, len, flags); + } + + release_sock(sk); + sk_psock_put(sk, psock); + + return copied; +} + +/* Copy of original proto with updated sock_map methods */ +static struct proto vsock_bpf_prot = { + .close = sock_map_close, + .recvmsg = vsock_bpf_recvmsg, + .sock_is_readable = sk_msg_is_readable, + .unhash = sock_map_unhash, +}; + +static void vsock_bpf_rebuild_protos(struct proto *prot, const struct proto *base) +{ + *prot = *base; + prot->close = sock_map_close; + prot->recvmsg = vsock_bpf_recvmsg; + prot->sock_is_readable = sk_msg_is_readable; +} + +static void vsock_bpf_check_needs_rebuild(struct proto *ops) +{ + /* Paired with the smp_store_release() below. */ + if (unlikely(ops != smp_load_acquire(&vsock_prot_saved))) { + spin_lock_bh(&vsock_prot_lock); + if (likely(ops != vsock_prot_saved)) { + vsock_bpf_rebuild_protos(&vsock_bpf_prot, ops); + /* Make sure proto function pointers are updated before publishing the + * pointer to the struct. + */ + smp_store_release(&vsock_prot_saved, ops); + } + spin_unlock_bh(&vsock_prot_lock); + } +} + +int vsock_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) +{ + struct vsock_sock *vsk; + + if (restore) { + sk->sk_write_space = psock->saved_write_space; + sock_replace_proto(sk, psock->sk_proto); + return 0; + } + + vsk = vsock_sk(sk); + if (!vsk->transport) + return -ENODEV; + + if (!vsk->transport->read_skb) + return -EOPNOTSUPP; + + vsock_bpf_check_needs_rebuild(psock->sk_proto); + sock_replace_proto(sk, &vsock_bpf_prot); + return 0; +} + +void __init vsock_bpf_build_proto(void) +{ + vsock_bpf_rebuild_protos(&vsock_bpf_prot, &vsock_proto); +} diff --git a/net/vmw_vsock/vsock_loopback.c b/net/vmw_vsock/vsock_loopback.c new file mode 100644 index 0000000000..0ce65d0a4a --- /dev/null +++ b/net/vmw_vsock/vsock_loopback.c @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* loopback transport for vsock using virtio_transport_common APIs + * + * Copyright (C) 2013-2019 Red Hat, Inc. + * Authors: Asias He <asias@redhat.com> + * Stefan Hajnoczi <stefanha@redhat.com> + * Stefano Garzarella <sgarzare@redhat.com> + * + */ +#include <linux/spinlock.h> +#include <linux/module.h> +#include <linux/list.h> +#include <linux/virtio_vsock.h> + +struct vsock_loopback { + struct workqueue_struct *workqueue; + + struct sk_buff_head pkt_queue; + struct work_struct pkt_work; +}; + +static struct vsock_loopback the_vsock_loopback; + +static u32 vsock_loopback_get_local_cid(void) +{ + return VMADDR_CID_LOCAL; +} + +static int vsock_loopback_send_pkt(struct sk_buff *skb) +{ + struct vsock_loopback *vsock = &the_vsock_loopback; + int len = skb->len; + + virtio_vsock_skb_queue_tail(&vsock->pkt_queue, skb); + queue_work(vsock->workqueue, &vsock->pkt_work); + + return len; +} + +static int vsock_loopback_cancel_pkt(struct vsock_sock *vsk) +{ + struct vsock_loopback *vsock = &the_vsock_loopback; + + virtio_transport_purge_skbs(vsk, &vsock->pkt_queue); + + return 0; +} + +static bool vsock_loopback_seqpacket_allow(u32 remote_cid); + +static struct virtio_transport loopback_transport = { + .transport = { + .module = THIS_MODULE, + + .get_local_cid = vsock_loopback_get_local_cid, + + .init = virtio_transport_do_socket_init, + .destruct = virtio_transport_destruct, + .release = virtio_transport_release, + .connect = virtio_transport_connect, + .shutdown = virtio_transport_shutdown, + .cancel_pkt = vsock_loopback_cancel_pkt, + + .dgram_bind = virtio_transport_dgram_bind, + .dgram_dequeue = virtio_transport_dgram_dequeue, + .dgram_enqueue = virtio_transport_dgram_enqueue, + .dgram_allow = virtio_transport_dgram_allow, + + .stream_dequeue = virtio_transport_stream_dequeue, + .stream_enqueue = virtio_transport_stream_enqueue, + .stream_has_data = virtio_transport_stream_has_data, + .stream_has_space = virtio_transport_stream_has_space, + .stream_rcvhiwat = virtio_transport_stream_rcvhiwat, + .stream_is_active = virtio_transport_stream_is_active, + .stream_allow = virtio_transport_stream_allow, + + .seqpacket_dequeue = virtio_transport_seqpacket_dequeue, + .seqpacket_enqueue = virtio_transport_seqpacket_enqueue, + .seqpacket_allow = vsock_loopback_seqpacket_allow, + .seqpacket_has_data = virtio_transport_seqpacket_has_data, + + .notify_poll_in = virtio_transport_notify_poll_in, + .notify_poll_out = virtio_transport_notify_poll_out, + .notify_recv_init = virtio_transport_notify_recv_init, + .notify_recv_pre_block = virtio_transport_notify_recv_pre_block, + .notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue, + .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue, + .notify_send_init = virtio_transport_notify_send_init, + .notify_send_pre_block = virtio_transport_notify_send_pre_block, + .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue, + .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue, + .notify_buffer_size = virtio_transport_notify_buffer_size, + .notify_set_rcvlowat = virtio_transport_notify_set_rcvlowat, + + .read_skb = virtio_transport_read_skb, + }, + + .send_pkt = vsock_loopback_send_pkt, +}; + +static bool vsock_loopback_seqpacket_allow(u32 remote_cid) +{ + return true; +} + +static void vsock_loopback_work(struct work_struct *work) +{ + struct vsock_loopback *vsock = + container_of(work, struct vsock_loopback, pkt_work); + struct sk_buff_head pkts; + struct sk_buff *skb; + + skb_queue_head_init(&pkts); + + spin_lock_bh(&vsock->pkt_queue.lock); + skb_queue_splice_init(&vsock->pkt_queue, &pkts); + spin_unlock_bh(&vsock->pkt_queue.lock); + + while ((skb = __skb_dequeue(&pkts))) { + virtio_transport_deliver_tap_pkt(skb); + virtio_transport_recv_pkt(&loopback_transport, skb); + } +} + +static int __init vsock_loopback_init(void) +{ + struct vsock_loopback *vsock = &the_vsock_loopback; + int ret; + + vsock->workqueue = alloc_workqueue("vsock-loopback", 0, 0); + if (!vsock->workqueue) + return -ENOMEM; + + skb_queue_head_init(&vsock->pkt_queue); + INIT_WORK(&vsock->pkt_work, vsock_loopback_work); + + ret = vsock_core_register(&loopback_transport.transport, + VSOCK_TRANSPORT_F_LOCAL); + if (ret) + goto out_wq; + + return 0; + +out_wq: + destroy_workqueue(vsock->workqueue); + return ret; +} + +static void __exit vsock_loopback_exit(void) +{ + struct vsock_loopback *vsock = &the_vsock_loopback; + + vsock_core_unregister(&loopback_transport.transport); + + flush_work(&vsock->pkt_work); + + virtio_vsock_skb_queue_purge(&vsock->pkt_queue); + + destroy_workqueue(vsock->workqueue); +} + +module_init(vsock_loopback_init); +module_exit(vsock_loopback_exit); +MODULE_LICENSE("GPL v2"); +MODULE_AUTHOR("Stefano Garzarella <sgarzare@redhat.com>"); +MODULE_DESCRIPTION("loopback transport for vsock"); +MODULE_ALIAS_NETPROTO(PF_VSOCK); |